From fcb97b4af47ad2d30d6e068bc0983fe0cda0be11 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Wed, 20 Sep 2023 09:41:31 -0400 Subject: [PATCH] Simplify flatteining and inlining of Coreschema --- pydantic/_internal/_core_utils.py | 164 ++++++++++++---------- pydantic/_internal/_dataclasses.py | 12 +- pydantic/_internal/_generate_schema.py | 3 - pydantic/_internal/_model_construction.py | 10 +- pydantic/_internal/_validate_call.py | 21 ++- pydantic/type_adapter.py | 7 +- tests/test_dataclasses.py | 4 +- tests/test_internal.py | 115 ++++----------- tests/test_main.py | 8 +- tests/test_model_signature.py | 6 +- tests/test_root_model.py | 5 - 11 files changed, 154 insertions(+), 201 deletions(-) diff --git a/pydantic/_internal/_core_utils.py b/pydantic/_internal/_core_utils.py index 2b82920fb7..5905490d77 100644 --- a/pydantic/_internal/_core_utils.py +++ b/pydantic/_internal/_core_utils.py @@ -5,7 +5,6 @@ Any, Callable, Hashable, - Iterable, TypeVar, Union, _GenericAlias, # type: ignore @@ -13,7 +12,7 @@ ) from pydantic_core import CoreSchema, core_schema -from typing_extensions import TypeAliasType, TypeGuard, get_args +from typing_extensions import TypeAliasType, TypedDict, TypeGuard, get_args from . import _repr @@ -40,6 +39,8 @@ _FUNCTION_WITH_INNER_SCHEMA_TYPES = {'function-before', 'function-after', 'function-wrap'} _LIST_LIKE_SCHEMA_WITH_ITEMS_TYPES = {'list', 'tuple-variable', 'set', 'frozenset'} +_DEFINITIONS_CACHE_METADATA_KEY = 'pydantic.definitions_cache' + def is_core_schema( schema: CoreSchemaOrField, @@ -416,92 +417,122 @@ def walk_core_schema(schema: core_schema.CoreSchema, f: Walk) -> core_schema.Cor return f(schema, _dispatch) -def _simplify_schema_references(schema: core_schema.CoreSchema, inline: bool) -> core_schema.CoreSchema: # noqa: C901 - all_defs: dict[str, core_schema.CoreSchema] = {} +class _DefinitionsState(TypedDict): + definitions: dict[str, core_schema.CoreSchema] + ref_counts: dict[str, int] + involved_in_recursion: dict[str, bool] + current_recursion_ref_count: dict[str, int] - def make_result(schema: core_schema.CoreSchema, defs: Iterable[core_schema.CoreSchema]) -> core_schema.CoreSchema: - definitions = list(defs) - if definitions: - return core_schema.definitions_schema(schema=schema, definitions=definitions) - return schema + +def simplify_schema_references(schema: core_schema.CoreSchema) -> core_schema.CoreSchema: # noqa: C901 + """Simplify schema references by: + 1. Inlining any definitions that are only referenced in one place and are not involved in a cycle. + 2. Removing any unused `ref` references from schemas. + """ + state = _DefinitionsState( + definitions={}, + ref_counts=defaultdict(int), + involved_in_recursion={}, + current_recursion_ref_count=defaultdict(int), + ) def collect_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema: + if 'metadata' in s: + definitions_cache: _DefinitionsState | None = s['metadata'].get(_DEFINITIONS_CACHE_METADATA_KEY, None) + if definitions_cache is not None: + state['definitions'].update(definitions_cache['definitions']) + return s + if s['type'] == 'definitions': for definition in s['definitions']: ref = get_ref(definition) assert ref is not None - all_defs[ref] = recurse(definition, collect_refs) + state['definitions'][ref] = definition + recurse(definition, collect_refs) return recurse(s['schema'], collect_refs) else: ref = get_ref(s) if ref is not None: - all_defs[ref] = s - return recurse(s, collect_refs) - - schema = walk_core_schema(schema, collect_refs) - - def flatten_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema: - if s['type'] == 'definitions': - # iterate ourselves, we don't want to flatten the actual defs! - definitions: list[CoreSchema] = s.pop('definitions') # type: ignore - schema: CoreSchema = s.pop('schema') # type: ignore - # remaining keys are optional like 'serialization' - schema: CoreSchema = {**schema, **s} # type: ignore - s['schema'] = recurse(schema, flatten_refs) - for definition in definitions: - recurse(definition, flatten_refs) # don't re-assign here! - return schema - else: - s = recurse(s, flatten_refs) - ref = get_ref(s) - if ref and ref in all_defs: - all_defs[ref] = s + state['definitions'][ref] = s + recurse(s, collect_refs) return core_schema.definition_reference_schema(schema_ref=ref) - return s - - schema = walk_core_schema(schema, flatten_refs) - - for def_schema in all_defs.values(): - walk_core_schema(def_schema, flatten_refs) - - if not inline: - return make_result(schema, all_defs.values()) + else: + return recurse(s, collect_refs) - ref_counts: defaultdict[str, int] = defaultdict(int) - involved_in_recursion: dict[str, bool] = {} - current_recursion_ref_count: defaultdict[str, int] = defaultdict(int) + schema = walk_core_schema(schema, collect_refs) def count_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema: + if 'metadata' in s: + definitions_cache: _DefinitionsState | None = s['metadata'].get(_DEFINITIONS_CACHE_METADATA_KEY, None) + if definitions_cache is not None: + for ref in definitions_cache['ref_counts']: + state['ref_counts'][ref] += definitions_cache['ref_counts'][ref] + # it's possible that a schema was seen before we hit the cache + # and also exists in the cache, in which case it is involved in a recursion + if state['current_recursion_ref_count'][ref] != 0: + state['involved_in_recursion'][ref] = True + # if it's involved in recursion in the inner schema mark it globally as involved in a recursion + for ref_in_recursion in definitions_cache['involved_in_recursion']: + if ref_in_recursion: + state['involved_in_recursion'][ref_in_recursion] = True + return s + if s['type'] != 'definition-ref': return recurse(s, count_refs) ref = s['schema_ref'] - ref_counts[ref] += 1 + state['ref_counts'][ref] += 1 - if ref_counts[ref] >= 2: + if state['ref_counts'][ref] >= 2: # If this model is involved in a recursion this should be detected # on its second encounter, we can safely stop the walk here. - if current_recursion_ref_count[ref] != 0: - involved_in_recursion[ref] = True + if state['current_recursion_ref_count'][ref] != 0: + state['involved_in_recursion'][ref] = True return s - current_recursion_ref_count[ref] += 1 - recurse(all_defs[ref], count_refs) - current_recursion_ref_count[ref] -= 1 + state['current_recursion_ref_count'][ref] += 1 + recurse(state['definitions'][ref], count_refs) + state['current_recursion_ref_count'][ref] -= 1 return s schema = walk_core_schema(schema, count_refs) - assert all(c == 0 for c in current_recursion_ref_count.values()), 'this is a bug! please report it' + assert all(c == 0 for c in state['current_recursion_ref_count'].values()), 'this is a bug! please report it' + + definitions_cache = _DefinitionsState( + definitions=state['definitions'], + ref_counts=dict(state['ref_counts']), + involved_in_recursion=state['involved_in_recursion'], + current_recursion_ref_count=dict(state['current_recursion_ref_count']), + ) + + def can_be_inlined(s: core_schema.DefinitionReferenceSchema, ref: str) -> bool: + if state['ref_counts'][ref] > 1: + return False + if state['involved_in_recursion'].get(ref, False): + return False + if 'serialization' in s: + return False + if 'metadata' in s: + metadata = s['metadata'] + for k in ( + 'pydantic_js_functions', + 'pydantic_js_annotation_functions', + 'pydantic.internal.union_discriminator', + ): + if k in metadata: + # we need to keep this as a ref + return False + return True def inline_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema: if s['type'] == 'definition-ref': ref = s['schema_ref'] - # Check if the reference is only used once and not involved in recursion - if ref_counts[ref] <= 1 and not involved_in_recursion.get(ref, False): + # Check if the reference is only used once, not involved in recursion and does not have + # any extra keys (like 'serialization') + if can_be_inlined(s, ref): # Inline the reference by replacing the reference with the actual schema - new = all_defs.pop(ref) - ref_counts[ref] -= 1 # because we just replaced it! - new.pop('ref') # type: ignore + new = state['definitions'].pop(ref) + state['ref_counts'][ref] -= 1 # because we just replaced it! # put all other keys that were on the def-ref schema into the inlined version # in particular this is needed for `serialization` if 'serialization' in s: @@ -515,23 +546,12 @@ def inline_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.Core schema = walk_core_schema(schema, inline_refs) - definitions = [d for d in all_defs.values() if ref_counts[d['ref']] > 0] # type: ignore - return make_result(schema, definitions) + definitions = [d for d in state['definitions'].values() if state['ref_counts'][d['ref']] > 0] # type: ignore - -def flatten_schema_defs(schema: core_schema.CoreSchema) -> core_schema.CoreSchema: - """Simplify schema references by: - 1. Grouping all definitions into a single top-level `definitions` schema, similar to a JSON schema's `#/$defs`. - """ - return _simplify_schema_references(schema, inline=False) - - -def inline_schema_defs(schema: core_schema.CoreSchema) -> core_schema.CoreSchema: - """Simplify schema references by: - 1. Inlining any definitions that are only referenced in one place and are not involved in a cycle. - 2. Removing any unused `ref` references from schemas. - """ - return _simplify_schema_references(schema, inline=True) + if definitions: + schema = core_schema.definitions_schema(schema=schema, definitions=definitions) + schema.setdefault('metadata', {})[_DEFINITIONS_CACHE_METADATA_KEY] = definitions_cache # type: ignore + return schema def pretty_print_core_schema( diff --git a/pydantic/_internal/_dataclasses.py b/pydantic/_internal/_dataclasses.py index 6d40460fdf..7efe35d888 100644 --- a/pydantic/_internal/_dataclasses.py +++ b/pydantic/_internal/_dataclasses.py @@ -16,7 +16,7 @@ from ..fields import FieldInfo from ..warnings import PydanticDeprecatedSince20 from . import _config, _decorators, _discriminated_union, _typing_extra -from ._core_utils import collect_invalid_schemas, flatten_schema_defs, inline_schema_defs +from ._core_utils import collect_invalid_schemas, simplify_schema_references from ._fields import collect_dataclass_fields from ._generate_schema import GenerateSchema from ._generics import get_standard_typevars_map @@ -152,19 +152,19 @@ def __init__(__dataclass_self__: PydanticDataclass, *args: Any, **kwargs: Any) - core_config = config_wrapper.core_config(cls) schema = gen_schema.collect_definitions(schema) - schema = flatten_schema_defs(schema) if collect_invalid_schemas(schema): set_dataclass_mock_validator(cls, cls.__name__, 'all referenced types') return False + schema = _discriminated_union.apply_discriminators(simplify_schema_references(schema)) + # We are about to set all the remaining required properties expected for this cast; # __pydantic_decorators__ and __pydantic_fields__ should already be set cls = typing.cast('type[PydanticDataclass]', cls) # debug(schema) - cls.__pydantic_core_schema__ = schema = _discriminated_union.apply_discriminators(flatten_schema_defs(schema)) - simplified_core_schema = inline_schema_defs(schema) - cls.__pydantic_validator__ = validator = SchemaValidator(simplified_core_schema, core_config) - cls.__pydantic_serializer__ = SchemaSerializer(simplified_core_schema, core_config) + cls.__pydantic_core_schema__ = schema + cls.__pydantic_validator__ = validator = SchemaValidator(schema, core_config) + cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config) if config_wrapper.validate_assignment: diff --git a/pydantic/_internal/_generate_schema.py b/pydantic/_internal/_generate_schema.py index 299b63de3c..a4b27647b4 100644 --- a/pydantic/_internal/_generate_schema.py +++ b/pydantic/_internal/_generate_schema.py @@ -52,7 +52,6 @@ from ._core_utils import ( CoreSchemaOrField, define_expected_missing_refs, - flatten_schema_defs, get_type_ref, is_list_like_schema_with_items_schema, ) @@ -518,8 +517,6 @@ def _unpack_refs_defs(self, schema: CoreSchema) -> CoreSchema: def get_ref(s: CoreSchema) -> str: return s['ref'] # type: ignore - schema = flatten_schema_defs(schema) - if schema['type'] == 'definitions': self.defs.definitions.update({get_ref(s): s for s in schema['definitions']}) schema = schema['schema'] diff --git a/pydantic/_internal/_model_construction.py b/pydantic/_internal/_model_construction.py index 925fcbd71f..fe95cacb0c 100644 --- a/pydantic/_internal/_model_construction.py +++ b/pydantic/_internal/_model_construction.py @@ -16,7 +16,7 @@ from ..fields import Field, FieldInfo, ModelPrivateAttr, PrivateAttr from ..warnings import PydanticDeprecatedSince20 from ._config import ConfigWrapper -from ._core_utils import collect_invalid_schemas, flatten_schema_defs, inline_schema_defs +from ._core_utils import collect_invalid_schemas, simplify_schema_references from ._decorators import ( ComputedFieldInfo, DecoratorInfos, @@ -487,16 +487,16 @@ def complete_model_class( core_config = config_wrapper.core_config(cls) schema = gen_schema.collect_definitions(schema) - schema = apply_discriminators(flatten_schema_defs(schema)) if collect_invalid_schemas(schema): set_model_mocks(cls, cls_name) return False + schema = apply_discriminators(simplify_schema_references(schema)) + # debug(schema) cls.__pydantic_core_schema__ = schema - simplified_core_schema = inline_schema_defs(schema) - cls.__pydantic_validator__ = SchemaValidator(simplified_core_schema, core_config) - cls.__pydantic_serializer__ = SchemaSerializer(simplified_core_schema, core_config) + cls.__pydantic_validator__ = SchemaValidator(schema, core_config) + cls.__pydantic_serializer__ = SchemaSerializer(schema, core_config) cls.__pydantic_complete__ = True # set __signature__ attr only for model class, but not for its instances diff --git a/pydantic/_internal/_validate_call.py b/pydantic/_internal/_validate_call.py index abb5440891..b749b062e7 100644 --- a/pydantic/_internal/_validate_call.py +++ b/pydantic/_internal/_validate_call.py @@ -10,7 +10,7 @@ from ..config import ConfigDict from . import _discriminated_union, _generate_schema, _typing_extra from ._config import ConfigWrapper -from ._core_utils import flatten_schema_defs, inline_schema_defs +from ._core_utils import simplify_schema_references @dataclass @@ -61,11 +61,12 @@ def __init__(self, function: Callable[..., Any], config: ConfigDict | None, vali namespace = _typing_extra.add_module_globals(function, None) config_wrapper = ConfigWrapper(config) gen_schema = _generate_schema.GenerateSchema(config_wrapper, namespace) - self.__pydantic_core_schema__ = schema = gen_schema.collect_definitions(gen_schema.generate_schema(function)) + schema = gen_schema.collect_definitions(gen_schema.generate_schema(function)) + schema = simplify_schema_references(schema) + self.__pydantic_core_schema__ = schema = schema core_config = config_wrapper.core_config(self) - schema = _discriminated_union.apply_discriminators(flatten_schema_defs(schema)) - simplified_schema = inline_schema_defs(schema) - self.__pydantic_validator__ = pydantic_core.SchemaValidator(simplified_schema, core_config) + schema = _discriminated_union.apply_discriminators(schema) + self.__pydantic_validator__ = pydantic_core.SchemaValidator(schema, core_config) if self._validate_return: return_type = ( @@ -74,13 +75,11 @@ def __init__(self, function: Callable[..., Any], config: ConfigDict | None, vali else Any ) gen_schema = _generate_schema.GenerateSchema(config_wrapper, namespace) - self.__return_pydantic_core_schema__ = schema = gen_schema.collect_definitions( - gen_schema.generate_schema(return_type) - ) + schema = gen_schema.collect_definitions(gen_schema.generate_schema(return_type)) + schema = _discriminated_union.apply_discriminators(simplify_schema_references(schema)) + self.__return_pydantic_core_schema__ = schema core_config = config_wrapper.core_config(self) - schema = _discriminated_union.apply_discriminators(flatten_schema_defs(schema)) - simplified_schema = inline_schema_defs(schema) - validator = pydantic_core.SchemaValidator(simplified_schema, core_config) + validator = pydantic_core.SchemaValidator(schema, core_config) if inspect.iscoroutinefunction(self.raw_function): async def return_val_wrapper(aw: Awaitable[Any]) -> None: diff --git a/pydantic/type_adapter.py b/pydantic/type_adapter.py index 17e2eb893e..85ee6d028d 100644 --- a/pydantic/type_adapter.py +++ b/pydantic/type_adapter.py @@ -166,21 +166,20 @@ def __init__(self, type: Any, *, config: ConfigDict | None = None, _parent_depth except AttributeError: core_schema = _get_schema(type, config_wrapper, parent_depth=_parent_depth + 1) - core_schema = _discriminated_union.apply_discriminators(_core_utils.flatten_schema_defs(core_schema)) - simplified_core_schema = _core_utils.inline_schema_defs(core_schema) + core_schema = _discriminated_union.apply_discriminators(_core_utils.simplify_schema_references(core_schema)) core_config = config_wrapper.core_config(None) validator: SchemaValidator try: validator = _getattr_no_parents(type, '__pydantic_validator__') except AttributeError: - validator = SchemaValidator(simplified_core_schema, core_config) + validator = SchemaValidator(core_schema, core_config) serializer: SchemaSerializer try: serializer = _getattr_no_parents(type, '__pydantic_serializer__') except AttributeError: - serializer = SchemaSerializer(simplified_core_schema, core_config) + serializer = SchemaSerializer(core_schema, core_config) self.core_schema = core_schema self.validator = validator diff --git a/tests/test_dataclasses.py b/tests/test_dataclasses.py index e6f32fa0fb..02a3cb22a6 100644 --- a/tests/test_dataclasses.py +++ b/tests/test_dataclasses.py @@ -2070,8 +2070,8 @@ class GenericDataclass(Generic[T]): # verify that generic parameters are showing up in the type ref for generic dataclasses # this can probably be removed if the schema changes in some way that makes this part of the test fail - assert '[int:' in validator1.core_schema['schema']['schema_ref'] - assert '[str:' in validator2.core_schema['schema']['schema_ref'] + assert '[int:' in validator1.core_schema['ref'] + assert '[str:' in validator2.core_schema['ref'] assert validator1.validate_python({'x': 1}).x == 1 assert validator2.validate_python({'x': 'hello world'}).x == 'hello world' diff --git a/tests/test_internal.py b/tests/test_internal.py index 4c8c07bdd1..9053457a68 100644 --- a/tests/test_internal.py +++ b/tests/test_internal.py @@ -4,18 +4,27 @@ from dataclasses import dataclass import pytest -from pydantic_core import SchemaValidator +from pydantic_core import CoreSchema, SchemaValidator from pydantic_core import core_schema as cs -from pydantic._internal._core_utils import flatten_schema_defs, inline_schema_defs +from pydantic._internal._core_utils import Walk, simplify_schema_references, walk_core_schema from pydantic._internal._repr import Representation +def remove_metadata(schema: CoreSchema) -> CoreSchema: + def inner(s: CoreSchema, recurse: Walk) -> CoreSchema: + s = s.copy() + s.pop('metadata', None) + return recurse(s, inner) + + return walk_core_schema(schema, inner) + + @pytest.mark.parametrize( - 'input_schema,flattened,inlined', + 'input_schema,inlined', [ # Test case 1: Simple schema with no references - (cs.list_schema(cs.int_schema()), cs.list_schema(cs.int_schema()), cs.list_schema(cs.int_schema())), + (cs.list_schema(cs.int_schema()), cs.list_schema(cs.int_schema())), # Test case 2: Schema with single-level nested references ( cs.definitions_schema( @@ -25,24 +34,14 @@ cs.int_schema(ref='int'), ], ), - cs.definitions_schema( - cs.list_schema(cs.definition_reference_schema('list_of_ints')), - definitions=[ - cs.list_schema(cs.definition_reference_schema('int'), ref='list_of_ints'), - cs.int_schema(ref='int'), - ], - ), - cs.list_schema(cs.list_schema(cs.int_schema())), + cs.list_schema(cs.list_schema(cs.int_schema(ref='int'), ref='list_of_ints')), ), # Test case 3: Schema with multiple single-level nested references ( cs.list_schema( cs.definitions_schema(cs.definition_reference_schema('int'), definitions=[cs.int_schema(ref='int')]) ), - cs.definitions_schema( - cs.list_schema(cs.definition_reference_schema('int')), definitions=[cs.int_schema(ref='int')] - ), - cs.list_schema(cs.int_schema()), + cs.list_schema(cs.int_schema(ref='int')), ), # Test case 4: A simple recursive schema ( @@ -51,10 +50,6 @@ cs.definition_reference_schema(schema_ref='list'), definitions=[cs.list_schema(cs.definition_reference_schema(schema_ref='list'), ref='list')], ), - cs.definitions_schema( - cs.definition_reference_schema(schema_ref='list'), - definitions=[cs.list_schema(cs.definition_reference_schema(schema_ref='list'), ref='list')], - ), ), # Test case 5: Deeply nested schema with multiple references ( @@ -66,15 +61,11 @@ cs.int_schema(ref='int'), ], ), - cs.definitions_schema( - cs.list_schema(cs.definition_reference_schema('list_of_lists_of_ints')), - definitions=[ - cs.list_schema(cs.definition_reference_schema('list_of_ints'), ref='list_of_lists_of_ints'), - cs.list_schema(cs.definition_reference_schema('int'), ref='list_of_ints'), - cs.int_schema(ref='int'), - ], + cs.list_schema( + cs.list_schema( + cs.list_schema(cs.int_schema(ref='int'), ref='list_of_ints'), ref='list_of_lists_of_ints' + ) ), - cs.list_schema(cs.list_schema(cs.list_schema(cs.int_schema()))), ), # Test case 6: More complex recursive schema ( @@ -96,21 +87,7 @@ cs.int_schema(ref='int_or_list'), ], ), - cs.definitions_schema( - cs.list_schema(cs.definition_reference_schema(schema_ref='list_of_ints_and_lists')), - definitions=[ - cs.list_schema( - cs.definition_reference_schema(schema_ref='int_or_list'), - ref='list_of_ints_and_lists', - ), - cs.int_schema(ref='int'), - cs.tuple_variable_schema( - cs.definition_reference_schema(schema_ref='list_of_ints_and_lists'), ref='a tuple' - ), - cs.int_schema(ref='int_or_list'), - ], - ), - cs.list_schema(cs.list_schema(cs.int_schema())), + cs.list_schema(cs.list_schema(cs.int_schema(ref='int_or_list'), ref='list_of_ints_and_lists')), ), # Test case 7: Schema with multiple definitions and nested references, some of which are unused ( @@ -125,17 +102,7 @@ ) ], ), - cs.definitions_schema( - cs.list_schema(cs.definition_reference_schema('list_of_ints')), - definitions=[ - cs.list_schema( - cs.definition_reference_schema('int'), - ref='list_of_ints', - ), - cs.int_schema(ref='int'), - ], - ), - cs.list_schema(cs.list_schema(cs.int_schema())), + cs.list_schema(cs.list_schema(cs.int_schema(ref='int'), ref='list_of_ints')), ), # Test case 8: Reference is used in multiple places ( @@ -154,19 +121,7 @@ cs.definitions_schema( cs.union_schema( [ - cs.definition_reference_schema('list_of_ints'), - cs.tuple_variable_schema(cs.definition_reference_schema('int')), - ] - ), - definitions=[ - cs.list_schema(cs.definition_reference_schema('int'), ref='list_of_ints'), - cs.int_schema(ref='int'), - ], - ), - cs.definitions_schema( - cs.union_schema( - [ - cs.list_schema(cs.definition_reference_schema('int')), + cs.list_schema(cs.definition_reference_schema('int'), ref='list_of_ints'), cs.tuple_variable_schema(cs.definition_reference_schema('int')), ] ), @@ -195,23 +150,6 @@ ), ], ), - cs.definitions_schema( - cs.definition_reference_schema('model'), - definitions=[ - cs.typed_dict_schema( - { - 'a': cs.typed_dict_field( - cs.nullable_schema(cs.definition_reference_schema(schema_ref='ref')), - ), - 'b': cs.typed_dict_field( - cs.nullable_schema(cs.definition_reference_schema(schema_ref='ref')), - ), - }, - ref='model', - ), - cs.int_schema(ref='ref'), - ], - ), cs.definitions_schema( cs.typed_dict_schema( { @@ -222,6 +160,7 @@ cs.nullable_schema(cs.definition_reference_schema(schema_ref='ref')), ), }, + ref='model', ), definitions=[ cs.int_schema(ref='ref'), @@ -230,12 +169,8 @@ ), ], ) -def test_build_schema_defs(input_schema: cs.CoreSchema, flattened: cs.CoreSchema, inlined: cs.CoreSchema): - actual_flattened = flatten_schema_defs(input_schema) - assert actual_flattened == flattened - SchemaValidator(actual_flattened) # check for validity - - actual_inlined = inline_schema_defs(input_schema) +def test_build_schema_defs(input_schema: cs.CoreSchema, inlined: cs.CoreSchema): + actual_inlined = remove_metadata(simplify_schema_references(input_schema)) assert actual_inlined == inlined SchemaValidator(actual_inlined) # check for validity diff --git a/tests/test_main.py b/tests/test_main.py index dd6e1ebb75..79ece3171b 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -2440,8 +2440,12 @@ def test_model_get_core_schema() -> None: class Model(BaseModel): @classmethod def __get_pydantic_core_schema__(cls, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema: - assert handler(int) == {'type': 'int'} - assert handler.generate_schema(int) == {'type': 'int'} + schema = handler(int) + schema.pop('metadata', None) # we don't care about this in tests + assert schema == {'type': 'int'} + schema = handler.generate_schema(int) + schema.pop('metadata', None) # we don't care about this in tests + assert schema == {'type': 'int'} return handler(source_type) Model() diff --git a/tests/test_model_signature.py b/tests/test_model_signature.py index da8ec6787b..b9587e051e 100644 --- a/tests/test_model_signature.py +++ b/tests/test_model_signature.py @@ -38,7 +38,11 @@ def test_generic_model_signature(): class Model(BaseModel, Generic[T]): a: T - sig = signature(Model[int]) + IntModel = Model[int] + + assert IntModel.model_validate({'a': '1'}).a == 1 + + sig = signature(IntModel) assert sig != signature(BaseModel) assert _equals(map(str, sig.parameters.values()), ('a: int',)) assert _equals(str(sig), '(*, a: int) -> None') diff --git a/tests/test_root_model.py b/tests/test_root_model.py index 048baeddae..9e08fd6566 100644 --- a/tests/test_root_model.py +++ b/tests/test_root_model.py @@ -47,11 +47,6 @@ def check_schema(schema: CoreSchema) -> None: # we assume the shape of the core schema here, which is not a guarantee # pydantic makes to its users but is useful to check here to make sure # we are doing the right thing internally - assert schema['type'] == 'definitions' - inner = schema['schema'] - assert inner['type'] == 'definition-ref' - ref = inner['schema_ref'] # type: ignore - schema = next(s for s in schema['definitions'] if s['ref'] == ref) # type: ignore assert schema['type'] == 'model' assert schema['root_model'] is True assert schema['custom_init'] is False