diff --git a/pydantic/_internal/_core_utils.py b/pydantic/_internal/_core_utils.py index e7701a74ba..99356885e2 100644 --- a/pydantic/_internal/_core_utils.py +++ b/pydantic/_internal/_core_utils.py @@ -32,14 +32,6 @@ _LIST_LIKE_SCHEMA_WITH_ITEMS_TYPES = {'list', 'tuple-variable', 'set', 'frozenset'} -def is_definition_ref_schema(s: core_schema.CoreSchema) -> TypeGuard[core_schema.DefinitionReferenceSchema]: - return s['type'] == 'definition-ref' - - -def is_definitions_schema(s: core_schema.CoreSchema) -> TypeGuard[core_schema.DefinitionsSchema]: - return s['type'] == 'definitions' - - def is_core_schema( schema: CoreSchemaOrField, ) -> TypeGuard[CoreSchema]: @@ -192,14 +184,15 @@ def walk(self, schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchem def _walk(self, schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema: schema = self._schema_type_to_method[schema['type']](schema, f) - ser_schema: core_schema.SerSchema | None = schema.get('serialization', None) # type: ignore + ser_schema: core_schema.SerSchema | None = schema.get('serialization') # type: ignore if ser_schema: schema['serialization'] = self._handle_ser_schemas(ser_schema.copy(), f) return schema def _handle_other_schemas(self, schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema: - if 'schema' in schema: - schema['schema'] = self.walk(schema['schema'], f) # type: ignore + sub_schema = schema.get('schema', None) + if sub_schema is not None: + schema['schema'] = self.walk(sub_schema, f) # type: ignore return schema def _handle_ser_schemas(self, ser_schema: core_schema.SerSchema, f: Walk) -> core_schema.SerSchema: @@ -232,31 +225,36 @@ def handle_definitions_schema(self, schema: core_schema.DefinitionsSchema, f: Wa return new_schema def handle_list_schema(self, schema: core_schema.ListSchema, f: Walk) -> core_schema.CoreSchema: - if 'items_schema' in schema: - schema['items_schema'] = self.walk(schema['items_schema'], f) + items_schema = schema.get('items_schema') + if items_schema is not None: + schema['items_schema'] = self.walk(items_schema, f) return schema def handle_set_schema(self, schema: core_schema.SetSchema, f: Walk) -> core_schema.CoreSchema: - if 'items_schema' in schema: - schema['items_schema'] = self.walk(schema['items_schema'], f) + items_schema = schema.get('items_schema') + if items_schema is not None: + schema['items_schema'] = self.walk(items_schema, f) return schema def handle_frozenset_schema(self, schema: core_schema.FrozenSetSchema, f: Walk) -> core_schema.CoreSchema: - if 'items_schema' in schema: - schema['items_schema'] = self.walk(schema['items_schema'], f) + items_schema = schema.get('items_schema') + if items_schema is not None: + schema['items_schema'] = self.walk(items_schema, f) return schema def handle_generator_schema(self, schema: core_schema.GeneratorSchema, f: Walk) -> core_schema.CoreSchema: - if 'items_schema' in schema: - schema['items_schema'] = self.walk(schema['items_schema'], f) + items_schema = schema.get('items_schema') + if items_schema is not None: + schema['items_schema'] = self.walk(items_schema, f) return schema def handle_tuple_variable_schema( self, schema: core_schema.TupleVariableSchema | core_schema.TuplePositionalSchema, f: Walk ) -> core_schema.CoreSchema: schema = cast(core_schema.TupleVariableSchema, schema) - if 'items_schema' in schema: - schema['items_schema'] = self.walk(schema['items_schema'], f) + items_schema = schema.get('items_schema') + if items_schema is not None: + schema['items_schema'] = self.walk(items_schema, f) return schema def handle_tuple_positional_schema( @@ -264,15 +262,18 @@ def handle_tuple_positional_schema( ) -> core_schema.CoreSchema: schema = cast(core_schema.TuplePositionalSchema, schema) schema['items_schema'] = [self.walk(v, f) for v in schema['items_schema']] - if 'extra_schema' in schema: - schema['extra_schema'] = self.walk(schema['extra_schema'], f) + extra_schema = schema.get('extra_schema') + if extra_schema is not None: + schema['extra_schema'] = self.walk(extra_schema, f) return schema def handle_dict_schema(self, schema: core_schema.DictSchema, f: Walk) -> core_schema.CoreSchema: - if 'keys_schema' in schema: - schema['keys_schema'] = self.walk(schema['keys_schema'], f) - if 'values_schema' in schema: - schema['values_schema'] = self.walk(schema['values_schema'], f) + keys_schema = schema.get('keys_schema') + if keys_schema is not None: + schema['keys_schema'] = self.walk(keys_schema, f) + values_schema = schema.get('values_schema') + if values_schema: + schema['values_schema'] = self.walk(values_schema, f) return schema def handle_function_schema(self, schema: AnyFunctionSchema, f: Walk) -> core_schema.CoreSchema: @@ -307,11 +308,12 @@ def handle_json_or_python_schema(self, schema: core_schema.JsonOrPythonSchema, f return schema def handle_model_fields_schema(self, schema: core_schema.ModelFieldsSchema, f: Walk) -> core_schema.CoreSchema: - if 'extra_validator' in schema: - schema['extra_validator'] = self.walk(schema['extra_validator'], f) + extra_validator = schema.get('extra_validator') + if extra_validator is not None: + schema['extra_validator'] = self.walk(extra_validator, f) replaced_fields: dict[str, core_schema.ModelField] = {} replaced_computed_fields: list[core_schema.ComputedField] = [] - for computed_field in schema.get('computed_fields', None) or (): + for computed_field in schema.get('computed_fields', ()): replaced_field = computed_field.copy() replaced_field['return_schema'] = self.walk(computed_field['return_schema'], f) replaced_computed_fields.append(replaced_field) @@ -325,10 +327,11 @@ def handle_model_fields_schema(self, schema: core_schema.ModelFieldsSchema, f: W return schema def handle_typed_dict_schema(self, schema: core_schema.TypedDictSchema, f: Walk) -> core_schema.CoreSchema: - if 'extra_validator' in schema: - schema['extra_validator'] = self.walk(schema['extra_validator'], f) + extra_validator = schema.get('extra_validator') + if extra_validator is not None: + schema['extra_validator'] = self.walk(extra_validator, f) replaced_computed_fields: list[core_schema.ComputedField] = [] - for computed_field in schema.get('computed_fields', None) or (): + for computed_field in schema.get('computed_fields', ()): replaced_field = computed_field.copy() replaced_field['return_schema'] = self.walk(computed_field['return_schema'], f) replaced_computed_fields.append(replaced_field) @@ -345,7 +348,7 @@ def handle_typed_dict_schema(self, schema: core_schema.TypedDictSchema, f: Walk) def handle_dataclass_args_schema(self, schema: core_schema.DataclassArgsSchema, f: Walk) -> core_schema.CoreSchema: replaced_fields: list[core_schema.DataclassField] = [] replaced_computed_fields: list[core_schema.ComputedField] = [] - for computed_field in schema.get('computed_fields', None) or (): + for computed_field in schema.get('computed_fields', ()): replaced_field = computed_field.copy() replaced_field['return_schema'] = self.walk(computed_field['return_schema'], f) replaced_computed_fields.append(replaced_field) @@ -395,12 +398,11 @@ def walk_core_schema(schema: core_schema.CoreSchema, f: Walk) -> core_schema.Cor Returns: core_schema.CoreSchema: A processed CoreSchema. """ - return f(schema.copy(), _dispatch) + return f(schema, _dispatch) def _simplify_schema_references(schema: core_schema.CoreSchema, inline: bool) -> core_schema.CoreSchema: # noqa: C901 - valid_defs: dict[str, core_schema.CoreSchema] = {} - invalid_defs: dict[str, core_schema.CoreSchema] = {} + all_defs: dict[str, core_schema.CoreSchema] = {} def make_result(schema: core_schema.CoreSchema, defs: Iterable[core_schema.CoreSchema]) -> core_schema.CoreSchema: definitions = list(defs) @@ -413,42 +415,34 @@ def collect_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.Cor for definition in s['definitions']: ref = get_ref(definition) assert ref is not None - def_schema = recurse(definition, collect_refs).copy() - if 'invalid' in def_schema.get('metadata', {}): - invalid_defs[ref] = def_schema - else: - valid_defs[ref] = def_schema + all_defs[ref] = recurse(definition, collect_refs) return recurse(s['schema'], collect_refs) - ref = get_ref(s) - if ref is not None: - if 'invalid' in s.get('metadata', {}): - invalid_defs[ref] = s - else: - valid_defs[ref] = s - return recurse(s, collect_refs) + else: + ref = get_ref(s) + if ref is not None: + all_defs[ref] = s + return recurse(s, collect_refs) schema = walk_core_schema(schema, collect_refs) - all_defs = {**invalid_defs, **valid_defs} - def flatten_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema: - if is_definitions_schema(s): - new: dict[str, Any] = dict(s) + if s['type'] == 'definitions': # iterate ourselves, we don't want to flatten the actual defs! - definitions: list[CoreSchema] = new.pop('definitions') - schema = cast(CoreSchema, new.pop('schema')) + definitions: list[CoreSchema] = s.pop('definitions') # type: ignore + schema: CoreSchema = s.pop('schema') # type: ignore # remaining keys are optional like 'serialization' - schema = cast(CoreSchema, {**schema, **new}) + schema: CoreSchema = {**schema, **s} # type: ignore s['schema'] = recurse(schema, flatten_refs) for definition in definitions: recurse(definition, flatten_refs) # don't re-assign here! return schema - s = recurse(s, flatten_refs) - ref = get_ref(s) - if ref and ref in all_defs: - all_defs[ref] = s - return core_schema.definition_reference_schema(schema_ref=ref) - return s + else: + s = recurse(s, flatten_refs) + ref = get_ref(s) + if ref and ref in all_defs: + all_defs[ref] = s + return core_schema.definition_reference_schema(schema_ref=ref) + return s schema = walk_core_schema(schema, flatten_refs) @@ -458,12 +452,12 @@ def flatten_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.Cor if not inline: return make_result(schema, all_defs.values()) - ref_counts: dict[str, int] = defaultdict(int) + ref_counts: defaultdict[str, int] = defaultdict(int) involved_in_recursion: dict[str, bool] = {} - current_recursion_ref_count: dict[str, int] = defaultdict(int) + current_recursion_ref_count: defaultdict[str, int] = defaultdict(int) def count_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema: - if not is_definition_ref_schema(s): + if s['type'] != 'definition-ref': return recurse(s, count_refs) ref = s['schema_ref'] ref_counts[ref] += 1 diff --git a/tests/benchmarks/test_fastapi_startup.py b/tests/benchmarks/test_fastapi_startup.py index 0488b873a6..843f14d9d4 100644 --- a/tests/benchmarks/test_fastapi_startup.py +++ b/tests/benchmarks/test_fastapi_startup.py @@ -122,11 +122,18 @@ def bench(): if __name__ == '__main__': - # run with `python tests/benchmarks/test_fastapi_startup.py` + # run with `pdm run tests/benchmarks/test_fastapi_startup.py` import cProfile import sys + import time INNER_DATA_MODEL_COUNT = 50 OUTER_DATA_MODEL_COUNT = 50 print(f'Python version: {sys.version}') - cProfile.run('test_fastapi_startup_perf(lambda f: f())', sort='tottime') + if sys.argv[-1] == 'cProfile': + cProfile.run('test_fastapi_startup_perf(lambda f: f())', sort='tottime') + else: + start = time.perf_counter() + test_fastapi_startup_perf(lambda f: f()) + end = time.perf_counter() + print(f'Time taken: {end - start:.2f}s')