Skip to content

Commit

Permalink
Add discriminated union support
Browse files Browse the repository at this point in the history
  • Loading branch information
dmontagu committed Feb 23, 2023
1 parent 73373c3 commit ecc5b85
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 49 deletions.
92 changes: 90 additions & 2 deletions pydantic/_internal/_generate_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@
import re
import typing
import warnings
from enum import Enum
from typing import TYPE_CHECKING, Any

from annotated_types import BaseMetadata, GroupedMetadata
from pydantic_core import SchemaError, SchemaValidator, core_schema
from typing_extensions import Annotated, Literal, get_args, get_origin, is_typeddict

from ..errors import PydanticSchemaGenerationError
from ..errors import PydanticSchemaGenerationError, PydanticUserError
from ..fields import FieldInfo
from ..json_schema import JsonSchemaMetadata, JsonSchemaValue
from . import _fields, _typing_extra
Expand Down Expand Up @@ -227,6 +228,8 @@ def generate_field_schema(
"""
assert field_info.annotation is not None, 'field_info.annotation should not be None when generating a schema'
schema = self.generate_schema(field_info.annotation)
if field_info.discriminator is not None:
schema = apply_discriminator(schema, field_info.discriminator)
schema = apply_annotations(schema, field_info.metadata)

if not field_info.is_required():
Expand Down Expand Up @@ -278,7 +281,8 @@ def _annotated_schema(self, annotated_type: Any) -> core_schema.CoreSchema:
"""
first_arg, *other_args = get_args(annotated_type)
schema = self.generate_schema(first_arg)
return apply_annotations(schema, other_args)
schema = apply_annotations(schema, other_args)
return schema

def _literal_schema(self, literal_type: Any) -> core_schema.LiteralSchema:
"""
Expand Down Expand Up @@ -642,6 +646,8 @@ def apply_single_annotation(schema: core_schema.CoreSchema, metadata: Any) -> co
return apply_annotations(schema, metadata)
elif isinstance(metadata, FieldInfo):
schema = apply_annotations(schema, metadata.metadata)
if metadata.discriminator is not None:
schema = apply_discriminator(schema, metadata.discriminator)
# TODO setting a default here needs to be tested
return wrap_default(metadata, schema)

Expand Down Expand Up @@ -716,3 +722,85 @@ def _get_pydantic_modify_json_schema(obj: Any) -> typing.Callable[[JsonSchemaVal
return obj.__modify_schema__

return modify_js_function


def apply_discriminator(schema: core_schema.CoreSchema, discriminator: str) -> core_schema.CoreSchema:
# Eventually: should add support for other discriminator types, and explicitly specified choices
if schema['type'] != 'union':
raise TypeError('`discriminator` can only be used with `Union` type with more than one variant')
choices = [*schema['choices'][::-1]]
if len(choices) < 2:
raise TypeError('`discriminator` can only be used with `Union` type with more than one variant')

# TODO: Need to make sure nullable unions are handled properly
aliases = {discriminator: None} # use a dict to ensure order is preserved
tagged_union_choices: dict[str, str | core_schema.CoreSchema] = {}
while choices:
choice = choices.pop()
if choice['type'] == 'union':
choices.extend(choice['choices'])
continue

discriminator_values = _get_discriminator_values_for_choice(choice, discriminator, aliases)
for value in discriminator_values:
if isinstance(value, Enum):
value = value.value
value = str(value)
if value in tagged_union_choices and tagged_union_choices[value] != choice:
raise ValueError(f'Value {value!r} for discriminator {discriminator!r} mapped to multiple choices')
tagged_union_choices[value] = choice

if len(aliases) > 1:
schema_discriminator: str | list[list[str | int]] = [[alias] for alias in aliases]
else:
schema_discriminator = discriminator

return core_schema.tagged_union_schema(
choices=tagged_union_choices,
discriminator=schema_discriminator,
custom_error_type=schema.get('custom_error_type'),
custom_error_message=schema.get('custom_error_message'),
custom_error_context=schema.get('custom_error_context'),
strict=False,
from_attributes=True,
ref=schema.get('ref'),
metadata=schema.get('metadata'),
serialization=schema.get('serialization'),
)


def _get_discriminator_values_for_choice(
choice: core_schema.CoreSchema, discriminator: str, aliases: dict[str, None]
) -> list[Any]:
if choice['type'] == 'tagged-union':
values: list[Any] = []
for inner_choice in choice['choices'].values():
if isinstance(inner_choice, str):
continue
values.extend(_get_discriminator_values_for_choice(inner_choice, discriminator, aliases))
return values

elif choice['type'] == 'model':
model_name = choice['cls'].__name__
# Unpack ModelSchema into the inner TypedDictSchema
typed_dict_schema = choice['schema']
if discriminator not in typed_dict_schema['fields']:
raise PydanticUserError(f'Model {model_name!r} needs a discriminator field for key {discriminator!r}')
discriminator_field = typed_dict_schema['fields'][discriminator]

# TODO: Should maybe reflect whether populate_by_alias works or whatever
alias = discriminator_field.get('validation_alias', discriminator)
aliases[alias] = None

discriminator_schema = discriminator_field['schema']
if discriminator_schema['type'] == 'default':
# Ignore a wrapping default schema if present
discriminator_schema = discriminator_schema['schema']
if discriminator_schema['type'] != 'literal':
raise PydanticUserError(f'Field {discriminator!r} of model {model_name!r} needs to be a `Literal`')
return discriminator_schema['expected']

else:
raise TypeError(
f"{choice['type']!r} is not a valid discriminated union variant; " "should be a `BaseModel` or `dataclass`"
)
106 changes: 59 additions & 47 deletions tests/test_discrimated_union.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import re
from enum import Enum
from enum import IntEnum
from typing import Generic, TypeVar, Union

import pytest
Expand All @@ -10,7 +10,6 @@
from pydantic.generics import GenericModel


@pytest.mark.xfail(reason='working on V2')
def test_discriminated_union_only_union():
with pytest.raises(
TypeError, match='`discriminator` can only be used with `Union` type with more than one variant'
Expand All @@ -20,7 +19,6 @@ class Model(BaseModel):
x: str = Field(..., discriminator='qwe')


@pytest.mark.xfail(reason='working on V2')
def test_discriminated_union_single_variant():
with pytest.raises(
TypeError, match='`discriminator` can only be used with `Union` type with more than one variant'
Expand All @@ -30,15 +28,15 @@ class Model(BaseModel):
x: Union[str] = Field(..., discriminator='qwe')


@pytest.mark.xfail(reason='working on V2')
def test_discriminated_union_invalid_type():
with pytest.raises(TypeError, match="Type 'str' is not a valid `BaseModel` or `dataclass`"):
with pytest.raises(
TypeError, match="'str' is not a valid discriminated union variant; should be a `BaseModel` or `dataclass`"
):

class Model(BaseModel):
x: Union[str, int] = Field(..., discriminator='qwe')


@pytest.mark.xfail(reason='working on V2')
def test_discriminated_union_defined_discriminator():
class Cat(BaseModel):
c: str
Expand All @@ -54,7 +52,6 @@ class Model(BaseModel):
number: int


@pytest.mark.xfail(reason='working on V2')
def test_discriminated_union_literal_discriminator():
class Cat(BaseModel):
pet_type: int
Expand All @@ -71,7 +68,7 @@ class Model(BaseModel):
number: int


@pytest.mark.xfail(reason='working on V2')
@pytest.mark.xfail(reason='working on V2 - __root__')
def test_discriminated_union_root_same_discriminator():
class BlackCat(BaseModel):
pet_type: Literal['blackcat']
Expand All @@ -91,7 +88,7 @@ class Pet(BaseModel):
__root__: Union[Cat, Dog] = Field(..., discriminator='pet_type')


@pytest.mark.xfail(reason='working on V2')
@pytest.mark.xfail(reason='working on V2 - __root__')
def test_discriminated_union_validation():
class BlackCat(BaseModel):
pet_type: Literal['cat']
Expand Down Expand Up @@ -182,7 +179,6 @@ class Model(BaseModel):
assert isinstance(m.pet.__root__, WhiteCat)


@pytest.mark.xfail(reason='working on V2')
def test_discriminated_annotated_union():
class BlackCat(BaseModel):
pet_type: Literal['cat']
Expand Down Expand Up @@ -210,29 +206,36 @@ class Model(BaseModel):
Model.model_validate({'pet': {'pet_typ': 'cat'}, 'number': 'x'})
assert exc_info.value.errors() == [
{
'ctx': {'discriminator': "'pet_type'"},
'input': {'pet_typ': 'cat'},
'loc': ('pet',),
'msg': "Discriminator 'pet_type' is missing in value",
'type': 'value_error.discriminated_union.missing_discriminator',
'ctx': {'discriminator_key': 'pet_type'},
'msg': "Unable to extract tag using discriminator 'pet_type'",
'type': 'union_tag_not_found',
},
{
'input': 'x',
'loc': ('number',),
'msg': 'Input should be a valid integer, unable to parse string as an ' 'integer',
'type': 'int_parsing',
},
{'loc': ('number',), 'msg': 'value is not a valid integer', 'type': 'type_error.integer'},
]

with pytest.raises(ValidationError) as exc_info:
Model.model_validate({'pet': {'pet_type': 'fish'}, 'number': 2})
assert exc_info.value.errors() == [
{
'ctx': {'discriminator': "'pet_type'", 'expected_tags': "'cat', 'dog'", 'tag': 'fish'},
'input': {'pet_type': 'fish'},
'loc': ('pet',),
'msg': "No match for discriminator 'pet_type' and value 'fish' " "(allowed values: 'cat', 'dog')",
'type': 'value_error.discriminated_union.invalid_discriminator',
'ctx': {'discriminator_key': 'pet_type', 'discriminator_value': 'fish', 'allowed_values': "'cat', 'dog'"},
},
'msg': "Input tag 'fish' found using 'pet_type' does not match any of the " "expected tags: 'cat', 'dog'",
'type': 'union_tag_invalid',
}
]

with pytest.raises(ValidationError) as exc_info:
Model.model_validate({'pet': {'pet_type': 'dog'}, 'number': 2})
assert exc_info.value.errors() == [
{'loc': ('pet', 'Dog', 'dog_name'), 'msg': 'field required', 'type': 'value_error.missing'},
{'input': {'pet_type': 'dog'}, 'loc': ('pet', 'dog', 'dog_name'), 'msg': 'Field required', 'type': 'missing'}
]
m = Model.model_validate({'pet': {'pet_type': 'dog', 'dog_name': 'milou'}, 'number': 2})
assert isinstance(m.pet, Dog)
Expand All @@ -241,20 +244,22 @@ class Model(BaseModel):
Model.model_validate({'pet': {'pet_type': 'cat', 'color': 'red'}, 'number': 2})
assert exc_info.value.errors() == [
{
'loc': ('pet', 'Union[BlackCat, WhiteCat]'),
'msg': "No match for discriminator 'color' and value 'red' " "(allowed values: 'black', 'white')",
'type': 'value_error.discriminated_union.invalid_discriminator',
'ctx': {'discriminator_key': 'color', 'discriminator_value': 'red', 'allowed_values': "'black', 'white'"},
'ctx': {'discriminator': "'color'", 'expected_tags': "'black', 'white'", 'tag': 'red'},
'input': {'color': 'red', 'pet_type': 'cat'},
'loc': ('pet', 'cat'),
'msg': "Input tag 'red' found using 'color' does not match any of the " "expected tags: 'black', 'white'",
'type': 'union_tag_invalid',
}
]

with pytest.raises(ValidationError) as exc_info:
Model.model_validate({'pet': {'pet_type': 'cat', 'color': 'white'}, 'number': 2})
assert exc_info.value.errors() == [
{
'loc': ('pet', 'Union[BlackCat, WhiteCat]', 'WhiteCat', 'white_infos'),
'msg': 'field required',
'type': 'value_error.missing',
'input': {'color': 'white', 'pet_type': 'cat'},
'loc': ('pet', 'cat', 'white', 'white_infos'),
'msg': 'Field required',
'type': 'missing',
}
]
m = Model.model_validate({'pet': {'pet_type': 'cat', 'color': 'white', 'white_infos': 'pika'}, 'number': 2})
Expand Down Expand Up @@ -291,7 +296,6 @@ class Top(BaseModel):
assert Top(sub=B(literal='b')).sub.literal == 'b'


@pytest.mark.xfail(reason='working on V2')
def test_discriminated_union_int():
class A(BaseModel):
m: Literal[1]
Expand All @@ -300,24 +304,25 @@ class B(BaseModel):
m: Literal[2]

class Top(BaseModel):
sub: Union[A, B] = Field(..., discriminator='l')
sub: Union[A, B] = Field(..., discriminator='m')

assert isinstance(Top.model_validate({'sub': {'m': 2}}).sub, B)
with pytest.raises(ValidationError) as exc_info:
Top.model_validate({'sub': {'m': 3}})
assert exc_info.value.errors() == [
{
'ctx': {'discriminator': "'m'", 'expected_tags': "'1', '2'", 'tag': '3'},
'input': {'m': 3},
'loc': ('sub',),
'msg': "No match for discriminator 'l' and value 3 (allowed values: 1, 2)",
'type': 'value_error.discriminated_union.invalid_discriminator',
'ctx': {'discriminator_key': 'm', 'discriminator_value': 3, 'allowed_values': '1, 2'},
'msg': "Input tag '3' found using 'm' does not match any of the expected " "tags: '1', '2'",
'type': 'union_tag_invalid',
}
]


@pytest.mark.xfail(reason='working on V2')
def test_discriminated_union_enum():
class EnumValue(Enum):
# TODO: Make this work with a base Enum, not just IntEnum / StrEnum
class EnumValue(IntEnum):
a = 1
b = 2

Expand All @@ -331,23 +336,20 @@ class Top(BaseModel):
sub: Union[A, B] = Field(..., discriminator='m')

assert isinstance(Top.model_validate({'sub': {'m': EnumValue.b}}).sub, B)
assert isinstance(Top.model_validate({'sub': {'m': 2}}).sub, B)
with pytest.raises(ValidationError) as exc_info:
Top.model_validate({'sub': {'m': 3}})
assert exc_info.value.errors() == [
{
'ctx': {'discriminator': "'m'", 'expected_tags': "'1', '2'", 'tag': '3'},
'input': {'m': 3},
'loc': ('sub',),
'msg': "No match for discriminator 'm' and value 3 (allowed values: <EnumValue.a: 1>, <EnumValue.b: 2>)",
'type': 'value_error.discriminated_union.invalid_discriminator',
'ctx': {
'discriminator_key': 'm',
'discriminator_value': 3,
'allowed_values': '<EnumValue.a: 1>, <EnumValue.b: 2>',
},
'msg': "Input tag '3' found using 'm' does not match any of the expected " "tags: '1', '2'",
'type': 'union_tag_invalid',
}
]


@pytest.mark.xfail(reason='working on V2')
def test_alias_different():
class Cat(BaseModel):
pet_type: Literal['cat'] = Field(alias='U')
Expand All @@ -357,12 +359,22 @@ class Dog(BaseModel):
pet_type: Literal['dog'] = Field(alias='T')
d: str

with pytest.raises(
PydanticUserError, match=re.escape("Aliases for discriminator 'pet_type' must be the same (got T, U)")
):
class Model(BaseModel):
pet: Union[Cat, Dog] = Field(discriminator='pet_type')

class Model(BaseModel):
pet: Union[Cat, Dog] = Field(discriminator='pet_type')
Model.model_validate({'pet': {'U': 'cat', 'c': 'my_cat'}})
Model.model_validate({'pet': {'T': 'dog', 'd': 'my_dog'}})
with pytest.raises(ValidationError) as exc_info:
Model.model_validate({'pet': {'W': 'dog', 'd': 'my_dog'}})
assert exc_info.value.errors() == [
{
'ctx': {'discriminator': "'pet_type' | 'U' | 'T'"},
'input': {'W': 'dog', 'd': 'my_dog'},
'loc': ('pet',),
'msg': "Unable to extract tag using discriminator 'pet_type' | 'U' | 'T'",
'type': 'union_tag_not_found',
}
]


def test_alias_same():
Expand Down Expand Up @@ -402,7 +414,7 @@ class Model(BaseModel):
assert isinstance(Model(**{'pet': {'pet_type': 'dog', 'name': 'Milou'}, 'n': 5}).pet, Dog)


@pytest.mark.xfail(reason='working on V2')
@pytest.mark.xfail(reason='working on V2 - generics')
def test_generic():
T = TypeVar('T')

Expand Down

0 comments on commit ecc5b85

Please sign in to comment.