Skip to content

Commit

Permalink
Add support for typing.Self (fix #5992) (#9023)
Browse files Browse the repository at this point in the history
Co-authored-by: yfares1 <yfares1@bloomberg.net>
  • Loading branch information
Youssefares and yfares1 committed Mar 25, 2024
1 parent 33a275a commit 426402a
Show file tree
Hide file tree
Showing 3 changed files with 218 additions and 9 deletions.
51 changes: 42 additions & 9 deletions pydantic/_internal/_generate_schema.py
Expand Up @@ -76,10 +76,8 @@
from ._fields import collect_dataclass_fields, get_type_hints_infer_globalns
from ._forward_ref import PydanticRecursiveRef
from ._generics import get_standard_typevars_map, has_instance_in_type, recursively_defined_type_refs, replace_types
from ._schema_generation_shared import (
CallbackGetCoreSchemaHandler,
)
from ._typing_extra import is_finalvar
from ._schema_generation_shared import CallbackGetCoreSchemaHandler
from ._typing_extra import is_finalvar, is_self_type
from ._utils import lenient_issubclass

if TYPE_CHECKING:
Expand Down Expand Up @@ -311,6 +309,7 @@ class GenerateSchema:
'_types_namespace_stack',
'_typevars_map',
'field_name_stack',
'model_type_stack',
'defs',
)

