New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix nested discriminated union schema gen, pt 2 #8932
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,7 +7,6 @@ | |
from ..errors import PydanticUserError | ||
from . import _core_utils | ||
from ._core_utils import ( | ||
NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY, | ||
CoreSchemaField, | ||
collect_definitions, | ||
simplify_schema_references, | ||
|
@@ -29,7 +28,7 @@ def __init__(self, ref: str) -> None: | |
super().__init__(f'Missing definition for ref {self.ref!r}') | ||
|
||
|
||
def set_discriminator(schema: CoreSchema, discriminator: Any) -> None: | ||
def set_discriminator_in_metadata(schema: CoreSchema, discriminator: Any) -> None: | ||
schema.setdefault('metadata', {}) | ||
metadata = schema.get('metadata') | ||
assert metadata is not None | ||
|
@@ -41,25 +40,16 @@ def apply_discriminators(schema: core_schema.CoreSchema) -> core_schema.CoreSche | |
|
||
def inner(s: core_schema.CoreSchema, recurse: _core_utils.Recurse) -> core_schema.CoreSchema: | ||
nonlocal definitions | ||
if 'metadata' in s: | ||
if s['metadata'].get(NEEDS_APPLY_DISCRIMINATED_UNION_METADATA_KEY, True) is False: | ||
return s | ||
|
||
s = recurse(s, inner) | ||
if s['type'] == 'tagged-union': | ||
return s | ||
|
||
metadata = s.get('metadata', {}) | ||
discriminator = metadata.get(CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY, None) | ||
discriminator = metadata.pop(CORE_SCHEMA_METADATA_DISCRIMINATOR_PLACEHOLDER_KEY, None) | ||
if discriminator is not None: | ||
if definitions is None: | ||
definitions = collect_definitions(schema) | ||
# After we collect the definitions schemas, we must run through the discriminator | ||
# application logic for each one. This step is crucial to prevent an exponential | ||
# increase in complexity that occurs if schemas are left as 'union' schemas | ||
# rather than 'tagged-union' schemas. | ||
# For more details, see https://github.com/pydantic/pydantic/pull/8904#discussion_r1504687302 | ||
definitions = {k: recurse(v, inner) for k, v in definitions.items()} | ||
s = apply_discriminator(s, discriminator, definitions) | ||
return s | ||
|
||
|
@@ -274,6 +264,10 @@ def _handle_choice(self, choice: core_schema.CoreSchema) -> None: | |
* Validating that each allowed discriminator value maps to a unique choice | ||
* Updating the _tagged_union_choices mapping that will ultimately be used to build the TaggedUnionSchema. | ||
""" | ||
if choice['type'] == 'definition-ref': | ||
if choice['schema_ref'] not in self.definitions: | ||
raise MissingDefinitionForUnionRef(choice['schema_ref']) | ||
|
||
if choice['type'] == 'none': | ||
self._should_be_nullable = True | ||
elif choice['type'] == 'definitions': | ||
|
@@ -285,17 +279,14 @@ def _handle_choice(self, choice: core_schema.CoreSchema) -> None: | |
# Reverse the choices list before extending the stack so that they get handled in the order they occur | ||
choices_schemas = [v[0] if isinstance(v, tuple) else v for v in choice['choices'][::-1]] | ||
self._choices_to_handle.extend(choices_schemas) | ||
elif choice['type'] == 'definition-ref': | ||
if choice['schema_ref'] not in self.definitions: | ||
raise MissingDefinitionForUnionRef(choice['schema_ref']) | ||
self._handle_choice(self.definitions[choice['schema_ref']]) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Important!!! If a choice is of type `definition-ref`, we no longer resolve the reference and recursively handle the referenced schema here. Our schema walking logic walks through both the schema and the definitions, so we can rest easy knowing that unions will be converted to tagged unions in the definitions list as well. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Does this work with JSON schema generation if e.g. the ref'ed schema is itself a discriminated union? Maybe that can't happen, and either way this seems like an improvement if no tests fail, but still There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Good question. I believe that the walk core schema logic handles definitions schemas such that the function being applied during the walk is also applied to all of the definitions in a definitions schema, so that's why I felt comfortable removing this step. This can be seen via the example test I added - discriminated union transformation logic is applied to the 2 schemas in the definitions list that require said changes! |
||
elif choice['type'] not in { | ||
'model', | ||
'typed-dict', | ||
'tagged-union', | ||
'lax-or-strict', | ||
'dataclass', | ||
'dataclass-args', | ||
'definition-ref', | ||
} and not _core_utils.is_function_with_inner_schema(choice): | ||
# We should eventually handle 'definition-ref' as well | ||
raise TypeError( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,10 +7,11 @@ | |
import pytest | ||
from dirty_equals import HasRepr, IsStr | ||
from pydantic_core import SchemaValidator, core_schema | ||
from typing_extensions import Annotated, Literal | ||
from typing_extensions import Annotated, Literal, TypedDict | ||
|
||
from pydantic import BaseModel, ConfigDict, Discriminator, Field, TypeAdapter, ValidationError, field_validator | ||
from pydantic._internal._discriminated_union import apply_discriminator | ||
from pydantic.dataclasses import dataclass as pydantic_dataclass | ||
from pydantic.errors import PydanticUserError | ||
from pydantic.fields import FieldInfo | ||
from pydantic.json_schema import GenerateJsonSchema | ||
|
@@ -1868,11 +1869,93 @@ class LeafState(BaseModel): | |
state_type: Literal['leaf'] | ||
|
||
AnyState = Annotated[Union[NestedState, LoopState, LeafState], Field(..., discriminator='state_type')] | ||
NestedState.model_rebuild() | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If you expand this section, you can see that this test showcases the example I've shown in the PR description 👍 |
||
LoopState.model_rebuild() | ||
adapter = TypeAdapter(AnyState) | ||
|
||
assert adapter.core_schema['schema']['type'] == 'tagged-union' | ||
for definition in adapter.core_schema['definitions']: | ||
if definition['schema']['model_name'] in ['NestedState', 'LoopState']: | ||
assert definition['schema']['fields']['substate']['schema']['schema']['type'] == 'tagged-union' | ||
|
||
|
||
def test_recursive_discriminiated_union_with_typed_dict() -> None:
    """Check that a self-referential TypedDict union is a tagged union at every level.

    Validates two invalid payloads (bad tag at the top level, bad tag in the
    nested ``x`` field) and then inspects the generated core schema.
    """

    class Foo(TypedDict):
        type: Literal['foo']
        x: 'Foobar'

    class Bar(TypedDict):
        type: Literal['bar']

    Foobar = Annotated[Union[Foo, Bar], Field(discriminator='type')]
    adapter = TypeAdapter(Foobar)

    # Exactly one error is expected per invalid input, because a tagged union is in use.
    invalid_payloads = (
        {'type': 'wrong'},
        {'type': 'foo', 'x': {'type': 'wrong'}},
    )
    for payload in invalid_payloads:
        with pytest.raises(ValidationError) as exc_info:
            adapter.validate_python(payload)
        assert len(exc_info.value.errors()) == 1

    generated = adapter.core_schema
    assert generated['schema']['type'] == 'tagged-union'
    for definition in generated['definitions']:
        if 'Foo' not in definition['ref']:
            continue
        # The recursive field inside the definition must have been converted as well.
        assert definition['fields']['x']['schema']['type'] == 'tagged-union'
|
||
|
||
def test_recursive_discriminiated_union_with_base_model() -> None:
    """Check that a self-referential BaseModel union is a tagged union at every level.

    Validates two invalid payloads (bad tag at the top level, bad tag in the
    nested ``x`` field) and then inspects the generated core schema.
    """

    class Foo(BaseModel):
        type: Literal['foo']
        x: 'Foobar'

    class Bar(BaseModel):
        type: Literal['bar']

    Foobar = Annotated[Union[Foo, Bar], Field(discriminator='type')]
    adapter = TypeAdapter(Foobar)

    # Exactly one error is expected per invalid input, because a tagged union is in use.
    invalid_payloads = (
        {'type': 'wrong'},
        {'type': 'foo', 'x': {'type': 'wrong'}},
    )
    for payload in invalid_payloads:
        with pytest.raises(ValidationError) as exc_info:
            adapter.validate_python(payload)
        assert len(exc_info.value.errors()) == 1

    generated = adapter.core_schema
    assert generated['schema']['type'] == 'tagged-union'
    for definition in generated['definitions']:
        if 'Foo' not in definition['ref']:
            continue
        # Model definitions wrap their fields in an inner 'schema' entry.
        model_fields = definition['schema']['fields']
        assert model_fields['x']['schema']['type'] == 'tagged-union'
|
||
|
||
def test_recursive_discriminated_union_with_pydantic_dataclass() -> None:
    """Check that a self-referential pydantic-dataclass union is a tagged union at every level.

    Validates two invalid payloads (bad tag at the top level, bad tag in the
    nested ``x`` field) and then inspects the generated core schema.
    """

    @pydantic_dataclass
    class Foo:
        type: Literal['foo']
        x: 'Foobar'

    @pydantic_dataclass
    class Bar:
        type: Literal['bar']

    Foobar = Annotated[Union[Foo, Bar], Field(discriminator='type')]
    adapter = TypeAdapter(Foobar)

    # Exactly one error is expected per invalid input, because a tagged union is in use.
    invalid_payloads = (
        {'type': 'wrong'},
        {'type': 'foo', 'x': {'type': 'wrong'}},
    )
    for payload in invalid_payloads:
        with pytest.raises(ValidationError) as exc_info:
            adapter.validate_python(payload)
        assert len(exc_info.value.errors()) == 1

    generated = adapter.core_schema
    assert generated['schema']['type'] == 'tagged-union'
    for definition in generated['definitions']:
        if 'Foo' not in definition['ref']:
            continue
        # Fields are iterated as entries carrying 'name' and 'schema'; only 'x' is constrained.
        for field in definition['schema']['fields']:
            if field['name'] == 'x':
                assert field['schema']['type'] == 'tagged-union'
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually get rid of this once we use it