From 39f1ae65eb2d11c770f7a9406fe7bb15cd39c640 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 21 Sep 2023 07:55:35 -0500 Subject: [PATCH] Improve cache based on PR feedback --- pydantic/_internal/_core_utils.py | 45 +++++++++++++++++--------- pydantic/_internal/_generate_schema.py | 18 +++++------ tests/test_generics.py | 9 ++---- tests/test_internal.py | 16 ++++++++- 4 files changed, 55 insertions(+), 33 deletions(-) diff --git a/pydantic/_internal/_core_utils.py b/pydantic/_internal/_core_utils.py index 1db168cbc7d..8293a34232b 100644 --- a/pydantic/_internal/_core_utils.py +++ b/pydantic/_internal/_core_utils.py @@ -42,6 +42,14 @@ _DEFINITIONS_CACHE_METADATA_KEY = 'pydantic.definitions_cache' NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY = 'pydantic.internal.needs_apply_discriminated_union' +"""Used to mark a schema that has a discriminated union that needs to be checked for validity at the end of +schema building because one of its members refers to a definition that was not yet defined when the union +was first encountered. +""" +HAS_INVALID_SCHEMAS_METADATA_KEY = 'pydantic.internal.invalid' +"""Used to mark a schema that is invalid because it refers to a definition that was not yet defined when the +schema was first encountered. 
+""" def is_core_schema( @@ -136,11 +144,11 @@ def _record_valid_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_sche def define_expected_missing_refs( schema: core_schema.CoreSchema, allowed_missing_refs: set[str] -) -> tuple[core_schema.CoreSchema, bool]: +) -> core_schema.CoreSchema | None: if not allowed_missing_refs: # in this case, there are no missing refs to potentially substitute, so there's no need to walk the schema # this is a common case (will be hit for all non-generic models), so it's worth optimizing for - return schema, False + return None refs = collect_definitions(schema).keys() @@ -149,29 +157,34 @@ def define_expected_missing_refs( definitions: list[core_schema.CoreSchema] = [ # TODO: Replace this with a (new) CoreSchema that, if present at any level, makes validation fail # Issue: https://github.com/pydantic/pydantic-core/issues/619 - core_schema.none_schema(ref=ref, metadata={'pydantic_debug_missing_ref': True, 'invalid': True}) + core_schema.none_schema(ref=ref, metadata={HAS_INVALID_SCHEMAS_METADATA_KEY: True}) for ref in expected_missing_refs ] - return core_schema.definitions_schema(schema, definitions), True - return schema, False + return core_schema.definitions_schema(schema, definitions) + return None -def collect_invalid_schemas(schema: core_schema.CoreSchema) -> list[core_schema.CoreSchema]: - invalid_schemas: list[core_schema.CoreSchema] = [] +def collect_invalid_schemas(schema: core_schema.CoreSchema) -> bool: + invalid = False def _is_schema_valid(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema: - metadata = s.get('metadata', None) - if metadata is None: - return recurse(s, _is_schema_valid) - invalid = metadata.get('invalid', None) - if invalid is False: - return s - elif invalid is True: - invalid_schemas.append(s) + nonlocal invalid + if 'metadata' in s: + metadata = s['metadata'] + if HAS_INVALID_SCHEMAS_METADATA_KEY in metadata: + invalid = metadata[HAS_INVALID_SCHEMAS_METADATA_KEY] + if 
invalid is True: + invalid = True + return s return recurse(s, _is_schema_valid) walk_core_schema(schema, _is_schema_valid) - return invalid_schemas + if 'metadata' in schema: + metadata = schema['metadata'] + metadata[HAS_INVALID_SCHEMAS_METADATA_KEY] = invalid + else: + schema['metadata'] = {HAS_INVALID_SCHEMAS_METADATA_KEY: invalid} + return invalid T = TypeVar('T') diff --git a/pydantic/_internal/_generate_schema.py b/pydantic/_internal/_generate_schema.py index e479e631825..0c5d874a944 100644 --- a/pydantic/_internal/_generate_schema.py +++ b/pydantic/_internal/_generate_schema.py @@ -50,6 +50,7 @@ build_metadata_dict, ) from ._core_utils import ( + HAS_INVALID_SCHEMAS_METADATA_KEY, NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY, CoreSchemaOrField, define_expected_missing_refs, @@ -513,15 +514,14 @@ def _model_schema(self, cls: type[BaseModel]) -> core_schema.CoreSchema: model_name=cls.__name__, ) inner_schema = apply_validators(fields_schema, decorators.root_validators.values(), None) - inner_schema, has_invalid_schema = define_expected_missing_refs( - inner_schema, recursively_defined_type_refs() - ) - inner_schema = apply_model_validators(inner_schema, model_validators, 'inner') - - if has_invalid_schema: + new_inner_schema = define_expected_missing_refs(inner_schema, recursively_defined_type_refs()) + if new_inner_schema is not None: + inner_schema = new_inner_schema self._has_invalid_schema = True - - metadata['invalid'] = has_invalid_schema + metadata[HAS_INVALID_SCHEMAS_METADATA_KEY] = True + else: + metadata[HAS_INVALID_SCHEMAS_METADATA_KEY] = False + inner_schema = apply_model_validators(inner_schema, model_validators, 'inner') model_schema = core_schema.model_schema( cls, @@ -655,7 +655,7 @@ def _get_first_two_args_or_any(self, obj: Any) -> tuple[Any, Any]: def _post_process_generated_schema(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema: metadata = schema.setdefault('metadata', {}) 
metadata[NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY] = self._needs_apply_discriminated_union - metadata['invalid'] = self._has_invalid_schema + metadata[HAS_INVALID_SCHEMAS_METADATA_KEY] = self._has_invalid_schema return schema def _generate_schema(self, obj: Any) -> core_schema.CoreSchema: diff --git a/tests/test_generics.py b/tests/test_generics.py index b8bbf92cb28..14be0624dc6 100644 --- a/tests/test_generics.py +++ b/tests/test_generics.py @@ -1782,7 +1782,7 @@ class M2(BaseModel, Generic[V3]): M1 = module.M1 # assert M1.__pydantic_core_schema__ == {} - assert collect_invalid_schemas(M1.__pydantic_core_schema__) == [] + assert collect_invalid_schemas(M1.__pydantic_core_schema__) is False def test_generic_recursive_models_complicated(create_module): @@ -1842,7 +1842,7 @@ class M2(BaseModel, Generic[V3]): M1 = module.M1 - assert collect_invalid_schemas(M1.__pydantic_core_schema__) == [] + assert collect_invalid_schemas(M1.__pydantic_core_schema__) is False def test_generic_recursive_models_in_container(create_module): @@ -1863,11 +1863,6 @@ class MyGenericModel(BaseModel, Generic[T]): assert type(instance.foobar[0]) == MyGenericModel[int] -def test_schema_is_valid(): - assert not collect_invalid_schemas(core_schema.none_schema()) - assert collect_invalid_schemas(core_schema.nullable_schema(core_schema.int_schema(metadata={'invalid': True}))) - - def test_generic_enum(): T = TypeVar('T') diff --git a/tests/test_internal.py b/tests/test_internal.py index 9053457a68a..5b2705dbf83 100644 --- a/tests/test_internal.py +++ b/tests/test_internal.py @@ -7,7 +7,13 @@ from pydantic_core import CoreSchema, SchemaValidator from pydantic_core import core_schema as cs -from pydantic._internal._core_utils import Walk, simplify_schema_references, walk_core_schema +from pydantic._internal._core_utils import ( + HAS_INVALID_SCHEMAS_METADATA_KEY, + Walk, + collect_invalid_schemas, + simplify_schema_references, + walk_core_schema, +) from pydantic._internal._repr import 
Representation @@ -192,3 +198,11 @@ class Obj(Representation): ' ) (Obj)', ] assert list(obj.__rich_repr__()) == [('int_attr', 42), ('str_attr', 'Marvin')] + + +def test_schema_is_valid(): + assert collect_invalid_schemas(cs.none_schema()) is False + assert ( + collect_invalid_schemas(cs.nullable_schema(cs.int_schema(metadata={HAS_INVALID_SCHEMAS_METADATA_KEY: True}))) + is True + )