From 426402af51602aaa773e283c4f12ae17a36c38b5 Mon Sep 17 00:00:00 2001 From: Youssef Fares Date: Mon, 25 Mar 2024 21:57:37 +0000 Subject: [PATCH] Add support for typing.Self (fix #5992) (#9023) Co-authored-by: yfares1 --- pydantic/_internal/_generate_schema.py | 51 ++++++-- pydantic/_internal/_typing_extra.py | 5 + tests/test_types_self.py | 171 +++++++++++++++++++++++++ 3 files changed, 218 insertions(+), 9 deletions(-) create mode 100644 tests/test_types_self.py diff --git a/pydantic/_internal/_generate_schema.py b/pydantic/_internal/_generate_schema.py index 0fd7da4037..71cf13affc 100644 --- a/pydantic/_internal/_generate_schema.py +++ b/pydantic/_internal/_generate_schema.py @@ -76,10 +76,8 @@ from ._fields import collect_dataclass_fields, get_type_hints_infer_globalns from ._forward_ref import PydanticRecursiveRef from ._generics import get_standard_typevars_map, has_instance_in_type, recursively_defined_type_refs, replace_types -from ._schema_generation_shared import ( - CallbackGetCoreSchemaHandler, -) -from ._typing_extra import is_finalvar +from ._schema_generation_shared import CallbackGetCoreSchemaHandler +from ._typing_extra import is_finalvar, is_self_type from ._utils import lenient_issubclass if TYPE_CHECKING: @@ -311,6 +309,7 @@ class GenerateSchema: '_types_namespace_stack', '_typevars_map', 'field_name_stack', + 'model_type_stack', 'defs', ) @@ -325,6 +324,7 @@ def __init__( self._types_namespace_stack = TypesNamespaceStack(types_namespace) self._typevars_map = typevars_map self.field_name_stack = _FieldNameStack() + self.model_type_stack = _ModelTypeStack() self.defs = _Definitions() @classmethod @@ -332,12 +332,14 @@ def __from_parent( cls, config_wrapper_stack: ConfigWrapperStack, types_namespace_stack: TypesNamespaceStack, + model_type_stack: _ModelTypeStack, typevars_map: dict[Any, Any] | None, defs: _Definitions, ) -> GenerateSchema: obj = cls.__new__(cls) obj._config_wrapper_stack = config_wrapper_stack obj._types_namespace_stack = types_namespace_stack + obj.model_type_stack = model_type_stack obj._typevars_map = typevars_map obj.field_name_stack = _FieldNameStack() obj.defs = defs @@ -357,6 +359,7 @@ def _current_generate_schema(self) -> GenerateSchema: return cls.__from_parent( self._config_wrapper_stack, self._types_namespace_stack, + self.model_type_stack, self._typevars_map, self.defs, ) @@ -622,6 +625,8 @@ def _generate_schema_from_property(self, obj: Any, source: Any) -> core_schema.C decide whether to use a `__pydantic_core_schema__` attribute, or generate a fresh schema. """ # avoid calling `__get_pydantic_core_schema__` if we've already visited this object + if is_self_type(obj): + obj = self.model_type_stack.get() with self.defs.get_schema_or_ref(obj) as (_, maybe_schema): if maybe_schema is not None: return maybe_schema @@ -735,7 +740,8 @@ def _generate_schema_inner(self, obj: Any) -> core_schema.CoreSchema: from ..main import BaseModel if lenient_issubclass(obj, BaseModel): - return self._model_schema(obj) + with self.model_type_stack.push(obj): + return self._model_schema(obj) if isinstance(obj, PydanticRecursiveRef): return core_schema.definition_reference_schema(schema_ref=obj.type_ref) @@ -815,7 +821,6 @@ def match_type(self, obj: Any) -> core_schema.CoreSchema: # noqa: C901 if _typing_extra.is_dataclass(obj): return self._dataclass_schema(obj, None) - res = self._get_prepare_pydantic_annotations_for_known_type(obj, ()) if res is not None: source_type, annotations = res @@ -1199,7 +1204,10 @@ def _typed_dict_schema(self, typed_dict_cls: Any, origin: Any) -> core_schema.Co """ from ..fields import FieldInfo - with self.defs.get_schema_or_ref(typed_dict_cls) as (typed_dict_ref, maybe_schema): + with self.model_type_stack.push(typed_dict_cls), self.defs.get_schema_or_ref(typed_dict_cls) as ( + typed_dict_ref, + maybe_schema, + ): if maybe_schema is not None: return maybe_schema @@ -1286,7 +1294,10 @@ def _typed_dict_schema(self, typed_dict_cls: Any, origin: Any) -> core_schema.Co def _namedtuple_schema(self, namedtuple_cls: Any, origin: Any) -> core_schema.CoreSchema: """Generate schema for a NamedTuple.""" - with self.defs.get_schema_or_ref(namedtuple_cls) as (namedtuple_ref, maybe_schema): + with self.model_type_stack.push(namedtuple_cls), self.defs.get_schema_or_ref(namedtuple_cls) as ( + namedtuple_ref, + maybe_schema, + ): if maybe_schema is not None: return maybe_schema typevars_map = get_standard_typevars_map(namedtuple_cls) @@ -1475,7 +1486,10 @@ def _dataclass_schema( self, dataclass: type[StandardDataclass], origin: type[StandardDataclass] | None ) -> core_schema.CoreSchema: """Generate schema for a dataclass.""" - with self.defs.get_schema_or_ref(dataclass) as (dataclass_ref, maybe_schema): + with self.model_type_stack.push(dataclass), self.defs.get_schema_or_ref(dataclass) as ( + dataclass_ref, + maybe_schema, + ): if maybe_schema is not None: return maybe_schema @@ -2254,3 +2268,22 @@ def get(self) -> str | None: return self._stack[-1] else: return None + + +class _ModelTypeStack: + __slots__ = ('_stack',) + + def __init__(self) -> None: + self._stack: list[type] = [] + + @contextmanager + def push(self, type_obj: type) -> Iterator[None]: + self._stack.append(type_obj) + yield + self._stack.pop() + + def get(self) -> type | None: + if self._stack: + return self._stack[-1] + else: + return None diff --git a/pydantic/_internal/_typing_extra.py b/pydantic/_internal/_typing_extra.py index c8be9e3e14..087353713f 100644 --- a/pydantic/_internal/_typing_extra.py +++ b/pydantic/_internal/_typing_extra.py @@ -487,3 +487,8 @@ def is_generic_alias(type_: type[Any]) -> bool: def is_generic_alias(type_: type[Any]) -> bool: return isinstance(type_, typing._GenericAlias) # type: ignore + + +def is_self_type(tp: Any) -> bool: + """Check if a given class is a Self type (from `typing` or `typing_extensions`)""" + return isinstance(tp, typing_base) and getattr(tp, '_name', None) == 'Self' diff --git a/tests/test_types_self.py b/tests/test_types_self.py new file mode 100644 index 0000000000..40bf8f1a9d --- /dev/null +++ b/tests/test_types_self.py @@ -0,0 +1,171 @@ +import dataclasses +import typing +from typing import List, Optional, Union + +import pytest +import typing_extensions +from typing_extensions import NamedTuple, TypedDict + +from pydantic import BaseModel, Field, TypeAdapter, ValidationError + + +@pytest.fixture( + name='Self', + params=[ + pytest.param(typing, id='typing.Self'), + pytest.param(typing_extensions, id='t_e.Self'), + ], +) +def fixture_self_all(request): + try: + return request.param.Self + except AttributeError: + pytest.skip(f'Self is not available from {request.param}') + + +def test_recursive_model(Self): + class SelfRef(BaseModel): + data: int + ref: typing.Optional[Self] = None + + assert SelfRef(data=1, ref={'data': 2}).model_dump() == {'data': 1, 'ref': {'data': 2, 'ref': None}} + + +def test_recursive_model_invalid(Self): + class SelfRef(BaseModel): + data: int + ref: typing.Optional[Self] = None + + with pytest.raises( + ValidationError, + match=r'ref\.ref\s+Input should be a valid dictionary or instance of SelfRef \[type=model_type,', + ): + SelfRef(data=1, ref={'data': 2, 'ref': 3}).model_dump() + + +def test_recursive_model_with_subclass(Self): + """Self refs should be valid and should reference the correct class in covariant direction""" + + class SelfRef(BaseModel): + x: int + ref: Self | None = None + + class SubSelfRef(SelfRef): + y: int + + assert SubSelfRef(x=1, ref=SubSelfRef(x=3, y=4), y=2).model_dump() == { + 'x': 1, + 'ref': {'x': 3, 'ref': None, 'y': 4}, # SubSelfRef.ref: SubSelfRef + 'y': 2, + } + assert SelfRef(x=1, ref=SubSelfRef(x=2, y=3)).model_dump() == { + 'x': 1, + 'ref': {'x': 2, 'ref': None}, + } # SelfRef.ref: SelfRef + + +def test_recursive_model_with_subclass_invalid(Self): + """Self refs are invalid in contravariant direction""" + + class SelfRef(BaseModel): + x: int + ref: Self | None = None + + class SubSelfRef(SelfRef): + y: int + + with pytest.raises( + ValidationError, + match=r'ref\s+Input should be a valid dictionary or instance of SubSelfRef \[type=model_type,', + ): + SubSelfRef(x=1, ref=SelfRef(x=2), y=3).model_dump() + + +def test_recursive_model_with_subclass_override(Self): + """Self refs should be overridable""" + + class SelfRef(BaseModel): + x: int + ref: Self | None = None + + class SubSelfRef(SelfRef): + y: int + ref: Optional[Union[SelfRef, Self]] = None + + assert SubSelfRef(x=1, ref=SubSelfRef(x=3, y=4), y=2).model_dump() == { + 'x': 1, + 'ref': {'x': 3, 'ref': None, 'y': 4}, + 'y': 2, + } + assert SubSelfRef(x=1, ref=SelfRef(x=3, y=4), y=2).model_dump() == { + 'x': 1, + 'ref': {'x': 3, 'ref': None}, + 'y': 2, + } + + +def test_self_type_with_field(Self): + with pytest.raises(TypeError, match=r'The following constraints cannot be applied.*\'gt\''): + + class SelfRef(BaseModel): + x: int + refs: typing.List[Self] = Field(..., gt=0) + + +def test_self_type_json_schema(Self): + class SelfRef(BaseModel): + x: int + refs: Optional[List[Self]] = [] + + assert SelfRef.model_json_schema() == { + '$defs': { + 'SelfRef': { + 'properties': { + 'x': {'title': 'X', 'type': 'integer'}, + 'refs': { + 'anyOf': [{'items': {'$ref': '#/$defs/SelfRef'}, 'type': 'array'}, {'type': 'null'}], + 'default': [], + 'title': 'Refs', + }, + }, + 'required': ['x'], + 'title': 'SelfRef', + 'type': 'object', + } + }, + 'allOf': [{'$ref': '#/$defs/SelfRef'}], + } + + +def test_self_type_in_named_tuple(Self): + class SelfRefNamedTuple(NamedTuple): + x: int + ref: Self | None + + ta = TypeAdapter(SelfRefNamedTuple) + assert ta.validate_python({'x': 1, 'ref': {'x': 2, 'ref': None}}) == (1, (2, None)) + + +def test_self_type_in_typed_dict(Self): + class SelfRefTypedDict(TypedDict): + x: int + ref: Self | None + + ta = TypeAdapter(SelfRefTypedDict) + assert ta.validate_python({'x': 1, 'ref': {'x': 2, 'ref': None}}) == {'x': 1, 'ref': {'x': 2, 'ref': None}} + + +def test_self_type_in_dataclass(Self): + @dataclasses.dataclass(frozen=True) + class SelfRef: + x: int + ref: Self | None + + class Model(BaseModel): + item: SelfRef + + m = Model.model_validate({'item': {'x': 1, 'ref': {'x': 2, 'ref': None}}}) + assert m.item.x == 1 + assert m.item.ref.x == 2 + with pytest.raises(dataclasses.FrozenInstanceError): + m.item.ref.x = 3