Subclass resolution #7366
Comments
Hi! Good question. I think something like the solution below should fit your needs. This solution allows Pydantic's I also included a few additional examples at the bottom. Let me know if you have any questions :)

from typing import List, Union

from pydantic import BaseModel, Field


class Animal(BaseModel):
    name: str


class DomesticAnimal(Animal):
    owner_name: str


class Shelter(BaseModel):
    animals: List[Union[DomesticAnimal, Animal]]


cat = DomesticAnimal(name="Simon", owner_name="Freddy")
lion = Animal(name="Fluffy")

shelter_from_model = Shelter(animals=[cat])
print(shelter_from_model)
# animals=[DomesticAnimal(name='Simon', owner_name='Freddy')]

shelter_from_model_dump = Shelter(animals=[cat.model_dump()])
print(shelter_from_model_dump)
# animals=[DomesticAnimal(name='Simon', owner_name='Freddy')]

shelter_from_animal_model = Shelter(animals=[lion])
print(shelter_from_animal_model)
# animals=[Animal(name='Fluffy')]

shelter_from_animal_model_dump = Shelter(animals=[lion.model_dump()])
print(shelter_from_animal_model_dump)
# animals=[Animal(name='Fluffy')]
@sydney-runkle thanks for the answer! I could use

What if I have multiple subclasses with the same fields but a different set of methods, or a different implementation of an abstract method?

class Cat(DomesticAnimal):
    def say_hi(self):
        print('Mew')


class Dog(DomesticAnimal):
    def say_hi(self):
        print('Woof')

Or use of

Wouldn't it be great to be able to have something like:

class Shelter(BaseModel):
    animals: List[Resolved[Animal]]

that would inject an
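To make the concern concrete, here is a minimal sketch of what a plain Union does when two subclasses share exactly the same fields. The class bodies follow the comment above; `revived_cat` and `revived_dog` are illustrative names. Instances keep their concrete class, but the dumped dictionaries are structurally identical, so validation has nothing to tell them apart and both resolve to the same class.

```python
from typing import List, Union

from pydantic import BaseModel


class Animal(BaseModel):
    name: str


class DomesticAnimal(Animal):
    owner_name: str


class Cat(DomesticAnimal):
    def say_hi(self):
        print('Mew')


class Dog(DomesticAnimal):
    def say_hi(self):
        print('Woof')


class Shelter(BaseModel):
    animals: List[Union[Cat, Dog]]


cat = Cat(name='Simon', owner_name='Freddy')
dog = Dog(name='Rex', owner_name='Freddy')

# Model instances are matched against their own class, so the type survives.
from_instances = Shelter(animals=[cat, dog])
print([type(a).__name__ for a in from_instances.animals])

# Dumped dicts carry no class information: both have the keys
# 'name' and 'owner_name', so they cannot resolve to different classes.
from_dumps = Shelter(animals=[cat.model_dump(), dog.model_dump()])
revived_cat, revived_dog = from_dumps.animals
print(type(revived_cat) is type(revived_dog))
```

This is why some extra piece of information (a discriminator field, or a type name injected on serialisation) is needed before the right `say_hi` can be recovered from serialised data.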
Hi @mdelmans, thanks for your follow-up. Using

I know you mentioned that you designed a custom validator / model resolver to fix this issue. Here's a similar approach. The value of this approach is that you're not required to list out all of the available subclasses of

This does require that all subclasses have a field called

from typing import ClassVar, Any, Literal, List, Union

from pydantic_core.core_schema import ValidatorFunctionWrapHandler
from typing_extensions import Annotated

from pydantic import BaseModel, model_validator, TypeAdapter, Field


class Animal(BaseModel):
    _subclasses: ClassVar[dict[str, type[Any]]] = {}
    _discriminating_type_adapter: ClassVar[TypeAdapter]

    @model_validator(mode='wrap')
    @classmethod
    def _parse_into_subclass(cls, v: Any, handler: ValidatorFunctionWrapHandler) -> 'Animal':
        if cls is Animal:
            return Animal._discriminating_type_adapter.validate_python(v)
        return handler(v)

    @classmethod
    def __pydantic_init_subclass__(cls, **kwargs):
        # This approach requires all subclasses have a field called 'kind' to be used as a discriminator
        Animal._subclasses[cls.model_fields['kind'].default] = cls

        # The following will create a new type adapter every time a new subclass is created,
        # which is fine if there aren't that many classes (as far as performance goes)
        Animal._discriminating_type_adapter = TypeAdapter(
            Annotated[Union[tuple(Animal._subclasses.values())], Field(discriminator='kind')])


class Dog(Animal):
    kind: Literal['dog'] = 'dog'
    name: str


class Cat(Animal):
    kind: Literal['cat'] = 'cat'
    name: str


class Shelter(BaseModel):
    animals: List[Animal]


my_dog = Dog(name="buddy")
my_cat = Cat(name="fluffy")

print(repr(my_dog))
# > Dog(kind='dog', name='buddy')
print(my_dog.model_dump())
# > {'kind': 'dog', 'name': 'buddy'}

my_shelter = Shelter(animals=[my_dog, my_cat])
print(repr(my_shelter))
# > Shelter(animals=[Dog(kind='dog', name='buddy'), Cat(kind='cat', name='fluffy')])

another_shelter = Shelter(animals=[my_dog.model_dump(), my_cat.model_dump()])
print(repr(another_shelter))
# > Shelter(animals=[Dog(kind='dog', name='buddy'), Cat(kind='cat', name='fluffy')])

# As a side note, this throws an error, as a Fish subclass of Animal has not been defined.
not_so_good_shelter = Shelter(animals=[{'kind': 'fish', 'name': 'slimy'}])

Do you have any ideas for how to cleanly implement this type of model resolution logic into Pydantic? I think a model validator based approach like this one should suit your general needs, though we're certainly open to hearing other ideas.
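For comparison, the union that `_parse_into_subclass` builds dynamically can also be written out by hand. The sketch below is that explicit form, with `kind` as the discriminator; `AnyAnimal` is an illustrative alias, and the models are standalone here rather than subclasses of the `Animal` base above. It also shows that an unregistered kind such as 'fish' is rejected with a `ValidationError`.

```python
from typing import List, Literal, Union

from pydantic import BaseModel, Field, ValidationError
from typing_extensions import Annotated


class Dog(BaseModel):
    kind: Literal['dog'] = 'dog'
    name: str


class Cat(BaseModel):
    kind: Literal['cat'] = 'cat'
    name: str


# Every member must be listed explicitly; this is the bookkeeping that the
# __pydantic_init_subclass__ hook in the comment above automates.
AnyAnimal = Annotated[Union[Dog, Cat], Field(discriminator='kind')]


class Shelter(BaseModel):
    animals: List[AnyAnimal]


shelter = Shelter(animals=[
    {'kind': 'dog', 'name': 'buddy'},
    {'kind': 'cat', 'name': 'fluffy'},
])
print([type(a).__name__ for a in shelter.animals])

# A 'kind' value with no matching member fails validation.
rejected = False
try:
    Shelter(animals=[{'kind': 'fish', 'name': 'slimy'}])
except ValidationError as exc:
    rejected = True
    print(exc.errors()[0]['type'])
```

The trade-off is the one described above: the explicit form is simpler, but every new subclass has to be added to the union manually.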
@sydney-runkle like how you are using

Made a small improvement (?) which does not require to define

from typing import ClassVar, Any, List, Union, Optional

from pydantic_core.core_schema import ValidatorFunctionWrapHandler
from typing_extensions import Annotated

from pydantic import BaseModel, model_validator, TypeAdapter, Field, field_validator


def kind_identifier(cls):
    return cls.__name__


class Animal(BaseModel):
    _subclasses: ClassVar[dict[str, type[Any]]] = {}
    _discriminating_type_adapter: ClassVar[TypeAdapter]

    kind: Annotated[Optional[str], Field(validate_default=True)] = None

    @field_validator("kind")
    @classmethod
    def set_kind(cls, v: Any):
        identifier = kind_identifier(cls)
        if v is None:
            v = identifier
        elif v != identifier:
            raise ValueError(f"Wrong type, given: {v}, expected: {identifier}")
        return v

    @model_validator(mode='wrap')
    @classmethod
    def _parse_into_subclass(cls, v: Any, handler: ValidatorFunctionWrapHandler) -> 'Animal':
        if cls is Animal:
            return Animal._discriminating_type_adapter.validate_python(v)
        return handler(v)

    @classmethod
    def __pydantic_init_subclass__(cls, **kwargs):
        # This approach requires all subclasses have a field called 'kind' to be used as a discriminator
        Animal._subclasses[kind_identifier(cls)] = cls

        # The following will create a new type adapter every time a new subclass is created,
        # which is fine if there aren't that many classes (as far as performance goes)
        Animal._discriminating_type_adapter = TypeAdapter(
            Union[tuple(Animal._subclasses.values())])


class Dog(Animal):
    name: str


class Cat(Animal):
    name: str


class Shelter(BaseModel):
    animals: List[Animal]


my_dog = Dog(name="buddy")
my_cat = Cat(name="fluffy")

print(repr(my_dog))
# > Dog(kind='Dog', name='buddy')
print(my_dog.model_dump())
# > {'kind': 'Dog', 'name': 'buddy'}

my_shelter = Shelter(animals=[my_dog, my_cat])
print(repr(my_shelter))
# > Shelter(animals=[Dog(kind='Dog', name='buddy'), Cat(kind='Cat', name='fluffy')])

another_shelter = Shelter(animals=[my_dog.model_dump(), my_cat.model_dump()])
print(repr(another_shelter))
# > Shelter(animals=[Dog(kind='Dog', name='buddy'), Cat(kind='Cat', name='fluffy')])

But I think neither solution solves the whole problem. The example below will not behave as expected with either implementation.

class Cat(Animal):
    kind: Literal['cat'] = 'cat'
    name: str


class Lion(Cat):
    kind: Literal['lion'] = 'lion'


class CatShelter(BaseModel):
    cats: List[Cat]


my_cat = Cat(name="fluffy")
my_lion = Lion(name="simba")

cat_shelter = CatShelter(cats=[my_cat.model_dump(), my_lion.model_dump()])
# > cats.1.kind
#   Input should be 'cat' [type=literal_error, input_value='lion', input_type=str]

I guess the problem in
Hi @mdelmans, apologies for the delay. I see what you mean regarding the problem with the second implementation. I think that using a metaclass would make things more complicated than needed. We have an open PR adding support for functional discriminators (#6915) as well as this open issue (#7462), both of which I think will help to solve this problem. I'd love to circle back to this specific use case while we develop said features next week to make sure this use case is supported by the new features.
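As a rough sketch of how a callable discriminator could cover the `Cat` / `Lion` case: this assumes a Pydantic release in which the functional discriminator work from #6915 is available as `Discriminator` and `Tag` (2.5 or later, as far as I know). `get_kind` and `AnyCat` are hypothetical names, and `kind` is declared as a plain `str` so that `Lion` can override the default without conflicting with a `Literal['cat']` annotation. The members still have to be listed explicitly, so this is not automatic subclass resolution.

```python
from typing import Any, List, Optional, Union

from pydantic import BaseModel, Discriminator, Tag
from typing_extensions import Annotated


class Cat(BaseModel):
    kind: str = 'cat'
    name: str


class Lion(Cat):
    kind: str = 'lion'


def get_kind(value: Any) -> Optional[str]:
    """Return the tag for raw dict input or for an existing model instance."""
    if isinstance(value, dict):
        return value.get('kind')
    return getattr(value, 'kind', None)


# Each union member is labelled with the tag that get_kind returns for it.
AnyCat = Annotated[
    Union[Annotated[Cat, Tag('cat')], Annotated[Lion, Tag('lion')]],
    Discriminator(get_kind),
]


class CatShelter(BaseModel):
    cats: List[AnyCat]


my_cat = Cat(name='fluffy')
my_lion = Lion(name='simba')

cat_shelter = CatShelter(cats=[my_cat.model_dump(), my_lion.model_dump()])
print([type(c).__name__ for c in cat_shelter.cats])
```

Because the tag is computed by a function rather than read from a `Literal` field, a parent class and its subclass can both appear in the same union.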
Hi, we are migrating to pydantic v2 and we have the same use case. I came up with a slightly different solution (see below). To be honest I'm less confident than before (with v1) in the code regarding our logic to conserve the type. BTW the new feature for discriminators using a callable looks great, but I don't yet see how this could simplify the code for this given use case 🤔. It would be super cool if there was an out-of-the-box feature to do that. For example we could pass

from pydantic import (
    BaseModel,
    SerializerFunctionWrapHandler,
    ValidatorFunctionWrapHandler,
    model_serializer,
    model_validator,
)
from typing import Any, Dict, List, Type, TypeVar

T = TypeVar('T')


def get_subclasses_recursive(cls: Type[T]) -> List[Type[T]]:
    """
    Returns all the subclasses of a given class.
    """
    subclasses = []
    for subclass in cls.__subclasses__():
        subclasses.append(subclass)
        subclasses.extend(get_subclasses_recursive(subclass))
    return subclasses


def get_subclass_recursive(cls: Type[T], name: str, allow_same_class: bool = False) -> Type[T]:
    # I oversimplified this to keep it short (there are checks for 0 or more than 1 subclasses
    # and we did not even use parameter `allow_same_class` to also match the parent class)
    return next(c for c in get_subclasses_recursive(cls=cls) if c.__name__ == name)


class TypeConservingModel(BaseModel):
    """
    Preserves the types of objects passed in pydantic models during serialization and de-serialization.
    This is achieved by injecting a field called "type" upon serialization.
    """

    @model_serializer(mode='wrap')
    def inject_type_on_serialization(self, handler: SerializerFunctionWrapHandler) -> Dict[str, Any]:
        result: Dict[str, Any] = handler(self)
        if 'type' in result:
            raise ValueError('Cannot use field "type". It is reserved.')
        result['type'] = f'{self.__class__.__name__}'
        return result

    @model_validator(mode='wrap')  # noqa  # the decorator position is correct
    @classmethod
    def retrieve_type_on_deserialization(cls, value: Any,
                                         handler: ValidatorFunctionWrapHandler) -> 'TypeConservingModel':
        if isinstance(value, dict):
            # WARNING: we do not want to modify `value` which will come from the outer scope
            # WARNING2: `sub_cls(**modified_value)` will trigger a recursion, and thus we need to remove `type`
            modified_value = value.copy()
            sub_cls_name = modified_value.pop('type', None)
            if sub_cls_name is not None:
                sub_cls = get_subclass_recursive(cls=TypeConservingModel, name=sub_cls_name, allow_same_class=True)
                return sub_cls(**modified_value)
            else:
                return handler(value)
        return handler(value)

Demonstration using the previous example with animals:

class Animal(TypeConservingModel):
    pass


class Dog(Animal):
    name: str


class Cat(Animal):
    name: str


class Shelter(BaseModel):
    animals: List[Animal]


my_dog = Dog(name="buddy")
my_cat = Cat(name="fluffy")

print(repr(my_dog))
# > Dog(name='buddy')
print(my_dog.model_dump())
# > {'name': 'buddy', 'type': 'Dog'}

my_shelter = Shelter(animals=[my_dog, my_cat])
print(repr(my_shelter))
# > Shelter(animals=[Dog(name='buddy'), Cat(name='fluffy')])

another_shelter = Shelter(animals=[my_dog.model_dump(), my_cat.model_dump()])
print(repr(another_shelter))
# > Shelter(animals=[Dog(name='buddy'), Cat(name='fluffy')])
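The comment above mentions that the real lookup checks for zero or multiple matches and can also match the parent class, but leaves those parts out. The following is one possible sketch of such a lookup, not the commenter's actual code: `resolve_subclass` is a hypothetical name, the choice of `LookupError` is an assumption, and plain classes stand in for the models so the example depends only on the standard library.

```python
from typing import List, Type, TypeVar

T = TypeVar('T')


def get_subclasses_recursive(cls: Type[T]) -> List[Type[T]]:
    """Return all direct and indirect subclasses of cls."""
    subclasses: List[Type[T]] = []
    for subclass in cls.__subclasses__():
        subclasses.append(subclass)
        subclasses.extend(get_subclasses_recursive(subclass))
    return subclasses


def resolve_subclass(cls: Type[T], name: str, allow_same_class: bool = False) -> Type[T]:
    """Find exactly one class named `name` below cls (hypothetical helper)."""
    candidates = get_subclasses_recursive(cls)
    if allow_same_class:
        candidates.append(cls)
    # With diamond inheritance the same class can be reached twice,
    # so de-duplicate while preserving order before counting matches.
    unique = list(dict.fromkeys(candidates))
    matches = [c for c in unique if c.__name__ == name]
    if not matches:
        raise LookupError(f'No subclass of {cls.__name__} is named {name!r}')
    if len(matches) > 1:
        raise LookupError(f'{len(matches)} subclasses of {cls.__name__} are named {name!r}')
    return matches[0]


class Animal:
    pass


class Dog(Animal):
    pass


class Cat(Animal):
    pass


class Lion(Cat):
    pass


print(resolve_subclass(Animal, 'Lion').__name__)
print(resolve_subclass(Animal, 'Animal', allow_same_class=True).__name__)
```

Matching on `__name__` alone means two subclasses with the same name in different modules are ambiguous, which is why the multiple-match check matters in practice.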
Initial Checks
Description
Pydantic allows subclassing and handles it well when we validate Python objects. For example, in the snippet below, Shelter will understand that the DomesticAnimal is a subclass of Animal and will allow it in the validator.
Output:
However, if we serialise cat before passing it to Shelter, it will strip its class to Animal and will not include extra fields.
Output:
And if we add extra="forbid" on the Animal class, the last example will fail the validation altogether, although cat is a perfect Animal in the OOP sense.
I managed to get around it by adding an extra type field to the base class and writing a custom validator / model resolver that converts the serialised data into the right class based on the value in the type field.
Is there a place for such a feature in V2 or is it out of scope?
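The snippets and outputs referred to in the description did not survive in this copy of the thread. The sketch below is a reconstruction of the behaviour described, not the original code: the `Animal`, `DomesticAnimal` and `Shelter` models follow the first reply in the thread, while the `Strict*` classes are illustrative names added to show the extra="forbid" case side by side.

```python
from typing import List

from pydantic import BaseModel, ConfigDict, ValidationError


class Animal(BaseModel):
    name: str


class DomesticAnimal(Animal):
    owner_name: str


class Shelter(BaseModel):
    animals: List[Animal]


cat = DomesticAnimal(name='Simon', owner_name='Freddy')

# Passing the instance: the subclass is accepted and kept as-is.
from_instance = Shelter(animals=[cat])
print(type(from_instance.animals[0]).__name__)

# Passing the dumped dict: it is validated as a plain Animal, and the
# unknown key 'owner_name' is dropped (extra fields are ignored by default).
from_dump = Shelter(animals=[cat.model_dump()])
print(type(from_dump.animals[0]).__name__, from_dump.animals[0].model_dump())


class StrictAnimal(BaseModel):
    model_config = ConfigDict(extra='forbid')
    name: str


class StrictDomesticAnimal(StrictAnimal):
    owner_name: str


class StrictShelter(BaseModel):
    animals: List[StrictAnimal]


strict_cat = StrictDomesticAnimal(name='Simon', owner_name='Freddy')

# With extra='forbid', the same round trip is rejected outright.
forbid_failed = False
try:
    StrictShelter(animals=[strict_cat.model_dump()])
except ValidationError:
    forbid_failed = True
print(forbid_failed)
```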
Affected Components
.model_dump() and .model_dump_json()
model_construct(), pickling, private attributes, ORM mode