Expand All @@ -325,19 +324,22 @@ def __init__(
self._types_namespace_stack = TypesNamespaceStack(types_namespace)
self._typevars_map = typevars_map
self.field_name_stack = _FieldNameStack()
self.model_type_stack = _ModelTypeStack()
self.defs = _Definitions()

@classmethod
def __from_parent(
cls,
config_wrapper_stack: ConfigWrapperStack,
types_namespace_stack: TypesNamespaceStack,
model_type_stack: _ModelTypeStack,
typevars_map: dict[Any, Any] | None,
defs: _Definitions,
) -> GenerateSchema:
obj = cls.__new__(cls)
obj._config_wrapper_stack = config_wrapper_stack
obj._types_namespace_stack = types_namespace_stack
obj.model_type_stack = model_type_stack
obj._typevars_map = typevars_map
obj.field_name_stack = _FieldNameStack()
obj.defs = defs
Expand All @@ -357,6 +359,7 @@ def _current_generate_schema(self) -> GenerateSchema:
return cls.__from_parent(
self._config_wrapper_stack,
self._types_namespace_stack,
self.model_type_stack,
self._typevars_map,
self.defs,
)
Expand Down Expand Up @@ -622,6 +625,8 @@ def _generate_schema_from_property(self, obj: Any, source: Any) -> core_schema.C
decide whether to use a `__pydantic_core_schema__` attribute, or generate a fresh schema.
"""
# avoid calling `__get_pydantic_core_schema__` if we've already visited this object
if is_self_type(obj):
obj = self.model_type_stack.get()
with self.defs.get_schema_or_ref(obj) as (_, maybe_schema):
if maybe_schema is not None:
return maybe_schema
Expand Down Expand Up @@ -735,7 +740,8 @@ def _generate_schema_inner(self, obj: Any) -> core_schema.CoreSchema:
from ..main import BaseModel

if lenient_issubclass(obj, BaseModel):
return self._model_schema(obj)
with self.model_type_stack.push(obj):
return self._model_schema(obj)

if isinstance(obj, PydanticRecursiveRef):
return core_schema.definition_reference_schema(schema_ref=obj.type_ref)
Expand Down Expand Up @@ -815,7 +821,6 @@ def match_type(self, obj: Any) -> core_schema.CoreSchema: # noqa: C901

if _typing_extra.is_dataclass(obj):
return self._dataclass_schema(obj, None)

res = self._get_prepare_pydantic_annotations_for_known_type(obj, ())
if res is not None:
source_type, annotations = res
Expand Down Expand Up @@ -1199,7 +1204,10 @@ def _typed_dict_schema(self, typed_dict_cls: Any, origin: Any) -> core_schema.Co
"""
from ..fields import FieldInfo

with self.defs.get_schema_or_ref(typed_dict_cls) as (typed_dict_ref, maybe_schema):
with self.model_type_stack.push(typed_dict_cls), self.defs.get_schema_or_ref(typed_dict_cls) as (
typed_dict_ref,
maybe_schema,
):
if maybe_schema is not None:
return maybe_schema

Expand Down Expand Up @@ -1286,7 +1294,10 @@ def _typed_dict_schema(self, typed_dict_cls: Any, origin: Any) -> core_schema.Co

def _namedtuple_schema(self, namedtuple_cls: Any, origin: Any) -> core_schema.CoreSchema:
"""Generate schema for a NamedTuple."""
with self.defs.get_schema_or_ref(namedtuple_cls) as (namedtuple_ref, maybe_schema):
with self.model_type_stack.push(namedtuple_cls), self.defs.get_schema_or_ref(namedtuple_cls) as (
namedtuple_ref,
maybe_schema,
):
if maybe_schema is not None:
return maybe_schema
typevars_map = get_standard_typevars_map(namedtuple_cls)
Expand Down Expand Up @@ -1475,7 +1486,10 @@ def _dataclass_schema(
self, dataclass: type[StandardDataclass], origin: type[StandardDataclass] | None
) -> core_schema.CoreSchema:
"""Generate schema for a dataclass."""
with self.defs.get_schema_or_ref(dataclass) as (dataclass_ref, maybe_schema):
with self.model_type_stack.push(dataclass), self.defs.get_schema_or_ref(dataclass) as (
dataclass_ref,
maybe_schema,
):
if maybe_schema is not None:
return maybe_schema

Expand Down Expand Up @@ -2254,3 +2268,22 @@ def get(self) -> str | None:
return self._stack[-1]
else:
return None


class _ModelTypeStack:
__slots__ = ('_stack',)

def __init__(self) -> None:
self._stack: list[type] = []

@contextmanager
def push(self, type_obj: type) -> Iterator[None]:
self._stack.append(type_obj)
yield
self._stack.pop()

def get(self) -> type | None:
if self._stack:
return self._stack[-1]
else:
return None
5 changes: 5 additions & 0 deletions pydantic/_internal/_typing_extra.py
Expand Up @@ -487,3 +487,8 @@ def is_generic_alias(type_: type[Any]) -> bool:

def is_generic_alias(type_: type[Any]) -> bool:
return isinstance(type_, typing._GenericAlias) # type: ignore


def is_self_type(tp: Any) -> bool:
"""Check if a given class is a Self type (from `typing` or `typing_extensions`)"""
return isinstance(tp, typing_base) and getattr(tp, '_name', None) == 'Self'
171 changes: 171 additions & 0 deletions tests/test_types_self.py
@@ -0,0 +1,171 @@
import dataclasses
import typing
from typing import List, Optional, Union

import pytest
import typing_extensions
from typing_extensions import NamedTuple, TypedDict

from pydantic import BaseModel, Field, TypeAdapter, ValidationError


@pytest.fixture(
name='Self',
params=[
pytest.param(typing, id='typing.Self'),
pytest.param(typing_extensions, id='t_e.Self'),
],
)
def fixture_self_all(request):
try:
return request.param.Self
except AttributeError:
pytest.skip(f'Self is not available from {request.param}')


def test_recursive_model(Self):
class SelfRef(BaseModel):
data: int
ref: typing.Optional[Self] = None

assert SelfRef(data=1, ref={'data': 2}).model_dump() == {'data': 1, 'ref': {'data': 2, 'ref': None}}


def test_recursive_model_invalid(Self):
class SelfRef(BaseModel):
data: int
ref: typing.Optional[Self] = None

with pytest.raises(
ValidationError,
match=r'ref\.ref\s+Input should be a valid dictionary or instance of SelfRef \[type=model_type,',
):
SelfRef(data=1, ref={'data': 2, 'ref': 3}).model_dump()


def test_recursive_model_with_subclass(Self):
"""Self refs should be valid and should reference the correct class in covariant direction"""

class SelfRef(BaseModel):
x: int
ref: Self | None = None

class SubSelfRef(SelfRef):
y: int

assert SubSelfRef(x=1, ref=SubSelfRef(x=3, y=4), y=2).model_dump() == {
'x': 1,
'ref': {'x': 3, 'ref': None, 'y': 4}, # SubSelfRef.ref: SubSelfRef
'y': 2,
}
assert SelfRef(x=1, ref=SubSelfRef(x=2, y=3)).model_dump() == {
'x': 1,
'ref': {'x': 2, 'ref': None},
} # SelfRef.ref: SelfRef


def test_recursive_model_with_subclass_invalid(Self):
"""Self refs are invalid in contravariant direction"""

class SelfRef(BaseModel):
x: int
ref: Self | None = None

class SubSelfRef(SelfRef):
y: int

with pytest.raises(
ValidationError,
match=r'ref\s+Input should be a valid dictionary or instance of SubSelfRef \[type=model_type,',
):
SubSelfRef(x=1, ref=SelfRef(x=2), y=3).model_dump()


def test_recursive_model_with_subclass_override(Self):
"""Self refs should be overridable"""

class SelfRef(BaseModel):
x: int
ref: Self | None = None

class SubSelfRef(SelfRef):
y: int
ref: Optional[Union[SelfRef, Self]] = None

assert SubSelfRef(x=1, ref=SubSelfRef(x=3, y=4), y=2).model_dump() == {
'x': 1,
'ref': {'x': 3, 'ref': None, 'y': 4},
'y': 2,
}
assert SubSelfRef(x=1, ref=SelfRef(x=3, y=4), y=2).model_dump() == {
'x': 1,
'ref': {'x': 3, 'ref': None},
'y': 2,
}


def test_self_type_with_field(Self):
with pytest.raises(TypeError, match=r'The following constraints cannot be applied.*\'gt\''):

class SelfRef(BaseModel):
x: int
refs: typing.List[Self] = Field(..., gt=0)


def test_self_type_json_schema(Self):
class SelfRef(BaseModel):
x: int
refs: Optional[List[Self]] = []

assert SelfRef.model_json_schema() == {
'$defs': {
'SelfRef': {
'properties': {
'x': {'title': 'X', 'type': 'integer'},
'refs': {
'anyOf': [{'items': {'$ref': '#/$defs/SelfRef'}, 'type': 'array'}, {'type': 'null'}],
'default': [],
'title': 'Refs',
},
},
'required': ['x'],
'title': 'SelfRef',
'type': 'object',
}
},
'allOf': [{'$ref': '#/$defs/SelfRef'}],
}


def test_self_type_in_named_tuple(Self):
class SelfRefNamedTuple(NamedTuple):
x: int
ref: Self | None

ta = TypeAdapter(SelfRefNamedTuple)
assert ta.validate_python({'x': 1, 'ref': {'x': 2, 'ref': None}}) == (1, (2, None))


def test_self_type_in_typed_dict(Self):
class SelfRefTypedDict(TypedDict):
x: int
ref: Self | None

ta = TypeAdapter(SelfRefTypedDict)
assert ta.validate_python({'x': 1, 'ref': {'x': 2, 'ref': None}}) == {'x': 1, 'ref': {'x': 2, 'ref': None}}


def test_self_type_in_dataclass(Self):
@dataclasses.dataclass(frozen=True)
class SelfRef:
x: int
ref: Self | None

class Model(BaseModel):
item: SelfRef

m = Model.model_validate({'item': {'x': 1, 'ref': {'x': 2, 'ref': None}}})
assert m.item.x == 1
assert m.item.ref.x == 2
with pytest.raises(dataclasses.FrozenInstanceError):
m.item.ref.x = 3

0 comments on commit 426402a

Please sign in to comment.