Skip to content

Commit

Permalink
Add support for discriminated unions
Browse files Browse the repository at this point in the history
  • Loading branch information
dmontagu committed Mar 9, 2023
1 parent 53fcbec commit b708438
Show file tree
Hide file tree
Showing 5 changed files with 623 additions and 178 deletions.
114 changes: 113 additions & 1 deletion pydantic/_internal/_generate_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@
import re
import typing
import warnings
from enum import Enum
from typing import TYPE_CHECKING, Any

from annotated_types import BaseMetadata, GroupedMetadata
from pydantic_core import SchemaError, SchemaValidator, core_schema
from typing_extensions import Annotated, Literal, get_args, get_origin, is_typeddict

from ..errors import PydanticSchemaGenerationError
from ..errors import PydanticSchemaGenerationError, PydanticUserError
from ..fields import FieldInfo
from ..json_schema import JsonSchemaMetadata, JsonSchemaValue
from . import _fields, _typing_extra
Expand Down Expand Up @@ -245,6 +246,8 @@ def generate_field_schema(
"""
assert field_info.annotation is not None, 'field_info.annotation should not be None when generating a schema'
schema = self.generate_schema(field_info.annotation)
if field_info.discriminator is not None:
schema = apply_discriminator(schema, field_info.discriminator)
schema = apply_annotations(schema, field_info.metadata)

if not field_info.is_required():
Expand Down Expand Up @@ -682,6 +685,8 @@ def apply_single_annotation(schema: core_schema.CoreSchema, metadata: Any) -> co
return apply_annotations(schema, metadata)
elif isinstance(metadata, FieldInfo):
schema = apply_annotations(schema, metadata.metadata)
if metadata.discriminator is not None:
schema = apply_discriminator(schema, metadata.discriminator)
# TODO setting a default here needs to be tested
return wrap_default(metadata, schema)

Expand Down Expand Up @@ -766,3 +771,110 @@ def get_model_self_schema(cls: type[BaseModel]) -> core_schema.ModelSchema:
core_schema.definition_reference_schema(model_ref),
metadata=build_metadata_dict(js_metadata=model_js_metadata),
)


def apply_discriminator(schema: core_schema.CoreSchema, discriminator: str) -> core_schema.CoreSchema:
# Eventually: should add support for other discriminator types, and explicitly specified choices
if schema['type'] != 'union':
raise TypeError('`discriminator` can only be used with `Union` type with more than one variant')
choices = [*schema['choices'][::-1]]
if len(choices) < 2:
raise TypeError('`discriminator` can only be used with `Union` type with more than one variant')

# TODO: Need to make sure nullable unions are handled properly
aliases = {discriminator: None} # this is meant to behave like a set, but use a dict to ensure order is preserved
tagged_union_choices: dict[str | int, str | int | core_schema.CoreSchema] = {}
while choices:
choice = choices.pop()
if choice['type'] == 'union':
choices.extend(choice['choices'])
continue

discriminator_values = _get_discriminator_values_for_choice(choice, discriminator, aliases)
if discriminator_values:

def _handle_discriminator_value(value: Any, choice_override: int | str | None = None) -> None:
# This function accepts choice_override so that we can produce a schema that doesn't copy choices
if not isinstance(value, (int, str, Enum)):
raise ValueError(f'Invalid discriminator value {value!r}; must be a string, int, or Enum')
if isinstance(value, Enum):
value = value.value
if value in tagged_union_choices:
# Need to walk the choices dict until we get to a "real" choice
existing_choice = tagged_union_choices[value]
while isinstance(existing_choice, (str, int)):
existing_choice = tagged_union_choices[existing_choice]
if existing_choice != choice:
raise ValueError(
f'Value {value!r} for discriminator {discriminator!r} mapped to multiple choices'
)
else:
tagged_union_choices[value] = choice if choice_override is None else choice_override

primary_value = discriminator_values[0]
_handle_discriminator_value(primary_value)
for other_value in discriminator_values[1:]:
_handle_discriminator_value(other_value, primary_value)

if len(aliases) > 1:
schema_discriminator: str | list[list[str | int]] = [[alias] for alias in aliases]
else:
schema_discriminator = discriminator

return core_schema.tagged_union_schema(
choices=tagged_union_choices,
discriminator=schema_discriminator,
custom_error_type=schema.get('custom_error_type'),
custom_error_message=schema.get('custom_error_message'),
custom_error_context=schema.get('custom_error_context'),
strict=False,
ref=schema.get('ref'),
metadata=schema.get('metadata'),
serialization=schema.get('serialization'),
)


def _get_discriminator_values_for_choice(
choice: core_schema.CoreSchema, discriminator: str, aliases: dict[str, None]
) -> list[Any]:
if choice['type'] == 'tagged-union':
values: list[Any] = []
for inner_choice in choice['choices'].values():
if isinstance(inner_choice, (str, int)):
continue
values.extend(_get_discriminator_values_for_choice(inner_choice, discriminator, aliases))
return values

elif choice['type'] == 'model':
model_name = choice['cls'].__name__
# Unpack ModelSchema into the inner TypedDictSchema
inner_schema = choice['schema']
if inner_schema['type'] == 'definitions':
inner_schema = inner_schema['schema'] # unpack a definitions schema
if inner_schema['type'] == 'typed-dict':
typed_dict_schema = inner_schema
if discriminator not in typed_dict_schema['fields']:
raise PydanticUserError(f'Model {model_name!r} needs a discriminator field for key {discriminator!r}')
discriminator_field = typed_dict_schema['fields'][discriminator]

# TODO: Should maybe reflect whether populate_by_alias works or whatever
alias = discriminator_field.get('validation_alias', discriminator)
aliases[alias] = None

discriminator_schema = discriminator_field['schema']
if discriminator_schema['type'] == 'default':
# Ignore a wrapping default schema if present
discriminator_schema = discriminator_schema['schema']
if discriminator_schema['type'] != 'literal':
raise PydanticUserError(f'Field {discriminator!r} of model {model_name!r} needs to be a `Literal`')
return discriminator_schema['expected']
else:
raise TypeError(
f"Expected a CoreSchema with type='typed-dict' for model {model_name!r}, "
f"got type={inner_schema['type']!r}"
)

else:
raise TypeError(
f"{choice['type']!r} is not a valid discriminated union variant; " "should be a `BaseModel` or `dataclass`"
)
5 changes: 4 additions & 1 deletion pydantic/_internal/_model_construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,17 +181,20 @@ def complete_model_class(
serialization_functions.check_for_unused()

core_config = generate_config(cls)
pre_validator_core_config = core_config.copy()
cls.model_fields = fields
cls.__pydantic_validator__ = SchemaValidator(inner_schema, core_config)
model_post_init = '__pydantic_post_init__' if hasattr(cls, '__pydantic_post_init__') else None
js_metadata = cls.model_json_schema_metadata()
cls.__pydantic_core_schema__ = outer_schema = core_schema.model_schema(
cls,
inner_schema,
config=core_config,
config=pre_validator_core_config,
call_after_init=model_post_init,
metadata=build_metadata_dict(js_metadata=js_metadata),
)
# print(cls.__pydantic_core_schema__)
print(cls.__name__, cls.__pydantic_core_schema__['config']['title'])
cls.__pydantic_serializer__ = SchemaSerializer(outer_schema, core_config)
cls.__pydantic_model_complete__ = True

Expand Down
9 changes: 9 additions & 0 deletions pydantic/json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,15 @@ def tagged_union_schema(self, schema: core_schema.TaggedUnionSchema) -> JsonSche
generated[str(k)] = self.generate_inner(v).copy()
except PydanticInvalidForJsonSchema:
pass

# Populate the schema with any "indirect" references
for k, v in schema['choices'].items():
if isinstance(v, (str, int)):
while isinstance(schema['choices'][v], (str, int)):
v = schema['choices'][v]
if str(v) in generated: # PydanticInvalidForJsonSchema may have been raised above
generated[str(k)] = generated[str(v)]

json_schema: JsonSchemaValue = {'oneOf': list(generated.values())}

# This reflects the v1 behavior, but we may want to only include the discriminator based on dialect / etc.
Expand Down
Loading

0 comments on commit b708438

Please sign in to comment.