Skip to content

Commit

Permalink
Add support for nested generics
Browse files Browse the repository at this point in the history
  • Loading branch information
David Montague committed Dec 17, 2019
1 parent 5510a13 commit fe5c43c
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 15 deletions.
68 changes: 53 additions & 15 deletions pydantic/generics.py
@@ -1,16 +1,25 @@
from typing import Any, ClassVar, Dict, Generic, Tuple, Type, TypeVar, Union, get_type_hints
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Tuple, Type, TypeVar, Union, cast, get_type_hints

from .class_validators import gather_all_validators
from .fields import FieldInfo, ModelField
from .main import BaseModel, create_model
from .utils import lenient_issubclass

_generic_types_cache: Dict[Tuple[Type[Any], Union[Any, Tuple[Any, ...]]], Type[BaseModel]] = {}
GenericModelT = TypeVar('GenericModelT', bound='GenericModel')
TypeVarType = Any # since mypy doesn't allow the use of TypeVar as a type


class GenericModel(BaseModel):
__slots__ = ()
__concrete__: ClassVar[bool] = False

if TYPE_CHECKING:
# Putting this in a TYPE_CHECKING block allows us to replace `if Generic not in cls.__bases__` with
# `not hasattr(cls, "__parameters__")`. This means we don't need to force non-concrete subclasses of
# `GenericModel` to also inherit from `Generic`, which would require changes to the use of `create_model` below.
__parameters__: ClassVar[Tuple[TypeVarType, ...]]

