-
-
Notifications
You must be signed in to change notification settings - Fork 1.7k
/
_core_utils.py
525 lines (427 loc) · 22.8 KB
/
_core_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
from __future__ import annotations
from collections import defaultdict
from typing import Any, Callable, Hashable, Iterable, TypeVar, Union, cast
from pydantic_core import CoreSchema, core_schema
from typing_extensions import TypeAliasType, TypeGuard, get_args
from . import _repr
# Function-validator schema variants that wrap an inner ``schema``.
FunctionSchemaWithInnerSchema = Union[
    core_schema.AfterValidatorFunctionSchema,
    core_schema.BeforeValidatorFunctionSchema,
    core_schema.WrapValidatorFunctionSchema,
]
# Every function-validator schema variant: the wrapping ones above plus the plain validator.
# (``typing.Union`` flattens the nested union, so the member order is after, before, wrap, plain.)
AnyFunctionSchema = Union[
    FunctionSchemaWithInnerSchema,
    core_schema.PlainValidatorFunctionSchema,
]
# Field entries that live inside model / dataclass / typed-dict containers, plus computed fields.
CoreSchemaField = Union[
    core_schema.ModelField,
    core_schema.DataclassField,
    core_schema.TypedDictField,
    core_schema.ComputedField,
]
# Either a full core schema or one of the field entries above.
CoreSchemaOrField = Union[core_schema.CoreSchema, CoreSchemaField]
# ``type`` tags of field-level entries; these are not standalone core schemas.
_CORE_SCHEMA_FIELD_TYPES = {'model-field', 'dataclass-field', 'typed-dict-field', 'computed-field'}
# ``type`` tags of function validators that wrap an inner ``schema``.
_FUNCTION_WITH_INNER_SCHEMA_TYPES = {'function-after', 'function-before', 'function-wrap'}
# ``type`` tags of collection schemas that carry a single ``items_schema``.
_LIST_LIKE_SCHEMA_WITH_ITEMS_TYPES = {'frozenset', 'list', 'set', 'tuple-variable'}


def is_core_schema_field(
    schema: CoreSchemaOrField,
) -> TypeGuard[CoreSchemaField]:
    """Return True when ``schema`` is a field entry (model, dataclass, typed-dict or computed field)."""
    schema_type = schema['type']
    return schema_type in _CORE_SCHEMA_FIELD_TYPES


def is_core_schema(
    schema: CoreSchemaOrField,
) -> TypeGuard[CoreSchema]:
    """Return True when ``schema`` is a standalone core schema, i.e. anything that is not a field entry."""
    return not is_core_schema_field(schema)


def is_function_with_inner_schema(
    schema: CoreSchemaOrField,
) -> TypeGuard[FunctionSchemaWithInnerSchema]:
    """Return True for before/after/wrap function validators, which all wrap an inner ``schema``."""
    schema_type = schema['type']
    return schema_type in _FUNCTION_WITH_INNER_SCHEMA_TYPES


def is_list_like_schema_with_items_schema(
    schema: CoreSchema,
) -> TypeGuard[
    core_schema.ListSchema | core_schema.TupleVariableSchema | core_schema.SetSchema | core_schema.FrozenSetSchema
]:
    """Return True for list, variable-length tuple, set and frozenset schemas (those with ``items_schema``)."""
    schema_type = schema['type']
    return schema_type in _LIST_LIKE_SCHEMA_WITH_ITEMS_TYPES
