Discriminated union type breaks from 2.6.3 to 2.6.4: TypeError: 'none' is not a valid discriminated union variant; should be a BaseModel or dataclass #9118

Closed
mcleantom opened this issue Mar 27, 2024 · 5 comments · Fixed by #9124
Labels: bug V2 (Bug related to Pydantic V2), pending (Awaiting a response / confirmation)


@mcleantom

mcleantom commented Mar 27, 2024

Initial Checks

  • I confirm that I'm using Pydantic V2

Description

I have some Pydantic models that I use to build configurable models based on conditions. They are quite complicated, being both self-referencing and generic. My code works in 2.6.3 but breaks in 2.6.4.

In 2.6.4 I get the error:

  File "C:\Users\tom.mclean\AppData\Roaming\JetBrains\PyCharmCE2023.3\scratches\scratch_6.py", line 102, in <module>
    BooCondition = create_conditions_type(Foo)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\tom.mclean\AppData\Roaming\JetBrains\PyCharmCE2023.3\scratches\scratch_6.py", line 87, in create_conditions_type
    AndCondition[return_value_type],
    ~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^
  File "C:\Users\tom.mclean\AppData\Local\miniconda3\envs\Module\Lib\site-packages\pydantic\main.py", line 687, in __class_getitem__
    submodel = _generics.create_generic_submodel(model_name, origin, args, params)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\tom.mclean\AppData\Local\miniconda3\envs\Module\Lib\site-packages\pydantic\_internal\_generics.py", line 138, in create_generic_submodel
    created_model = meta(
                    ^^^^^
  File "C:\Users\tom.mclean\AppData\Local\miniconda3\envs\Module\Lib\site-packages\pydantic\_internal\_model_construction.py", line 178, in __new__
    set_model_fields(cls, bases, config_wrapper, types_namespace)
  File "C:\Users\tom.mclean\AppData\Local\miniconda3\envs\Module\Lib\site-packages\pydantic\_internal\_model_construction.py", line 452, in set_model_fields
    fields, class_vars = collect_model_fields(cls, bases, config_wrapper, types_namespace, typevars_map=typevars_map)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\tom.mclean\AppData\Local\miniconda3\envs\Module\Lib\site-packages\pydantic\_internal\_fields.py", line 232, in collect_model_fields
    field.apply_typevars_map(typevars_map, types_namespace)
  File "C:\Users\tom.mclean\AppData\Local\miniconda3\envs\Module\Lib\site-packages\pydantic\fields.py", line 556, in apply_typevars_map
    self.annotation = _generics.replace_types(annotation, typevars_map)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\tom.mclean\AppData\Local\miniconda3\envs\Module\Lib\site-packages\pydantic\_internal\_generics.py", line 290, in replace_types
    resolved_type_args = tuple(replace_types(arg, type_map) for arg in type_args)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\tom.mclean\AppData\Local\miniconda3\envs\Module\Lib\site-packages\pydantic\_internal\_generics.py", line 290, in <genexpr>
    resolved_type_args = tuple(replace_types(arg, type_map) for arg in type_args)
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\tom.mclean\AppData\Local\miniconda3\envs\Module\Lib\site-packages\pydantic\_internal\_generics.py", line 282, in replace_types
    annotated = replace_types(annotated_type, type_map)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\tom.mclean\AppData\Local\miniconda3\envs\Module\Lib\site-packages\pydantic\_internal\_generics.py", line 290, in replace_types
    resolved_type_args = tuple(replace_types(arg, type_map) for arg in type_args)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\tom.mclean\AppData\Local\miniconda3\envs\Module\Lib\site-packages\pydantic\_internal\_generics.py", line 290, in <genexpr>
    resolved_type_args = tuple(replace_types(arg, type_map) for arg in type_args)
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\tom.mclean\AppData\Local\miniconda3\envs\Module\Lib\site-packages\pydantic\_internal\_generics.py", line 323, in replace_types
    return type_[resolved_type_args]
           ~~~~~^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\tom.mclean\AppData\Local\miniconda3\envs\Module\Lib\site-packages\pydantic\main.py", line 687, in __class_getitem__
    submodel = _generics.create_generic_submodel(model_name, origin, args, params)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\tom.mclean\AppData\Local\miniconda3\envs\Module\Lib\site-packages\pydantic\_internal\_generics.py", line 138, in create_generic_submodel
    created_model = meta(
                    ^^^^^
  File "C:\Users\tom.mclean\AppData\Local\miniconda3\envs\Module\Lib\site-packages\pydantic\_internal\_model_construction.py", line 183, in __new__
    complete_model_class(
  File "C:\Users\tom.mclean\AppData\Local\miniconda3\envs\Module\Lib\site-packages\pydantic\_internal\_model_construction.py", line 527, in complete_model_class
    schema = gen_schema.clean_schema(schema)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\tom.mclean\AppData\Local\miniconda3\envs\Module\Lib\site-packages\pydantic\_internal\_generate_schema.py", line 427, in clean_schema
    schema = _discriminated_union.apply_discriminators(schema)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\tom.mclean\AppData\Local\miniconda3\envs\Module\Lib\site-packages\pydantic\_internal\_discriminated_union.py", line 56, in apply_discriminators
    return simplify_schema_references(_core_utils.walk_core_schema(schema, inner))
                                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\tom.mclean\AppData\Local\miniconda3\envs\Module\Lib\site-packages\pydantic\_internal\_core_utils.py", line 416, in walk_core_schema
    return f(schema.copy(), _dispatch)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\tom.mclean\AppData\Local\miniconda3\envs\Module\Lib\site-packages\pydantic\_internal\_discriminated_union.py", line 44, in inner
    s = recurse(s, inner)
        ^^^^^^^^^^^^^^^^^
  File "C:\Users\tom.mclean\AppData\Local\miniconda3\envs\Module\Lib\site-packages\pydantic\_internal\_core_utils.py", line 199, in walk
    return f(schema, self._walk)
           ^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\tom.mclean\AppData\Local\miniconda3\envs\Module\Lib\site-packages\pydantic\_internal\_discriminated_union.py", line 44, in inner
    s = recurse(s, inner)
        ^^^^^^^^^^^^^^^^^
  File "C:\Users\tom.mclean\AppData\Local\miniconda3\envs\Module\Lib\site-packages\pydantic\_internal\_core_utils.py", line 202, in _walk
    schema = self._schema_type_to_method[schema['type']](schema.copy(), f)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\tom.mclean\AppData\Local\miniconda3\envs\Module\Lib\site-packages\pydantic\_internal\_core_utils.py", line 234, in handle_definitions_schema
    updated_definition = self.walk(definition, f)
                         ^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\tom.mclean\AppData\Local\miniconda3\envs\Module\Lib\site-packages\pydantic\_internal\_core_utils.py", line 199, in walk
    return f(schema, self._walk)
           ^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\tom.mclean\AppData\Local\miniconda3\envs\Module\Lib\site-packages\pydantic\_internal\_discriminated_union.py", line 44, in inner
    s = recurse(s, inner)
        ^^^^^^^^^^^^^^^^^
  File "C:\Users\tom.mclean\AppData\Local\miniconda3\envs\Module\Lib\site-packages\pydantic\_internal\_core_utils.py", line 202, in _walk
    schema = self._schema_type_to_method[schema['type']](schema.copy(), f)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\tom.mclean\AppData\Local\miniconda3\envs\Module\Lib\site-packages\pydantic\_internal\_core_utils.py", line 211, in _handle_other_schemas
    schema['schema'] = self.walk(sub_schema, f)  # type: ignore
                       ^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\tom.mclean\AppData\Local\miniconda3\envs\Module\Lib\site-packages\pydantic\_internal\_core_utils.py", line 199, in walk
    return f(schema, self._walk)
           ^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\tom.mclean\AppData\Local\miniconda3\envs\Module\Lib\site-packages\pydantic\_internal\_discriminated_union.py", line 44, in inner
    s = recurse(s, inner)
        ^^^^^^^^^^^^^^^^^
  File "C:\Users\tom.mclean\AppData\Local\miniconda3\envs\Module\Lib\site-packages\pydantic\_internal\_core_utils.py", line 202, in _walk
    schema = self._schema_type_to_method[schema['type']](schema.copy(), f)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\tom.mclean\AppData\Local\miniconda3\envs\Module\Lib\site-packages\pydantic\_internal\_core_utils.py", line 339, in handle_model_fields_schema
    replaced_field['schema'] = self.walk(v['schema'], f)
                               ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\tom.mclean\AppData\Local\miniconda3\envs\Module\Lib\site-packages\pydantic\_internal\_core_utils.py", line 199, in walk
    return f(schema, self._walk)
           ^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\tom.mclean\AppData\Local\miniconda3\envs\Module\Lib\site-packages\pydantic\_internal\_discriminated_union.py", line 44, in inner
    s = recurse(s, inner)
        ^^^^^^^^^^^^^^^^^
  File "C:\Users\tom.mclean\AppData\Local\miniconda3\envs\Module\Lib\site-packages\pydantic\_internal\_core_utils.py", line 202, in _walk
    schema = self._schema_type_to_method[schema['type']](schema.copy(), f)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\tom.mclean\AppData\Local\miniconda3\envs\Module\Lib\site-packages\pydantic\_internal\_core_utils.py", line 254, in handle_list_schema
    schema['items_schema'] = self.walk(items_schema, f)
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\tom.mclean\AppData\Local\miniconda3\envs\Module\Lib\site-packages\pydantic\_internal\_core_utils.py", line 199, in walk
    return f(schema, self._walk)
           ^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\tom.mclean\AppData\Local\miniconda3\envs\Module\Lib\site-packages\pydantic\_internal\_discriminated_union.py", line 53, in inner
    s = apply_discriminator(s, discriminator, definitions)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\tom.mclean\AppData\Local\miniconda3\envs\Module\Lib\site-packages\pydantic\_internal\_discriminated_union.py", line 95, in apply_discriminator
    return _ApplyInferredDiscriminator(discriminator, definitions or {}).apply(schema)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\tom.mclean\AppData\Local\miniconda3\envs\Module\Lib\site-packages\pydantic\_internal\_discriminated_union.py", line 190, in apply
    schema = self._apply_to_root(schema)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\tom.mclean\AppData\Local\miniconda3\envs\Module\Lib\site-packages\pydantic\_internal\_discriminated_union.py", line 230, in _apply_to_root
    self._handle_choice(choice)
  File "C:\Users\tom.mclean\AppData\Local\miniconda3\envs\Module\Lib\site-packages\pydantic\_internal\_discriminated_union.py", line 305, in _handle_choice
    inferred_discriminator_values = self._infer_discriminator_values_for_choice(choice, source_name=None)
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\tom.mclean\AppData\Local\miniconda3\envs\Module\Lib\site-packages\pydantic\_internal\_discriminated_union.py", line 384, in _infer_discriminator_values_for_choice
    return self._infer_discriminator_values_for_choice(self.definitions[schema_ref], source_name=source_name)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\tom.mclean\AppData\Local\miniconda3\envs\Module\Lib\site-packages\pydantic\_internal\_discriminated_union.py", line 386, in _infer_discriminator_values_for_choice
    raise TypeError(
TypeError: 'none' is not a valid discriminated union variant; should be a `BaseModel` or `dataclass`

Example Code

from __future__ import annotations

from pydantic import BaseModel, Field
from typing import Annotated, Union, Literal, Generic, TypeVar, Optional
from abc import ABC, abstractmethod


ReturnValueType = TypeVar("ReturnValueType")


class BaseCondition(BaseModel, ABC, Generic[ReturnValueType]):
    type: str
    value: Optional[ReturnValueType] = None

    @abstractmethod
    def passes(self, values: dict[str, float]) -> bool:
        pass


class IsInIntervalCondition(BaseCondition[ReturnValueType]):
    type: Literal["interval"] = "interval"
    key: str
    gt: Optional[float] = None
    lt: Optional[float] = None
    gte: Optional[float] = None
    lte: Optional[float] = None

    def passes(self, value: dict[str, float]) -> bool:
        value = value[self.key]
        if self.gt is not None and value <= self.gt:
            return False
        if self.lt is not None and value >= self.lt:
            return False
        if self.gte is not None and value < self.gte:
            return False
        if self.lte is not None and value > self.lte:
            return False
        return True


class InSetCondition(BaseCondition[ReturnValueType]):
    type: Literal["set"] = "set"
    values: list[float]
    key: str

    def passes(self, values: dict[str, float]) -> bool:
        # The `values` argument shadows the `values` field, so read the field via `self`.
        if self.key not in values:
            return False
        return values[self.key] in self.values


class AndCondition(BaseCondition[ReturnValueType]):
    type: Literal["and"] = "and"
    conditions: list[Conditions]

    def passes(self, values: dict[str, float]) -> bool:
        for condition in self.conditions:
            if not condition.passes(values):
                return False
        return True


class OrCondition(BaseCondition[ReturnValueType]):
    type: Literal["or"] = "or"
    conditions: list[Conditions]

    def passes(self, value: dict[str, float]) -> bool:
        for i in self.conditions:
            if i.passes(value):
                return True
        return False


class AlwaysCondition(BaseCondition[ReturnValueType]):
    type: Literal["always"] = "always"

    def passes(self, value: dict[str, float]) -> bool:
        return True


def create_conditions_type(return_value_type) -> type:
    return Annotated[
        Union[
            IsInIntervalCondition[return_value_type],
            InSetCondition[return_value_type],
            AndCondition[return_value_type],
            OrCondition[return_value_type],
            AlwaysCondition[return_value_type]
        ],
        Field(discriminator="type")
    ]


Conditions = create_conditions_type(ReturnValueType)


class Foo(BaseModel):
    x: str


FooCondition = create_conditions_type(Foo)


class Bar(BaseModel):
    foo_condition: FooCondition

Python, Pydantic & OS Version

             pydantic version: 2.6.4
        pydantic-core version: 2.16.3
          pydantic-core build: profile=release pgo=true
                 install path: C:\Users\tom.mclean\AppData\Local\miniconda3\envs\PitchLights\Lib\site-packages\pydantic
               python version: 3.11.7 | packaged by Anaconda, Inc. | (main, Dec 15 2023, 18:05:47) [MSC v.1916 64 bit (AMD64)]
                     platform: Windows-10-10.0.19045-SP0
             related packages: typing_extensions-4.9.0
                       commit: unknown
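The environment block above has the layout of Pydantic's built-in version report, which the issue template asks reporters to paste. A minimal way to regenerate it (this only prints the report; nothing here is specific to the bug):

```python
# Print the same environment summary that Pydantic bug reports include.
import pydantic.version

report = pydantic.version.version_info()
print(report)
```

From a shell, `python -c "import pydantic.version; print(pydantic.version.version_info())"` does the same.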
mcleantom added the bug V2 and pending labels on Mar 27, 2024
@sydney-runkle
Member

@mcleantom,

Thanks a ton for reporting this. We'll look into a fix for 2.7!

@sydney-runkle
Member

So I'll note that without the generic complexity, this simplified version works:

from __future__ import annotations


from typing import Annotated, Literal, List, Union
from pydantic import BaseModel, Field

class Dog(BaseModel):
    type_: Literal['dog']
    friends: List[Pet]

class Cat(BaseModel):
    type_: Literal['cat']
    friends: List[Pet]

Pet = Annotated[Union[Dog, Cat], Field(..., discriminator='type_')]
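One way to confirm that this non-generic version builds and validates a nested payload is to wrap `Pet` in a `TypeAdapter`. This is a sketch: the explicit `model_rebuild()` calls, the sample payloads and the `err_type` variable are additions for illustration, not part of the comment above.

```python
from __future__ import annotations

from typing import Annotated, List, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter, ValidationError


class Dog(BaseModel):
    type_: Literal['dog']
    friends: List[Pet]


class Cat(BaseModel):
    type_: Literal['cat']
    friends: List[Pet]


# The literal value of `type_` selects which variant a payload is validated against.
Pet = Annotated[Union[Dog, Cat], Field(discriminator='type_')]

# `Pet` did not exist when the classes were created; resolve the forward reference now.
Dog.model_rebuild()
Cat.model_rebuild()

ta = TypeAdapter(Pet)
pet = ta.validate_python({'type_': 'dog', 'friends': [{'type_': 'cat', 'friends': []}]})
print(type(pet).__name__, type(pet.friends[0]).__name__)  # Dog Cat

# An unknown tag is rejected by the discriminator instead of trying every variant.
err_type = None
try:
    ta.validate_python({'type_': 'fish', 'friends': []})
except ValidationError as exc:
    err_type = exc.errors()[0]['type']
print(err_type)  # union_tag_invalid
```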

I'm working on creating an MRE (minimal reproducible example) that includes the generic complexity.

@sydney-runkle
Member

Ah here we go:

from __future__ import annotations


from typing import Annotated, Literal, List, Union, TypeVar, Generic
from pydantic import BaseModel, Field, TypeAdapter

T = TypeVar('T')

class Dog(BaseModel, Generic[T]):
    type_: Literal['dog']
    friends: List[GenericPet]
    id: T

class Cat(BaseModel, Generic[T]):
    type_: Literal['cat']
    friends: List[GenericPet]
    id: T

GenericPet = Annotated[Union[Dog[T], Cat[T]], Field(..., discriminator='type_')]
ta = TypeAdapter(Dog[int])
#> TypeError: 'none' is not a valid discriminated union variant; should be a `BaseModel` or `dataclass`
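On a release that includes the fix (the thread targets 2.7), the same MRE is expected to build instead of raising. A sketch that exercises it end to end, assuming pydantic 2.7 or later is installed and the code runs as a top-level script; the payload is made up for illustration.

```python
from __future__ import annotations

from typing import Annotated, Generic, List, Literal, TypeVar, Union

from pydantic import BaseModel, Field, TypeAdapter

T = TypeVar('T')


class Dog(BaseModel, Generic[T]):
    type_: Literal['dog']
    friends: List[GenericPet]
    id: T


class Cat(BaseModel, Generic[T]):
    type_: Literal['cat']
    friends: List[GenericPet]
    id: T


# Recursive, generic discriminated union: the combination that failed on 2.6.4.
GenericPet = Annotated[Union[Dog[T], Cat[T]], Field(discriminator='type_')]

# On 2.6.4 this line raised "'none' is not a valid discriminated union variant".
ta = TypeAdapter(Dog[int])

dog = ta.validate_python(
    {
        'type_': 'dog',
        'id': 1,
        'friends': [
            {'type_': 'dog', 'friends': [], 'id': 2},
            {'type_': 'cat', 'friends': [], 'id': 3},
        ],
    }
)
print(dog.id, [(friend.type_, friend.id) for friend in dog.friends])
```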

@sydney-runkle
Member

I've opened a PR with a fix - this will be released in 2.7 soon (beginning of next week)
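For projects that hit this regression, a small import-time guard can fail fast on a release older than the one the fix is expected in. A sketch: `pydantic_has_fix` and `FIXED_IN` are hypothetical names, and the guard assumes, per this thread, that 2.7 carries the fix. It also rejects 2.6.3, which the report says was unaffected, so it is deliberately conservative.

```python
import re

from pydantic import VERSION

# Assumption: the fix ships in 2.7, as stated in this thread.
FIXED_IN = (2, 7)


def pydantic_has_fix(version: str = VERSION) -> bool:
    """Return True if `version` (e.g. '2.6.4') is at least FIXED_IN (major, minor)."""
    match = re.match(r'(\d+)\.(\d+)', version)
    if match is None:
        raise ValueError(f'unrecognized pydantic version string: {version!r}')
    # Pre-release or patch suffixes are ignored; only major.minor is compared.
    return (int(match.group(1)), int(match.group(2))) >= FIXED_IN


if not pydantic_has_fix():
    raise RuntimeError(f'pydantic {VERSION} may hit issue #9118; upgrade to 2.7 or later')
```

Comparing integer tuples rather than raw version strings keeps `'2.10'` ordered after `'2.7'`.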

@mcleantom
Author

Amazing! Thanks!
