From d58286c4b4d3392abf032430de811ad6ca5c4d7c Mon Sep 17 00:00:00 2001 From: Sydney Runkle <54324534+sydney-runkle@users.noreply.github.com> Date: Thu, 1 Feb 2024 07:55:07 -0600 Subject: [PATCH] Fix inheriting annotations in dataclasses (#8679) Co-authored-by: Alex Hall --- pydantic/dataclasses.py | 49 +++++++++++++++++------------- tests/test_dataclasses.py | 63 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 91 insertions(+), 21 deletions(-) diff --git a/pydantic/dataclasses.py b/pydantic/dataclasses.py index bb26540cce..d9c9c903b1 100644 --- a/pydantic/dataclasses.py +++ b/pydantic/dataclasses.py @@ -93,7 +93,7 @@ def dataclass( @dataclass_transform(field_specifiers=(dataclasses.field, Field)) -def dataclass( +def dataclass( # noqa: C901 _cls: type[_T] | None = None, *, init: Literal[False] = False, @@ -153,26 +153,33 @@ def make_pydantic_fields_compatible(cls: type[Any]) -> None: into `x: int = dataclasses.field(default=pydantic.Field(..., kw_only=True), kw_only=True)` """ - # In Python < 3.9, `__annotations__` might not be present if there are no fields. - # we therefore need to use `getattr` to avoid an `AttributeError`. - for field_name in getattr(cls, '__annotations__', []): - field_value = getattr(cls, field_name, None) - # Process only if this is an instance of `FieldInfo`. - if not isinstance(field_value, FieldInfo): - continue - - # Initialize arguments for the standard `dataclasses.field`. - field_args: dict = {'default': field_value} - - # Handle `kw_only` for Python 3.10+ - if sys.version_info >= (3, 10) and field_value.kw_only: - field_args['kw_only'] = True - - # Set `repr` attribute if it's explicitly specified to be not `True`. - if field_value.repr is not True: - field_args['repr'] = field_value.repr - - setattr(cls, field_name, dataclasses.field(**field_args)) + for annotation_cls in cls.__mro__: + # In Python < 3.9, `__annotations__` might not be present if there are no fields. + # we therefore need to use `getattr` to avoid an `AttributeError`. + annotations = getattr(annotation_cls, '__annotations__', []) + for field_name in annotations: + field_value = getattr(cls, field_name, None) + # Process only if this is an instance of `FieldInfo`. + if not isinstance(field_value, FieldInfo): + continue + + # Initialize arguments for the standard `dataclasses.field`. + field_args: dict = {'default': field_value} + + # Handle `kw_only` for Python 3.10+ + if sys.version_info >= (3, 10) and field_value.kw_only: + field_args['kw_only'] = True + + # Set `repr` attribute if it's explicitly specified to be not `True`. + if field_value.repr is not True: + field_args['repr'] = field_value.repr + + setattr(cls, field_name, dataclasses.field(**field_args)) + # In Python 3.8, dataclasses checks cls.__dict__['__annotations__'] for annotations, + # so we must make sure it's initialized before we add to it. + if cls.__dict__.get('__annotations__') is None: + cls.__annotations__ = {} + cls.__annotations__[field_name] = annotations[field_name] def create_dataclass(cls: type[Any]) -> type[PydanticDataclass]: """Create a Pydantic dataclass from a regular dataclass. diff --git a/tests/test_dataclasses.py b/tests/test_dataclasses.py index fcf635f8b8..dc60bd32fe 100644 --- a/tests/test_dataclasses.py +++ b/tests/test_dataclasses.py @@ -2759,3 +2759,66 @@ def test_disallow_init_false_and_init_var_true() -> None: @pydantic.dataclasses.dataclass class Foo: bar: str = Field(..., init=False, init_var=True) + + +def test_annotations_valid_for_field_inheritance() -> None: + # testing https://github.com/pydantic/pydantic/issues/8670 + + @pydantic.dataclasses.dataclass() + class A: + a: int = pydantic.dataclasses.Field() + + @pydantic.dataclasses.dataclass() + class B(A): + ... + + assert B.__pydantic_fields__['a'].annotation is int + + assert B(a=1).a == 1 + + +def test_annotations_valid_for_field_inheritance_with_existing_field() -> None: + # variation on testing https://github.com/pydantic/pydantic/issues/8670 + + @pydantic.dataclasses.dataclass() + class A: + a: int = pydantic.dataclasses.Field() + + @pydantic.dataclasses.dataclass() + class B(A): + b: str = pydantic.dataclasses.Field() + + assert B.__pydantic_fields__['a'].annotation is int + assert B.__pydantic_fields__['b'].annotation is str + + b = B(a=1, b='b') + assert b.a == 1 + assert b.b == 'b' + + +def test_annotation_with_double_override() -> None: + @pydantic.dataclasses.dataclass() + class A: + a: int + b: int + c: int = pydantic.dataclasses.Field() + d: int = pydantic.dataclasses.Field() + + # note, the order of fields is different here, as to test that the annotation + # is correctly set on the field no matter the base's default / current class's default + @pydantic.dataclasses.dataclass() + class B(A): + a: str + c: str + b: str = pydantic.dataclasses.Field() + d: str = pydantic.dataclasses.Field() + + @pydantic.dataclasses.dataclass() + class C(B): + ... + + for class_ in [B, C]: + instance = class_(a='a', b='b', c='c', d='d') + for field_name in ['a', 'b', 'c', 'd']: + assert class_.__pydantic_fields__[field_name].annotation is str + assert getattr(instance, field_name) == field_name