def get_type_ref(type_: type[Any], args_override: tuple[type[Any], ...] | None = None) -> str:
    """Produces the ref to be used for this type by pydantic_core's core schemas.

    This `args_override` argument was added for the purpose of creating valid recursive references
    when creating generic models without needing to create a concrete class.

    Args:
        type_: The type to build a reference string for.
        args_override: Type arguments to use when ``type_`` has no non-empty
            ``__pydantic_generic_metadata__['args']`` of its own.

    Returns:
        ``'<module>.<name>'`` for a ``TypeAliasType``, otherwise ``'<module>.<qualname>:<id>'``;
        in both cases followed by ``'[<arg ref>,...]'`` when there are type arguments.
    """
    origin = type_
    args = args_override or ()
    generic_metadata = getattr(type_, '__pydantic_generic_metadata__', None)
    if generic_metadata:
        # Presumably set on parametrized pydantic generics; each entry only wins when non-empty.
        origin = generic_metadata['origin'] or origin
        args = generic_metadata['args'] or args

    module_name = getattr(origin, '__module__', '<No __module__>')
    if isinstance(origin, TypeAliasType):
        type_ref = f'{module_name}.{origin.__name__}'
    else:
        # Use a sentinel rather than an f-string default: a default argument is evaluated eagerly,
        # which would call str(origin) on every call even when __qualname__ exists, and str() of
        # exotic objects may itself raise.
        missing = object()
        qualname = getattr(origin, '__qualname__', missing)
        if qualname is missing:
            try:
                qualname = f'<No __qualname__: {origin}>'
            except Exception:
                qualname = '<No __qualname__>'
        type_ref = f'{module_name}.{qualname}:{id(origin)}'

    arg_refs: list[str] = [
        # Handle string literals as a special case; we may be able to remove this special handling if we
        # wrap them in a ForwardRef at some point.
        f'{arg}:str-{id(arg)}' if isinstance(arg, str) else f'{_repr.display_as_type(arg)}:{id(arg)}'
        for arg in args
    ]
    if arg_refs:
        type_ref = f'{type_ref}[{",".join(arg_refs)}]'
    return type_ref
def get_ref(s: core_schema.CoreSchema) -> None | str:
    """Return the ``ref`` key of schema ``s``, or None when the schema declares no ref.

    A thin typed accessor so type checkers accept a ``ref`` lookup on any ``CoreSchema`` variant.
    """
    ref = s.get('ref')
    return ref
