Subclass resolution #7366

Closed
5 of 13 tasks
mdelmans opened this issue Sep 7, 2023 · 6 comments

Comments

@mdelmans

mdelmans commented Sep 7, 2023

Initial Checks

  • I have searched Google & GitHub for similar requests and couldn't find anything
  • I have read and followed the docs and still think this feature is missing

Description

Pydantic allows subclassing and handles it well when we validate Python objects. For example, in the snippet below, Shelter will understand that the DomesticAnimal is a subclass of Animal and will allow it in the validator.

from typing import List
from pydantic import BaseModel


class Animal(BaseModel):
    name: str


class DomesticAnimal(Animal):
    owner_name: str


class Shelter(BaseModel):
    animals: List[Animal]

cat = DomesticAnimal(name="Simon", owner_name="Freddy")
shelter = Shelter(animals=[cat])
print(shelter)

Output:

animals=[DomesticAnimal(name='Simon', owner_name='Freddy')]

However, if we serialise cat before passing it to Shelter, it will strip its class to Animal and will not include extra fields.

shelter = Shelter(animals=[cat.model_dump()])

Output:

animals=[Animal(name='Simon')]

And if we add extra="forbid" to the Animal class, the last example fails validation altogether, although cat is a perfectly valid Animal in the OOP sense.
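To make the extra="forbid" case concrete, here is a minimal sketch assuming Pydantic v2. The models are the ones from the snippet above; the try/except wrapper and the error_types variable are only there for illustration.

```python
from typing import List

from pydantic import BaseModel, ConfigDict, ValidationError


class Animal(BaseModel):
    # Reject any key that is not a declared field of the model being validated.
    model_config = ConfigDict(extra="forbid")

    name: str


class DomesticAnimal(Animal):
    owner_name: str


class Shelter(BaseModel):
    animals: List[Animal]


cat = DomesticAnimal(name="Simon", owner_name="Freddy")

# Passing the instance still works, because it already is an Animal.
shelter = Shelter(animals=[cat])

# Passing the dumped dict fails: 'owner_name' is an extra key for Animal.
error_types = []
try:
    Shelter(animals=[cat.model_dump()])
except ValidationError as exc:
    error_types = [error["type"] for error in exc.errors()]

print(error_types)
```

So the same data is accepted as an object but rejected as a dict, which is the asymmetry this issue is about.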

I managed to get around it by adding an extra type field to the base class and writing a custom validator / model resolver that converts the serialised data into the right class based on the value in the type field.

Is there a place for such a feature in V2, or is it out of scope?


@pydantic-hooky pydantic-hooky bot added the unconfirmed Bug not yet confirmed as valid/applicable label Sep 7, 2023
@sydney-runkle
Member

Hi! Good question.

I think something like the below solution should fit your needs. This solution allows Pydantic's Union parsing logic to handle the pattern you used your extra type field for.

I also included a few additional examples at the bottom. Let me know if you have any questions :)

from typing import List, Union
from pydantic import BaseModel


class Animal(BaseModel):
    name: str


class DomesticAnimal(Animal):
    owner_name: str


class Shelter(BaseModel):
    animals: List[Union[DomesticAnimal, Animal]]

cat = DomesticAnimal(name="Simon", owner_name="Freddy")
lion = Animal(name="Fluffy")

shelter_from_model = Shelter(animals=[cat])
print(shelter_from_model)
# animals=[DomesticAnimal(name='Simon', owner_name='Freddy')]

shelter_from_model_dump = Shelter(animals=[cat.model_dump()])
print(shelter_from_model_dump)
# animals=[DomesticAnimal(name='Simon', owner_name='Freddy')]

shelter_from_animal_model = Shelter(animals=[lion])
print(shelter_from_animal_model)
# animals=[Animal(name='Fluffy')]

shelter_from_animal_model_dump = Shelter(animals=[lion.model_dump()])
print(shelter_from_animal_model_dump)
# animals=[Animal(name='Fluffy')]

@sydney-runkle sydney-runkle assigned sydney-runkle and unassigned lig Sep 7, 2023
@sydney-runkle sydney-runkle added question and removed feature request unconfirmed Bug not yet confirmed as valid/applicable labels Sep 7, 2023
@mdelmans
Author

mdelmans commented Sep 8, 2023

@sydney-runkle thanks for the answer !

I could use Union, but then I would rely on each subclass having a unique set of fields to tell the subclasses apart.

What if I have multiple subclasses with the same fields but different methods, or different implementations of an abstract method?

class Cat(DomesticAnimal):
    def say_hi(self):
        print('Mew')

class Dog(DomesticAnimal):
    def say_hi(self):
        print('Woof')
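A minimal sketch of that concern, assuming Pydantic v2 and a plain (non-discriminated) Union. The Shelter model and the variable names here are made up for illustration, and say_hi returns a string instead of printing so the result is easy to inspect.

```python
from typing import List, Union

from pydantic import BaseModel


class DomesticAnimal(BaseModel):
    name: str
    owner_name: str


class Cat(DomesticAnimal):
    def say_hi(self) -> str:
        return "Mew"


class Dog(DomesticAnimal):
    def say_hi(self) -> str:
        return "Woof"


class Shelter(BaseModel):
    # Cat and Dog have identical fields, so a dict matches both equally well.
    animals: List[Union[Cat, Dog]]


dog = Dog(name="Rex", owner_name="Freddy")

# An instance keeps its class.
from_instance = Shelter(animals=[dog])

# A dumped dict carries no class information; the first matching member wins.
from_dump = Shelter(animals=[dog.model_dump()])

print(type(from_instance.animals[0]).__name__)
print(type(from_dump.animals[0]).__name__)
```

With this ordering the dumped dog should come back as a Cat and say "Mew", since nothing in the data distinguishes the two classes.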

Or is using BaseModel for OOP not encouraged?


Wouldn't it be great to be able to have something like:

class Shelter(BaseModel):
    animals: List[Resolved[Animal]]

that would inject a type: Optional[str] = None field into the Animal model and use it to resolve to the specific subclass?

@sydney-runkle
Member

sydney-runkle commented Sep 8, 2023

Hi @mdelmans, thanks for your follow up. Using BaseModel for OOP can be great, so no worries there.

I know you mentioned that you designed a custom validator / model resolver to fix this issue. Here's a similar approach. The value of this approach is that you're not required to list out all of the available subclasses of Animal, as they're cataloged in the __pydantic_init_subclass__ method of the Animal class.

This does require that all subclasses have a field called kind that can be used as a discriminator for the annotated Union. In some way or another, Pydantic needs to be informed of the possible types allowed for a given field.

from typing import ClassVar, Any, Literal, List, Union

from pydantic_core.core_schema import ValidatorFunctionWrapHandler
from typing_extensions import Annotated

from pydantic import BaseModel, model_validator, TypeAdapter, Field


class Animal(BaseModel):
    _subclasses: ClassVar[dict[str, type[Any]]] = {}
    _discriminating_type_adapter: ClassVar[TypeAdapter]

    @model_validator(mode='wrap')
    @classmethod
    def _parse_into_subclass(cls, v: Any, handler: ValidatorFunctionWrapHandler) -> 'Animal':
        if cls is Animal:
            return Animal._discriminating_type_adapter.validate_python(v)
        return handler(v)

    @classmethod
    def __pydantic_init_subclass__(cls, **kwargs):
        # This approach requires all subclasses have a field called 'kind' to be used as a discriminator
        Animal._subclasses[cls.model_fields['kind'].default] = cls

        # The following will create a new type adapter every time a new subclass is created,
        # which is fine if there aren't that many classes (as far as performance goes)
        Animal._discriminating_type_adapter = TypeAdapter(
            Annotated[Union[tuple(Animal._subclasses.values())], Field(discriminator='kind')])


class Dog(Animal):
    kind: Literal['dog'] = 'dog'
    name: str


class Cat(Animal):
    kind: Literal['cat'] = 'cat'
    name: str


class Shelter(BaseModel):
    animals: List[Animal]


my_dog = Dog(name="buddy")
my_cat = Cat(name="fluffy")

print(repr(my_dog))
# > Dog(kind='dog', name='buddy')
print(my_dog.model_dump())
# > {'kind': 'dog', 'name': 'buddy'}

my_shelter = Shelter(animals=[my_dog, my_cat])
print(repr(my_shelter))
# > Shelter(animals=[Dog(kind='dog', name='buddy'), Cat(kind='cat', name='fluffy')])

another_shelter = Shelter(animals=[my_dog.model_dump(), my_cat.model_dump()])
print(repr(another_shelter))
# > Shelter(animals=[Dog(kind='dog', name='buddy'), Cat(kind='cat', name='fluffy')])

# As a side note, this throws an error, as a Fish subclass of Animal has not been defined.
not_so_good_shelter = Shelter(animals=[{'kind': 'fish', 'name': 'slimy'}])
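For comparison, here is a sketch of the hand-written equivalent, assuming Pydantic v2: the subclasses are listed explicitly in a discriminated union instead of being collected automatically. The AnyAnimal alias and the fish_error_types variable are names chosen for this illustration only.

```python
from typing import Annotated, List, Literal, Union

from pydantic import BaseModel, Field, ValidationError


class Dog(BaseModel):
    kind: Literal["dog"] = "dog"
    name: str


class Cat(BaseModel):
    kind: Literal["cat"] = "cat"
    name: str


# The 'kind' value selects which member of the union validates the data.
AnyAnimal = Annotated[Union[Dog, Cat], Field(discriminator="kind")]


class Shelter(BaseModel):
    animals: List[AnyAnimal]


shelter = Shelter(
    animals=[{"kind": "cat", "name": "fluffy"}, {"kind": "dog", "name": "buddy"}]
)
print(repr(shelter))

# An unknown tag is rejected before any member model is tried.
fish_error_types = []
try:
    Shelter(animals=[{"kind": "fish", "name": "slimy"}])
except ValidationError as exc:
    fish_error_types = [error["type"] for error in exc.errors()]

print(fish_error_types)
```

The trade-off is the one described above: this version is simpler, but every new subclass has to be added to the union by hand.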

Do you have any ideas for how to cleanly implement this type of model resolution logic into Pydantic? I think a model validator based approach like this one should suit your general needs, though we're certainly open to hearing other ideas.

@mdelmans
Author

@sydney-runkle I like how you are using __pydantic_init_subclass__!

I made a small improvement (?) that does not require defining kind explicitly:

from typing import ClassVar, Any, List, Union, Optional

from pydantic_core.core_schema import ValidatorFunctionWrapHandler
from typing_extensions import Annotated

from pydantic import BaseModel, model_validator, TypeAdapter, Field, field_validator


def kind_identifier(cls):
    return cls.__name__


class Animal(BaseModel):
    _subclasses: ClassVar[dict[str, type[Any]]] = {}
    _discriminating_type_adapter: ClassVar[TypeAdapter]

    kind: Annotated[Optional[str], Field(validate_default=True)] = None

    @field_validator("kind")
    @classmethod
    def set_kind(cls, v: Any):
        identifier = kind_identifier(cls)
        if v is None:
            v = identifier
        elif v != identifier:
            raise ValueError(f"Wrong type, given: {v}, expected: {identifier}")
        return v

    @model_validator(mode='wrap')
    @classmethod
    def _parse_into_subclass(cls, v: Any, handler: ValidatorFunctionWrapHandler) -> 'Animal':
        if cls is Animal:
            return Animal._discriminating_type_adapter.validate_python(v)
        return handler(v)

    @classmethod
    def __pydantic_init_subclass__(cls, **kwargs):
        # This approach requires all subclasses have a field called 'kind' to be used as a discriminator
        Animal._subclasses[kind_identifier(cls)] = cls

        # The following will create a new type adapter every time a new subclass is created,
        # which is fine if there aren't that many classes (as far as performance goes)
        Animal._discriminating_type_adapter = TypeAdapter(
            Union[tuple(Animal._subclasses.values())])


class Dog(Animal):
    name: str


class Cat(Animal):
    name: str


class Shelter(BaseModel):
    animals: List[Animal]


my_dog = Dog(name="buddy")
my_cat = Cat(name="fluffy")


print(repr(my_dog))
# > Dog(kind='Dog', name='buddy')
print(my_dog.model_dump())
# > {'kind': 'Dog', 'name': 'buddy'}

my_shelter = Shelter(animals=[my_dog, my_cat])
print(repr(my_shelter))
# > Shelter(animals=[Dog(kind='Dog', name='buddy'), Cat(kind='Cat', name='fluffy')])

another_shelter = Shelter(animals=[my_dog.model_dump(), my_cat.model_dump()])
print(repr(another_shelter))
# > Shelter(animals=[Dog(kind='Dog', name='buddy'), Cat(kind='Cat', name='fluffy')])

But I think neither solution solves the whole problem. The example below does not behave as expected with either implementation.

class Cat(Animal):
    kind: Literal['cat'] = 'cat'
    name: str


class Lion(Cat):
    kind: Literal['lion'] = 'lion'


class CatShelter(BaseModel):
    cats: List[Cat]


my_cat = Cat(name="fluffy")
my_lion = Lion(name="simba")

cat_shelter = CatShelter(cats=[my_cat.model_dump(), my_lion.model_dump()])
# >cats.1.kind
#      Input should be 'cat' [type=literal_error, input_value='lion', input_type=str]

I guess the problem is in the if cls is Animal: line. Now we also want to apply the discriminating type adapter to Cat. Do we need a new metaclass?
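One possible direction without a metaclass, as a rough sketch assuming Pydantic v2: instead of special-casing Animal, every class looks for a matching kind in its own subclass tree. The helper find_by_kind and the validator name are made up for this illustration, and Animal gets a default kind here so that every class in the tree has one.

```python
from typing import Any, List, Literal, Optional, Type

from pydantic import BaseModel, ValidatorFunctionWrapHandler, model_validator


def find_by_kind(root: Type["Animal"], kind: Any) -> Optional[Type["Animal"]]:
    """Search root and all of its subclasses for a matching 'kind' default."""
    stack = [root]
    while stack:
        candidate = stack.pop()
        if candidate.model_fields["kind"].default == kind:
            return candidate
        stack.extend(candidate.__subclasses__())
    return None


class Animal(BaseModel):
    kind: str = "animal"
    name: str

    @model_validator(mode="wrap")
    @classmethod
    def _resolve_in_own_subtree(
        cls, value: Any, handler: ValidatorFunctionWrapHandler
    ) -> "Animal":
        if isinstance(value, dict):
            target = find_by_kind(cls, value.get("kind"))
            if target is not None and target is not cls:
                # Delegate to the more specific class. It runs this same
                # validator with cls == target and falls through to handler.
                return target.model_validate(value)
        return handler(value)


class Cat(Animal):
    kind: Literal["cat"] = "cat"


class Lion(Cat):
    kind: Literal["lion"] = "lion"


class CatShelter(BaseModel):
    cats: List[Cat]


my_cat = Cat(name="fluffy")
my_lion = Lion(name="simba")

cat_shelter = CatShelter(cats=[my_cat.model_dump(), my_lion.model_dump()])
print(repr(cat_shelter))
```

Because the search starts at cls, a field typed List[Cat] only resolves to Cat or its subclasses; a kind from elsewhere in the Animal tree falls through to normal Cat validation and is rejected by the Literal.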

@sydney-runkle
Member

Hi @mdelmans,

Apologies for the delay. I see what you mean regarding the problem with the second implementation. I think that using a metaclass would make things more complicated than needed.

We have an open PR adding support for functional discriminators (#6915) as well as this open issue (#7462), both of which I think will help to solve this problem.

I'd love to circle back to this specific use case while we develop said features next week to make sure this use case is supported by the new features.
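For readers arriving later: callable discriminators are exposed as Discriminator and Tag in later Pydantic 2.x releases. The sketch below assumes such a release; get_kind and the AnyCat alias are hypothetical names, and the union members still have to be listed by hand.

```python
from typing import Annotated, Any, List, Literal, Optional, Union

from pydantic import BaseModel, Discriminator, Tag


class Cat(BaseModel):
    kind: Literal["cat"] = "cat"
    name: str


class Lion(Cat):
    kind: Literal["lion"] = "lion"


def get_kind(value: Any) -> Optional[str]:
    """Return the tag for either a dict or a model instance."""
    if isinstance(value, dict):
        return value.get("kind")
    return getattr(value, "kind", None)


# Each member is tagged; the callable's return value picks the member.
AnyCat = Annotated[
    Union[Annotated[Cat, Tag("cat")], Annotated[Lion, Tag("lion")]],
    Discriminator(get_kind),
]


class CatShelter(BaseModel):
    cats: List[AnyCat]


my_cat = Cat(name="fluffy")
my_lion = Lion(name="simba")

cat_shelter = CatShelter(cats=[my_cat.model_dump(), my_lion.model_dump()])
print(repr(cat_shelter))
```

This handles the Cat / Lion case from the previous comment, but it does not discover subclasses automatically.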

@thibaultbetremieux

thibaultbetremieux commented Oct 2, 2023

Hi, we are migrating to pydantic v2 and we have the same use case.

I came up with a slightly different solution (see below). To be honest, I'm less confident than I was with v1 in the code that implements our type-conserving logic.

BTW, the new feature for discriminators using a callable looks great, but I don't yet see how it could simplify the code for this use case 🤔. It would be super cool if there were an out-of-the-box feature for this. For example, we could pass class_discriminator_field='kind' to Animal via subclass parameters.

from pydantic import (
    BaseModel,
    SerializerFunctionWrapHandler,
    ValidatorFunctionWrapHandler,
    model_serializer,
    model_validator,
)
from typing import Any, Dict, List, Type, TypeVar

T = TypeVar('T')


def get_subclasses_recursive(cls: Type[T]) -> List[Type[T]]:
    """
    Returns all the subclasses of a given class.
    """
    subclasses = []
    for subclass in cls.__subclasses__():
        subclasses.append(subclass)
        subclasses.extend(get_subclasses_recursive(subclass))
    return subclasses


def get_subclass_recursive(cls: Type[T], name: str, allow_same_class: bool = False) -> Type[T]:
    # I oversimplified this to keep it short (there are checks for 0 or more than 1 subclasses
    # and we did not even use parameter `allow_same_class` to also match the parent class)
    return next(c for c in get_subclasses_recursive(cls=cls) if c.__name__ == name)


class TypeConservingModel(BaseModel):
    """
    Preserves the types of objects passed in pydantic models during serialization and de-serialization.
    This is achieved by injecting a field called "type" upon serialization.
    """

    # The handler of a wrap serializer is a SerializerFunctionWrapHandler, not a validator handler.
    @model_serializer(mode='wrap')
    def inject_type_on_serialization(self, handler: SerializerFunctionWrapHandler) -> Dict[str, Any]:
        result: Dict[str, Any] = handler(self)
        if 'type' in result:
            raise ValueError('Cannot use field "type". It is reserved.')
        result['type'] = self.__class__.__name__
        return result

    @model_validator(mode='wrap')  # noqa  # the decorator position is correct
    @classmethod
    def retrieve_type_on_deserialization(cls, value: Any,
                                         handler: ValidatorFunctionWrapHandler) -> 'TypeConservingModel':
        if isinstance(value, dict):
            # WARNING: we do not want to modify `value` which will come from the outer scope
            # WARNING2: `sub_cls(**modified_value)` will trigger a recursion, and thus we need to remove `type`
            modified_value = value.copy()
            sub_cls_name = modified_value.pop('type', None)
            if sub_cls_name is not None:
                sub_cls = get_subclass_recursive(cls=TypeConservingModel, name=sub_cls_name, allow_same_class=True)
                return sub_cls(**modified_value)
            else:
                return handler(value)
        return handler(value)

Demonstration using the previous example with animals:

class Animal(TypeConservingModel):
    pass


class Dog(Animal):
    name: str


class Cat(Animal):
    name: str


class Shelter(BaseModel):
    animals: List[Animal]


my_dog = Dog(name="buddy")
my_cat = Cat(name="fluffy")


print(repr(my_dog))
# > Dog(name='buddy')
print(my_dog.model_dump())
# > {'name': 'buddy', 'type': 'Dog'}

my_shelter = Shelter(animals=[my_dog, my_cat])
print(repr(my_shelter))
# > Shelter(animals=[Dog(name='buddy'), Cat(name='fluffy')])

another_shelter = Shelter(animals=[my_dog.model_dump(), my_cat.model_dump()])
print(repr(another_shelter))
# > Shelter(animals=[Dog(name='buddy'), Cat(name='fluffy')])
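
As a rough illustration of the class_discriminator_field idea mentioned above, here is a plain-Python sketch of how a class keyword argument could drive the registry. This is not a Pydantic API: the Discriminated base class and its dump / load methods are hypothetical, and only the parameter name comes from the suggestion.

```python
from typing import Any, ClassVar, Dict, Optional, Type


class Discriminated:
    """Plain-Python base class; NOT a Pydantic API, just an illustration."""

    _field: ClassVar[Optional[str]] = None
    _registry: ClassVar[Dict[str, Type["Discriminated"]]] = {}

    def __init_subclass__(
        cls, class_discriminator_field: Optional[str] = None, **kwargs: Any
    ) -> None:
        super().__init_subclass__(**kwargs)
        if class_discriminator_field is not None:
            # This class starts a new family with its own registry.
            cls._field = class_discriminator_field
            cls._registry = {}
        if cls._field is not None:
            cls._registry[cls.__name__] = cls

    def __init__(self, **data: Any) -> None:
        self.__dict__.update(data)

    def dump(self) -> Dict[str, Any]:
        # Inject the class name under the configured discriminator key.
        return {**self.__dict__, self._field: type(self).__name__}

    @classmethod
    def load(cls, data: Dict[str, Any]) -> "Discriminated":
        payload = dict(data)  # do not mutate the caller's dict
        target = cls._registry[payload.pop(cls._field)]
        if not issubclass(target, cls):
            raise TypeError(f"{target.__name__} is not a {cls.__name__}")
        return target(**payload)


class Animal(Discriminated, class_discriminator_field="kind"):
    pass


class Cat(Animal):
    pass


class Lion(Cat):
    pass


dumped = Lion(name="simba").dump()
restored = Cat.load(dumped)
print(dumped, type(restored).__name__)
```

In Pydantic itself the same keyword would presumably arrive in the kwargs of __pydantic_init_subclass__, but wiring it into validation and serialization is the part that still needs a design.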
