Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix nested discriminated union schema gen, pt 2 #8932

Merged
merged 3 commits into from Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 0 additions & 5 deletions pydantic/_internal/_core_utils.py
Expand Up @@ -42,11 +42,6 @@

_DEFINITIONS_CACHE_METADATA_KEY = 'pydantic.definitions_cache'

NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY = 'pydantic.internal.needs_apply_discriminated_union'
"""Used to mark a schema that has a discriminated union that needs to be checked for validity at the end of
schema building because one of it's members refers to a definition that was not yet defined when the union
was first encountered.
"""
TAGGED_UNION_TAG_KEY = 'pydantic.internal.tagged_union_tag'
"""
Used in a `Tag` schema to specify the tag used for a discriminated union.
Expand Down
23 changes: 7 additions & 16 deletions pydantic/_internal/_discriminated_union.py
Expand Up @@ -7,7 +7,6 @@
from ..errors import PydanticUserError
from . import _core_utils
from ._core_utils import (
NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY,
CoreSchemaField,
collect_definitions,
simplify_schema_references,
Expand All @@ -29,7 +28,7 @@ def __init__(self, ref: str) -> None:
super().__init__(f'Missing definition for ref {self.ref!r}')


def set_discriminator(schema: CoreSchema, discriminator: Any) -> None:
def set_discriminator_in_metadata(schema: CoreSchema, discriminator: Any) -> None:
schema.setdefault('metadata', {})
metadata = schema.get('metadata')
assert metadata is not None
Expand All @@ -41,25 +40,16 @@ def apply_discriminators(schema: core_schema.CoreSchema) -> core_schema.CoreSche

def inner(s: core_schema.CoreSchema, recurse: _core_utils.Recurse) -> core_schema.CoreSchema:
nonlocal definitions
if 'metadata' in s:
if s['metadata'].get(NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY, True) is False:
return s

s = recurse(s, inner)
if s['type'] == 'tagged-union':
return s

metadata = s.get('metadata', {})
discriminator = metadata.get(CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY, None)
discriminator = metadata.pop(CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY, None)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually get rid of this once we use it

if discriminator is not None:
if definitions is None:
definitions = collect_definitions(schema)
# After we collect the definitions schemas, we must run through the discriminator
# application logic for each one. This step is crucial to prevent an exponential
# increase in complexity that occurs if schemas are left as 'union' schemas
# rather than 'tagged-union' schemas.
# For more details, see https://github.com/pydantic/pydantic/pull/8904#discussion_r1504687302
definitions = {k: recurse(v, inner) for k, v in definitions.items()}
s = apply_discriminator(s, discriminator, definitions)
return s

Expand Down Expand Up @@ -274,6 +264,10 @@ def _handle_choice(self, choice: core_schema.CoreSchema) -> None:
* Validating that each allowed discriminator value maps to a unique choice
* Updating the _tagged_union_choices mapping that will ultimately be used to build the TaggedUnionSchema.
"""
if choice['type'] == 'definition-ref':
if choice['schema_ref'] not in self.definitions:
raise MissingDefinitionForUnionRef(choice['schema_ref'])

if choice['type'] == 'none':
self._should_be_nullable = True
elif choice['type'] == 'definitions':
Expand All @@ -285,17 +279,14 @@ def _handle_choice(self, choice: core_schema.CoreSchema) -> None:
# Reverse the choices list before extending the stack so that they get handled in the order they occur
choices_schemas = [v[0] if isinstance(v, tuple) else v for v in choice['choices'][::-1]]
self._choices_to_handle.extend(choices_schemas)
elif choice['type'] == 'definition-ref':
if choice['schema_ref'] not in self.definitions:
raise MissingDefinitionForUnionRef(choice['schema_ref'])
self._handle_choice(self.definitions[choice['schema_ref']])
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Important!!! If a choice is of type definition-ref, we want to just reuse that ref for the given choice. Before, we were going through this whole thing of fetching the value from definitions, then using that, but that ends up not working for nested / recursive schemas.