def collect_definitions(schema: core_schema.CoreSchema) -> dict[str, core_schema.CoreSchema]:
    """Map every non-empty ``ref`` seen while walking ``schema`` to the schema node that declares it.

    When the same ref is seen more than once, the last visited node wins.
    """
    found: dict[str, CoreSchema] = {}

    def _record_valid_refs(node: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
        node_ref = get_ref(node)
        if node_ref:
            found[node_ref] = node
        return recurse(node, _record_valid_refs)

    walk_core_schema(schema, _record_valid_refs)
    return found
def define_expected_missing_refs(
    schema: core_schema.CoreSchema, allowed_missing_refs: set[str]
) -> core_schema.CoreSchema:
    """Add placeholder definitions for allowed refs that ``schema`` never defines.

    Every ref in ``allowed_missing_refs`` that no node of ``schema`` declares gets a ``none_schema``
    placeholder tagged with ``pydantic_debug_missing_ref`` and ``invalid`` metadata, and ``schema`` is
    wrapped in a ``definitions`` schema holding those placeholders. Otherwise ``schema`` is returned as-is.
    """
    # Fast path: nothing may be missing, so there is no need to walk the schema.
    # This is a common case (will be hit for all non-generic models), so it's worth optimizing for.
    if not allowed_missing_refs:
        return schema

    defined_refs: set[str] = set()

    def _record_refs(node: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
        node_ref: str | None = node.get('ref')
        if node_ref:
            defined_refs.add(node_ref)
        return recurse(node, _record_refs)

    walk_core_schema(schema, _record_refs)

    expected_missing_refs = allowed_missing_refs.difference(defined_refs)
    if not expected_missing_refs:
        return schema

    # TODO: Replace this with a (new) CoreSchema that, if present at any level, makes validation fail
    # Issue: https://github.com/pydantic/pydantic-core/issues/619
    placeholders: list[core_schema.CoreSchema] = []
    for missing_ref in expected_missing_refs:
        placeholders.append(
            core_schema.none_schema(ref=missing_ref, metadata={'pydantic_debug_missing_ref': True, 'invalid': True})
        )
    return core_schema.definitions_schema(schema, placeholders)
def collect_invalid_schemas(schema: core_schema.CoreSchema) -> list[core_schema.CoreSchema]:
    """Return every schema node whose ``metadata`` has a truthy ``'invalid'`` entry, in walk order.

    The walk is started with ``_dispatch`` directly rather than ``walk_core_schema``. The latter calls the
    visitor on the root and then (through ``_dispatch``) again on a copy of the root, so an invalid root
    would have been reported twice. Here every node, root included, is visited exactly once.
    """
    invalid_schemas: list[core_schema.CoreSchema] = []

    def _is_schema_valid(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
        if s.get('metadata', {}).get('invalid'):
            invalid_schemas.append(s)
        return recurse(s, _is_schema_valid)

    _dispatch(schema, _is_schema_valid)
    return invalid_schemas
T = TypeVar('T')

# Callback shapes used by the schema walker. ``CoreSchema`` is the same object as
# ``core_schema.CoreSchema`` (re-exported by pydantic_core).
# A ``Recurse`` receives a schema plus the ``Walk`` function to continue the traversal with.
Recurse = Callable[[CoreSchema, 'Walk'], CoreSchema]
# A ``Walk`` function is applied to a schema and receives the ``Recurse`` to hand that schema on to.
Walk = Callable[[CoreSchema, Recurse], CoreSchema]
# TODO: Should we move _WalkCoreSchema into pydantic_core proper?
#  Issue: https://github.com/pydantic/pydantic-core/issues/615
class _WalkCoreSchema:
    """Table-driven traversal of pydantic-core ``CoreSchema`` dicts.

    ``walk(schema, f)`` calls ``f`` on a shallow copy of ``schema`` together with ``self._walk``.
    ``_walk`` looks up the ``handle_<type>_schema`` method for the node's ``type`` tag (falling back to
    ``_handle_other_schemas``). That handler re-walks every nested sub-schema with ``f`` and assigns
    the results back into the node it was given. Any ``serialization`` sub-schemas are walked
    afterwards.
    """

    def __init__(self):
        self._schema_type_to_method = self._build_schema_type_to_method()

    def _build_schema_type_to_method(self) -> dict[core_schema.CoreSchemaType, Recurse]:
        """Map every ``CoreSchemaType`` tag to its ``handle_<tag>_schema`` method (dashes become underscores).

        Tags without a dedicated handler map to ``_handle_other_schemas``.
        """
        mapping: dict[core_schema.CoreSchemaType, Recurse] = {}
        key: core_schema.CoreSchemaType
        for key in get_args(core_schema.CoreSchemaType):
            method_name = f"handle_{key.replace('-', '_')}_schema"
            mapping[key] = getattr(self, method_name, self._handle_other_schemas)
        return mapping

    def walk(self, schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema:
        """Apply ``f`` to a shallow copy of ``schema``; ``f`` continues into children via ``self._walk``."""
        return f(schema.copy(), self._walk)

    def _walk(self, schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema:
        """Dispatch ``schema`` to its type handler, then walk its serialization schemas (if any)."""
        schema = self._schema_type_to_method[schema['type']](schema, f)
        ser_schema: core_schema.SerSchema | None = schema.get('serialization')  # type: ignore
        if ser_schema:
            schema['serialization'] = self._handle_ser_schemas(ser_schema.copy(), f)
        return schema

    def _handle_other_schemas(self, schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema:
        """Fallback handler: walk the generic ``schema`` sub-key when the node has one."""
        sub_schema = schema.get('schema', None)
        if sub_schema is not None:
            schema['schema'] = self.walk(sub_schema, f)  # type: ignore
        return schema

    def _handle_ser_schemas(self, ser_schema: core_schema.SerSchema, f: Walk) -> core_schema.SerSchema:
        """Walk the optional ``schema`` and ``return_schema`` entries of a serialization schema."""
        schema: core_schema.CoreSchema | None = ser_schema.get('schema', None)
        if schema is not None:
            ser_schema['schema'] = self.walk(schema, f)  # type: ignore
        return_schema: core_schema.CoreSchema | None = ser_schema.get('return_schema', None)
        if return_schema is not None:
            ser_schema['return_schema'] = self.walk(return_schema, f)  # type: ignore
        return ser_schema

    def _walk_items_schema(self, schema: Any, f: Walk) -> core_schema.CoreSchema:
        """Shared handler body for collection schemas: walk the optional ``items_schema`` entry."""
        items_schema = schema.get('items_schema')
        if items_schema is not None:
            schema['items_schema'] = self.walk(items_schema, f)
        return schema

    def _walk_computed_fields(self, schema: Any, f: Walk) -> None:
        """Walk each computed field's ``return_schema`` and store copies of the fields back on ``schema``.

        The ``computed_fields`` key is only rewritten when there is at least one computed field.
        """
        replaced_computed_fields: list[core_schema.ComputedField] = []
        for computed_field in schema.get('computed_fields', ()):
            replaced_field = computed_field.copy()
            replaced_field['return_schema'] = self.walk(computed_field['return_schema'], f)
            replaced_computed_fields.append(replaced_field)
        if replaced_computed_fields:
            schema['computed_fields'] = replaced_computed_fields

    def handle_definitions_schema(self, schema: core_schema.DefinitionsSchema, f: Walk) -> core_schema.CoreSchema:
        """Walk the definitions and the wrapped schema, dropping definitions that lost their ``ref``."""
        new_definitions: list[core_schema.CoreSchema] = []
        for definition in schema['definitions']:
            updated_definition = self.walk(definition, f)
            if 'ref' in updated_definition:
                # If the updated definition schema doesn't have a 'ref', it shouldn't go in the definitions
                # This is most likely to happen due to replacing something with a definition reference, in
                # which case it should certainly not go in the definitions list
                new_definitions.append(updated_definition)
        new_inner_schema = self.walk(schema['schema'], f)

        if not new_definitions and len(schema) == 3:
            # Only 'type', 'schema' and 'definitions' are present, so this would be a "trivial"
            # definitions schema that just wraps the inner schema; return the inner schema instead.
            return new_inner_schema

        new_schema = schema.copy()
        new_schema['schema'] = new_inner_schema
        new_schema['definitions'] = new_definitions
        return new_schema

    def handle_list_schema(self, schema: core_schema.ListSchema, f: Walk) -> core_schema.CoreSchema:
        return self._walk_items_schema(schema, f)

    def handle_set_schema(self, schema: core_schema.SetSchema, f: Walk) -> core_schema.CoreSchema:
        return self._walk_items_schema(schema, f)

    def handle_frozenset_schema(self, schema: core_schema.FrozenSetSchema, f: Walk) -> core_schema.CoreSchema:
        return self._walk_items_schema(schema, f)

    def handle_generator_schema(self, schema: core_schema.GeneratorSchema, f: Walk) -> core_schema.CoreSchema:
        return self._walk_items_schema(schema, f)

    def handle_tuple_variable_schema(
        self, schema: core_schema.TupleVariableSchema | core_schema.TuplePositionalSchema, f: Walk
    ) -> core_schema.CoreSchema:
        return self._walk_items_schema(schema, f)

    def handle_tuple_positional_schema(
        self, schema: core_schema.TupleVariableSchema | core_schema.TuplePositionalSchema, f: Walk
    ) -> core_schema.CoreSchema:
        """Walk each positional item schema and the optional ``extra_schema``."""
        schema = cast(core_schema.TuplePositionalSchema, schema)
        schema['items_schema'] = [self.walk(v, f) for v in schema['items_schema']]
        extra_schema = schema.get('extra_schema')
        if extra_schema is not None:
            schema['extra_schema'] = self.walk(extra_schema, f)
        return schema

    def handle_dict_schema(self, schema: core_schema.DictSchema, f: Walk) -> core_schema.CoreSchema:
        """Walk the optional ``keys_schema`` and ``values_schema`` entries."""
        keys_schema = schema.get('keys_schema')
        if keys_schema is not None:
            schema['keys_schema'] = self.walk(keys_schema, f)
        values_schema = schema.get('values_schema')
        if values_schema is not None:
            schema['values_schema'] = self.walk(values_schema, f)
        return schema

    def handle_function_schema(self, schema: AnyFunctionSchema, f: Walk) -> core_schema.CoreSchema:
        """Walk the inner ``schema`` of before/after/wrap function validators.

        NOTE(review): ``_build_schema_type_to_method`` looks up ``handle_<tag>_schema``, and the function
        tags are e.g. ``'function-before'``. Unless ``CoreSchemaType`` also contains a bare ``'function'``
        tag, this method is not reached by dispatch; ``_handle_other_schemas`` walks ``schema`` for them instead.
        """
        if not is_function_with_inner_schema(schema):
            return schema
        schema['schema'] = self.walk(schema['schema'], f)
        return schema

    def handle_union_schema(self, schema: core_schema.UnionSchema, f: Walk) -> core_schema.CoreSchema:
        """Walk each choice; ``(schema, label)`` tuple choices keep their label."""
        new_choices: list[CoreSchema | tuple[CoreSchema, str]] = []
        for v in schema['choices']:
            if isinstance(v, tuple):
                new_choices.append((self.walk(v[0], f), v[1]))
            else:
                new_choices.append(self.walk(v, f))
        schema['choices'] = new_choices
        return schema

    def handle_tagged_union_schema(self, schema: core_schema.TaggedUnionSchema, f: Walk) -> core_schema.CoreSchema:
        """Walk each schema-valued choice; ``str``/``int`` choice values are copied through unchanged."""
        new_choices: dict[Hashable, core_schema.CoreSchema] = {}
        for k, v in schema['choices'].items():
            new_choices[k] = v if isinstance(v, (str, int)) else self.walk(v, f)
        schema['choices'] = new_choices
        return schema

    def handle_chain_schema(self, schema: core_schema.ChainSchema, f: Walk) -> core_schema.CoreSchema:
        schema['steps'] = [self.walk(v, f) for v in schema['steps']]
        return schema

    def handle_lax_or_strict_schema(self, schema: core_schema.LaxOrStrictSchema, f: Walk) -> core_schema.CoreSchema:
        schema['lax_schema'] = self.walk(schema['lax_schema'], f)
        schema['strict_schema'] = self.walk(schema['strict_schema'], f)
        return schema

    def handle_json_or_python_schema(self, schema: core_schema.JsonOrPythonSchema, f: Walk) -> core_schema.CoreSchema:
        schema['json_schema'] = self.walk(schema['json_schema'], f)
        schema['python_schema'] = self.walk(schema['python_schema'], f)
        return schema

    def handle_model_fields_schema(self, schema: core_schema.ModelFieldsSchema, f: Walk) -> core_schema.CoreSchema:
        """Walk ``extra_validator``, then computed fields, then each field's ``schema`` (on field copies)."""
        extra_validator = schema.get('extra_validator')
        if extra_validator is not None:
            schema['extra_validator'] = self.walk(extra_validator, f)
        self._walk_computed_fields(schema, f)
        replaced_fields: dict[str, core_schema.ModelField] = {}
        for k, v in schema['fields'].items():
            replaced_field = v.copy()
            replaced_field['schema'] = self.walk(v['schema'], f)
            replaced_fields[k] = replaced_field
        schema['fields'] = replaced_fields
        return schema

    def handle_typed_dict_schema(self, schema: core_schema.TypedDictSchema, f: Walk) -> core_schema.CoreSchema:
        """Walk ``extra_validator``, then computed fields, then each field's ``schema`` (on field copies)."""
        extra_validator = schema.get('extra_validator')
        if extra_validator is not None:
            schema['extra_validator'] = self.walk(extra_validator, f)
        self._walk_computed_fields(schema, f)
        replaced_fields: dict[str, core_schema.TypedDictField] = {}
        for k, v in schema['fields'].items():
            replaced_field = v.copy()
            replaced_field['schema'] = self.walk(v['schema'], f)
            replaced_fields[k] = replaced_field
        schema['fields'] = replaced_fields
        return schema

    def handle_dataclass_args_schema(self, schema: core_schema.DataclassArgsSchema, f: Walk) -> core_schema.CoreSchema:
        """Walk computed fields, then each dataclass field's ``schema`` (on field copies)."""
        self._walk_computed_fields(schema, f)
        replaced_fields: list[core_schema.DataclassField] = []
        for field in schema['fields']:
            replaced_field = field.copy()
            replaced_field['schema'] = self.walk(field['schema'], f)
            replaced_fields.append(replaced_field)
        schema['fields'] = replaced_fields
        return schema

    def handle_arguments_schema(self, schema: core_schema.ArgumentsSchema, f: Walk) -> core_schema.CoreSchema:
        """Walk each parameter's ``schema`` (on parameter copies) and the optional var-args/var-kwargs schemas."""
        replaced_arguments_schema: list[core_schema.ArgumentsParameter] = []
        for param in schema['arguments_schema']:
            replaced_param = param.copy()
            replaced_param['schema'] = self.walk(param['schema'], f)
            replaced_arguments_schema.append(replaced_param)
        schema['arguments_schema'] = replaced_arguments_schema
        if 'var_args_schema' in schema:
            schema['var_args_schema'] = self.walk(schema['var_args_schema'], f)
        if 'var_kwargs_schema' in schema:
            schema['var_kwargs_schema'] = self.walk(schema['var_kwargs_schema'], f)
        return schema

    def handle_call_schema(self, schema: core_schema.CallSchema, f: Walk) -> core_schema.CoreSchema:
        """Walk ``arguments_schema`` and the optional ``return_schema``."""
        schema['arguments_schema'] = self.walk(schema['arguments_schema'], f)
        if 'return_schema' in schema:
            schema['return_schema'] = self.walk(schema['return_schema'], f)
        return schema
# Shared walker instance; its bound ``walk`` is handed to the visitor as the top-level ``recurse``.
_dispatch = _WalkCoreSchema().walk


def walk_core_schema(schema: core_schema.CoreSchema, f: Walk) -> core_schema.CoreSchema:
    """Recursively traverse a CoreSchema.

    ``f`` is first called as ``f(schema, recurse)`` with the schema you passed in. At this top level
    ``recurse`` is the walker's ``walk``, which calls ``f`` again on a shallow copy of the node it is
    given and then walks that copy's sub-schemas. Deeper down, ``f`` receives shallow copies of nested
    schemas. Children are only ever reassigned on those copies, so an ``f`` that merely forwards the
    node it was given to ``recurse`` leaves ``schema`` itself unmodified.

    Args:
        schema: The CoreSchema to process.
        f: The function applied to each visited schema. It receives the current schema and the
            "next" function to call; pass ``f`` (or e.g. ``functools.partial(some_method, some_context)``)
            back into ``recurse`` to continue, carrying context without globals or other mutable state.

    Returns:
        Whatever ``f`` returns for the top-level schema.
    """
    recurse = _dispatch
    return f(schema, recurse)
def _simplify_schema_references(schema: core_schema.CoreSchema, inline: bool) -> core_schema.CoreSchema:  # noqa: C901
    """Gather every ``ref``-bearing schema into one definitions list, optionally inlining single uses.

    Steps:
    1. ``collect_refs`` records every schema that declares a ``ref`` in ``all_defs`` and unwraps
       ``definitions`` schemas (their entries are recorded, their wrapped schema is kept).
    2. ``flatten_refs`` replaces each recorded schema, in the tree and inside the recorded definitions,
       with a ``definition-ref`` to it.
    3. With ``inline=False`` the result is the flattened schema wrapped in a single ``definitions`` schema.
    4. With ``inline=True`` references are counted (``count_refs``), refs reached again while already
       being expanded are marked as recursive, and every ref used at most once and not recursive is
       inlined (``inline_refs``). Only definitions that are still referenced are kept.

    Args:
        schema: The schema to simplify.
        inline: Whether to inline singly-used, non-recursive definitions.

    Returns:
        The simplified schema, wrapped in a ``definitions`` schema when any definitions remain.
    """
    all_defs: dict[str, core_schema.CoreSchema] = {}

    def make_result(schema: core_schema.CoreSchema, defs: Iterable[core_schema.CoreSchema]) -> core_schema.CoreSchema:
        """Wrap ``schema`` in a ``definitions`` schema when ``defs`` is non-empty."""
        definitions = list(defs)
        if definitions:
            return core_schema.definitions_schema(schema=schema, definitions=definitions)
        return schema

    def collect_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
        """Record ref-bearing schemas in ``all_defs`` and strip ``definitions`` wrappers."""
        if s['type'] == 'definitions':
            for definition in s['definitions']:
                ref = get_ref(definition)
                assert ref is not None
                # Copy first: below the top level `recurse` is `_WalkCoreSchema._walk`, which does not copy
                # and reassigns children in place. Walking the shared definition dict directly would
                # therefore write into the caller's input schema.
                all_defs[ref] = recurse(definition.copy(), collect_refs)
            # Same reasoning for the wrapped schema.
            return recurse(s['schema'].copy(), collect_refs)
        else:
            ref = get_ref(s)
            if ref is not None:
                all_defs[ref] = s
            return recurse(s, collect_refs)

    schema = walk_core_schema(schema, collect_refs)

    def flatten_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
        """Replace recorded schemas with ``definition-ref`` schemas, storing walked versions in ``all_defs``."""
        if s['type'] == 'definitions':
            # `collect_refs` already unwraps definitions schemas, so this branch should normally be
            # unreachable; it is kept as a defensive fallback.
            # iterate ourselves, we don't want to flatten the actual defs!
            definitions: list[CoreSchema] = s['definitions']  # type: ignore
            # Remaining keys (optional ones like 'serialization') move onto the wrapped schema. 'type',
            # 'schema' and 'definitions' must be excluded: merging 'type' would turn the wrapped schema
            # back into a 'definitions' schema with no 'schema'/'definitions' keys.
            extra_keys = {k: v for k, v in s.items() if k not in ('type', 'schema', 'definitions')}
            inner: CoreSchema = {**s['schema'], **extra_keys}  # type: ignore
            # Run the merged schema through `flatten_refs` itself so that it is walked and, if it
            # carries a recorded ref, replaced by a `definition-ref` like any other node.
            new_inner = flatten_refs(inner, recurse)
            for definition in definitions:
                recurse(definition, flatten_refs)  # don't re-assign here!
            return new_inner
        else:
            s = recurse(s, flatten_refs)
            ref = get_ref(s)
            if ref and ref in all_defs:
                all_defs[ref] = s
                return core_schema.definition_reference_schema(schema_ref=ref)
            return s

    schema = walk_core_schema(schema, flatten_refs)

    # Flatten inside the recorded definitions too; assigning to existing keys while iterating
    # `.values()` is safe because the dict's size does not change.
    for def_schema in all_defs.values():
        walk_core_schema(def_schema, flatten_refs)

    if not inline:
        return make_result(schema, all_defs.values())

    ref_counts: defaultdict[str, int] = defaultdict(int)
    involved_in_recursion: dict[str, bool] = {}
    # Number of times each ref is currently being expanded on the walk stack.
    current_recursion_ref_count: defaultdict[str, int] = defaultdict(int)

    def count_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
        """Count ``definition-ref`` uses, expanding each ref's definition on its first encounter."""
        if s['type'] != 'definition-ref':
            return recurse(s, count_refs)
        ref = s['schema_ref']
        ref_counts[ref] += 1

        if ref_counts[ref] >= 2:
            # If this model is involved in a recursion this should be detected
            # on its second encounter, we can safely stop the walk here.
            if current_recursion_ref_count[ref] != 0:
                involved_in_recursion[ref] = True
            return s

        current_recursion_ref_count[ref] += 1
        recurse(all_defs[ref], count_refs)
        current_recursion_ref_count[ref] -= 1
        return s

    schema = walk_core_schema(schema, count_refs)

    assert all(c == 0 for c in current_recursion_ref_count.values()), 'this is a bug! please report it'

    def inline_refs(s: core_schema.CoreSchema, recurse: Recurse) -> core_schema.CoreSchema:
        """Replace singly-used, non-recursive ``definition-ref`` schemas with their definition."""
        if s['type'] == 'definition-ref':
            ref = s['schema_ref']
            # Check if the reference is only used once and not involved in recursion
            if ref_counts[ref] <= 1 and not involved_in_recursion.get(ref, False):
                # Inline the reference by replacing the reference with the actual schema
                new = all_defs.pop(ref)
                ref_counts[ref] -= 1  # because we just replaced it!
                new.pop('ref')  # type: ignore
                # put all other keys that were on the def-ref schema into the inlined version
                # in particular this is needed for `serialization`
                if 'serialization' in s:
                    new['serialization'] = s['serialization']
                s = recurse(new, inline_refs)
                return s
            else:
                return recurse(s, inline_refs)
        else:
            return recurse(s, inline_refs)

    schema = walk_core_schema(schema, inline_refs)

    # Keep only definitions that are still referenced after inlining.
    definitions = [d for d in all_defs.values() if ref_counts[d['ref']] > 0]  # type: ignore
    return make_result(schema, definitions)
def flatten_schema_defs(schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
    """Hoist every ``ref``-bearing schema into one top-level ``definitions`` schema.

    Each such schema is replaced where it occurs by a ``definition-ref`` pointing at it, similar to a
    JSON schema's ``#/$defs``. Nothing is inlined.
    """
    inline = False
    return _simplify_schema_references(schema, inline=inline)
def inline_schema_defs(schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
    """Hoist ``ref``-bearing schemas into definitions, then inline the ones that need not stay separate.

    Definitions referenced at most once and not involved in a cycle are inlined at their use site
    (dropping their ``ref``). Definitions that are no longer referenced are removed.
    """
    inline = True
    return _simplify_schema_references(schema, inline=inline)