def __new__(cls, *args: Any, **kwargs: Any) -> Any:
if cls.__concrete__:
return super().__new__(cls)
Expand All @@ -28,11 +37,11 @@ def __class_getitem__(cls: Type[GenericModelT], params: Union[Type[Any], Tuple[T
params = (params,)
if cls is GenericModel and any(isinstance(param, TypeVar) for param in params): # type: ignore
raise TypeError(f'Type parameters should be placed on typing.Generic, not GenericModel')
if Generic not in cls.__bases__:
if not hasattr(cls, '__parameters__'):
raise TypeError(f'Type {cls.__name__} must inherit from typing.Generic before being parameterized')

check_parameters_count(cls, params)
typevars_map: Dict[Any, Any] = dict(zip(cls.__parameters__, params)) # type: ignore
typevars_map: Dict[TypeVarType, Type[Any]] = dict(zip(cls.__parameters__, params))
type_hints = get_type_hints(cls).items()
instance_type_hints = {k: v for k, v in type_hints if getattr(v, '__origin__', None) is not ClassVar}
concrete_type_hints: Dict[str, Type[Any]] = {
Expand All @@ -41,19 +50,25 @@ def __class_getitem__(cls: Type[GenericModelT], params: Union[Type[Any], Tuple[T

model_name = cls.__concrete_name__(params)
validators = gather_all_validators(cls)
fields: Dict[str, Tuple[Type[Any], Any]] = {
k: (v, cls.__fields__[k].field_info) for k, v in concrete_type_hints.items() if k in cls.__fields__
}
created_model = create_model(
model_name=model_name,
__module__=cls.__module__,
__base__=cls,
__config__=None,
__validators__=validators,
**fields,
fields = _build_generic_fields(cls.__fields__, concrete_type_hints, typevars_map)
created_model = cast(
Type[GenericModel], # casting ensures mypy is aware of the __concrete__ and __parameters__ attributes
create_model(
model_name=model_name,
__module__=cls.__module__,
__base__=cls,
__config__=None,
__validators__=validators,
**fields,
),
)
created_model.Config = cls.Config
created_model.__concrete__ = True # type: ignore
concrete = all(not isinstance(v, TypeVar) for v in concrete_type_hints.values()) # type: ignore
created_model.__concrete__ = concrete
if not concrete:
parameters = tuple(v for v in concrete_type_hints.values() if _is_typevar(v))
parameters = tuple({k: None for k in parameters}.keys()) # get unique params while maintaining order
created_model.__parameters__ = parameters
_generic_types_cache[(cls, params)] = created_model
if len(params) == 1:
_generic_types_cache[(cls, params[0])] = created_model
Expand All @@ -78,7 +93,30 @@ def resolve_type_hint(type_: Any, typevars_map: Dict[Any, Any]) -> Type[Any]:

def check_parameters_count(cls: Type[GenericModel], parameters: Tuple[Any, ...]) -> None:
actual = len(parameters)
expected = len(cls.__parameters__) # type: ignore
expected = len(cls.__parameters__)
if actual != expected:
description = 'many' if actual > expected else 'few'
raise TypeError(f'Too {description} parameters for {cls.__name__}; actual {actual}, expected {expected}')


def _build_generic_fields(
raw_fields: Dict[str, ModelField],
concrete_type_hints: Dict[str, Type[Any]],
typevars_map: Dict[TypeVarType, Type[Any]],
) -> Dict[str, Tuple[Type[Any], FieldInfo]]:
return {
k: (_parameterize_generic_field(v, typevars_map), raw_fields[k].field_info)
for k, v in concrete_type_hints.items()
if k in raw_fields
}


def _parameterize_generic_field(field_type: Type[Any], typevars_map: Dict[TypeVarType, Type[Any]]) -> Type[Any]:
if lenient_issubclass(field_type, GenericModel) and not field_type.__concrete__:
parameters = tuple(typevars_map.get(param, param) for param in field_type.__parameters__)
field_type = field_type[parameters]
return field_type


def _is_typevar(v: Any) -> bool:
return isinstance(v, TypeVar) # type: ignore
73 changes: 73 additions & 0 deletions tests/test_generics.py
Expand Up @@ -444,3 +444,76 @@ def __concrete_name__(cls: Type[Any], params: Tuple[Type[Any], ...]) -> str:

assert repr(MyModel[int](value=1)) == 'OptionalIntWrapper(value=1)'
assert repr(MyModel[str](value=None)) == 'OptionalStrWrapper(value=None)'


@skip_36
def test_nested():
AT = TypeVar('AT')

class InnerT(GenericModel, Generic[AT]):
a: AT

inner_int = InnerT[int](a=8)
inner_str = InnerT[str](a='ate')
inner_dict_any = InnerT[Any](a={})
inner_int_any = InnerT[Any](a=7)

class OuterT_SameType(GenericModel, Generic[AT]):
i: InnerT[AT]

OuterT_SameType[int](i=inner_int)
OuterT_SameType[str](i=inner_str)
OuterT_SameType[int](i=inner_int_any) # ensure parsing the broader inner type works

with pytest.raises(ValidationError) as exc_info:
OuterT_SameType[int](i=inner_str)
assert exc_info.value.errors() == [
{'loc': ('i', 'a'), 'msg': 'value is not a valid integer', 'type': 'type_error.integer'}
]

with pytest.raises(ValidationError) as exc_info:
OuterT_SameType[int](i=inner_dict_any)
assert exc_info.value.errors() == [
{'loc': ('i', 'a'), 'msg': 'value is not a valid integer', 'type': 'type_error.integer'}
]


@skip_36
def test_partial_specification():
AT = TypeVar('AT')
BT = TypeVar('BT')

class Model(GenericModel, Generic[AT, BT]):
a: AT
b: BT

partial_model = Model[int, BT]
concrete_model = partial_model[str]
concrete_model(a=1, b='abc')
with pytest.raises(ValidationError) as exc_info:
concrete_model(a='abc', b=None)
assert exc_info.value.errors() == [
{'loc': ('a',), 'msg': 'value is not a valid integer', 'type': 'type_error.integer'},
{'loc': ('b',), 'msg': 'none is not an allowed value', 'type': 'type_error.none.not_allowed'},
]


@skip_36
def test_multiple_specification():
AT = TypeVar('AT')
BT = TypeVar('BT')

class Model(GenericModel, Generic[AT, BT]):
a: AT
b: BT

CT = TypeVar('CT')
partial_model = Model[CT, CT]
concrete_model = partial_model[str]

with pytest.raises(ValidationError) as exc_info:
concrete_model(a=None, b=None)
assert exc_info.value.errors() == [
{'loc': ('a',), 'msg': 'none is not an allowed value', 'type': 'type_error.none.not_allowed'},
{'loc': ('b',), 'msg': 'none is not an allowed value', 'type': 'type_error.none.not_allowed'},
]

0 comments on commit fe5c43c

Please sign in to comment.