Our schema walking logic walks through both the schema and the definitions, so we can rest easy knowing that unions will be converted to tagged unions in the definitions list as well.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this work with JSON schema generation if e.g. the ref'ed schema is itself a discriminated union? Maybe that can't happen, and either way this seems like an improvement if no tests fail, but still

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question. I believe that the walk core schema logic handles definitions schemas such that the function being applied during the walk is also applied to all of the definitions in a definitions schema, so that's why I felt comfortable removing this step.

This can be seen via the example test I added - discriminated union transformation logic is applied to the 2 schemas in the definitions list that require said changes!

elif choice['type'] not in {
'model',
'typed-dict',
'tagged-union',
'lax-or-strict',
'dataclass',
'dataclass-args',
'definition-ref',
} and not _core_utils.is_function_with_inner_schema(choice):
# We should eventually handle 'definition-ref' as well
raise TypeError(
Expand Down
25 changes: 4 additions & 21 deletions pydantic/_internal/_generate_schema.py
Expand Up @@ -48,7 +48,6 @@
from ._config import ConfigWrapper, ConfigWrapperStack
from ._core_metadata import CoreMetadataHandler, build_metadata_dict
from ._core_utils import (
NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY,
CoreSchemaOrField,
collect_invalid_schemas,
define_expected_missing_refs,
Expand Down Expand Up @@ -312,7 +311,6 @@ class GenerateSchema:
'_config_wrapper_stack',
'_types_namespace_stack',
'_typevars_map',
'_needs_apply_discriminated_union',
'_has_invalid_schema',
'field_name_stack',
'defs',
Expand All @@ -328,7 +326,6 @@ def __init__(
self._config_wrapper_stack = ConfigWrapperStack(config_wrapper)
self._types_namespace_stack = TypesNamespaceStack(types_namespace)
self._typevars_map = typevars_map
self._needs_apply_discriminated_union = False
self._has_invalid_schema = False
self.field_name_stack = _FieldNameStack()
self.defs = _Definitions()
Expand All @@ -345,7 +342,6 @@ def __from_parent(
obj._config_wrapper_stack = config_wrapper_stack
obj._types_namespace_stack = types_namespace_stack
obj._typevars_map = typevars_map
obj._needs_apply_discriminated_union = False
obj._has_invalid_schema = False
obj.field_name_stack = _FieldNameStack()
obj.defs = defs
Expand Down Expand Up @@ -426,15 +422,10 @@ def _apply_discriminator_to_union(
)
except _discriminated_union.MissingDefinitionForUnionRef:
# defer until defs are resolved
_discriminated_union.set_discriminator(
_discriminated_union.set_discriminator_in_metadata(
schema,
discriminator,
)
if 'metadata' in schema:
schema['metadata'][NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY] = True
else:
schema['metadata'] = {NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY: True}
self._needs_apply_discriminated_union = True
return schema

class CollectedInvalid(Exception):
Expand Down Expand Up @@ -736,24 +727,16 @@ def _get_first_two_args_or_any(self, obj: Any) -> tuple[Any, Any]:
return args[0], args[1]

def _post_process_generated_schema(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema:
if 'metadata' in schema:
metadata = schema['metadata']
metadata[NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY] = self._needs_apply_discriminated_union
else:
schema['metadata'] = {
NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY: self._needs_apply_discriminated_union,
}
if 'metadata' not in schema:
schema['metadata'] = {}
return schema

def _generate_schema(self, obj: Any) -> core_schema.CoreSchema:
"""Recursively generate a pydantic-core schema for any supported python type."""
has_invalid_schema = self._has_invalid_schema
self._has_invalid_schema = False
needs_apply_discriminated_union = self._needs_apply_discriminated_union
self._needs_apply_discriminated_union = False
schema = self._post_process_generated_schema(self._generate_schema_inner(obj))
schema = self._generate_schema_inner(obj)
self._has_invalid_schema = self._has_invalid_schema or has_invalid_schema
self._needs_apply_discriminated_union = self._needs_apply_discriminated_union or needs_apply_discriminated_union
return schema

def _generate_schema_inner(self, obj: Any) -> core_schema.CoreSchema:
Expand Down
89 changes: 86 additions & 3 deletions tests/test_discriminated_union.py
Expand Up @@ -7,10 +7,11 @@
import pytest
from dirty_equals import HasRepr, IsStr
from pydantic_core import SchemaValidator, core_schema
from typing_extensions import Annotated, Literal
from typing_extensions import Annotated, Literal, TypedDict

from pydantic import BaseModel, ConfigDict, Discriminator, Field, TypeAdapter, ValidationError, field_validator
from pydantic._internal._discriminated_union import apply_discriminator
from pydantic.dataclasses import dataclass as pydantic_dataclass
from pydantic.errors import PydanticUserError
from pydantic.fields import FieldInfo
from pydantic.json_schema import GenerateJsonSchema
Expand Down Expand Up @@ -1868,11 +1869,93 @@ class LeafState(BaseModel):
state_type: Literal['leaf']

AnyState = Annotated[Union[NestedState, LoopState, LeafState], Field(..., discriminator='state_type')]
NestedState.model_rebuild()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you expand this section, you can see that this test showcases the example I've shown in the PR description 👍

LoopState.model_rebuild()
adapter = TypeAdapter(AnyState)

assert adapter.core_schema['schema']['type'] == 'tagged-union'
for definition in adapter.core_schema['definitions']:
if definition['schema']['model_name'] in ['NestedState', 'LoopState']:
assert definition['schema']['fields']['substate']['schema']['schema']['type'] == 'tagged-union'


def test_recursive_discriminiated_union_with_typed_dict() -> None:
class Foo(TypedDict):
type: Literal['foo']
x: 'Foobar'

class Bar(TypedDict):
type: Literal['bar']

Foobar = Annotated[Union[Foo, Bar], Field(discriminator='type')]
ta = TypeAdapter(Foobar)

# len of errors should be 1 for each case, bc we're using a tagged union
with pytest.raises(ValidationError) as e:
ta.validate_python({'type': 'wrong'})
assert len(e.value.errors()) == 1

with pytest.raises(ValidationError) as e:
ta.validate_python({'type': 'foo', 'x': {'type': 'wrong'}})
assert len(e.value.errors()) == 1

core_schema = ta.core_schema
assert core_schema['schema']['type'] == 'tagged-union'
for definition in core_schema['definitions']:
if 'Foo' in definition['ref']:
assert definition['fields']['x']['schema']['type'] == 'tagged-union'


def test_recursive_discriminiated_union_with_base_model() -> None:
class Foo(BaseModel):
type: Literal['foo']
x: 'Foobar'

class Bar(BaseModel):
type: Literal['bar']

Foobar = Annotated[Union[Foo, Bar], Field(discriminator='type')]
ta = TypeAdapter(Foobar)

# len of errors should be 1 for each case, bc we're using a tagged union
with pytest.raises(ValidationError) as e:
ta.validate_python({'type': 'wrong'})
assert len(e.value.errors()) == 1

with pytest.raises(ValidationError) as e:
ta.validate_python({'type': 'foo', 'x': {'type': 'wrong'}})
assert len(e.value.errors()) == 1

core_schema = ta.core_schema
assert core_schema['schema']['type'] == 'tagged-union'
for definition in core_schema['definitions']:
if 'Foo' in definition['ref']:
assert definition['schema']['fields']['x']['schema']['type'] == 'tagged-union'


def test_recursive_discriminated_union_with_pydantic_dataclass() -> None:
@pydantic_dataclass
class Foo:
type: Literal['foo']
x: 'Foobar'

@pydantic_dataclass
class Bar:
type: Literal['bar']

Foobar = Annotated[Union[Foo, Bar], Field(discriminator='type')]
ta = TypeAdapter(Foobar)

# len of errors should be 1 for each case, bc we're using a tagged union
with pytest.raises(ValidationError) as e:
ta.validate_python({'type': 'wrong'})
assert len(e.value.errors()) == 1

with pytest.raises(ValidationError) as e:
ta.validate_python({'type': 'foo', 'x': {'type': 'wrong'}})
assert len(e.value.errors()) == 1

core_schema = ta.core_schema
assert core_schema['schema']['type'] == 'tagged-union'
for definition in core_schema['definitions']:
if 'Foo' in definition['ref']:
for field in definition['schema']['fields']:
assert field['schema']['type'] == 'tagged-union' if field['name'] == 'x' else True