Skip to content

Commit

Permalink
Fix schema references in discriminated unions (#7646)
Browse files Browse the repository at this point in the history
  • Loading branch information
adriangb committed Sep 26, 2023
1 parent 734f7ad commit b51804e
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 33 deletions.
13 changes: 11 additions & 2 deletions pydantic/_internal/_discriminated_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@

from ..errors import PydanticUserError
from . import _core_utils
from ._core_utils import NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY, CoreSchemaField, collect_definitions
from ._core_utils import (
NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY,
CoreSchemaField,
collect_definitions,
simplify_schema_references,
)

CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY = 'pydantic.internal.union_discriminator'

Expand Down Expand Up @@ -49,7 +54,7 @@ def inner(s: core_schema.CoreSchema, recurse: _core_utils.Recurse) -> core_schem
s = apply_discriminator(s, discriminator, definitions)
return s

return _core_utils.walk_core_schema(schema, inner)
return simplify_schema_references(_core_utils.walk_core_schema(schema, inner))


def apply_discriminator(
Expand Down Expand Up @@ -177,6 +182,10 @@ def apply(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
if self._should_be_nullable and not self._is_nullable:
schema = core_schema.nullable_schema(schema)
self._used = True
new_defs = collect_definitions(schema)
missing_defs = self.definitions.keys() - new_defs.keys()
if missing_defs:
schema = core_schema.definitions_schema(schema, [self.definitions[ref] for ref in missing_defs])
return schema

def _apply_to_root(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
Expand Down
68 changes: 37 additions & 31 deletions tests/test_discriminated_union.py
Original file line number Diff line number Diff line change
Expand Up @@ -955,54 +955,60 @@ def test_lax_or_strict_definitions() -> None:
discriminated_schema = apply_discriminator(core_schema.union_schema([cat, dog]), 'kind')
# insert_assert(discriminated_schema)
assert discriminated_schema == {
'type': 'tagged-union',
'choices': {
'cat': {
'type': 'typed-dict',
'fields': {'kind': {'type': 'typed-dict-field', 'schema': {'type': 'literal', 'expected': ['cat']}}},
},
'DOG': {
'type': 'lax-or-strict',
'lax_schema': {
'type': 'definitions',
'schema': {
'type': 'tagged-union',
'choices': {
'cat': {
'type': 'typed-dict',
'fields': {
'kind': {'type': 'typed-dict-field', 'schema': {'type': 'literal', 'expected': ['DOG']}}
'kind': {'type': 'typed-dict-field', 'schema': {'type': 'literal', 'expected': ['cat']}}
},
},
'strict_schema': {
'type': 'definitions',
'schema': {
'DOG': {
'type': 'lax-or-strict',
'lax_schema': {
'type': 'typed-dict',
'fields': {
'kind': {'type': 'typed-dict-field', 'schema': {'type': 'literal', 'expected': ['dog']}}
'kind': {'type': 'typed-dict-field', 'schema': {'type': 'literal', 'expected': ['DOG']}}
},
},
'definitions': [{'type': 'int', 'ref': 'my-int-definition'}],
},
},
'dog': {
'type': 'lax-or-strict',
'lax_schema': {
'type': 'typed-dict',
'fields': {
'kind': {'type': 'typed-dict-field', 'schema': {'type': 'literal', 'expected': ['DOG']}}
'strict_schema': {
'type': 'definitions',
'schema': {
'type': 'typed-dict',
'fields': {
'kind': {'type': 'typed-dict-field', 'schema': {'type': 'literal', 'expected': ['dog']}}
},
},
'definitions': [{'type': 'int', 'ref': 'my-int-definition'}],
},
},
'strict_schema': {
'type': 'definitions',
'schema': {
'dog': {
'type': 'lax-or-strict',
'lax_schema': {
'type': 'typed-dict',
'fields': {
'kind': {'type': 'typed-dict-field', 'schema': {'type': 'literal', 'expected': ['dog']}}
'kind': {'type': 'typed-dict-field', 'schema': {'type': 'literal', 'expected': ['DOG']}}
},
},
'strict_schema': {
'type': 'definitions',
'schema': {
'type': 'typed-dict',
'fields': {
'kind': {'type': 'typed-dict-field', 'schema': {'type': 'literal', 'expected': ['dog']}}
},
},
'definitions': [{'type': 'int', 'ref': 'my-int-definition'}],
},
'definitions': [{'type': 'int', 'ref': 'my-int-definition'}],
},
},
'discriminator': 'kind',
'strict': False,
'from_attributes': True,
},
'discriminator': 'kind',
'strict': False,
'from_attributes': True,
'definitions': [{'type': 'str', 'ref': 'my-str-definition'}],
}


Expand Down

0 comments on commit b51804e

Please sign in to comment.