diff --git a/pydantic/main.py b/pydantic/main.py index b6f0928bb20..fb57dc36a22 100644 --- a/pydantic/main.py +++ b/pydantic/main.py @@ -136,8 +136,6 @@ def is_valid_field(name: str) -> bool: def validate_custom_root_type(fields: Dict[str, ModelField]) -> None: if len(fields) > 1: raise ValueError('__root__ cannot be mixed with other fields') - if fields[ROOT_KEY].shape == SHAPE_MAPPING: - raise TypeError('custom root type cannot allow mapping') UNTOUCHED_TYPES = FunctionType, property, type, classmethod, staticmethod @@ -246,7 +244,24 @@ def __new__(mcs, name, bases, namespace, **kwargs): # noqa C901 '__repr__': Representation.__repr__, **{n: v for n, v in namespace.items() if n not in fields}, } - return super().__new__(mcs, name, bases, new_namespace, **kwargs) + cls = super().__new__(mcs, name, bases, new_namespace, **kwargs) + if _custom_root_type: + _setup_custom_root_type(cls) + return cls + + +def _setup_custom_root_type(cls: Type['Model']) -> None: + original_parse_obj = cls.parse_obj + + root_field = cls.__fields__[ROOT_KEY] + root_field_is_mapping = root_field.shape == SHAPE_MAPPING + + def parse_obj(obj: Any) -> 'Model': + if not isinstance(obj, dict) or root_field_is_mapping: + obj = {ROOT_KEY: obj} + return original_parse_obj(obj) + + setattr(cls, 'parse_obj', staticmethod(parse_obj)) class BaseModel(metaclass=ModelMetaclass): @@ -379,14 +394,11 @@ def json( @classmethod def parse_obj(cls: Type['Model'], obj: Any) -> 'Model': if not isinstance(obj, dict): - if cls.__custom_root_type__: - obj = {ROOT_KEY: obj} - else: - try: - obj = dict(obj) - except (TypeError, ValueError) as e: - exc = TypeError(f'{cls.__name__} expected dict not {type(obj).__name__}') - raise ValidationError([ErrorWrapper(exc, loc=ROOT_KEY)], cls) from e + try: + obj = dict(obj) + except (TypeError, ValueError) as e: + exc = TypeError(f'{cls.__name__} expected dict not {type(obj).__name__}') + raise ValidationError([ErrorWrapper(exc, loc=ROOT_KEY)], cls) from e return cls(**obj) @classmethod diff --git a/tests/test_main.py b/tests/test_main.py index 29c7f948c4f..9b0a620af85 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -900,10 +900,10 @@ class MyModel(BaseModel): def test_parse_root_as_mapping(): - with pytest.raises(TypeError, match='custom root type cannot allow mapping'): + class MyModel(BaseModel): + __root__: Mapping[str, str] - class MyModel(BaseModel): - __root__: Mapping[str, str] + assert MyModel.parse_obj({1: 2}).__root__ == {'1': '2'} def test_untouched_types():