Skip to content

Commit

Permalink
Add discriminated union support (v2) (#5051)
Browse files Browse the repository at this point in the history
* Add support for discriminated unions

* Fix issue with duplicate schemas in oneOf

* Address TODOs

* Remove print statement

* Remove unwanted debug change

* Massively broken, but pushing as a checkpoint and reminder that this still needs more work

* Rework discriminated union to use a class-based approach to better handle recursion

* Fix 3.7 tests

* Make Json/HashableJson private

* Add a docstring to the final method in _ApplyInferredDiscriminator

* Fix discriminated union test from test_forward_ref.py

* Add test ensuring that literal enums produce valid JSON schema

* Get to 100% code coverage

* Improve comment

* Fix failing tests

* Convert __root__ tests to use Validator
  • Loading branch information
dmontagu committed Mar 15, 2023
1 parent 82bfb19 commit 14ebd21
Show file tree
Hide file tree
Showing 7 changed files with 1,988 additions and 555 deletions.
343 changes: 343 additions & 0 deletions pydantic/_internal/_discriminated_union.py

Large diffs are not rendered by default.

6 changes: 5 additions & 1 deletion pydantic/_internal/_generate_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from ..errors import PydanticSchemaGenerationError, PydanticUserError
from ..fields import FieldInfo
from ..json_schema import JsonSchemaMetadata, JsonSchemaValue
from . import _fields, _typing_extra
from . import _discriminated_union, _fields, _typing_extra
from ._core_metadata import CoreMetadataHandler, build_metadata_dict
from ._core_utils import get_type_ref
from ._decorators import SerializationFunctions, Serializer, ValidationFunctions, Validator
Expand Down Expand Up @@ -259,6 +259,8 @@ def generate_field_schema(
"""
assert field_info.annotation is not None, 'field_info.annotation should not be None when generating a schema'
schema = self.generate_schema(field_info.annotation)
if field_info.discriminator is not None:
schema = _discriminated_union.apply_discriminator(schema, field_info.discriminator)
schema = apply_annotations(schema, field_info.metadata)

schema = apply_validators(schema, validator_functions.get_field_decorators(name))
Expand Down Expand Up @@ -712,6 +714,8 @@ def apply_single_annotation(schema: core_schema.CoreSchema, metadata: Any) -> co
return apply_annotations(schema, metadata)
elif isinstance(metadata, FieldInfo):
schema = apply_annotations(schema, metadata.metadata)
if metadata.discriminator is not None:
schema = _discriminated_union.apply_discriminator(schema, metadata.discriminator)
# TODO setting a default here needs to be tested
return wrap_default(metadata, schema)

Expand Down
32 changes: 30 additions & 2 deletions pydantic/json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import re
from dataclasses import is_dataclass
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, Counter, Dict, NewType, Sequence, cast
from typing import TYPE_CHECKING, Any, Callable, Counter, Dict, Iterable, List, NewType, Sequence, Tuple, Union, cast

from pydantic_core import CoreSchema, CoreSchemaType, core_schema
from pydantic_core.core_schema import TypedDictField
Expand Down Expand Up @@ -497,7 +497,18 @@ def tagged_union_schema(self, schema: core_schema.TaggedUnionSchema) -> JsonSche
generated[str(k)] = self.generate_inner(v).copy()
except PydanticInvalidForJsonSchema:
pass
json_schema: JsonSchemaValue = {'oneOf': list(generated.values())}

# Populate the schema with any "indirect" references
for k, v in schema['choices'].items():
if isinstance(v, (str, int)):
while isinstance(schema['choices'][v], (str, int)):
v = schema['choices'][v]
if str(v) in generated:
# while it might seem unnecessary to check `if str(v) in generated`, a PydanticInvalidForJsonSchema
# may have been raised above, which would mean that the schema we want to reference won't be present
generated[str(k)] = generated[str(v)]

json_schema: JsonSchemaValue = {'oneOf': _deduplicate_schemas(generated.values())}

# This reflects the v1 behavior, but we may want to only include the discriminator based on dialect / etc.
if 'discriminator' in schema and isinstance(schema['discriminator'], str):
Expand Down Expand Up @@ -1062,3 +1073,20 @@ def model_schema(
) -> dict[str, Any]:
model = _utils.get_model(model)
return model.model_json_schema(by_alias=by_alias, ref_template=ref_template, schema_generator=schema_generator)


_Json = Union[Dict[str, Any], List[Any], str, int, float, bool, None]
_HashableJson = Union[Tuple[Tuple[str, Any], ...], Tuple[Any, ...], str, int, float, bool, None]


def _deduplicate_schemas(schemas: Iterable[_Json]) -> list[_Json]:
return list({_make_json_hashable(schema): schema for schema in schemas}.values())


def _make_json_hashable(value: _Json) -> _HashableJson:
if isinstance(value, dict):
return tuple(sorted((k, _make_json_hashable(v)) for k, v in value.items()))
elif isinstance(value, list):
return tuple(_make_json_hashable(v) for v in value)
else:
return value
Loading

0 comments on commit 14ebd21

Please sign in to comment.