diff --git a/pydantic/_internal/_generate_schema.py b/pydantic/_internal/_generate_schema.py index e66282907c..e632ce2854 100644 --- a/pydantic/_internal/_generate_schema.py +++ b/pydantic/_internal/_generate_schema.py @@ -272,12 +272,35 @@ def _add_custom_serialization_from_json_encoders( return schema +TypesNamespace = Union[Dict[str, Any], None] + + +class TypesNamespaceStack: + """A stack of types namespaces.""" + + def __init__(self, types_namespace: TypesNamespace): + self._types_namespace_stack: list[TypesNamespace] = [types_namespace] + + @property + def tail(self) -> TypesNamespace: + return self._types_namespace_stack[-1] + + @contextmanager + def push(self, for_type: type[Any]): + types_namespace = {**_typing_extra.get_cls_types_namespace(for_type), **(self.tail or {})} + self._types_namespace_stack.append(types_namespace) + try: + yield + finally: + self._types_namespace_stack.pop() + + class GenerateSchema: """Generate core schema for a Pydantic model, dataclass and types like `str`, `datetime`, ... .""" __slots__ = ( '_config_wrapper_stack', - '_types_namespace', + '_types_namespace_stack', '_typevars_map', '_needs_apply_discriminated_union', '_has_invalid_schema', @@ -293,7 +316,7 @@ def __init__( ) -> None: # we need a stack for recursing into child models self._config_wrapper_stack = ConfigWrapperStack(config_wrapper) - self._types_namespace = types_namespace + self._types_namespace_stack = TypesNamespaceStack(types_namespace) self._typevars_map = typevars_map self._needs_apply_discriminated_union = False self._has_invalid_schema = False @@ -304,13 +327,13 @@ def __init__( def __from_parent( cls, config_wrapper_stack: ConfigWrapperStack, - types_namespace: dict[str, Any] | None, + types_namespace_stack: TypesNamespaceStack, typevars_map: dict[Any, Any] | None, defs: _Definitions, ) -> GenerateSchema: obj = cls.__new__(cls) obj._config_wrapper_stack = config_wrapper_stack - obj._types_namespace = types_namespace + obj._types_namespace_stack = types_namespace_stack obj._typevars_map = typevars_map obj._needs_apply_discriminated_union = False obj._has_invalid_schema = False @@ -322,12 +345,16 @@ def __from_parent( def _config_wrapper(self) -> ConfigWrapper: return self._config_wrapper_stack.tail + @property + def _types_namespace(self) -> dict[str, Any] | None: + return self._types_namespace_stack.tail + @property def _current_generate_schema(self) -> GenerateSchema: cls = self._config_wrapper.schema_generator or GenerateSchema return cls.__from_parent( self._config_wrapper_stack, - self._types_namespace, + self._types_namespace_stack, self._typevars_map, self.defs, ) @@ -524,7 +551,7 @@ def _model_schema(self, cls: type[BaseModel]) -> core_schema.CoreSchema: extras_schema = self.generate_schema(extra_items_type) break - with self._config_wrapper_stack.push(config_wrapper): + with self._config_wrapper_stack.push(config_wrapper), self._types_namespace_stack.push(cls): self = self._current_generate_schema if cls.__pydantic_root_model__: root_field = self._common_field_schema('root', fields['root'], decorators) @@ -1114,19 +1141,14 @@ def _type_alias_type_schema( origin = get_origin(obj) or obj - namespace = (self._types_namespace or {}).copy() - new_namespace = {**_typing_extra.get_cls_types_namespace(origin), **namespace} annotation = origin.__value__ - - self._types_namespace = new_namespace typevars_map = get_standard_typevars_map(obj) - - annotation = _typing_extra.eval_type_lenient(annotation, self._types_namespace, None) - annotation = replace_types(annotation, typevars_map) - schema = self.generate_schema(annotation) - assert schema['type'] != 'definitions' - schema['ref'] = ref # type: ignore - self._types_namespace = namespace or None + with self._types_namespace_stack.push(origin): + annotation = _typing_extra.eval_type_lenient(annotation, self._types_namespace, None) + annotation = replace_types(annotation, typevars_map) + schema = self.generate_schema(annotation) + assert schema['type'] != 'definitions' + schema['ref'] = ref # type: ignore self.defs.definitions[ref] = schema return core_schema.definition_reference_schema(ref) @@ -1173,7 +1195,7 @@ def _typed_dict_schema(self, typed_dict_cls: Any, origin: Any) -> core_schema.Co except AttributeError: config = None - with self._config_wrapper_stack.push(config): + with self._config_wrapper_stack.push(config), self._types_namespace_stack.push(typed_dict_cls): core_config = self._config_wrapper.core_config(typed_dict_cls) self = self._current_generate_schema @@ -1427,7 +1449,7 @@ def _dataclass_schema( dataclass = origin config = getattr(dataclass, '__pydantic_config__', None) - with self._config_wrapper_stack.push(config): + with self._config_wrapper_stack.push(config), self._types_namespace_stack.push(dataclass): core_config = self._config_wrapper.core_config(dataclass) self = self._current_generate_schema diff --git a/tests/test_create_model.py b/tests/test_create_model.py index 91b9ad9676..fe78db8370 100644 --- a/tests/test_create_model.py +++ b/tests/test_create_model.py @@ -581,3 +581,23 @@ def test_json_schema_with_inner_models_with_duplicate_names(): 'title': 'a', 'type': 'object', } + + +def test_resolving_forward_refs_across_modules(create_module): + module = create_module( + # language=Python + """\ +from __future__ import annotations +from dataclasses import dataclass +from pydantic import BaseModel + +class X(BaseModel): + pass + +@dataclass +class Y: + x: X + """ + ) + Z = create_model('Z', y=(module.Y, ...)) + assert Z(y={'x': {}}).y is not None