Skip to content

Commit

Permalink
fix: support generic models with discriminated union (#3551)
Browse files Browse the repository at this point in the history
  • Loading branch information
PrettyWood committed Dec 24, 2021
1 parent edad0db commit e882277
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 1 deletion.
4 changes: 4 additions & 0 deletions pydantic/fields.py
Expand Up @@ -733,6 +733,10 @@ def prepare_discriminated_union_sub_fields(self) -> None:
Note that this process can be aborted if a `ForwardRef` is encountered
"""
assert self.discriminator_key is not None

if self.type_.__class__ is DeferredType:
return

assert self.sub_fields is not None
sub_fields_mapping: Dict[str, 'ModelField'] = {}
all_aliases: Set[str] = set()
Expand Down
37 changes: 36 additions & 1 deletion tests/test_discrimated_union.py
@@ -1,12 +1,14 @@
import re
import sys
from enum import Enum
from typing import Union
from typing import Generic, TypeVar, Union

import pytest
from typing_extensions import Annotated, Literal

from pydantic import BaseModel, Field, ValidationError
from pydantic.errors import ConfigError
from pydantic.generics import GenericModel


def test_discriminated_union_only_union():
Expand Down Expand Up @@ -361,3 +363,36 @@ class Model(BaseModel):
n: int

assert isinstance(Model(**{'pet': {'pet_type': 'dog', 'name': 'Milou'}, 'n': 5}).pet, Dog)


@pytest.mark.skipif(sys.version_info < (3, 7), reason='generics only supported for python 3.7 and above')
def test_generic():
T = TypeVar('T')

class Success(GenericModel, Generic[T]):
type: Literal['Success'] = 'Success'
data: T

class Failure(BaseModel):
type: Literal['Failure'] = 'Failure'
error_message: str

class Container(GenericModel, Generic[T]):
result: Union[Success[T], Failure] = Field(discriminator='type')

with pytest.raises(ValidationError, match="Discriminator 'type' is missing in value"):
Container[str].parse_obj({'result': {}})

with pytest.raises(
ValidationError,
match=re.escape("No match for discriminator 'type' and value 'Other' (allowed values: 'Success', 'Failure')"),
):
Container[str].parse_obj({'result': {'type': 'Other'}})

with pytest.raises(
ValidationError, match=re.escape('Container[str]\nresult -> Success[str] -> data\n field required')
):
Container[str].parse_obj({'result': {'type': 'Success'}})

# coercion is done properly
assert Container[str].parse_obj({'result': {'type': 'Success', 'data': 1}}).result.data == '1'

0 comments on commit e882277

Please sign in to comment.