diff --git a/pydantic/generics.py b/pydantic/generics.py index 278dc791ae7..5183a444210 100644 --- a/pydantic/generics.py +++ b/pydantic/generics.py @@ -1,16 +1,25 @@ -from typing import Any, ClassVar, Dict, Generic, Tuple, Type, TypeVar, Union, get_type_hints +from typing import TYPE_CHECKING, Any, ClassVar, Dict, Tuple, Type, TypeVar, Union, cast, get_type_hints from .class_validators import gather_all_validators +from .fields import FieldInfo, ModelField from .main import BaseModel, create_model +from .utils import lenient_issubclass _generic_types_cache: Dict[Tuple[Type[Any], Union[Any, Tuple[Any, ...]]], Type[BaseModel]] = {} GenericModelT = TypeVar('GenericModelT', bound='GenericModel') +TypeVarType = Any # since mypy doesn't allow the use of TypeVar as a type class GenericModel(BaseModel): __slots__ = () __concrete__: ClassVar[bool] = False + if TYPE_CHECKING: + # Putting this in a TYPE_CHECKING block allows us to replace `if Generic not in cls.__bases__` with + # `not hasattr(cls, "__parameters__")`. This means we don't need to force non-concrete subclasses of + # `GenericModel` to also inherit from `Generic`, which would require changes to the use of `create_model` below. + __parameters__: ClassVar[Tuple[TypeVarType, ...]] + def __new__(cls, *args: Any, **kwargs: Any) -> Any: if cls.__concrete__: return super().__new__(cls) @@ -28,11 +37,11 @@ def __class_getitem__(cls: Type[GenericModelT], params: Union[Type[Any], Tuple[T params = (params,) if cls is GenericModel and any(isinstance(param, TypeVar) for param in params): # type: ignore raise TypeError(f'Type parameters should be placed on typing.Generic, not GenericModel') - if Generic not in cls.__bases__: + if not hasattr(cls, '__parameters__'): raise TypeError(f'Type {cls.__name__} must inherit from typing.Generic before being parameterized') check_parameters_count(cls, params) - typevars_map: Dict[Any, Any] = dict(zip(cls.__parameters__, params)) # type: ignore + typevars_map: Dict[TypeVarType, Type[Any]] = dict(zip(cls.__parameters__, params)) type_hints = get_type_hints(cls).items() instance_type_hints = {k: v for k, v in type_hints if getattr(v, '__origin__', None) is not ClassVar} concrete_type_hints: Dict[str, Type[Any]] = { @@ -41,19 +50,25 @@ def __class_getitem__(cls: Type[GenericModelT], params: Union[Type[Any], Tuple[T model_name = cls.__concrete_name__(params) validators = gather_all_validators(cls) - fields: Dict[str, Tuple[Type[Any], Any]] = { - k: (v, cls.__fields__[k].field_info) for k, v in concrete_type_hints.items() if k in cls.__fields__ - } - created_model = create_model( - model_name=model_name, - __module__=cls.__module__, - __base__=cls, - __config__=None, - __validators__=validators, - **fields, + fields = _build_generic_fields(cls.__fields__, concrete_type_hints, typevars_map) + created_model = cast( + Type[GenericModel], # casting ensures mypy is aware of the __concrete__ and __parameters__ attributes + create_model( + model_name=model_name, + __module__=cls.__module__, + __base__=cls, + __config__=None, + __validators__=validators, + **fields, + ), ) created_model.Config = cls.Config - created_model.__concrete__ = True # type: ignore + concrete = all(not isinstance(v, TypeVar) for v in concrete_type_hints.values()) # type: ignore + created_model.__concrete__ = concrete + if not concrete: + parameters = tuple(v for v in concrete_type_hints.values() if _is_typevar(v)) + parameters = tuple({k: None for k in parameters}.keys()) # get unique params while maintaining order + created_model.__parameters__ = parameters _generic_types_cache[(cls, params)] = created_model if len(params) == 1: _generic_types_cache[(cls, params[0])] = created_model @@ -78,7 +93,30 @@ def resolve_type_hint(type_: Any, typevars_map: Dict[Any, Any]) -> Type[Any]: def check_parameters_count(cls: Type[GenericModel], parameters: Tuple[Any, ...]) -> None: actual = len(parameters) - expected = len(cls.__parameters__) # type: ignore + expected = len(cls.__parameters__) if actual != expected: description = 'many' if actual > expected else 'few' raise TypeError(f'Too {description} parameters for {cls.__name__}; actual {actual}, expected {expected}') + + +def _build_generic_fields( + raw_fields: Dict[str, ModelField], + concrete_type_hints: Dict[str, Type[Any]], + typevars_map: Dict[TypeVarType, Type[Any]], +) -> Dict[str, Tuple[Type[Any], FieldInfo]]: + return { + k: (_parameterize_generic_field(v, typevars_map), raw_fields[k].field_info) + for k, v in concrete_type_hints.items() + if k in raw_fields + } + + +def _parameterize_generic_field(field_type: Type[Any], typevars_map: Dict[TypeVarType, Type[Any]]) -> Type[Any]: + if lenient_issubclass(field_type, GenericModel) and not field_type.__concrete__: + parameters = tuple(typevars_map.get(param, param) for param in field_type.__parameters__) + field_type = field_type[parameters] + return field_type + + +def _is_typevar(v: Any) -> bool: + return isinstance(v, TypeVar) # type: ignore diff --git a/tests/test_generics.py b/tests/test_generics.py index 72d11b5aa1b..d9b0c4292f2 100644 --- a/tests/test_generics.py +++ b/tests/test_generics.py @@ -444,3 +444,76 @@ def __concrete_name__(cls: Type[Any], params: Tuple[Type[Any], ...]) -> str: assert repr(MyModel[int](value=1)) == 'OptionalIntWrapper(value=1)' assert repr(MyModel[str](value=None)) == 'OptionalStrWrapper(value=None)' + + +@skip_36 +def test_nested(): + AT = TypeVar('AT') + + class InnerT(GenericModel, Generic[AT]): + a: AT + + inner_int = InnerT[int](a=8) + inner_str = InnerT[str](a='ate') + inner_dict_any = InnerT[Any](a={}) + inner_int_any = InnerT[Any](a=7) + + class OuterT_SameType(GenericModel, Generic[AT]): + i: InnerT[AT] + + OuterT_SameType[int](i=inner_int) + OuterT_SameType[str](i=inner_str) + OuterT_SameType[int](i=inner_int_any) # ensure parsing the broader inner type works + + with pytest.raises(ValidationError) as exc_info: + OuterT_SameType[int](i=inner_str) + assert exc_info.value.errors() == [ + {'loc': ('i', 'a'), 'msg': 'value is not a valid integer', 'type': 'type_error.integer'} + ] + + with pytest.raises(ValidationError) as exc_info: + OuterT_SameType[int](i=inner_dict_any) + assert exc_info.value.errors() == [ + {'loc': ('i', 'a'), 'msg': 'value is not a valid integer', 'type': 'type_error.integer'} + ] + + +@skip_36 +def test_partial_specification(): + AT = TypeVar('AT') + BT = TypeVar('BT') + + class Model(GenericModel, Generic[AT, BT]): + a: AT + b: BT + + partial_model = Model[int, BT] + concrete_model = partial_model[str] + concrete_model(a=1, b='abc') + with pytest.raises(ValidationError) as exc_info: + concrete_model(a='abc', b=None) + assert exc_info.value.errors() == [ + {'loc': ('a',), 'msg': 'value is not a valid integer', 'type': 'type_error.integer'}, + {'loc': ('b',), 'msg': 'none is not an allowed value', 'type': 'type_error.none.not_allowed'}, + ] + + +@skip_36 +def test_multiple_specification(): + AT = TypeVar('AT') + BT = TypeVar('BT') + + class Model(GenericModel, Generic[AT, BT]): + a: AT + b: BT + + CT = TypeVar('CT') + partial_model = Model[CT, CT] + concrete_model = partial_model[str] + + with pytest.raises(ValidationError) as exc_info: + concrete_model(a=None, b=None) + assert exc_info.value.errors() == [ + {'loc': ('a',), 'msg': 'none is not an allowed value', 'type': 'type_error.none.not_allowed'}, + {'loc': ('b',), 'msg': 'none is not an allowed value', 'type': 'type_error.none.not_allowed'}, + ]