diff --git a/pydantic/_internal/_discriminated_union.py b/pydantic/_internal/_discriminated_union.py new file mode 100644 index 0000000000..45e6bc6c5d --- /dev/null +++ b/pydantic/_internal/_discriminated_union.py @@ -0,0 +1,343 @@ +from __future__ import annotations as _annotations + +from enum import Enum +from typing import Sequence + +from pydantic_core import core_schema + +from ..errors import PydanticUserError + + +def apply_discriminator(schema: core_schema.CoreSchema, discriminator: str) -> core_schema.CoreSchema: + return _ApplyInferredDiscriminator(discriminator).apply(schema) + + +class _ApplyInferredDiscriminator: + """ + This class is used to convert an input schema containing a union schema into one where that union is + replaced with a tagged-union, with all the associated debugging and performance benefits. + + This is done by: + * Validating that the input schema is compatible with the provided discriminator + * Introspecting the schema to determine which discriminator values should map to which union choices + * Handling various edge cases such as 'definitions', 'default', 'nullable' schemas, and more + + I have chosen to implement the conversion algorithm in this class, rather than a function, + to make it easier to maintain state while recursively walking the provided CoreSchema. + """ + + def __init__(self, discriminator: str): + # `discriminator` should be the name of the field which will serve as the discriminator. + # It must be the python name of the field, and *not* the field's alias. Note that as of now, + # all members of a discriminated union _must_ use a field with the same name as the discriminator. + # This may change if/when we expose a way to manually specify the TaggedUnionSchema's choices. + self.discriminator = discriminator + + # `_discriminator_alias` will hold the value, if present, of the alias for the discriminator + # + # Note: following the v1 implementation, we currently disallow the use of different aliases + # for different choices. This is not a limitation of pydantic_core, but if we try to handle + # this, the inference logic gets complicated very quickly, and could result in confusing + # debugging challenges for users making subtle mistakes. + # + # Rather than trying to do the most powerful inference possible, I think we should eventually + # expose a way to more-manually control the way the TaggedUnionSchema is constructed through + # the use of a new type which would be placed as an Annotation on the Union type. This would + # provide the full flexibility/power of pydantic_core's TaggedUnionSchema where necessary for + # more complex cases, without over-complicating the inference logic for the common cases. + self._discriminator_alias: str | None = None + + # `_should_be_nullable` indicates whether the converted union has `None` as an allowed value. + # If `None` is an acceptable value of the (possibly-wrapped) union, we ignore it while + # constructing the TaggedUnionSchema, but set the `_should_be_nullable` attribute to True. + # Once we have constructed the TaggedUnionSchema, if `_should_be_nullable` is True, we ensure + # that the final schema gets wrapped as a NullableSchema. This has the same semantics on the + # python side, but resolves the issue that `None` cannot correspond to any discriminator values. + self._should_be_nullable = False + + # `_is_nullable` is used to track if the final produced schema will definitely be nullable; + # we set it to True if the input schema is wrapped in a nullable schema that we know will be preserved + # as an indication that, even if None is discovered as one of the union choices, we will not need to wrap + # the final value in another nullable schema. + # + # This is more complicated than just checking for the final outermost schema having type 'nullable' thanks + # to the possible presence of other wrapper schemas such as DefinitionsSchema, WithDefaultSchema, etc. + self._is_nullable = False + + # `_choices_to_handle` serves as a stack of choices to add to the tagged union. Initially, choices + # from the union in the wrapped schema will be appended to this list, and the recursive choice-handling + # algorithm may add more choices to this stack as (nested) unions are encountered. + self._choices_to_handle: list[core_schema.CoreSchema] = [] + + # `_tagged_union_choices` is built during the call to `apply`, and will hold the choices to be included + # in the output TaggedUnionSchema that will replace the union from the input schema + self._tagged_union_choices: dict[str | int, str | int | core_schema.CoreSchema] = {} + + # `_used` is changed to True after applying the discriminator to prevent accidental re-use + self._used = False + + def apply(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema: + """ + Return a new CoreSchema based on `schema` that uses a tagged-union with the discriminator provided + to this class. + """ + assert not self._used + schema = self._apply_to_root(schema) + if self._should_be_nullable and not self._is_nullable: + schema = core_schema.nullable_schema(schema) + self._used = True + return schema + + def _apply_to_root(self, schema: core_schema.CoreSchema) -> core_schema.CoreSchema: + """ + This method handles the outer-most stage of recursion over the input schema: + unwrapping nullable or definitions schemas, and calling the `_handle_choice` + method iteratively on the choices extracted (recursively) from the possibly-wrapped union. + """ + if schema['type'] == 'nullable': + self._is_nullable = True + wrapped = self._apply_to_root(schema['schema']) + nullable_wrapper = schema.copy() + nullable_wrapper['schema'] = wrapped + return nullable_wrapper + + if schema['type'] == 'definitions': + wrapped = self._apply_to_root(schema['schema']) + definitions_wrapper = schema.copy() + definitions_wrapper['schema'] = wrapped + return definitions_wrapper + + if schema['type'] != 'union': + raise TypeError('`discriminator` can only be used with `Union` type with more than one variant') + + if len(schema['choices']) < 2: + raise TypeError('`discriminator` can only be used with `Union` type with more than one variant') + + # Reverse the choices list before extending the stack so that they get handled in the order they occur + self._choices_to_handle.extend(schema['choices'][::-1]) + while self._choices_to_handle: + choice = self._choices_to_handle.pop() + self._handle_choice(choice, None) + + if self._discriminator_alias is not None and self._discriminator_alias != self.discriminator: + # * We need to annotate `discriminator` as a union here to handle both branches of this conditional + # * We need to annotate `discriminator` as list[list[str | int]] and not list[list[str]] due to the + # invariance of list, and because list[list[str | int]] is the type of the discriminator argument + # to tagged_union_schema below + # * See the docstring of pydantic_core.core_schema.tagged_union_schema for more details about how to + # interpret the value of the discriminator argument to tagged_union_schema. (The list[list[str]] here + # is the appropriate way to provide a list of fallback attributes to check for a discriminator value.) + discriminator: str | list[list[str | int]] = [[self.discriminator], [self._discriminator_alias]] + else: + discriminator = self.discriminator + return core_schema.tagged_union_schema( + choices=self._tagged_union_choices, + discriminator=discriminator, + custom_error_type=schema.get('custom_error_type'), + custom_error_message=schema.get('custom_error_message'), + custom_error_context=schema.get('custom_error_context'), + strict=False, + ref=schema.get('ref'), + metadata=schema.get('metadata'), + serialization=schema.get('serialization'), + ) + + def _handle_choice(self, choice: core_schema.CoreSchema, definitions: list[core_schema.CoreSchema] | None) -> None: + """ + This method handles the "middle" stage of recursion over the input schema. + Specifically, it is responsible for handling each choice of the outermost union + (and any "coalesced" choices obtained from inner unions). + + Here, "handling" entails: + * Coalescing nested unions and compatible tagged-unions + * Tracking the presence of 'none' and 'nullable' schemas occurring as choices + * Validating that each allowed discriminator value maps to a unique choice + * Updating the _tagged_union_choices mapping that will ultimately be used to build the TaggedUnionSchema. + """ + if choice['type'] == 'none': + self._should_be_nullable = True + elif choice['type'] == 'definitions': + definitions = (definitions or []) + choice['definitions'] + self._handle_choice(choice['schema'], definitions) + elif choice['type'] == 'nullable': + self._should_be_nullable = True + self._handle_choice(choice['schema'], definitions) # unwrap the nullable schema + elif choice['type'] == 'union': + # Reverse the choices list before extending the stack so that they get handled in the order they occur + self._choices_to_handle.extend(choice['choices'][::-1]) + elif choice['type'] not in {'model', 'typed-dict', 'tagged-union', 'function', 'lax-or-strict'}: + # We should eventually handle 'definition-ref' as well + raise TypeError( + f'{choice["type"]!r} is not a valid discriminated union variant;' + ' should be a `BaseModel` or `dataclass`' + ) + else: + if choice['type'] == 'tagged-union' and self._is_discriminator_shared(choice): + # In this case, this inner tagged-union is compatible with the outer tagged-union, + # and its choices can be coalesced into the outer TaggedUnionSchema. + subchoices = [x for x in choice['choices'].values() if not isinstance(x, (str, int))] + # Reverse the choices list before extending the stack so that they get handled in the order they occur + self._choices_to_handle.extend(subchoices[::-1]) + return + + inferred_discriminator_values = self._infer_discriminator_values_for_choice(choice, source_name=None) + if definitions: + choice = core_schema.definitions_schema(choice, definitions) + self._set_unique_choice_for_values(choice, inferred_discriminator_values) + + def _is_discriminator_shared(self, choice: core_schema.TaggedUnionSchema) -> bool: + """ + This method returns a boolean indicating whether the discriminator for the `choice` + is the same as that being used for the outermost tagged union. This is used to + determine whether this TaggedUnionSchema choice should be "coalesced" into the top level, + or whether it should be treated as a separate (nested) choice. + """ + inner_discriminator = choice['discriminator'] + return inner_discriminator == self.discriminator or ( + isinstance(inner_discriminator, list) + and (self.discriminator in inner_discriminator or [self.discriminator] in inner_discriminator) + ) + + def _infer_discriminator_values_for_choice( + self, choice: core_schema.CoreSchema, source_name: str | None + ) -> list[str | int]: + """ + This function recurses over `choice`, extracting all discriminator values that should map to this choice. + + `model_name` is accepted for the purpose of producing useful error messages. + """ + if choice['type'] == 'definitions': + return self._infer_discriminator_values_for_choice(choice['schema'], source_name=source_name) + + elif choice['type'] == 'function': + if choice['mode'] != 'plain': + return self._infer_discriminator_values_for_choice(choice['schema'], source_name=source_name) + else: + raise TypeError( + f'{choice["type"]!r} with mode={choice["mode"]!r} is not a valid discriminated union variant;' + ' should be a `BaseModel` or `dataclass`' + ) + + elif choice['type'] == 'lax-or-strict': + return sorted( + set( + self._infer_discriminator_values_for_choice(choice['lax_schema'], source_name=None) + + self._infer_discriminator_values_for_choice(choice['strict_schema'], source_name=None) + ) + ) + + elif choice['type'] == 'tagged-union': + values: list[str | int] = [] + # Ignore str/int "choices" since these are just references to other choices + subchoices = [x for x in choice['choices'].values() if not isinstance(x, (str, int))] + for subchoice in subchoices: + subchoice_values = self._infer_discriminator_values_for_choice(subchoice, source_name=None) + values.extend(subchoice_values) + return values + + elif choice['type'] == 'union': + values = [] + for subchoice in choice['choices']: + subchoice_values = self._infer_discriminator_values_for_choice(subchoice, source_name=None) + values.extend(subchoice_values) + return values + + elif choice['type'] == 'nullable': + self._should_be_nullable = True + return self._infer_discriminator_values_for_choice(choice['schema'], source_name=None) + + elif choice['type'] == 'model': + return self._infer_discriminator_values_for_choice(choice['schema'], source_name=choice['cls'].__name__) + + elif choice['type'] == 'typed-dict': + return self._infer_discriminator_values_for_typed_dict_choice(choice, source_name=source_name) + + else: + raise TypeError( + f'{choice["type"]!r} is not a valid discriminated union variant;' + ' should be a `BaseModel` or `dataclass`' + ) + + def _infer_discriminator_values_for_typed_dict_choice( + self, choice: core_schema.TypedDictSchema, source_name: str | None = None + ) -> list[str | int]: + """ + This method just extracts the _infer_discriminator_values_for_choice logic specific to TypedDictSchema + for the sake of readability. + """ + if source_name is None: + source = 'TypedDict' # We may eventually want to provide a more useful name + else: + source = f'Model {source_name!r}' + + field = choice['fields'].get(self.discriminator) + if field is None: + raise PydanticUserError(f'{source} needs a discriminator field for key {self.discriminator!r}') + alias = field.get('validation_alias', self.discriminator) + if not isinstance(alias, str): + raise TypeError(f'Alias {alias!r} is not supported in a discriminated union') + if self._discriminator_alias is None: + self._discriminator_alias = alias + elif self._discriminator_alias != alias: + raise PydanticUserError( + f'Aliases for discriminator {self.discriminator!r} must be the same ' + f'(got {alias}, {self._discriminator_alias})' + ) + return self._infer_discriminator_values_for_field(field['schema'], source) + + def _infer_discriminator_values_for_field(self, schema: core_schema.CoreSchema, source: str) -> list[str | int]: + """ + When inferring discriminator values for a field, we typically extract the expected values from a literal schema. + This function does that, but also handles nested unions and defaults. + """ + if schema['type'] == 'literal': + values = [] + for v in schema['expected']: + if isinstance(v, Enum): + v = v.value + if not isinstance(v, (str, int)): + raise TypeError(f'Unsupported value for discriminator field: {v!r}') + values.append(v) + return values + + elif schema['type'] == 'union': + # Generally when multiple values are allowed they should be placed in a single `Literal`, but + # we add this case to handle the situation where a field is annotated as a `Union` of `Literal`s. + # For example, this lets us handle `Union[Literal['key'], Union[Literal['Key'], Literal['KEY']]]` + values = [] + for choice in schema['choices']: + choice_values = self._infer_discriminator_values_for_field(choice, source) + values.extend(choice_values) + return values + + elif schema['type'] == 'default': + # This will happen if the field has a default value; we ignore it while extracting the discriminator values + return self._infer_discriminator_values_for_field(schema['schema'], source) + + else: + raise PydanticUserError(f'{source} needs field {self.discriminator!r} to be of type `Literal`') + + def _set_unique_choice_for_values(self, choice: core_schema.CoreSchema, values: Sequence[str | int]) -> None: + """ + This method updates `self.tagged_union_choices` so that all provided (discriminator) `values` map to the + provided `choice`, validating that none of these values already map to another (different) choice. + """ + primary_value: str | int | None = None + for discriminator_value in values: + if discriminator_value in self._tagged_union_choices: + # It is okay if `value` is already in tagged_union_choices as long as it maps to the same value. + # Because tagged_union_choices may map values to other values, we need to walk the choices dict + # until we get to a "real" choice, and confirm that is equal to the one assigned. + existing_choice = self._tagged_union_choices[discriminator_value] + while isinstance(existing_choice, (str, int)): + existing_choice = self._tagged_union_choices[existing_choice] + if existing_choice != choice: + raise TypeError( + f'Value {discriminator_value!r} for discriminator ' + f'{self.discriminator!r} mapped to multiple choices' + ) + elif primary_value is None: + self._tagged_union_choices[discriminator_value] = choice + primary_value = discriminator_value + else: + self._tagged_union_choices[discriminator_value] = primary_value diff --git a/pydantic/_internal/_generate_schema.py b/pydantic/_internal/_generate_schema.py index 2faac59805..2b9ab07cf2 100644 --- a/pydantic/_internal/_generate_schema.py +++ b/pydantic/_internal/_generate_schema.py @@ -18,7 +18,7 @@ from ..errors import PydanticSchemaGenerationError, PydanticUserError from ..fields import FieldInfo from ..json_schema import JsonSchemaMetadata, JsonSchemaValue -from . import _fields, _typing_extra +from . import _discriminated_union, _fields, _typing_extra from ._core_metadata import CoreMetadataHandler, build_metadata_dict from ._core_utils import get_type_ref from ._decorators import SerializationFunctions, Serializer, ValidationFunctions, Validator @@ -259,6 +259,8 @@ def generate_field_schema( """ assert field_info.annotation is not None, 'field_info.annotation should not be None when generating a schema' schema = self.generate_schema(field_info.annotation) + if field_info.discriminator is not None: + schema = _discriminated_union.apply_discriminator(schema, field_info.discriminator) schema = apply_annotations(schema, field_info.metadata) schema = apply_validators(schema, validator_functions.get_field_decorators(name)) @@ -712,6 +714,8 @@ def apply_single_annotation(schema: core_schema.CoreSchema, metadata: Any) -> co return apply_annotations(schema, metadata) elif isinstance(metadata, FieldInfo): schema = apply_annotations(schema, metadata.metadata) + if metadata.discriminator is not None: + schema = _discriminated_union.apply_discriminator(schema, metadata.discriminator) # TODO setting a default here needs to be tested return wrap_default(metadata, schema) diff --git a/pydantic/json_schema.py b/pydantic/json_schema.py index 021bceff6b..d7856614ae 100644 --- a/pydantic/json_schema.py +++ b/pydantic/json_schema.py @@ -4,7 +4,7 @@ import re from dataclasses import is_dataclass from enum import Enum -from typing import TYPE_CHECKING, Any, Callable, Counter, Dict, NewType, Sequence, cast +from typing import TYPE_CHECKING, Any, Callable, Counter, Dict, Iterable, List, NewType, Sequence, Tuple, Union, cast from pydantic_core import CoreSchema, CoreSchemaType, core_schema from pydantic_core.core_schema import TypedDictField @@ -497,7 +497,18 @@ def tagged_union_schema(self, schema: core_schema.TaggedUnionSchema) -> JsonSche generated[str(k)] = self.generate_inner(v).copy() except PydanticInvalidForJsonSchema: pass - json_schema: JsonSchemaValue = {'oneOf': list(generated.values())} + + # Populate the schema with any "indirect" references + for k, v in schema['choices'].items(): + if isinstance(v, (str, int)): + while isinstance(schema['choices'][v], (str, int)): + v = schema['choices'][v] + if str(v) in generated: + # while it might seem unnecessary to check `if str(v) in generated`, a PydanticInvalidForJsonSchema + # may have been raised above, which would mean that the schema we want to reference won't be present + generated[str(k)] = generated[str(v)] + + json_schema: JsonSchemaValue = {'oneOf': _deduplicate_schemas(generated.values())} # This reflects the v1 behavior, but we may want to only include the discriminator based on dialect / etc. if 'discriminator' in schema and isinstance(schema['discriminator'], str): @@ -1062,3 +1073,20 @@ def model_schema( ) -> dict[str, Any]: model = _utils.get_model(model) return model.model_json_schema(by_alias=by_alias, ref_template=ref_template, schema_generator=schema_generator) + + +_Json = Union[Dict[str, Any], List[Any], str, int, float, bool, None] +_HashableJson = Union[Tuple[Tuple[str, Any], ...], Tuple[Any, ...], str, int, float, bool, None] + + +def _deduplicate_schemas(schemas: Iterable[_Json]) -> list[_Json]: + return list({_make_json_hashable(schema): schema for schema in schemas}.values()) + + +def _make_json_hashable(value: _Json) -> _HashableJson: + if isinstance(value, dict): + return tuple(sorted((k, _make_json_hashable(v)) for k, v in value.items())) + elif isinstance(value, list): + return tuple(_make_json_hashable(v) for v in value) + else: + return value diff --git a/tests/test_discrimated_union.py b/tests/test_discrimated_union.py deleted file mode 100644 index 1076424055..0000000000 --- a/tests/test_discrimated_union.py +++ /dev/null @@ -1,434 +0,0 @@ -import re -from enum import Enum -from typing import Generic, TypeVar, Union - -import pytest -from typing_extensions import Annotated, Literal - -from pydantic import BaseModel, ConfigDict, Field, ValidationError -from pydantic.errors import PydanticUserError - - -@pytest.mark.xfail(reason='working on V2') -def test_discriminated_union_only_union(): - with pytest.raises( - TypeError, match='`discriminator` can only be used with `Union` type with more than one variant' - ): - - class Model(BaseModel): - x: str = Field(..., discriminator='qwe') - - -@pytest.mark.xfail(reason='working on V2') -def test_discriminated_union_single_variant(): - with pytest.raises( - TypeError, match='`discriminator` can only be used with `Union` type with more than one variant' - ): - - class Model(BaseModel): - x: Union[str] = Field(..., discriminator='qwe') - - -@pytest.mark.xfail(reason='working on V2') -def test_discriminated_union_invalid_type(): - with pytest.raises(TypeError, match="Type 'str' is not a valid `BaseModel` or `dataclass`"): - - class Model(BaseModel): - x: Union[str, int] = Field(..., discriminator='qwe') - - -@pytest.mark.xfail(reason='working on V2') -def test_discriminated_union_defined_discriminator(): - class Cat(BaseModel): - c: str - - class Dog(BaseModel): - pet_type: Literal['dog'] - d: str - - with pytest.raises(PydanticUserError, match="Model 'Cat' needs a discriminator field for key 'pet_type'"): - - class Model(BaseModel): - pet: Union[Cat, Dog] = Field(..., discriminator='pet_type') - number: int - - -@pytest.mark.xfail(reason='working on V2') -def test_discriminated_union_literal_discriminator(): - class Cat(BaseModel): - pet_type: int - c: str - - class Dog(BaseModel): - pet_type: Literal['dog'] - d: str - - with pytest.raises(PydanticUserError, match="Field 'pet_type' of model 'Cat' needs to be a `Literal`"): - - class Model(BaseModel): - pet: Union[Cat, Dog] = Field(..., discriminator='pet_type') - number: int - - -@pytest.mark.xfail(reason='working on V2') -def test_discriminated_union_root_same_discriminator(): - class BlackCat(BaseModel): - pet_type: Literal['blackcat'] - - class WhiteCat(BaseModel): - pet_type: Literal['whitecat'] - - class Cat(BaseModel): - __root__: Union[BlackCat, WhiteCat] - - class Dog(BaseModel): - pet_type: Literal['dog'] - - with pytest.raises(PydanticUserError, match="Field 'pet_type' is not the same for all submodels of 'Cat'"): - - class Pet(BaseModel): - __root__: Union[Cat, Dog] = Field(..., discriminator='pet_type') - - -@pytest.mark.xfail(reason='working on V2') -def test_discriminated_union_validation(): - class BlackCat(BaseModel): - pet_type: Literal['cat'] - color: Literal['black'] - black_infos: str - - class WhiteCat(BaseModel): - pet_type: Literal['cat'] - color: Literal['white'] - white_infos: str - - class Cat(BaseModel): - __root__: Annotated[Union[BlackCat, WhiteCat], Field(discriminator='color')] - - class Dog(BaseModel): - pet_type: Literal['dog'] - d: str - - class Lizard(BaseModel): - pet_type: Literal['reptile', 'lizard'] - m: str - - class Model(BaseModel): - pet: Annotated[Union[Cat, Dog, Lizard], Field(discriminator='pet_type')] - number: int - - with pytest.raises(ValidationError) as exc_info: - Model.model_validate({'pet': {'pet_typ': 'cat'}, 'number': 'x'}) - assert exc_info.value.errors() == [ - { - 'loc': ('pet',), - 'msg': "Discriminator 'pet_type' is missing in value", - 'type': 'value_error.discriminated_union.missing_discriminator', - 'ctx': {'discriminator_key': 'pet_type'}, - }, - {'loc': ('number',), 'msg': 'value is not a valid integer', 'type': 'type_error.integer'}, - ] - - with pytest.raises(ValidationError) as exc_info: - Model.model_validate({'pet': 'fish', 'number': 2}) - assert exc_info.value.errors() == [ - { - 'loc': ('pet',), - 'msg': "Discriminator 'pet_type' is missing in value", - 'type': 'value_error.discriminated_union.missing_discriminator', - 'ctx': {'discriminator_key': 'pet_type'}, - }, - ] - - with pytest.raises(ValidationError) as exc_info: - Model.model_validate({'pet': {'pet_type': 'fish'}, 'number': 2}) - assert exc_info.value.errors() == [ - { - 'loc': ('pet',), - 'msg': ( - "No match for discriminator 'pet_type' and value 'fish' " - "(allowed values: 'cat', 'dog', 'reptile', 'lizard')" - ), - 'type': 'value_error.discriminated_union.invalid_discriminator', - 'ctx': { - 'discriminator_key': 'pet_type', - 'discriminator_value': 'fish', - 'allowed_values': "'cat', 'dog', 'reptile', 'lizard'", - }, - }, - ] - - with pytest.raises(ValidationError) as exc_info: - Model.model_validate({'pet': {'pet_type': 'lizard'}, 'number': 2}) - assert exc_info.value.errors() == [ - {'loc': ('pet', 'Lizard', 'l'), 'msg': 'field required', 'type': 'value_error.missing'}, - ] - - m = Model.model_validate({'pet': {'pet_type': 'lizard', 'l': 'pika'}, 'number': 2}) - assert isinstance(m.pet, Lizard) - assert m.model_dump() == {'pet': {'pet_type': 'lizard', 'l': 'pika'}, 'number': 2} - - with pytest.raises(ValidationError) as exc_info: - Model.model_validate({'pet': {'pet_type': 'cat', 'color': 'white'}, 'number': 2}) - assert exc_info.value.errors() == [ - { - 'loc': ('pet', 'Cat', '__root__', 'WhiteCat', 'white_infos'), - 'msg': 'field required', - 'type': 'value_error.missing', - } - ] - m = Model.model_validate({'pet': {'pet_type': 'cat', 'color': 'white', 'white_infos': 'pika'}, 'number': 2}) - assert isinstance(m.pet.__root__, WhiteCat) - - -@pytest.mark.xfail(reason='working on V2') -def test_discriminated_annotated_union(): - class BlackCat(BaseModel): - pet_type: Literal['cat'] - color: Literal['black'] - black_infos: str - - class WhiteCat(BaseModel): - pet_type: Literal['cat'] - color: Literal['white'] - white_infos: str - - Cat = Annotated[Union[BlackCat, WhiteCat], Field(discriminator='color')] - - class Dog(BaseModel): - pet_type: Literal['dog'] - dog_name: str - - Pet = Annotated[Union[Cat, Dog], Field(discriminator='pet_type')] - - class Model(BaseModel): - pet: Pet - number: int - - with pytest.raises(ValidationError) as exc_info: - Model.model_validate({'pet': {'pet_typ': 'cat'}, 'number': 'x'}) - assert exc_info.value.errors() == [ - { - 'loc': ('pet',), - 'msg': "Discriminator 'pet_type' is missing in value", - 'type': 'value_error.discriminated_union.missing_discriminator', - 'ctx': {'discriminator_key': 'pet_type'}, - }, - {'loc': ('number',), 'msg': 'value is not a valid integer', 'type': 'type_error.integer'}, - ] - - with pytest.raises(ValidationError) as exc_info: - Model.model_validate({'pet': {'pet_type': 'fish'}, 'number': 2}) - assert exc_info.value.errors() == [ - { - 'loc': ('pet',), - 'msg': "No match for discriminator 'pet_type' and value 'fish' " "(allowed values: 'cat', 'dog')", - 'type': 'value_error.discriminated_union.invalid_discriminator', - 'ctx': {'discriminator_key': 'pet_type', 'discriminator_value': 'fish', 'allowed_values': "'cat', 'dog'"}, - }, - ] - - with pytest.raises(ValidationError) as exc_info: - Model.model_validate({'pet': {'pet_type': 'dog'}, 'number': 2}) - assert exc_info.value.errors() == [ - {'loc': ('pet', 'Dog', 'dog_name'), 'msg': 'field required', 'type': 'value_error.missing'}, - ] - m = Model.model_validate({'pet': {'pet_type': 'dog', 'dog_name': 'milou'}, 'number': 2}) - assert isinstance(m.pet, Dog) - - with pytest.raises(ValidationError) as exc_info: - Model.model_validate({'pet': {'pet_type': 'cat', 'color': 'red'}, 'number': 2}) - assert exc_info.value.errors() == [ - { - 'loc': ('pet', 'Union[BlackCat, WhiteCat]'), - 'msg': "No match for discriminator 'color' and value 'red' " "(allowed values: 'black', 'white')", - 'type': 'value_error.discriminated_union.invalid_discriminator', - 'ctx': {'discriminator_key': 'color', 'discriminator_value': 'red', 'allowed_values': "'black', 'white'"}, - } - ] - - with pytest.raises(ValidationError) as exc_info: - Model.model_validate({'pet': {'pet_type': 'cat', 'color': 'white'}, 'number': 2}) - assert exc_info.value.errors() == [ - { - 'loc': ('pet', 'Union[BlackCat, WhiteCat]', 'WhiteCat', 'white_infos'), - 'msg': 'field required', - 'type': 'value_error.missing', - } - ] - m = Model.model_validate({'pet': {'pet_type': 'cat', 'color': 'white', 'white_infos': 'pika'}, 'number': 2}) - assert isinstance(m.pet, WhiteCat) - - -def test_discriminated_union_basemodel_instance_value(): - class A(BaseModel): - foo: Literal['a'] - - class B(BaseModel): - foo: Literal['b'] - - class Top(BaseModel): - sub: Union[A, B] = Field(..., discriminator='foo') - - t = Top(sub=A(foo='a')) - assert isinstance(t, Top) - - -def test_discriminated_union_basemodel_instance_value_with_alias(): - class A(BaseModel): - literal: Literal['a'] = Field(alias='lit') - - class B(BaseModel): - model_config = ConfigDict(populate_by_name=True) - literal: Literal['b'] = Field(alias='lit') - - class Top(BaseModel): - sub: Union[A, B] = Field(..., discriminator='literal') - - assert Top(sub=A(lit='a')).sub.literal == 'a' - assert Top(sub=B(lit='b')).sub.literal == 'b' - assert Top(sub=B(literal='b')).sub.literal == 'b' - - -@pytest.mark.xfail(reason='working on V2') -def test_discriminated_union_int(): - class A(BaseModel): - m: Literal[1] - - class B(BaseModel): - m: Literal[2] - - class Top(BaseModel): - sub: Union[A, B] = Field(..., discriminator='l') - - assert isinstance(Top.model_validate({'sub': {'m': 2}}).sub, B) - with pytest.raises(ValidationError) as exc_info: - Top.model_validate({'sub': {'m': 3}}) - assert exc_info.value.errors() == [ - { - 'loc': ('sub',), - 'msg': "No match for discriminator 'l' and value 3 (allowed values: 1, 2)", - 'type': 'value_error.discriminated_union.invalid_discriminator', - 'ctx': {'discriminator_key': 'm', 'discriminator_value': 3, 'allowed_values': '1, 2'}, - } - ] - - -@pytest.mark.xfail(reason='working on V2') -def test_discriminated_union_enum(): - class EnumValue(Enum): - a = 1 - b = 2 - - class A(BaseModel): - m: Literal[EnumValue.a] - - class B(BaseModel): - m: Literal[EnumValue.b] - - class Top(BaseModel): - sub: Union[A, B] = Field(..., discriminator='m') - - assert isinstance(Top.model_validate({'sub': {'m': EnumValue.b}}).sub, B) - with pytest.raises(ValidationError) as exc_info: - Top.model_validate({'sub': {'m': 3}}) - assert exc_info.value.errors() == [ - { - 'loc': ('sub',), - 'msg': "No match for discriminator 'm' and value 3 (allowed values: , )", - 'type': 'value_error.discriminated_union.invalid_discriminator', - 'ctx': { - 'discriminator_key': 'm', - 'discriminator_value': 3, - 'allowed_values': ', ', - }, - } - ] - - -@pytest.mark.xfail(reason='working on V2') -def test_alias_different(): - class Cat(BaseModel): - pet_type: Literal['cat'] = Field(alias='U') - c: str - - class Dog(BaseModel): - pet_type: Literal['dog'] = Field(alias='T') - d: str - - with pytest.raises( - PydanticUserError, match=re.escape("Aliases for discriminator 'pet_type' must be the same (got T, U)") - ): - - class Model(BaseModel): - pet: Union[Cat, Dog] = Field(discriminator='pet_type') - - -def test_alias_same(): - class Cat(BaseModel): - pet_type: Literal['cat'] = Field(alias='typeOfPet') - c: str - - class Dog(BaseModel): - pet_type: Literal['dog'] = Field(alias='typeOfPet') - d: str - - class Model(BaseModel): - pet: Union[Cat, Dog] = Field(discriminator='pet_type') - - assert Model(**{'pet': {'typeOfPet': 'dog', 'd': 'milou'}}).pet.pet_type == 'dog' - - -def test_nested(): - class Cat(BaseModel): - pet_type: Literal['cat'] - name: str - - class Dog(BaseModel): - pet_type: Literal['dog'] - name: str - - CommonPet = Annotated[Union[Cat, Dog], Field(discriminator='pet_type')] - - class Lizard(BaseModel): - pet_type: Literal['reptile', 'lizard'] - name: str - - class Model(BaseModel): - pet: Union[CommonPet, Lizard] = Field(..., discriminator='pet_type') - n: int - - assert isinstance(Model(**{'pet': {'pet_type': 'dog', 'name': 'Milou'}, 'n': 5}).pet, Dog) - - -@pytest.mark.xfail(reason='working on V2') -def test_generic(): - T = TypeVar('T') - - class Success(BaseModel, Generic[T]): - type: Literal['Success'] = 'Success' - data: T - - class Failure(BaseModel): - type: Literal['Failure'] = 'Failure' - error_message: str - - class Container(BaseModel, Generic[T]): - result: Union[Success[T], Failure] = Field(discriminator='type') - - with pytest.raises(ValidationError, match="Discriminator 'type' is missing in value"): - Container[str].model_validate({'result': {}}) - - with pytest.raises( - ValidationError, - match=re.escape("No match for discriminator 'type' and value 'Other' (allowed values: 'Success', 'Failure')"), - ): - Container[str].model_validate({'result': {'type': 'Other'}}) - - with pytest.raises( - ValidationError, match=re.escape('Container[str]\nresult -> Success[str] -> data\n field required') - ): - Container[str].model_validate({'result': {'type': 'Success'}}) - - # coercion is done properly - assert Container[str].model_validate({'result': {'type': 'Success', 'data': 1}}).result.data == '1' diff --git a/tests/test_discriminated_union.py b/tests/test_discriminated_union.py new file mode 100644 index 0000000000..ffa56f13eb --- /dev/null +++ b/tests/test_discriminated_union.py @@ -0,0 +1,1034 @@ +import re +import sys +from enum import Enum, IntEnum +from typing import Generic, Optional, TypeVar, Union + +import pytest +from dirty_equals import HasRepr, IsStr +from pydantic_core import SchemaValidator, core_schema +from typing_extensions import Annotated, Literal + +from pydantic import BaseModel, ConfigDict, Field, ValidationError, Validator +from pydantic._internal._discriminated_union import apply_discriminator +from pydantic.errors import PydanticUserError + + +def test_discriminated_union_only_union(): + with pytest.raises( + TypeError, match='`discriminator` can only be used with `Union` type with more than one variant' + ): + + class Model(BaseModel): + x: str = Field(..., discriminator='qwe') + + +def test_discriminated_union_single_variant(): + with pytest.raises( + TypeError, match='`discriminator` can only be used with `Union` type with more than one variant' + ): + + class Model(BaseModel): + x: Union[str] = Field(..., discriminator='qwe') + + +def test_discriminated_union_invalid_type(): + with pytest.raises( + TypeError, match="'str' is not a valid discriminated union variant; should be a `BaseModel` or `dataclass`" + ): + + class Model(BaseModel): + x: Union[str, int] = Field(..., discriminator='qwe') + + +def test_discriminated_union_defined_discriminator(): + class Cat(BaseModel): + c: str + + class Dog(BaseModel): + pet_type: Literal['dog'] + d: str + + with pytest.raises(PydanticUserError, match="Model 'Cat' needs a discriminator field for key 'pet_type'"): + + class Model(BaseModel): + pet: Union[Cat, Dog] = Field(..., discriminator='pet_type') + number: int + + +def test_discriminated_union_literal_discriminator(): + class Cat(BaseModel): + pet_type: int + c: str + + class Dog(BaseModel): + pet_type: Literal['dog'] + d: str + + with pytest.raises(PydanticUserError, match="Model 'Cat' needs field 'pet_type' to be of type `Literal`"): + + class Model(BaseModel): + pet: Union[Cat, Dog] = Field(..., discriminator='pet_type') + number: int + + +def test_discriminated_union_root_same_discriminator(): + class BlackCat(BaseModel): + pet_type: Literal['blackcat'] + + class WhiteCat(BaseModel): + pet_type: Literal['whitecat'] + + Cat = Union[BlackCat, WhiteCat] + + class Dog(BaseModel): + pet_type: Literal['dog'] + + CatDog = Validator(Annotated[Union[Cat, Dog], Field(..., discriminator='pet_type')]) + CatDog({'pet_type': 'blackcat'}) + CatDog({'pet_type': 'whitecat'}) + CatDog({'pet_type': 'dog'}) + with pytest.raises(ValidationError) as exc_info: + CatDog({'pet_type': 'llama'}) + assert exc_info.value.errors() == [ + { + 'ctx': {'discriminator': "'pet_type'", 'expected_tags': "'blackcat', 'whitecat', 'dog'", 'tag': 'llama'}, + 'input': {'pet_type': 'llama'}, + 'loc': (), + 'msg': "Input tag 'llama' found using 'pet_type' does not match any of the " + "expected tags: 'blackcat', 'whitecat', 'dog'", + 'type': 'union_tag_invalid', + } + ] + + +def test_discriminated_union_validation(): + class BlackCat(BaseModel): + pet_type: Literal['cat'] + color: Literal['black'] + black_infos: str + + class WhiteCat(BaseModel): + pet_type: Literal['cat'] + color: Literal['white'] + white_infos: str + + Cat = Annotated[Union[BlackCat, WhiteCat], Field(discriminator='color')] + + class Dog(BaseModel): + pet_type: Literal['dog'] + d: str + + class Lizard(BaseModel): + pet_type: Literal['reptile', 'lizard'] + m: str + + class Model(BaseModel): + pet: Annotated[Union[Cat, Dog, Lizard], Field(discriminator='pet_type')] + number: int + + with pytest.raises(ValidationError) as exc_info: + Model.model_validate({'pet': {'pet_typ': 'cat'}, 'number': 'x'}) + assert exc_info.value.errors() == [ + { + 'ctx': {'discriminator': "'pet_type'"}, + 'input': {'pet_typ': 'cat'}, + 'loc': ('pet',), + 'msg': "Unable to extract tag using discriminator 'pet_type'", + 'type': 'union_tag_not_found', + }, + { + 'input': 'x', + 'loc': ('number',), + 'msg': 'Input should be a valid integer, unable to parse string as an ' 'integer', + 'type': 'int_parsing', + }, + ] + + with pytest.raises(ValidationError) as exc_info: + Model.model_validate({'pet': 'fish', 'number': 2}) + assert exc_info.value.errors() == [ + { + 'input': 'fish', + 'loc': ('pet',), + 'msg': 'Input should be a valid dictionary or instance to extract fields ' 'from', + 'type': 'dict_attributes_type', + } + ] + + with pytest.raises(ValidationError) as exc_info: + Model.model_validate({'pet': {'pet_type': 'fish'}, 'number': 2}) + assert exc_info.value.errors() == [ + { + 'ctx': {'discriminator': "'pet_type'", 'expected_tags': "'cat', 'dog', 'reptile', 'lizard'", 'tag': 'fish'}, + 'input': {'pet_type': 'fish'}, + 'loc': ('pet',), + 'msg': "Input tag 'fish' found using 'pet_type' does not match any of the " + "expected tags: 'cat', 'dog', 'reptile', 'lizard'", + 'type': 'union_tag_invalid', + } + ] + + with pytest.raises(ValidationError) as exc_info: + Model.model_validate({'pet': {'pet_type': 'lizard'}, 'number': 2}) + assert exc_info.value.errors() == [ + {'input': {'pet_type': 'lizard'}, 'loc': ('pet', 'lizard', 'm'), 'msg': 'Field required', 'type': 'missing'} + ] + + m = Model.model_validate({'pet': {'pet_type': 'lizard', 'm': 'pika'}, 'number': 2}) + assert isinstance(m.pet, Lizard) + assert m.model_dump() == {'pet': {'pet_type': 'lizard', 'm': 'pika'}, 'number': 2} + + with pytest.raises(ValidationError) as exc_info: + Model.model_validate({'pet': {'pet_type': 'cat', 'color': 'white'}, 'number': 2}) + assert exc_info.value.errors() == [ + { + 'input': {'color': 'white', 'pet_type': 'cat'}, + 'loc': ('pet', 'cat', 'white', 'white_infos'), + 'msg': 'Field required', + 'type': 'missing', + } + ] + m = Model.model_validate({'pet': {'pet_type': 'cat', 'color': 'white', 'white_infos': 'pika'}, 'number': 2}) + assert isinstance(m.pet, WhiteCat) + + +def test_discriminated_annotated_union(): + class BlackCat(BaseModel): + pet_type: Literal['cat'] + color: Literal['black'] + black_infos: str + + class WhiteCat(BaseModel): + pet_type: Literal['cat'] + color: Literal['white'] + white_infos: str + + Cat = Annotated[Union[BlackCat, WhiteCat], Field(discriminator='color')] + + class Dog(BaseModel): + pet_type: Literal['dog'] + dog_name: str + + Pet = Annotated[Union[Cat, Dog], Field(discriminator='pet_type')] + + class Model(BaseModel): + pet: Pet + number: int + + with pytest.raises(ValidationError) as exc_info: + Model.model_validate({'pet': {'pet_typ': 'cat'}, 'number': 'x'}) + assert exc_info.value.errors() == [ + { + 'ctx': {'discriminator': "'pet_type'"}, + 'input': {'pet_typ': 'cat'}, + 'loc': ('pet',), + 'msg': "Unable to extract tag using discriminator 'pet_type'", + 'type': 'union_tag_not_found', + }, + { + 'input': 'x', + 'loc': ('number',), + 'msg': 'Input should be a valid integer, unable to parse string as an ' 'integer', + 'type': 'int_parsing', + }, + ] + + with pytest.raises(ValidationError) as exc_info: + Model.model_validate({'pet': {'pet_type': 'fish'}, 'number': 2}) + assert exc_info.value.errors() == [ + { + 'ctx': {'discriminator': "'pet_type'", 'expected_tags': "'cat', 'dog'", 'tag': 'fish'}, + 'input': {'pet_type': 'fish'}, + 'loc': ('pet',), + 'msg': "Input tag 'fish' found using 'pet_type' does not match any of the " "expected tags: 'cat', 'dog'", + 'type': 'union_tag_invalid', + } + ] + + with pytest.raises(ValidationError) as exc_info: + Model.model_validate({'pet': {'pet_type': 'dog'}, 'number': 2}) + assert exc_info.value.errors() == [ + {'input': {'pet_type': 'dog'}, 'loc': ('pet', 'dog', 'dog_name'), 'msg': 'Field required', 'type': 'missing'} + ] + m = Model.model_validate({'pet': {'pet_type': 'dog', 'dog_name': 'milou'}, 'number': 2}) + assert isinstance(m.pet, Dog) + + with pytest.raises(ValidationError) as exc_info: + Model.model_validate({'pet': {'pet_type': 'cat', 'color': 'red'}, 'number': 2}) + assert exc_info.value.errors() == [ + { + 'ctx': {'discriminator': "'color'", 'expected_tags': "'black', 'white'", 'tag': 'red'}, + 'input': {'color': 'red', 'pet_type': 'cat'}, + 'loc': ('pet', 'cat'), + 'msg': "Input tag 'red' found using 'color' does not match any of the " "expected tags: 'black', 'white'", + 'type': 'union_tag_invalid', + } + ] + + with pytest.raises(ValidationError) as exc_info: + Model.model_validate({'pet': {'pet_type': 'cat', 'color': 'white'}, 'number': 2}) + assert exc_info.value.errors() == [ + { + 'input': {'color': 'white', 'pet_type': 'cat'}, + 'loc': ('pet', 'cat', 'white', 'white_infos'), + 'msg': 'Field required', + 'type': 'missing', + } + ] + m = Model.model_validate({'pet': {'pet_type': 'cat', 'color': 'white', 'white_infos': 'pika'}, 'number': 2}) + assert isinstance(m.pet, WhiteCat) + + +def test_discriminated_union_basemodel_instance_value(): + class A(BaseModel): + foo: Literal['a'] + + class B(BaseModel): + foo: Literal['b'] + + class Top(BaseModel): + sub: Union[A, B] = Field(..., discriminator='foo') + + t = Top(sub=A(foo='a')) + assert isinstance(t, Top) + + +def test_discriminated_union_basemodel_instance_value_with_alias(): + class A(BaseModel): + literal: Literal['a'] = Field(alias='lit') + + class B(BaseModel): + model_config = ConfigDict(populate_by_name=True) + literal: Literal['b'] = Field(alias='lit') + + class Top(BaseModel): + sub: Union[A, B] = Field(..., discriminator='literal') + + with pytest.raises(ValidationError) as exc_info: + Top(sub=A(literal='a')) + # TODO: Adding this note here that we should make sure the produced error messages for DiscriminatedUnion + # have the same behavior as elsewhere when aliases are involved. + # (I.e., possibly using the alias value as the 'loc') + assert exc_info.value.errors() == [ + {'input': {'literal': 'a'}, 'loc': ('literal',), 'msg': 'Field required', 'type': 'missing'} + ] + assert Top(sub=A(lit='a')).sub.literal == 'a' + assert Top(sub=B(lit='b')).sub.literal == 'b' + assert Top(sub=B(literal='b')).sub.literal == 'b' + + +def test_discriminated_union_int(): + class A(BaseModel): + m: Literal[1] + + class B(BaseModel): + m: Literal[2] + + class Top(BaseModel): + sub: Union[A, B] = Field(..., discriminator='m') + + assert isinstance(Top.model_validate({'sub': {'m': 2}}).sub, B) + with pytest.raises(ValidationError) as exc_info: + Top.model_validate({'sub': {'m': 3}}) + assert exc_info.value.errors() == [ + { + 'ctx': {'discriminator': "'m'", 'expected_tags': '1, 2', 'tag': '3'}, + 'input': {'m': 3}, + 'loc': ('sub',), + 'msg': "Input tag '3' found using 'm' does not match any of the expected " "tags: 1, 2", + 'type': 'union_tag_invalid', + } + ] + + +class FooIntEnum(int, Enum): + pass + + +class FooStrEnum(str, Enum): + pass + + +ENUM_TEST_CASES = [ + pytest.param(Enum, {'a': 1, 'b': 2}, marks=pytest.mark.xfail(reason='Plain Enum not yet supported')), + pytest.param(Enum, {'a': 'v_a', 'b': 'v_b'}, marks=pytest.mark.xfail(reason='Plain Enum not yet supported')), + (FooIntEnum, {'a': 1, 'b': 2}), + (IntEnum, {'a': 1, 'b': 2}), + (FooStrEnum, {'a': 'v_a', 'b': 'v_b'}), +] +if sys.version_info >= (3, 11): + from enum import StrEnum + + ENUM_TEST_CASES.append((StrEnum, {'a': 'v_a', 'b': 'v_b'})) + + +@pytest.mark.parametrize('base_class,choices', ENUM_TEST_CASES) +def test_discriminated_union_enum(base_class, choices): + EnumValue = base_class('EnumValue', choices) + + class A(BaseModel): + m: Literal[EnumValue.a] + + class B(BaseModel): + m: Literal[EnumValue.b] + + class Top(BaseModel): + sub: Union[A, B] = Field(..., discriminator='m') + + assert isinstance(Top.model_validate({'sub': {'m': EnumValue.b}}).sub, B) + assert isinstance(Top.model_validate({'sub': {'m': EnumValue.b.value}}).sub, B) + with pytest.raises(ValidationError) as exc_info: + Top.model_validate({'sub': {'m': 3}}) + + expected_tags = f'{EnumValue.a.value!r}, {EnumValue.b.value!r}' + assert exc_info.value.errors() == [ + { + 'ctx': {'discriminator': "'m'", 'expected_tags': expected_tags, 'tag': '3'}, + 'input': {'m': 3}, + 'loc': ('sub',), + 'msg': f"Input tag '3' found using 'm' does not match any of the expected tags: {expected_tags}", + 'type': 'union_tag_invalid', + } + ] + + +def test_alias_different(): + class Cat(BaseModel): + pet_type: Literal['cat'] = Field(alias='U') + c: str + + class Dog(BaseModel): + pet_type: Literal['dog'] = Field(alias='T') + d: str + + with pytest.raises(TypeError, match=re.escape("Aliases for discriminator 'pet_type' must be the same (got T, U)")): + + class Model(BaseModel): + pet: Union[Cat, Dog] = Field(discriminator='pet_type') + + +def test_alias_same(): + class Cat(BaseModel): + pet_type: Literal['cat'] = Field(alias='typeOfPet') + c: str + + class Dog(BaseModel): + pet_type: Literal['dog'] = Field(alias='typeOfPet') + d: str + + class Model(BaseModel): + pet: Union[Cat, Dog] = Field(discriminator='pet_type') + + assert Model(**{'pet': {'typeOfPet': 'dog', 'd': 'milou'}}).pet.pet_type == 'dog' + + +def test_nested(): + class Cat(BaseModel): + pet_type: Literal['cat'] + name: str + + class Dog(BaseModel): + pet_type: Literal['dog'] + name: str + + CommonPet = Annotated[Union[Cat, Dog], Field(discriminator='pet_type')] + + class Lizard(BaseModel): + pet_type: Literal['reptile', 'lizard'] + name: str + + class Model(BaseModel): + pet: Union[CommonPet, Lizard] = Field(..., discriminator='pet_type') + n: int + + assert isinstance(Model(**{'pet': {'pet_type': 'dog', 'name': 'Milou'}, 'n': 5}).pet, Dog) + + +def test_generic(): + T = TypeVar('T') + + class Success(BaseModel, Generic[T]): + type: Literal['Success'] = 'Success' + data: T + + class Failure(BaseModel): + type: Literal['Failure'] = 'Failure' + error_message: str + + class Container(BaseModel, Generic[T]): + result: Union[Success[T], Failure] = Field(discriminator='type') + + with pytest.raises(ValidationError, match="Unable to extract tag using discriminator 'type'"): + Container[str].model_validate({'result': {}}) + + with pytest.raises( + ValidationError, + match=re.escape( + "Input tag 'Other' found using 'type' does not match any of the expected tags: 'Success', 'Failure'" + ), + ): + Container[str].model_validate({'result': {'type': 'Other'}}) + + # See https://github.com/pydantic/pydantic-core/issues/425 for why this is set weirdly; this is an unrelated issue + # If/when the issue is fixed, the following line should replace the current title = 'Failure' line + # title = 'Container[str]' + title = 'Failure' + + with pytest.raises(ValidationError, match=f'{title}\nresult -> Success -> data') as exc_info: + Container[str].model_validate({'result': {'type': 'Success'}}) + assert exc_info.value.errors() == [ + {'input': {'type': 'Success'}, 'loc': ('result', 'Success', 'data'), 'msg': 'Field required', 'type': 'missing'} + ] + + # invalid types error + with pytest.raises(ValidationError) as exc_info: + Container[str].model_validate({'result': {'type': 'Success', 'data': 1}}) + assert exc_info.value.errors() == [ + { + 'input': 1, + 'loc': ('result', 'Success', 'data'), + 'msg': 'Input should be a valid string', + 'type': 'string_type', + } + ] + + assert Container[str].model_validate({'result': {'type': 'Success', 'data': '1'}}).result.data == '1' + + +def test_optional_union(): + class Cat(BaseModel): + pet_type: Literal['cat'] + name: str + + class Dog(BaseModel): + pet_type: Literal['dog'] + name: str + + class Pet(BaseModel): + pet: Optional[Union[Cat, Dog]] = Field(discriminator='pet_type') + + assert Pet(pet={'pet_type': 'cat', 'name': 'Milo'}).model_dump() == {'pet': {'name': 'Milo', 'pet_type': 'cat'}} + assert Pet(pet={'pet_type': 'dog', 'name': 'Otis'}).model_dump() == {'pet': {'name': 'Otis', 'pet_type': 'dog'}} + assert Pet(pet=None).model_dump() == {'pet': None} + + with pytest.raises(ValidationError) as exc_info: + Pet() + assert exc_info.value.errors() == [{'input': {}, 'loc': ('pet',), 'msg': 'Field required', 'type': 'missing'}] + + with pytest.raises(ValidationError) as exc_info: + Pet(pet={'name': 'Benji'}) + assert exc_info.value.errors() == [ + { + 'ctx': {'discriminator': "'pet_type'"}, + 'input': {'name': 'Benji'}, + 'loc': ('pet',), + 'msg': "Unable to extract tag using discriminator 'pet_type'", + 'type': 'union_tag_not_found', + } + ] + + with pytest.raises(ValidationError) as exc_info: + Pet(pet={'pet_type': 'lizard'}) + assert exc_info.value.errors() == [ + { + 'ctx': {'discriminator': "'pet_type'", 'expected_tags': "'cat', 'dog'", 'tag': 'lizard'}, + 'input': {'pet_type': 'lizard'}, + 'loc': ('pet',), + 'msg': "Input tag 'lizard' found using 'pet_type' does not match any of the " "expected tags: 'cat', 'dog'", + 'type': 'union_tag_invalid', + } + ] + + +def test_optional_union_with_defaults(): + class Cat(BaseModel): + pet_type: Literal['cat'] = 'cat' + name: str + + class Dog(BaseModel): + pet_type: Literal['dog'] = 'dog' + name: str + + class Pet(BaseModel): + pet: Optional[Union[Cat, Dog]] = Field(default=None, discriminator='pet_type') + + assert Pet(pet={'pet_type': 'cat', 'name': 'Milo'}).model_dump() == {'pet': {'name': 'Milo', 'pet_type': 'cat'}} + assert Pet(pet={'pet_type': 'dog', 'name': 'Otis'}).model_dump() == {'pet': {'name': 'Otis', 'pet_type': 'dog'}} + assert Pet(pet=None).model_dump() == {'pet': None} + assert Pet().model_dump() == {'pet': None} + + with pytest.raises(ValidationError) as exc_info: + Pet(pet={'name': 'Benji'}) + assert exc_info.value.errors() == [ + { + 'ctx': {'discriminator': "'pet_type'"}, + 'input': {'name': 'Benji'}, + 'loc': ('pet',), + 'msg': "Unable to extract tag using discriminator 'pet_type'", + 'type': 'union_tag_not_found', + } + ] + + with pytest.raises(ValidationError) as exc_info: + Pet(pet={'pet_type': 'lizard'}) + assert exc_info.value.errors() == [ + { + 'ctx': {'discriminator': "'pet_type'", 'expected_tags': "'cat', 'dog'", 'tag': 'lizard'}, + 'input': {'pet_type': 'lizard'}, + 'loc': ('pet',), + 'msg': "Input tag 'lizard' found using 'pet_type' does not match any of the " "expected tags: 'cat', 'dog'", + 'type': 'union_tag_invalid', + } + ] + + +def test_aliases_matching_is_not_sufficient() -> None: + class Case1(BaseModel): + kind_one: Literal['1'] = Field(alias='kind') + + class Case2(BaseModel): + kind_two: Literal['2'] = Field(alias='kind') + + with pytest.raises(PydanticUserError, match="Model 'Case1' needs a discriminator field for key 'kind'"): + + class TaggedParent(BaseModel): + tagged: Union[Case1, Case2] = Field(discriminator='kind') + + +def test_nested_optional_unions() -> None: + class Cat(BaseModel): + pet_type: Literal['cat'] = 'cat' + + class Dog(BaseModel): + pet_type: Literal['dog'] = 'dog' + + class Lizard(BaseModel): + pet_type: Literal['lizard', 'reptile'] = 'lizard' + + MaybeCatDog = Annotated[Optional[Union[Cat, Dog]], Field(discriminator='pet_type')] + MaybeDogLizard = Annotated[Union[Dog, Lizard, None], Field(discriminator='pet_type')] + + class Pet(BaseModel): + pet: Union[MaybeCatDog, MaybeDogLizard] = Field(discriminator='pet_type') + + Pet.model_validate({'pet': {'pet_type': 'dog'}}) + Pet.model_validate({'pet': {'pet_type': 'cat'}}) + Pet.model_validate({'pet': {'pet_type': 'lizard'}}) + Pet.model_validate({'pet': {'pet_type': 'reptile'}}) + Pet.model_validate({'pet': None}) + + with pytest.raises(ValidationError) as exc_info: + Pet.model_validate({'pet': {'pet_type': None}}) + assert exc_info.value.errors() == [ + {'input': None, 'loc': ('pet',), 'msg': 'Input should be a valid string', 'type': 'string_type'} + ] + + with pytest.raises(ValidationError) as exc_info: + Pet.model_validate({'pet': {'pet_type': 'fox'}}) + assert exc_info.value.errors() == [ + { + 'ctx': {'discriminator': "'pet_type'", 'expected_tags': "'cat', 'dog', 'lizard', 'reptile'", 'tag': 'fox'}, + 'input': {'pet_type': 'fox'}, + 'loc': ('pet',), + 'msg': "Input tag 'fox' found using 'pet_type' does not match any of the " + "expected tags: 'cat', 'dog', 'lizard', 'reptile'", + 'type': 'union_tag_invalid', + } + ] + + +def test_nested_discriminated_union() -> None: + class Cat(BaseModel): + pet_type: Literal['cat', 'CAT'] + + class Dog(BaseModel): + pet_type: Literal['dog', 'DOG'] + + class Lizard(BaseModel): + pet_type: Literal['lizard', 'LIZARD'] + + CatDog = Annotated[Union[Cat, Dog], Field(discriminator='pet_type')] + CatDogLizard = Annotated[Union[CatDog, Lizard], Field(discriminator='pet_type')] + + class Pet(BaseModel): + pet: CatDogLizard + + Pet.model_validate({'pet': {'pet_type': 'dog'}}) + Pet.model_validate({'pet': {'pet_type': 'cat'}}) + Pet.model_validate({'pet': {'pet_type': 'lizard'}}) + + with pytest.raises(ValidationError) as exc_info: + Pet.model_validate({'pet': {'pet_type': 'reptile'}}) + assert exc_info.value.errors() == [ + { + 'ctx': { + 'discriminator': "'pet_type'", + 'expected_tags': "'cat', 'dog', 'lizard', 'CAT', 'DOG', 'LIZARD'", + 'tag': 'reptile', + }, + 'input': {'pet_type': 'reptile'}, + 'loc': ('pet',), + 'msg': "Input tag 'reptile' found using 'pet_type' does not match any of the " + "expected tags: 'cat', 'dog', 'lizard', 'CAT', 'DOG', 'LIZARD'", + 'type': 'union_tag_invalid', + } + ] + + +def test_unions_of_optionals() -> None: + class Cat(BaseModel): + pet_type: Literal['cat'] = Field(alias='typeOfPet') + c: str + + class Dog(BaseModel): + pet_type: Literal['dog'] = Field(alias='typeOfPet') + d: str + + class Lizard(BaseModel): + pet_type: Literal['lizard'] = Field(alias='typeOfPet') + + MaybeCat = Annotated[Union[Cat, None], 'some annotation'] + MaybeDogLizard = Annotated[Optional[Union[Dog, Lizard]], 'some other annotation'] + + class Model(BaseModel): + maybe_pet: Union[MaybeCat, MaybeDogLizard] = Field(discriminator='pet_type') + + assert Model(**{'maybe_pet': None}).maybe_pet is None + assert Model(**{'maybe_pet': {'typeOfPet': 'dog', 'd': 'milou'}}).maybe_pet.pet_type == 'dog' + assert Model(**{'maybe_pet': {'typeOfPet': 'lizard'}}).maybe_pet.pet_type == 'lizard' + + +def test_union_discriminator_literals() -> None: + class Cat(BaseModel): + pet_type: Union[Literal['cat'], Literal['CAT']] = Field(alias='typeOfPet') + + class Dog(BaseModel): + pet_type: Literal['dog'] = Field(alias='typeOfPet') + + class Model(BaseModel): + pet: Union[Cat, Dog] = Field(discriminator='pet_type') + + assert Model(**{'pet': {'typeOfPet': 'dog'}}).pet.pet_type == 'dog' + assert Model(**{'pet': {'typeOfPet': 'cat'}}).pet.pet_type == 'cat' + assert Model(**{'pet': {'typeOfPet': 'CAT'}}).pet.pet_type == 'CAT' + with pytest.raises(ValidationError) as exc_info: + Model(**{'pet': {'typeOfPet': 'Cat'}}) + assert exc_info.value.errors() == [ + { + 'ctx': {'discriminator': "'pet_type' | 'typeOfPet'", 'expected_tags': "'cat', 'dog', 'CAT'", 'tag': 'Cat'}, + 'input': {'typeOfPet': 'Cat'}, + 'loc': ('pet',), + 'msg': "Input tag 'Cat' found using 'pet_type' | 'typeOfPet' does not match " + "any of the expected tags: 'cat', 'dog', 'CAT'", + 'type': 'union_tag_invalid', + } + ] + + +def test_none_schema() -> None: + cat_fields = {'kind': core_schema.typed_dict_field(core_schema.literal_schema('cat'))} + dog_fields = {'kind': core_schema.typed_dict_field(core_schema.literal_schema('dog'))} + cat = core_schema.typed_dict_schema(cat_fields) + dog = core_schema.typed_dict_schema(dog_fields) + schema = core_schema.union_schema(cat, dog, core_schema.none_schema()) + schema = apply_discriminator(schema, 'kind') + validator = SchemaValidator(schema) + assert validator.validate_python({'kind': 'cat'})['kind'] == 'cat' + assert validator.validate_python({'kind': 'dog'})['kind'] == 'dog' + assert validator.validate_python(None) is None + with pytest.raises(ValidationError) as exc_info: + validator.validate_python({'kind': 'lizard'}) + assert exc_info.value.errors() == [ + { + 'ctx': {'discriminator': "'kind'", 'expected_tags': "'cat', 'dog'", 'tag': 'lizard'}, + 'input': {'kind': 'lizard'}, + 'loc': (), + 'msg': "Input tag 'lizard' found using 'kind' does not match any of the " "expected tags: 'cat', 'dog'", + 'type': 'union_tag_invalid', + } + ] + + +def test_nested_unwrapping() -> None: + cat_fields = {'kind': core_schema.typed_dict_field(core_schema.literal_schema('cat'))} + dog_fields = {'kind': core_schema.typed_dict_field(core_schema.literal_schema('dog'))} + cat = core_schema.typed_dict_schema(cat_fields) + dog = core_schema.typed_dict_schema(dog_fields) + schema = core_schema.union_schema(cat, dog) + for _ in range(3): + schema = core_schema.nullable_schema(schema) + schema = core_schema.nullable_schema(schema) + schema = core_schema.definitions_schema(schema, []) + schema = core_schema.definitions_schema(schema, []) + + schema = apply_discriminator(schema, 'kind') + + validator = SchemaValidator(schema) + assert validator.validate_python({'kind': 'cat'})['kind'] == 'cat' + assert validator.validate_python({'kind': 'dog'})['kind'] == 'dog' + assert validator.validate_python(None) is None + with pytest.raises(ValidationError) as exc_info: + validator.validate_python({'kind': 'lizard'}) + assert exc_info.value.errors() == [ + { + 'ctx': {'discriminator': "'kind'", 'expected_tags': "'cat', 'dog'", 'tag': 'lizard'}, + 'input': {'kind': 'lizard'}, + 'loc': (), + 'msg': "Input tag 'lizard' found using 'kind' does not match any of the " "expected tags: 'cat', 'dog'", + 'type': 'union_tag_invalid', + } + ] + + +def test_distinct_choices() -> None: + class Cat(BaseModel): + pet_type: Literal['cat', 'dog'] = Field(alias='typeOfPet') + + class Dog(BaseModel): + pet_type: Literal['dog'] = Field(alias='typeOfPet') + + with pytest.raises(TypeError, match="Value 'dog' for discriminator 'pet_type' mapped to multiple choices"): + + class Model(BaseModel): + pet: Union[Cat, Dog] = Field(discriminator='pet_type') + + +def test_invalid_discriminated_union_type() -> None: + class Cat(BaseModel): + pet_type: Literal['cat'] = Field(alias='typeOfPet') + + class Dog(BaseModel): + pet_type: Literal['dog'] = Field(alias='typeOfPet') + + with pytest.raises( + TypeError, match="'str' is not a valid discriminated union variant; should be a `BaseModel` or `dataclass`" + ): + + class Model(BaseModel): + pet: Union[Cat, Dog, str] = Field(discriminator='pet_type') + + +def test_single_item_union_error() -> None: + fields = {'kind': core_schema.typed_dict_field(core_schema.literal_schema('only_choice'))} + schema = core_schema.union_schema(core_schema.typed_dict_schema(fields=fields)) + with pytest.raises( + TypeError, match='`discriminator` can only be used with `Union` type with more than one variant' + ): + apply_discriminator(schema, 'kind') + + +def test_invalid_alias() -> None: + cat_fields = { + 'kind': core_schema.typed_dict_field(core_schema.literal_schema('cat'), validation_alias=['cat', 'CAT']) + } + dog_fields = {'kind': core_schema.typed_dict_field(core_schema.literal_schema('dog'))} + cat = core_schema.typed_dict_schema(cat_fields) + dog = core_schema.typed_dict_schema(dog_fields) + schema = core_schema.union_schema(cat, dog) + + with pytest.raises(TypeError, match=re.escape("Alias ['cat', 'CAT'] is not supported in a discriminated union")): + apply_discriminator(schema, 'kind') + + +def test_invalid_discriminator_type() -> None: + cat_fields = {'kind': core_schema.typed_dict_field(core_schema.int_schema())} + dog_fields = {'kind': core_schema.typed_dict_field(core_schema.str_schema())} + cat = core_schema.typed_dict_schema(cat_fields) + dog = core_schema.typed_dict_schema(dog_fields) + + with pytest.raises(TypeError, match=re.escape("TypedDict needs field 'kind' to be of type `Literal`")): + apply_discriminator(core_schema.union_schema(cat, dog), 'kind') + + +def test_missing_discriminator_field() -> None: + cat_fields = {'kind': core_schema.typed_dict_field(core_schema.int_schema())} + dog_fields = {} + cat = core_schema.typed_dict_schema(cat_fields) + dog = core_schema.typed_dict_schema(dog_fields) + + with pytest.raises(TypeError, match=re.escape("TypedDict needs a discriminator field for key 'kind'")): + apply_discriminator(core_schema.union_schema(dog, cat), 'kind') + + +def test_invalid_discriminator_value() -> None: + cat_fields = {'kind': core_schema.typed_dict_field(core_schema.literal_schema(None))} + dog_fields = {'kind': core_schema.typed_dict_field(core_schema.literal_schema(1.5))} + cat = core_schema.typed_dict_schema(cat_fields) + dog = core_schema.typed_dict_schema(dog_fields) + + with pytest.raises(TypeError, match=re.escape('Unsupported value for discriminator field: None')): + apply_discriminator(core_schema.union_schema(cat, dog), 'kind') + + with pytest.raises(TypeError, match=re.escape('Unsupported value for discriminator field: 1.5')): + apply_discriminator(core_schema.union_schema(dog, cat), 'kind') + + +def test_wrap_function_schema() -> None: + cat_fields = {'kind': core_schema.typed_dict_field(core_schema.literal_schema('cat'))} + dog_fields = {'kind': core_schema.typed_dict_field(core_schema.literal_schema('dog'))} + cat = core_schema.function_wrap_schema(lambda x, y, z: None, core_schema.typed_dict_schema(cat_fields)) + dog = core_schema.typed_dict_schema(dog_fields) + schema = core_schema.union_schema(cat, dog) + + assert apply_discriminator(schema, 'kind') == { + 'choices': { + 'cat': { + 'function': HasRepr(IsStr(regex=r'\. at 0x[0-9a-fA-F]+>')), + 'mode': 'wrap', + 'schema': { + 'fields': {'kind': {'schema': {'expected': ('cat',), 'type': 'literal'}}}, + 'type': 'typed-dict', + }, + 'type': 'function', + }, + 'dog': {'fields': {'kind': {'schema': {'expected': ('dog',), 'type': 'literal'}}}, 'type': 'typed-dict'}, + }, + 'discriminator': 'kind', + 'strict': False, + 'type': 'tagged-union', + } + + +def test_plain_function_schema_is_invalid() -> None: + with pytest.raises( + TypeError, + match="'function' with mode='plain' is not a valid discriminated union variant; " + "should be a `BaseModel` or `dataclass`", + ): + apply_discriminator( + core_schema.union_schema( + core_schema.function_plain_schema(lambda x, y: None), + core_schema.int_schema(), + ), + 'kind', + ) + + +def test_invalid_str_choice_discriminator_values() -> None: + cat = core_schema.typed_dict_schema({'kind': core_schema.typed_dict_field(core_schema.literal_schema('cat'))}) + dog = core_schema.str_schema() + + schema = core_schema.union_schema( + cat, + # NOTE: Wrapping the union with a validator results in failure to more thoroughly decompose the tagged union. + # I think this would be difficult to avoid in the general case, and I would suggest that we not attempt to do + # more than this until presented with scenarios where it is helpful/necessary. + core_schema.function_wrap_schema(lambda x, y, z: x, dog), + ) + + with pytest.raises( + TypeError, match="'str' is not a valid discriminated union variant; should be a `BaseModel` or `dataclass`" + ): + apply_discriminator(schema, 'kind') + + +def test_lax_or_strict_definitions() -> None: + cat = core_schema.typed_dict_schema({'kind': core_schema.typed_dict_field(core_schema.literal_schema('cat'))}) + lax_dog = core_schema.typed_dict_schema({'kind': core_schema.typed_dict_field(core_schema.literal_schema('DOG'))}) + strict_dog = core_schema.definitions_schema( + core_schema.typed_dict_schema({'kind': core_schema.typed_dict_field(core_schema.literal_schema('dog'))}), + [core_schema.int_schema(ref='my-int-definition')], + ) + dog = core_schema.definitions_schema( + core_schema.lax_or_strict_schema(lax_schema=lax_dog, strict_schema=strict_dog), + [core_schema.str_schema(ref='my-str-definition')], + ) + discriminated_schema = apply_discriminator(core_schema.union_schema(cat, dog), 'kind') + assert discriminated_schema == { + 'choices': { + 'DOG': { + 'definitions': [{'ref': 'my-str-definition', 'type': 'str'}], + 'schema': { + 'lax_schema': { + 'fields': {'kind': {'schema': {'expected': ('DOG',), 'type': 'literal'}}}, + 'type': 'typed-dict', + }, + 'strict_schema': { + 'definitions': [{'ref': 'my-int-definition', 'type': 'int'}], + 'schema': { + 'fields': {'kind': {'schema': {'expected': ('dog',), 'type': 'literal'}}}, + 'type': 'typed-dict', + }, + 'type': 'definitions', + }, + 'type': 'lax-or-strict', + }, + 'type': 'definitions', + }, + 'cat': {'fields': {'kind': {'schema': {'expected': ('cat',), 'type': 'literal'}}}, 'type': 'typed-dict'}, + 'dog': 'DOG', + }, + 'discriminator': 'kind', + 'strict': False, + 'type': 'tagged-union', + } + + +def test_wrapped_nullable_union() -> None: + cat = core_schema.typed_dict_schema({'kind': core_schema.typed_dict_field(core_schema.literal_schema('cat'))}) + dog = core_schema.typed_dict_schema({'kind': core_schema.typed_dict_field(core_schema.literal_schema('dog'))}) + ant = core_schema.typed_dict_schema({'kind': core_schema.typed_dict_field(core_schema.literal_schema('ant'))}) + + schema = core_schema.union_schema( + ant, + # NOTE: Wrapping the union with a validator results in failure to more thoroughly decompose the tagged union. + # I think this would be difficult to avoid in the general case, and I would suggest that we not attempt to do + # more than this until presented with scenarios where it is helpful/necessary. + core_schema.function_wrap_schema( + lambda x, y, z: x, core_schema.nullable_schema(core_schema.union_schema(cat, dog)) + ), + ) + discriminated_schema = apply_discriminator(schema, 'kind') + validator = SchemaValidator(discriminated_schema) + assert validator.validate_python({'kind': 'ant'})['kind'] == 'ant' + assert validator.validate_python({'kind': 'cat'})['kind'] == 'cat' + assert validator.validate_python(None) is None + with pytest.raises(ValidationError) as exc_info: + validator.validate_python({'kind': 'armadillo'}) + assert exc_info.value.errors() == [ + { + 'ctx': {'discriminator': "'kind'", 'expected_tags': "'ant', 'cat', 'dog'", 'tag': 'armadillo'}, + 'input': {'kind': 'armadillo'}, + 'loc': (), + 'msg': "Input tag 'armadillo' found using 'kind' does not match any of the " + "expected tags: 'ant', 'cat', 'dog'", + 'type': 'union_tag_invalid', + } + ] + + assert discriminated_schema == { + 'schema': { + 'choices': { + 'ant': { + 'fields': {'kind': {'schema': {'expected': ('ant',), 'type': 'literal'}}}, + 'type': 'typed-dict', + }, + 'cat': { + 'function': HasRepr(IsStr(regex=r'\. at 0x[0-9a-fA-F]+>')), + 'mode': 'wrap', + 'schema': { + 'schema': { + 'choices': ( + { + 'fields': {'kind': {'schema': {'expected': ('cat',), 'type': 'literal'}}}, + 'type': 'typed-dict', + }, + { + 'fields': {'kind': {'schema': {'expected': ('dog',), 'type': 'literal'}}}, + 'type': 'typed-dict', + }, + ), + 'type': 'union', + }, + 'type': 'nullable', + }, + 'type': 'function', + }, + 'dog': 'cat', + }, + 'discriminator': 'kind', + 'strict': False, + 'type': 'tagged-union', + }, + 'type': 'nullable', + } diff --git a/tests/test_forward_ref.py b/tests/test_forward_ref.py index 1a25bd0ac5..9c54fe3873 100644 --- a/tests/test_forward_ref.py +++ b/tests/test_forward_ref.py @@ -540,7 +540,6 @@ class NestedTuple(BaseModel): assert obj.model_dump() == {'x': (1, {'x': (2, {'x': (3, None)})})} -@pytest.mark.xfail(reason='TODO discriminator') def test_discriminated_union_forward_ref(create_module): @create_module def module(): @@ -551,7 +550,9 @@ def module(): from pydantic import BaseModel, Field class Pet(BaseModel): - __root__: Union['Cat', 'Dog'] = Field(..., discriminator='type') + pet: Union['Cat', 'Dog'] = Field(discriminator='type') + + model_config = dict(undefined_types_warning=False) class Cat(BaseModel): type: Literal['cat'] @@ -560,28 +561,38 @@ class Dog(BaseModel): type: Literal['dog'] with pytest.raises(PydanticUserError, match='`Pet` is not fully defined; you should define `Cat`'): - module.Pet.model_validate({'type': 'pika'}) + module.Pet.model_validate({'pet': {'type': 'pika'}}) module.Pet.model_rebuild() - with pytest.raises(ValidationError, match="No match for discriminator 'type' and value 'pika'"): - module.Pet.model_validate({'type': 'pika'}) + with pytest.raises( + ValidationError, + match="Input tag 'pika' found using 'type' does not match any of the expected tags: 'cat', 'dog'", + ): + module.Pet.model_validate({'pet': {'type': 'pika'}}) assert module.Pet.model_json_schema() == { 'title': 'Pet', - 'discriminator': {'propertyName': 'type', 'mapping': {'cat': '#/$defs/Cat', 'dog': '#/$defs/Dog'}}, - 'oneOf': [{'$ref': '#/$defs/Cat'}, {'$ref': '#/$defs/Dog'}], + 'required': ['pet'], + 'type': 'object', + 'properties': { + 'pet': { + 'title': 'Pet', + 'discriminator': {'mapping': {'cat': '#/$defs/Cat', 'dog': '#/$defs/Dog'}, 'propertyName': 'type'}, + 'oneOf': [{'$ref': '#/$defs/Cat'}, {'$ref': '#/$defs/Dog'}], + } + }, '$defs': { 'Cat': { 'title': 'Cat', 'type': 'object', - 'properties': {'type': {'title': 'Type', 'enum': ['cat'], 'type': 'string'}}, + 'properties': {'type': {'const': 'cat', 'title': 'Type'}}, 'required': ['type'], }, 'Dog': { 'title': 'Dog', 'type': 'object', - 'properties': {'type': {'title': 'Type', 'enum': ['dog'], 'type': 'string'}}, + 'properties': {'type': {'const': 'dog', 'title': 'Type'}}, 'required': ['type'], }, }, diff --git a/tests/test_json_schema.py b/tests/test_json_schema.py index 496c827a4b..8188a5bf24 100644 --- a/tests/test_json_schema.py +++ b/tests/test_json_schema.py @@ -2664,8 +2664,386 @@ class MyModel(BaseModel): } -@pytest.mark.xfail(reason='working on V2 - discriminated union') def test_discriminated_union(): + class Cat(BaseModel): + pet_type: Literal['cat'] + + class Dog(BaseModel): + pet_type: Literal['dog'] + + class Lizard(BaseModel): + pet_type: Literal['reptile', 'lizard'] + + class Model(BaseModel): + pet: Union[Cat, Dog, Lizard] = Field(..., discriminator='pet_type') + + assert Model.model_json_schema() == { + '$defs': { + 'Cat': { + 'properties': {'pet_type': {'const': 'cat', 'title': 'Pet Type'}}, + 'required': ['pet_type'], + 'title': 'Cat', + 'type': 'object', + }, + 'Dog': { + 'properties': {'pet_type': {'const': 'dog', 'title': 'Pet Type'}}, + 'required': ['pet_type'], + 'title': 'Dog', + 'type': 'object', + }, + 'Lizard': { + 'properties': {'pet_type': {'enum': ['reptile', 'lizard'], 'title': 'Pet Type'}}, + 'required': ['pet_type'], + 'title': 'Lizard', + 'type': 'object', + }, + }, + 'properties': { + 'pet': { + 'discriminator': { + 'mapping': { + 'cat': '#/$defs/Cat', + 'dog': '#/$defs/Dog', + 'lizard': '#/$defs/Lizard', + 'reptile': '#/$defs/Lizard', + }, + 'propertyName': 'pet_type', + }, + 'oneOf': [ + {'$ref': '#/$defs/Cat'}, + {'$ref': '#/$defs/Dog'}, + {'$ref': '#/$defs/Lizard'}, + ], + 'title': 'Pet', + } + }, + 'required': ['pet'], + 'title': 'Model', + 'type': 'object', + } + + +def test_discriminated_annotated_union(): + class Cat(BaseModel): + pet_type: Literal['cat'] + + class Dog(BaseModel): + pet_type: Literal['dog'] + + class Lizard(BaseModel): + pet_type: Literal['reptile', 'lizard'] + + class Model(BaseModel): + pet: Annotated[Union[Cat, Dog, Lizard], Field(..., discriminator='pet_type')] + + assert Model.model_json_schema() == { + '$defs': { + 'Cat': { + 'properties': {'pet_type': {'const': 'cat', 'title': 'Pet Type'}}, + 'required': ['pet_type'], + 'title': 'Cat', + 'type': 'object', + }, + 'Dog': { + 'properties': {'pet_type': {'const': 'dog', 'title': 'Pet Type'}}, + 'required': ['pet_type'], + 'title': 'Dog', + 'type': 'object', + }, + 'Lizard': { + 'properties': {'pet_type': {'enum': ['reptile', 'lizard'], 'title': 'Pet Type'}}, + 'required': ['pet_type'], + 'title': 'Lizard', + 'type': 'object', + }, + }, + 'properties': { + 'pet': { + 'discriminator': { + 'mapping': { + 'cat': '#/$defs/Cat', + 'dog': '#/$defs/Dog', + 'lizard': '#/$defs/Lizard', + 'reptile': '#/$defs/Lizard', + }, + 'propertyName': 'pet_type', + }, + 'oneOf': [ + {'$ref': '#/$defs/Cat'}, + {'$ref': '#/$defs/Dog'}, + {'$ref': '#/$defs/Lizard'}, + ], + 'title': 'Pet', + } + }, + 'required': ['pet'], + 'title': 'Model', + 'type': 'object', + } + + +def test_nested_discriminated_union(): + class BlackCatWithHeight(BaseModel): + color: Literal['black'] + info: Literal['height'] + height: float + + class BlackCatWithWeight(BaseModel): + color: Literal['black'] + info: Literal['weight'] + weight: float + + BlackCat = Annotated[Union[BlackCatWithHeight, BlackCatWithWeight], Field(discriminator='info')] + + class WhiteCat(BaseModel): + color: Literal['white'] + white_cat_info: str + + class Cat(BaseModel): + pet: Annotated[Union[BlackCat, WhiteCat], Field(discriminator='color')] + + assert Cat.model_json_schema() == { + '$defs': { + 'BlackCatWithHeight': { + 'properties': { + 'color': {'const': 'black', 'title': 'Color'}, + 'height': {'title': 'Height', 'type': 'number'}, + 'info': {'const': 'height', 'title': 'Info'}, + }, + 'required': ['color', 'info', 'height'], + 'title': 'BlackCatWithHeight', + 'type': 'object', + }, + 'BlackCatWithWeight': { + 'properties': { + 'color': {'const': 'black', 'title': 'Color'}, + 'info': {'const': 'weight', 'title': 'Info'}, + 'weight': {'title': 'Weight', 'type': 'number'}, + }, + 'required': ['color', 'info', 'weight'], + 'title': 'BlackCatWithWeight', + 'type': 'object', + }, + 'WhiteCat': { + 'properties': { + 'color': {'const': 'white', 'title': 'Color'}, + 'white_cat_info': {'title': 'White Cat Info', 'type': 'string'}, + }, + 'required': ['color', 'white_cat_info'], + 'title': 'WhiteCat', + 'type': 'object', + }, + }, + 'properties': { + 'pet': { + 'discriminator': { + 'mapping': { + 'black': { + 'discriminator': { + 'mapping': { + 'height': '#/$defs/BlackCatWithHeight', + 'weight': '#/$defs/BlackCatWithWeight', + }, + 'propertyName': 'info', + }, + 'oneOf': [{'$ref': '#/$defs/BlackCatWithHeight'}, {'$ref': '#/$defs/BlackCatWithWeight'}], + }, + 'white': '#/$defs/WhiteCat', + }, + 'propertyName': 'color', + }, + 'oneOf': [ + { + 'discriminator': { + 'mapping': {'height': '#/$defs/BlackCatWithHeight', 'weight': '#/$defs/BlackCatWithWeight'}, + 'propertyName': 'info', + }, + 'oneOf': [{'$ref': '#/$defs/BlackCatWithHeight'}, {'$ref': '#/$defs/BlackCatWithWeight'}], + }, + {'$ref': '#/$defs/WhiteCat'}, + ], + 'title': 'Pet', + } + }, + 'required': ['pet'], + 'title': 'Cat', + 'type': 'object', + } + + +def test_deeper_nested_discriminated_annotated_union(): + class BlackCatWithHeight(BaseModel): + pet_type: Literal['cat'] + color: Literal['black'] + info: Literal['height'] + black_infos: str + + class BlackCatWithWeight(BaseModel): + pet_type: Literal['cat'] + color: Literal['black'] + info: Literal['weight'] + black_infos: str + + BlackCat = Annotated[Union[BlackCatWithHeight, BlackCatWithWeight], Field(discriminator='info')] + + class WhiteCat(BaseModel): + pet_type: Literal['cat'] + color: Literal['white'] + white_infos: str + + Cat = Annotated[Union[BlackCat, WhiteCat], Field(discriminator='color')] + + class Dog(BaseModel): + pet_type: Literal['dog'] + dog_name: str + + Pet = Annotated[Union[Cat, Dog], Field(discriminator='pet_type')] + + class Model(BaseModel): + pet: Pet + number: int + + assert Model.model_json_schema() == { + '$defs': { + 'BlackCatWithHeight': { + 'properties': { + 'black_infos': {'title': 'Black ' 'Infos', 'type': 'string'}, + 'color': {'const': 'black', 'title': 'Color'}, + 'info': {'const': 'height', 'title': 'Info'}, + 'pet_type': {'const': 'cat', 'title': 'Pet ' 'Type'}, + }, + 'required': ['pet_type', 'color', 'info', 'black_infos'], + 'title': 'BlackCatWithHeight', + 'type': 'object', + }, + 'BlackCatWithWeight': { + 'properties': { + 'black_infos': {'title': 'Black ' 'Infos', 'type': 'string'}, + 'color': {'const': 'black', 'title': 'Color'}, + 'info': {'const': 'weight', 'title': 'Info'}, + 'pet_type': {'const': 'cat', 'title': 'Pet ' 'Type'}, + }, + 'required': ['pet_type', 'color', 'info', 'black_infos'], + 'title': 'BlackCatWithWeight', + 'type': 'object', + }, + 'Dog': { + 'properties': { + 'dog_name': {'title': 'Dog Name', 'type': 'string'}, + 'pet_type': {'const': 'dog', 'title': 'Pet Type'}, + }, + 'required': ['pet_type', 'dog_name'], + 'title': 'Dog', + 'type': 'object', + }, + 'WhiteCat': { + 'properties': { + 'color': {'const': 'white', 'title': 'Color'}, + 'pet_type': {'const': 'cat', 'title': 'Pet Type'}, + 'white_infos': {'title': 'White Infos', 'type': 'string'}, + }, + 'required': ['pet_type', 'color', 'white_infos'], + 'title': 'WhiteCat', + 'type': 'object', + }, + }, + 'properties': { + 'number': {'title': 'Number', 'type': 'integer'}, + 'pet': { + 'discriminator': { + 'mapping': { + 'cat': { + 'discriminator': { + 'mapping': { + 'black': { + 'discriminator': { + 'mapping': { + 'height': '#/$defs/BlackCatWithHeight', + 'weight': '#/$defs/BlackCatWithWeight', + }, + 'propertyName': 'info', + }, + 'oneOf': [ + {'$ref': '#/$defs/BlackCatWithHeight'}, + {'$ref': '#/$defs/BlackCatWithWeight'}, + ], + }, + 'white': '#/$defs/WhiteCat', + }, + 'propertyName': 'color', + }, + 'oneOf': [ + { + 'discriminator': { + 'mapping': { + 'height': '#/$defs/BlackCatWithHeight', + 'weight': '#/$defs/BlackCatWithWeight', + }, + 'propertyName': 'info', + }, + 'oneOf': [ + {'$ref': '#/$defs/BlackCatWithHeight'}, + {'$ref': '#/$defs/BlackCatWithWeight'}, + ], + }, + {'$ref': '#/$defs/WhiteCat'}, + ], + }, + 'dog': '#/$defs/Dog', + }, + 'propertyName': 'pet_type', + }, + 'oneOf': [ + { + 'discriminator': { + 'mapping': { + 'black': { + 'discriminator': { + 'mapping': { + 'height': '#/$defs/BlackCatWithHeight', + 'weight': '#/$defs/BlackCatWithWeight', + }, + 'propertyName': 'info', + }, + 'oneOf': [ + {'$ref': '#/$defs/BlackCatWithHeight'}, + {'$ref': '#/$defs/BlackCatWithWeight'}, + ], + }, + 'white': '#/$defs/WhiteCat', + }, + 'propertyName': 'color', + }, + 'oneOf': [ + { + 'discriminator': { + 'mapping': { + 'height': '#/$defs/BlackCatWithHeight', + 'weight': '#/$defs/BlackCatWithWeight', + }, + 'propertyName': 'info', + }, + 'oneOf': [ + {'$ref': '#/$defs/BlackCatWithHeight'}, + {'$ref': '#/$defs/BlackCatWithWeight'}, + ], + }, + {'$ref': '#/$defs/WhiteCat'}, + ], + }, + {'$ref': '#/$defs/Dog'}, + ], + 'title': 'Pet', + }, + }, + 'required': ['pet', 'number'], + 'title': 'Model', + 'type': 'object', + } + + +@pytest.mark.xfail(reason='working on V2 - __root__') +def test_discriminated_union_root(): class BlackCat(BaseModel): pet_type: Literal['cat'] color: Literal['black'] @@ -2752,31 +3130,42 @@ class Model(BaseModel): } -@pytest.mark.xfail(reason='working on V2 - discriminated union') -def test_discriminated_annotated_union(): +def test_discriminated_annotated_union_literal_enum(): + class PetType(Enum): + cat = 'cat' + dog = 'dog' + + class PetColor(str, Enum): + black = 'black' + white = 'white' + + class PetInfo(Enum): + height = 0 + weight = 1 + class BlackCatWithHeight(BaseModel): - pet_type: Literal['cat'] - color: Literal['black'] - info: Literal['height'] + pet_type: Literal[PetType.cat] + color: Literal[PetColor.black] + info: Literal[PetInfo.height] black_infos: str class BlackCatWithWeight(BaseModel): - pet_type: Literal['cat'] - color: Literal['black'] - info: Literal['weight'] + pet_type: Literal[PetType.cat] + color: Literal[PetColor.black] + info: Literal[PetInfo.weight] black_infos: str BlackCat = Annotated[Union[BlackCatWithHeight, BlackCatWithWeight], Field(discriminator='info')] class WhiteCat(BaseModel): - pet_type: Literal['cat'] - color: Literal['white'] + pet_type: Literal[PetType.cat] + color: Literal[PetColor.white] white_infos: str Cat = Annotated[Union[BlackCat, WhiteCat], Field(discriminator='color')] class Dog(BaseModel): - pet_type: Literal['dog'] + pet_type: Literal[PetType.dog] dog_name: str Pet = Annotated[Union[Cat, Dog], Field(discriminator='pet_type')] @@ -2786,87 +3175,142 @@ class Model(BaseModel): number: int assert Model.model_json_schema() == { - 'title': 'Model', - 'type': 'object', + '$defs': { + 'BlackCatWithHeight': { + 'properties': { + 'black_infos': {'title': 'Black ' 'Infos', 'type': 'string'}, + 'color': {'const': 'black', 'title': 'Color'}, + 'info': {'const': 0, 'title': 'Info'}, + 'pet_type': {'const': 'cat', 'title': 'Pet ' 'Type'}, + }, + 'required': ['pet_type', 'color', 'info', 'black_infos'], + 'title': 'BlackCatWithHeight', + 'type': 'object', + }, + 'BlackCatWithWeight': { + 'properties': { + 'black_infos': {'title': 'Black ' 'Infos', 'type': 'string'}, + 'color': {'const': 'black', 'title': 'Color'}, + 'info': {'const': 1, 'title': 'Info'}, + 'pet_type': {'const': 'cat', 'title': 'Pet ' 'Type'}, + }, + 'required': ['pet_type', 'color', 'info', 'black_infos'], + 'title': 'BlackCatWithWeight', + 'type': 'object', + }, + 'Dog': { + 'properties': { + 'dog_name': {'title': 'Dog Name', 'type': 'string'}, + 'pet_type': {'const': 'dog', 'title': 'Pet Type'}, + }, + 'required': ['pet_type', 'dog_name'], + 'title': 'Dog', + 'type': 'object', + }, + 'WhiteCat': { + 'properties': { + 'color': {'const': 'white', 'title': 'Color'}, + 'pet_type': {'const': 'cat', 'title': 'Pet Type'}, + 'white_infos': {'title': 'White Infos', 'type': 'string'}, + }, + 'required': ['pet_type', 'color', 'white_infos'], + 'title': 'WhiteCat', + 'type': 'object', + }, + }, 'properties': { + 'number': {'title': 'Number', 'type': 'integer'}, 'pet': { - 'title': 'Pet', 'discriminator': { - 'propertyName': 'pet_type', 'mapping': { 'cat': { - 'BlackCatWithHeight': {'$ref': '#/$defs/BlackCatWithHeight'}, - 'BlackCatWithWeight': {'$ref': '#/$defs/BlackCatWithWeight'}, - 'WhiteCat': {'$ref': '#/$defs/WhiteCat'}, + 'discriminator': { + 'mapping': { + 'black': { + 'discriminator': { + 'mapping': { + '0': '#/$defs/BlackCatWithHeight', + '1': '#/$defs/BlackCatWithWeight', + }, + 'propertyName': 'info', + }, + 'oneOf': [ + {'$ref': '#/$defs/BlackCatWithHeight'}, + {'$ref': '#/$defs/BlackCatWithWeight'}, + ], + }, + 'white': '#/$defs/WhiteCat', + }, + 'propertyName': 'color', + }, + 'oneOf': [ + { + 'discriminator': { + 'mapping': { + '0': '#/$defs/BlackCatWithHeight', + '1': '#/$defs/BlackCatWithWeight', + }, + 'propertyName': 'info', + }, + 'oneOf': [ + {'$ref': '#/$defs/BlackCatWithHeight'}, + {'$ref': '#/$defs/BlackCatWithWeight'}, + ], + }, + {'$ref': '#/$defs/WhiteCat'}, + ], }, 'dog': '#/$defs/Dog', }, + 'propertyName': 'pet_type', }, 'oneOf': [ { + 'discriminator': { + 'mapping': { + 'black': { + 'discriminator': { + 'mapping': { + '0': '#/$defs/BlackCatWithHeight', + '1': '#/$defs/BlackCatWithWeight', + }, + 'propertyName': 'info', + }, + 'oneOf': [ + {'$ref': '#/$defs/BlackCatWithHeight'}, + {'$ref': '#/$defs/BlackCatWithWeight'}, + ], + }, + 'white': '#/$defs/WhiteCat', + }, + 'propertyName': 'color', + }, 'oneOf': [ { + 'discriminator': { + 'mapping': {'0': '#/$defs/BlackCatWithHeight', '1': '#/$defs/BlackCatWithWeight'}, + 'propertyName': 'info', + }, 'oneOf': [ {'$ref': '#/$defs/BlackCatWithHeight'}, {'$ref': '#/$defs/BlackCatWithWeight'}, - ] + ], }, {'$ref': '#/$defs/WhiteCat'}, - ] + ], }, {'$ref': '#/$defs/Dog'}, ], + 'title': 'Pet', }, - 'number': {'title': 'Number', 'type': 'integer'}, }, 'required': ['pet', 'number'], - '$defs': { - 'BlackCatWithHeight': { - 'title': 'BlackCatWithHeight', - 'type': 'object', - 'properties': { - 'pet_type': {'title': 'Pet Type', 'enum': ['cat'], 'type': 'string'}, - 'color': {'title': 'Color', 'enum': ['black'], 'type': 'string'}, - 'info': {'title': 'Info', 'enum': ['height'], 'type': 'string'}, - 'black_infos': {'title': 'Black Infos', 'type': 'string'}, - }, - 'required': ['pet_type', 'color', 'info', 'black_infos'], - }, - 'BlackCatWithWeight': { - 'title': 'BlackCatWithWeight', - 'type': 'object', - 'properties': { - 'pet_type': {'title': 'Pet Type', 'enum': ['cat'], 'type': 'string'}, - 'color': {'title': 'Color', 'enum': ['black'], 'type': 'string'}, - 'info': {'title': 'Info', 'enum': ['weight'], 'type': 'string'}, - 'black_infos': {'title': 'Black Infos', 'type': 'string'}, - }, - 'required': ['pet_type', 'color', 'info', 'black_infos'], - }, - 'WhiteCat': { - 'title': 'WhiteCat', - 'type': 'object', - 'properties': { - 'pet_type': {'title': 'Pet Type', 'enum': ['cat'], 'type': 'string'}, - 'color': {'title': 'Color', 'enum': ['white'], 'type': 'string'}, - 'white_infos': {'title': 'White Infos', 'type': 'string'}, - }, - 'required': ['pet_type', 'color', 'white_infos'], - }, - 'Dog': { - 'title': 'Dog', - 'type': 'object', - 'properties': { - 'pet_type': {'title': 'Pet Type', 'enum': ['dog'], 'type': 'string'}, - 'dog_name': {'title': 'Dog Name', 'type': 'string'}, - }, - 'required': ['pet_type', 'dog_name'], - }, - }, + 'title': 'Model', + 'type': 'object', } -@pytest.mark.xfail(reason='working on V2 - discriminated union') +@pytest.mark.xfail(reason='working on V2 - discriminated union with alias') def test_alias_same(): class Cat(BaseModel): pet_type: Literal['cat'] = Field(alias='typeOfPet') @@ -2950,7 +3394,6 @@ class NestedModel: } -@pytest.mark.xfail(reason='working on V2 - discriminated union') def test_discriminated_union_in_list(): class BlackCat(BaseModel): pet_type: Literal['cat'] @@ -2975,65 +3418,69 @@ class Model(BaseModel): n: int assert Model.model_json_schema() == { - 'title': 'Model', - 'type': 'object', + '$defs': { + 'BlackCat': { + 'properties': { + 'black_name': {'title': 'Black Name', 'type': 'string'}, + 'color': {'const': 'black', 'title': 'Color'}, + 'pet_type': {'const': 'cat', 'title': 'Pet Type'}, + }, + 'required': ['pet_type', 'color', 'black_name'], + 'title': 'BlackCat', + 'type': 'object', + }, + 'Dog': { + 'properties': { + 'name': {'title': 'Name', 'type': 'string'}, + 'pet_type': {'const': 'dog', 'title': 'Pet Type'}, + }, + 'required': ['pet_type', 'name'], + 'title': 'Dog', + 'type': 'object', + }, + 'WhiteCat': { + 'properties': { + 'color': {'const': 'white', 'title': 'Color'}, + 'pet_type': {'const': 'cat', 'title': 'Pet Type'}, + 'white_name': {'title': 'White Name', 'type': 'string'}, + }, + 'required': ['pet_type', 'color', 'white_name'], + 'title': 'WhiteCat', + 'type': 'object', + }, + }, 'properties': { + 'n': {'title': 'N', 'type': 'integer'}, 'pets': { - 'title': 'Pets', 'discriminator': { - 'propertyName': 'pet_type', 'mapping': { 'cat': { - 'BlackCat': {'$ref': '#/$defs/BlackCat'}, - 'WhiteCat': {'$ref': '#/$defs/WhiteCat'}, + 'discriminator': { + 'mapping': {'black': '#/$defs/BlackCat', 'white': '#/$defs/WhiteCat'}, + 'propertyName': 'color', + }, + 'oneOf': [{'$ref': '#/$defs/BlackCat'}, {'$ref': '#/$defs/WhiteCat'}], }, 'dog': '#/$defs/Dog', }, + 'propertyName': 'pet_type', }, 'oneOf': [ { - 'oneOf': [ - {'$ref': '#/$defs/BlackCat'}, - {'$ref': '#/$defs/WhiteCat'}, - ], + 'discriminator': { + 'mapping': {'black': '#/$defs/BlackCat', 'white': '#/$defs/WhiteCat'}, + 'propertyName': 'color', + }, + 'oneOf': [{'$ref': '#/$defs/BlackCat'}, {'$ref': '#/$defs/WhiteCat'}], }, {'$ref': '#/$defs/Dog'}, ], + 'title': 'Pets', }, - 'n': {'title': 'N', 'type': 'integer'}, }, 'required': ['pets', 'n'], - '$defs': { - 'BlackCat': { - 'title': 'BlackCat', - 'type': 'object', - 'properties': { - 'pet_type': {'title': 'Pet Type', 'enum': ['cat'], 'type': 'string'}, - 'color': {'title': 'Color', 'enum': ['black'], 'type': 'string'}, - 'black_name': {'title': 'Black Name', 'type': 'string'}, - }, - 'required': ['pet_type', 'color', 'black_name'], - }, - 'WhiteCat': { - 'title': 'WhiteCat', - 'type': 'object', - 'properties': { - 'pet_type': {'title': 'Pet Type', 'enum': ['cat'], 'type': 'string'}, - 'color': {'title': 'Color', 'enum': ['white'], 'type': 'string'}, - 'white_name': {'title': 'White Name', 'type': 'string'}, - }, - 'required': ['pet_type', 'color', 'white_name'], - }, - 'Dog': { - 'title': 'Dog', - 'type': 'object', - 'properties': { - 'pet_type': {'title': 'Pet Type', 'enum': ['dog'], 'type': 'string'}, - 'name': {'title': 'Name', 'type': 'string'}, - }, - 'required': ['pet_type', 'name'], - }, - }, + 'title': 'Model', + 'type': 'object', }