Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove schema building caches #7624

Merged
merged 2 commits into from Sep 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
163 changes: 59 additions & 104 deletions pydantic/_internal/_core_utils.py
Expand Up @@ -14,7 +14,7 @@

from pydantic_core import CoreSchema, core_schema
from pydantic_core import validate_core_schema as _validate_core_schema
from typing_extensions import TypeAliasType, TypedDict, TypeGuard, get_args
from typing_extensions import TypeAliasType, TypeGuard, get_args

from . import _repr

Expand Down Expand Up @@ -128,12 +128,6 @@ def collect_definitions(schema: core_schema.CoreSchema) -> dict[str, core_schema
defs: dict[str, CoreSchema] = {}

def _record_valid_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
if 'metadata' in s:
definitions_cache: _DefinitionsState | None = s['metadata'].get(_DEFINITIONS_CACHE_METADATA_KEY, None)
if definitions_cache is not None:
defs.update(definitions_cache['definitions'])
return s

ref = get_ref(s)
if ref:
defs[ref] = s
Expand Down Expand Up @@ -215,7 +209,7 @@ def walk(self, schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchem
return f(schema, self._walk)

def _walk(self, schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema:
schema = self._schema_type_to_method[schema['type']](schema, f)
schema = self._schema_type_to_method[schema['type']](schema.copy(), f)
ser_schema: core_schema.SerSchema | None = schema.get('serialization') # type: ignore
if ser_schema:
schema['serialization'] = self._handle_ser_schemas(ser_schema, f)
Expand Down Expand Up @@ -436,101 +430,62 @@ def walk_core_schema(schema: core_schema.CoreSchema, f: Walk) -> core_schema.Cor
Returns:
core_schema.CoreSchema: A processed CoreSchema.
"""
return f(schema, _dispatch)


class _DefinitionsState(TypedDict):
definitions: dict[str, core_schema.CoreSchema]
ref_counts: dict[str, int]
involved_in_recursion: dict[str, bool]
current_recursion_ref_count: dict[str, int]
return f(schema.copy(), _dispatch)


def simplify_schema_references(schema: core_schema.CoreSchema) -> core_schema.CoreSchema: # noqa: C901
"""Simplify schema references by:
1. Inlining any definitions that are only referenced in one place and are not involved in a cycle.
2. Removing any unused `ref` references from schemas.
"""
state = _DefinitionsState(
definitions={},
ref_counts=defaultdict(int),
involved_in_recursion={},
current_recursion_ref_count=defaultdict(int),
)
definitions: dict[str, core_schema.CoreSchema] = {}
ref_counts: dict[str, int] = defaultdict(int)
involved_in_recursion: dict[str, bool] = {}
current_recursion_ref_count: dict[str, int] = defaultdict(int)

def collect_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
if 'metadata' in s:
definitions_cache: _DefinitionsState | None = s['metadata'].get(_DEFINITIONS_CACHE_METADATA_KEY, None)
if definitions_cache is not None:
state['definitions'].update(definitions_cache['definitions'])
return s

if s['type'] == 'definitions':
for definition in s['definitions']:
ref = get_ref(definition)
assert ref is not None
state['definitions'][ref] = definition
definitions[ref] = definition
recurse(definition, collect_refs)
return recurse(s['schema'], collect_refs)
else:
ref = get_ref(s)
if ref is not None:
state['definitions'][ref] = s
recurse(s, collect_refs)
new = recurse(s, collect_refs)
new_ref = get_ref(new)
if new_ref:
definitions[new_ref] = new
return core_schema.definition_reference_schema(schema_ref=ref)
else:
return recurse(s, collect_refs)

schema = walk_core_schema(schema, collect_refs)

def count_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
if 'metadata' in s:
definitions_cache: _DefinitionsState | None = s['metadata'].get(_DEFINITIONS_CACHE_METADATA_KEY, None)
if definitions_cache is not None:
for ref in definitions_cache['ref_counts']:
state['ref_counts'][ref] += definitions_cache['ref_counts'][ref]
# it's possible that a schema was seen before we hit the cache
# and also exists in the cache, in which case it is involved in a recursion
if state['current_recursion_ref_count'][ref] != 0:
state['involved_in_recursion'][ref] = True
# if it's involved in recursion in the inner schema mark it globally as involved in a recursion
for ref_in_recursion in definitions_cache['involved_in_recursion']:
if ref_in_recursion:
state['involved_in_recursion'][ref_in_recursion] = True
return s

if s['type'] != 'definition-ref':
return recurse(s, count_refs)
ref = s['schema_ref']
state['ref_counts'][ref] += 1
ref_counts[ref] += 1

if state['ref_counts'][ref] >= 2:
if ref_counts[ref] >= 2:
# If this model is involved in a recursion this should be detected
# on its second encounter, we can safely stop the walk here.
if state['current_recursion_ref_count'][ref] != 0:
state['involved_in_recursion'][ref] = True
if current_recursion_ref_count[ref] != 0:
involved_in_recursion[ref] = True
return s

state['current_recursion_ref_count'][ref] += 1
recurse(state['definitions'][ref], count_refs)
state['current_recursion_ref_count'][ref] -= 1
current_recursion_ref_count[ref] += 1
recurse(definitions[ref], count_refs)
current_recursion_ref_count[ref] -= 1
return s

schema = walk_core_schema(schema, count_refs)

assert all(c == 0 for c in state['current_recursion_ref_count'].values()), 'this is a bug! please report it'

definitions_cache = _DefinitionsState(
definitions=state['definitions'],
ref_counts=dict(state['ref_counts']),
involved_in_recursion=state['involved_in_recursion'],
current_recursion_ref_count=dict(state['current_recursion_ref_count']),
)
assert all(c == 0 for c in current_recursion_ref_count.values()), 'this is a bug! please report it'

def can_be_inlined(s: core_schema.DefinitionReferenceSchema, ref: str) -> bool:
if state['ref_counts'][ref] > 1:
if ref_counts[ref] > 1:
return False
if state['involved_in_recursion'].get(ref, False):
if involved_in_recursion.get(ref, False):
return False
if 'serialization' in s:
return False
Expand All @@ -553,8 +508,8 @@ def inline_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.Core
# any extra keys (like 'serialization')
if can_be_inlined(s, ref):
# Inline the reference by replacing the reference with the actual schema
new = state['definitions'].pop(ref)
state['ref_counts'][ref] -= 1 # because we just replaced it!
new = definitions.pop(ref)
ref_counts[ref] -= 1 # because we just replaced it!
# put all other keys that were on the def-ref schema into the inlined version
# in particular this is needed for `serialization`
if 'serialization' in s:
Expand All @@ -568,17 +523,44 @@ def inline_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.Core

schema = walk_core_schema(schema, inline_refs)

definitions = [d for d in state['definitions'].values() if state['ref_counts'][d['ref']] > 0] # type: ignore
def_values = [v for v in definitions.values() if ref_counts[v['ref']] > 0] # type: ignore

if definitions:
schema = core_schema.definitions_schema(schema=schema, definitions=definitions)
if 'metadata' in schema:
schema['metadata'][_DEFINITIONS_CACHE_METADATA_KEY] = definitions_cache
else:
schema['metadata'] = {_DEFINITIONS_CACHE_METADATA_KEY: definitions_cache}
if def_values:
schema = core_schema.definitions_schema(schema=schema, definitions=def_values)
return schema


def _strip_metadata(schema: CoreSchema) -> CoreSchema:
    """Return a copy of ``schema`` with all ``'metadata'`` keys removed.

    Used to de-noise a core schema (e.g. before pretty-printing). The input
    schema is never mutated: every mapping that has a key popped is
    shallow-copied first.

    Args:
        schema: The core schema to strip.

    Returns:
        A new core schema equivalent to ``schema`` but without ``metadata``
        entries (and, for ``model`` schemas, without keys that merely restate
        their defaults).
    """

    def strip_metadata(s: CoreSchema, recurse: Recurse) -> CoreSchema:
        # Shallow-copy before popping so the caller's schema is untouched.
        s = s.copy()
        s.pop('metadata', None)
        if s['type'] == 'model-fields':
            # Copy each field schema, then strip metadata from the copies.
            s['fields'] = {k: v.copy() for k, v in s['fields'].items()}
            for field_schema in s['fields'].values():
                field_schema.pop('metadata', None)
            computed_fields = s.get('computed_fields', None)
            if computed_fields:
                # Bug fix: pop metadata from the copies stored on `s`, not from
                # the original computed-field schemas. The previous code popped
                # from the originals, which both mutated the input schema and
                # left 'metadata' present in the copied output.
                s['computed_fields'] = [cf.copy() for cf in computed_fields]
                for cf in s['computed_fields']:
                    cf.pop('metadata', None)
            else:
                # Drop an empty/absent computed_fields entry entirely.
                s.pop('computed_fields', None)
        elif s['type'] == 'model':
            # Remove keys that only restate their defaults to keep output short.
            if s.get('custom_init', True) is False:
                s.pop('custom_init')
            if s.get('root_model', True) is False:
                s.pop('root_model')
            if {'title'}.issuperset(s.get('config', {}).keys()):
                # A config containing at most 'title' carries no information here.
                s.pop('config', None)

        return recurse(s, strip_metadata)

    return walk_core_schema(schema, strip_metadata)


def pretty_print_core_schema(
schema: CoreSchema,
include_metadata: bool = False,
Expand All @@ -593,34 +575,7 @@ def pretty_print_core_schema(
from rich import print # type: ignore # install it manually in your dev env

if not include_metadata:

def strip_metadata(s: CoreSchema, recurse: Recurse) -> CoreSchema:
s.pop('metadata', None)
if s['type'] == 'model-fields':
s = s.copy()
s['fields'] = {k: v.copy() for k, v in s['fields'].items()}
for field_name, field_schema in s['fields'].items():
field_schema.pop('metadata', None)
s['fields'][field_name] = field_schema
computed_fields = s.get('computed_fields', None)
if computed_fields:
s['computed_fields'] = [cf.copy() for cf in computed_fields]
for cf in computed_fields:
cf.pop('metadata', None)
else:
s.pop('computed_fields', None)
elif s['type'] == 'model':
# remove some defaults
if s.get('custom_init', True) is False:
s.pop('custom_init')
if s.get('root_model', True) is False:
s.pop('root_model')
if {'title'}.issuperset(s.get('config', {}).keys()):
s.pop('config')

return recurse(s, strip_metadata)

schema = walk_core_schema(schema, strip_metadata)
schema = _strip_metadata(schema)

return print(schema)

Expand Down
20 changes: 4 additions & 16 deletions pydantic/_internal/_generate_schema.py
Expand Up @@ -54,6 +54,7 @@
NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY,
CoreSchemaOrField,
define_expected_missing_refs,
get_ref,
get_type_ref,
is_list_like_schema_with_items_schema,
)
Expand Down Expand Up @@ -396,6 +397,8 @@ def collect_definitions(self, schema: CoreSchema) -> CoreSchema:
ref = cast('str | None', schema.get('ref', None))
if ref:
self.defs.definitions[ref] = schema
if 'ref' in schema:
schema = core_schema.definition_reference_schema(schema['ref'])
return core_schema.definitions_schema(
schema,
list(self.defs.definitions.values()),
Expand Down Expand Up @@ -554,19 +557,6 @@ def _model_schema(self, cls: type[BaseModel]) -> core_schema.CoreSchema:
self.defs.definitions[model_ref] = self._post_process_generated_schema(schema)
return core_schema.definition_reference_schema(model_ref)

def _unpack_refs_defs(self, schema: CoreSchema) -> CoreSchema:
"""Unpack all 'definitions' schemas into `GenerateSchema.defs.definitions`
and return the inner schema.
"""

def get_ref(s: CoreSchema) -> str:
return s['ref'] # type: ignore

if schema['type'] == 'definitions':
self.defs.definitions.update({get_ref(s): s for s in schema['definitions']})
schema = schema['schema']
return schema

def _generate_schema_from_property(self, obj: Any, source: Any) -> core_schema.CoreSchema | None:
"""Try to generate schema from either the `__get_pydantic_core_schema__` function or
`__pydantic_core_schema__` property.
Expand Down Expand Up @@ -603,9 +593,7 @@ def _generate_schema_from_property(self, obj: Any, source: Any) -> core_schema.C
source, CallbackGetCoreSchemaHandler(self._generate_schema, self, ref_mode=ref_mode)
)

schema = self._unpack_refs_defs(schema)

ref: str | None = schema.get('ref', None)
ref = get_ref(schema)
if ref:
self.defs.definitions[ref] = self._post_process_generated_schema(schema)
return core_schema.definition_reference_schema(ref)
Expand Down
22 changes: 22 additions & 0 deletions tests/test_edge_cases.py
Expand Up @@ -2612,3 +2612,25 @@ def __exit__(self, _exception_type, exception, exception_traceback):
MyModel(**data)

assert len(traceback_exceptions) == 1


def test_recursive_walk_fails_on_double_diamond_composition():
    """Regression test: schema building must not fail when the same submodel
    is referenced multiple times, forming nested "diamond" compositions.
    """

    class A(BaseModel):
        pass

    # B references A twice: first diamond.
    class B(BaseModel):
        a_1: A
        a_2: A

    class C(BaseModel):
        b: B

    # D references C twice: second diamond on top of the first.
    # NOTE(review): D is defined but never instantiated below — presumably its
    # schema construction alone is part of the regression being exercised.
    class D(BaseModel):
        c_1: C
        c_2: C

    class E(BaseModel):
        c: C

    # This is just to check that the above model contraption doesn't fail
    assert E(c=C(b=B(a_1=A(), a_2=A()))).model_dump() == {'c': {'b': {'a_1': {}, 'a_2': {}}}}