diff --git a/changes/1104-dmontagu.md b/changes/1104-dmontagu.md new file mode 100644 index 0000000000..eb975379d9 --- /dev/null +++ b/changes/1104-dmontagu.md @@ -0,0 +1 @@ +Add support for nested generic models diff --git a/docs/examples/models_generics_nested.py b/docs/examples/models_generics_nested.py new file mode 100644 index 0000000000..d59010261b --- /dev/null +++ b/docs/examples/models_generics_nested.py @@ -0,0 +1,21 @@ +from typing import Generic, TypeVar + +from pydantic import ValidationError +from pydantic.generics import GenericModel + +T = TypeVar('T') + +class InnerT(GenericModel, Generic[T]): + inner: T + +class OuterT(GenericModel, Generic[T]): + outer: T + nested: InnerT[T] + +nested = InnerT[int](inner=1) +print(OuterT[int](outer=1, nested=nested)) +try: + nested = InnerT[str](inner='a') + print(OuterT[int](outer='a', nested=nested)) +except ValidationError as e: + print(e) diff --git a/docs/examples/models_generics_typevars.py b/docs/examples/models_generics_typevars.py new file mode 100644 index 0000000000..ea15619b19 --- /dev/null +++ b/docs/examples/models_generics_typevars.py @@ -0,0 +1,24 @@ +from typing import Generic, TypeVar + +from pydantic import ValidationError +from pydantic.generics import GenericModel + +AT = TypeVar('AT') +BT = TypeVar('BT') + +class Model(GenericModel, Generic[AT, BT]): + a: AT + b: BT + +print(Model(a='a', b='a')) + +IntT = TypeVar('IntT', bound=int) +typevar_model = Model[int, IntT] +print(typevar_model(a=1, b=1)) +try: + typevar_model(a='a', b='a') +except ValidationError as exc: + print(exc) + +concrete_model = typevar_model[int] +print(concrete_model(a=1, b=1)) diff --git a/docs/usage/models.md b/docs/usage/models.md index 807f2eeb0d..e81a2d2bb1 100644 --- a/docs/usage/models.md +++ b/docs/usage/models.md @@ -303,6 +303,26 @@ If the name of the concrete subclasses is important, you can also override the d ``` _(This script is complete, it should run "as is")_ +Using the same TypeVar in nested models allows you to enforce typing relationships at different points in your model: + +```py +{!.tmp_examples/models_generics_nested.py!} +``` +_(This script is complete, it should run "as is")_ + +Pydantic also treats `GenericModel` similarly to how it treats built-in generic types like `List` and `Dict` when it +comes to leaving them unparameterized, or using bounded `TypeVar` instances: + +* If you don't specify parameters before instantiating the generic model, they will be treated as `Any` +* You can parametrize models with one or more *bounded* parameters to add subclass checks + +Also, like `List` and `Dict`, any parameters specified using a `TypeVar` can later be substituted with concrete types. + +```py +{!.tmp_examples/models_generics_typevars.py!} +``` +_(This script is complete, it should run "as is")_ + ## Dynamic model creation There are some occasions where the shape of a model is not known until runtime. For this *pydantic* provides diff --git a/pydantic/generics.py b/pydantic/generics.py index 278dc791ae..0ff2f0566f 100644 --- a/pydantic/generics.py +++ b/pydantic/generics.py @@ -1,21 +1,24 @@ -from typing import Any, ClassVar, Dict, Generic, Tuple, Type, TypeVar, Union, get_type_hints +from typing import TYPE_CHECKING, Any, ClassVar, Dict, Tuple, Type, TypeVar, Union, cast, get_type_hints from .class_validators import gather_all_validators +from .fields import FieldInfo, ModelField from .main import BaseModel, create_model +from .utils import lenient_issubclass _generic_types_cache: Dict[Tuple[Type[Any], Union[Any, Tuple[Any, ...]]], Type[BaseModel]] = {} GenericModelT = TypeVar('GenericModelT', bound='GenericModel') +TypeVarType = Any # since mypy doesn't allow the use of TypeVar as a type class GenericModel(BaseModel): __slots__ = () __concrete__: ClassVar[bool] = False - def __new__(cls, *args: Any, **kwargs: Any) -> Any: - if cls.__concrete__: - return super().__new__(cls) - else: - raise TypeError(f'Type {cls.__name__} cannot be used without generic parameters, e.g. {cls.__name__}[T]') + if TYPE_CHECKING: + # Putting this in a TYPE_CHECKING block allows us to replace `if Generic not in cls.__bases__` with + # `not hasattr(cls, "__parameters__")`. This means we don't need to force non-concrete subclasses of + # `GenericModel` to also inherit from `Generic`, which would require changes to the use of `create_model` below. + __parameters__: ClassVar[Tuple[TypeVarType, ...]] # Setting the return type as Type[Any] instead of Type[BaseModel] prevents PyCharm warnings def __class_getitem__(cls: Type[GenericModelT], params: Union[Type[Any], Tuple[Type[Any], ...]]) -> Type[Any]: @@ -28,11 +31,11 @@ def __class_getitem__(cls: Type[GenericModelT], params: Union[Type[Any], Tuple[T params = (params,) if cls is GenericModel and any(isinstance(param, TypeVar) for param in params): # type: ignore raise TypeError(f'Type parameters should be placed on typing.Generic, not GenericModel') - if Generic not in cls.__bases__: + if not hasattr(cls, '__parameters__'): raise TypeError(f'Type {cls.__name__} must inherit from typing.Generic before being parameterized') check_parameters_count(cls, params) - typevars_map: Dict[Any, Any] = dict(zip(cls.__parameters__, params)) # type: ignore + typevars_map: Dict[TypeVarType, Type[Any]] = dict(zip(cls.__parameters__, params)) type_hints = get_type_hints(cls).items() instance_type_hints = {k: v for k, v in type_hints if getattr(v, '__origin__', None) is not ClassVar} concrete_type_hints: Dict[str, Type[Any]] = { @@ -41,19 +44,25 @@ def __class_getitem__(cls: Type[GenericModelT], params: Union[Type[Any], Tuple[T model_name = cls.__concrete_name__(params) validators = gather_all_validators(cls) - fields: Dict[str, Tuple[Type[Any], Any]] = { - k: (v, cls.__fields__[k].field_info) for k, v in concrete_type_hints.items() if k in cls.__fields__ - } - created_model = create_model( - model_name=model_name, - __module__=cls.__module__, - __base__=cls, - __config__=None, - __validators__=validators, - **fields, + fields = _build_generic_fields(cls.__fields__, concrete_type_hints, typevars_map) + created_model = cast( + Type[GenericModel], # casting ensures mypy is aware of the __concrete__ and __parameters__ attributes + create_model( + model_name=model_name, + __module__=cls.__module__, + __base__=cls, + __config__=None, + __validators__=validators, + **fields, + ), ) created_model.Config = cls.Config - created_model.__concrete__ = True # type: ignore + concrete = all(not _is_typevar(v) for v in concrete_type_hints.values()) + created_model.__concrete__ = concrete + if not concrete: + parameters = tuple(v for v in concrete_type_hints.values() if _is_typevar(v)) + parameters = tuple({k: None for k in parameters}.keys()) # get unique params while maintaining order + created_model.__parameters__ = parameters _generic_types_cache[(cls, params)] = created_model if len(params) == 1: _generic_types_cache[(cls, params[0])] = created_model @@ -78,7 +87,30 @@ def resolve_type_hint(type_: Any, typevars_map: Dict[Any, Any]) -> Type[Any]: def check_parameters_count(cls: Type[GenericModel], parameters: Tuple[Any, ...]) -> None: actual = len(parameters) - expected = len(cls.__parameters__) # type: ignore + expected = len(cls.__parameters__) if actual != expected: description = 'many' if actual > expected else 'few' raise TypeError(f'Too {description} parameters for {cls.__name__}; actual {actual}, expected {expected}') + + +def _build_generic_fields( + raw_fields: Dict[str, ModelField], + concrete_type_hints: Dict[str, Type[Any]], + typevars_map: Dict[TypeVarType, Type[Any]], +) -> Dict[str, Tuple[Type[Any], FieldInfo]]: + return { + k: (_parameterize_generic_field(v, typevars_map), raw_fields[k].field_info) + for k, v in concrete_type_hints.items() + if k in raw_fields + } + + +def _parameterize_generic_field(field_type: Type[Any], typevars_map: Dict[TypeVarType, Type[Any]]) -> Type[Any]: + if lenient_issubclass(field_type, GenericModel) and not field_type.__concrete__: + parameters = tuple(typevars_map.get(param, param) for param in field_type.__parameters__) + field_type = field_type[parameters] + return field_type + + +def _is_typevar(v: Any) -> bool: + return isinstance(v, TypeVar) # type: ignore diff --git a/tests/test_generics.py b/tests/test_generics.py index 72d11b5aa1..2680b5a4ba 100644 --- a/tests/test_generics.py +++ b/tests/test_generics.py @@ -248,25 +248,6 @@ class Config: result.data = 2 -@skip_36 -def test_generic_instantiation_error(): - with pytest.raises(TypeError) as exc_info: - GenericModel() - assert str(exc_info.value) == 'Type GenericModel cannot be used without generic parameters, e.g. GenericModel[T]' - - -@skip_36 -def test_parameterized_generic_instantiation_error(): - data_type = TypeVar('data_type') - - class Result(GenericModel, Generic[data_type]): - data: data_type - - with pytest.raises(TypeError) as exc_info: - Result(data=1) - assert str(exc_info.value) == 'Type Result cannot be used without generic parameters, e.g. Result[T]' - - @skip_36 def test_deep_generic(): T = TypeVar('T') @@ -444,3 +425,157 @@ def __concrete_name__(cls: Type[Any], params: Tuple[Type[Any], ...]) -> str: assert repr(MyModel[int](value=1)) == 'OptionalIntWrapper(value=1)' assert repr(MyModel[str](value=None)) == 'OptionalStrWrapper(value=None)' + + +@skip_36 +def test_nested(): + AT = TypeVar('AT') + + class InnerT(GenericModel, Generic[AT]): + a: AT + + inner_int = InnerT[int](a=8) + inner_str = InnerT[str](a='ate') + inner_dict_any = InnerT[Any](a={}) + inner_int_any = InnerT[Any](a=7) + + class OuterT_SameType(GenericModel, Generic[AT]): + i: InnerT[AT] + + OuterT_SameType[int](i=inner_int) + OuterT_SameType[str](i=inner_str) + OuterT_SameType[int](i=inner_int_any) # ensure parsing the broader inner type works + + with pytest.raises(ValidationError) as exc_info: + OuterT_SameType[int](i=inner_str) + assert exc_info.value.errors() == [ + {'loc': ('i', 'a'), 'msg': 'value is not a valid integer', 'type': 'type_error.integer'} + ] + + with pytest.raises(ValidationError) as exc_info: + OuterT_SameType[int](i=inner_dict_any) + assert exc_info.value.errors() == [ + {'loc': ('i', 'a'), 'msg': 'value is not a valid integer', 'type': 'type_error.integer'} + ] + + +@skip_36 +def test_partial_specification(): + AT = TypeVar('AT') + BT = TypeVar('BT') + + class Model(GenericModel, Generic[AT, BT]): + a: AT + b: BT + + partial_model = Model[int, BT] + concrete_model = partial_model[str] + concrete_model(a=1, b='abc') + with pytest.raises(ValidationError) as exc_info: + concrete_model(a='abc', b=None) + assert exc_info.value.errors() == [ + {'loc': ('a',), 'msg': 'value is not a valid integer', 'type': 'type_error.integer'}, + {'loc': ('b',), 'msg': 'none is not an allowed value', 'type': 'type_error.none.not_allowed'}, + ] + + +@skip_36 +def test_partial_specification_name(): + AT = TypeVar('AT') + BT = TypeVar('BT') + + class Model(GenericModel, Generic[AT, BT]): + a: AT + b: BT + + partial_model = Model[int, BT] + assert partial_model.__name__ == 'Model[int, BT]' + concrete_model = partial_model[str] + assert concrete_model.__name__ == 'Model[int, BT][str]' + + +@skip_36 +def test_partial_specification_instantiation(): + AT = TypeVar('AT') + BT = TypeVar('BT') + + class Model(GenericModel, Generic[AT, BT]): + a: AT + b: BT + + partial_model = Model[int, BT] + partial_model(a=1, b=2) + + partial_model(a=1, b='a') + + with pytest.raises(ValidationError) as exc_info: + partial_model(a='a', b=2) + assert exc_info.value.errors() == [ + {'loc': ('a',), 'msg': 'value is not a valid integer', 'type': 'type_error.integer'} + ] + + +@skip_36 +def test_partial_specification_instantiation_bounded(): + AT = TypeVar('AT') + BT = TypeVar('BT', bound=int) + + class Model(GenericModel, Generic[AT, BT]): + a: AT + b: BT + + Model(a=1, b=1) + with pytest.raises(ValidationError) as exc_info: + Model(a=1, b='a') + assert exc_info.value.errors() == [ + {'loc': ('b',), 'msg': 'value is not a valid integer', 'type': 'type_error.integer'} + ] + + partial_model = Model[int, BT] + partial_model(a=1, b=1) + with pytest.raises(ValidationError) as exc_info: + partial_model(a=1, b='a') + assert exc_info.value.errors() == [ + {'loc': ('b',), 'msg': 'value is not a valid integer', 'type': 'type_error.integer'} + ] + + +@skip_36 +def test_typevar_parametrization(): + AT = TypeVar('AT') + BT = TypeVar('BT') + + class Model(GenericModel, Generic[AT, BT]): + a: AT + b: BT + + CT = TypeVar('CT', bound=int) + DT = TypeVar('DT', bound=int) + + with pytest.raises(ValidationError) as exc_info: + Model[CT, DT](a='a', b='b') + assert exc_info.value.errors() == [ + {'loc': ('a',), 'msg': 'value is not a valid integer', 'type': 'type_error.integer'}, + {'loc': ('b',), 'msg': 'value is not a valid integer', 'type': 'type_error.integer'}, + ] + + +@skip_36 +def test_multiple_specification(): + AT = TypeVar('AT') + BT = TypeVar('BT') + + class Model(GenericModel, Generic[AT, BT]): + a: AT + b: BT + + CT = TypeVar('CT') + partial_model = Model[CT, CT] + concrete_model = partial_model[str] + + with pytest.raises(ValidationError) as exc_info: + concrete_model(a=None, b=None) + assert exc_info.value.errors() == [ + {'loc': ('a',), 'msg': 'none is not an allowed value', 'type': 'type_error.none.not_allowed'}, + {'loc': ('b',), 'msg': 'none is not an allowed value', 'type': 'type_error.none.not_allowed'}, + ]