From a825106c2c5f2599643e1ad6d4caa1bfc6e1e999 Mon Sep 17 00:00:00 2001 From: Eric Jolibois Date: Thu, 29 Dec 2022 11:18:00 +0100 Subject: [PATCH] fix: avoid multiple calls of `__post_init__` when dataclasses are inherited (#4493) --- changes/4487-PrettyWood.md | 1 + pydantic/dataclasses.py | 5 ++++- tests/test_dataclasses.py | 25 +++++++++++++++++++++++++ 3 files changed, 30 insertions(+), 1 deletion(-) create mode 100644 changes/4487-PrettyWood.md diff --git a/changes/4487-PrettyWood.md b/changes/4487-PrettyWood.md new file mode 100644 index 0000000000..8f18dd0c50 --- /dev/null +++ b/changes/4487-PrettyWood.md @@ -0,0 +1 @@ +fix: avoid multiple calls of `__post_init__` when dataclasses are inherited diff --git a/pydantic/dataclasses.py b/pydantic/dataclasses.py index 00e413a666..1856a1203c 100644 --- a/pydantic/dataclasses.py +++ b/pydantic/dataclasses.py @@ -288,7 +288,10 @@ def handle_extra_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None: init(self, *args, **kwargs) if hasattr(dc_cls, '__post_init__'): - post_init = dc_cls.__post_init__ + try: + post_init = dc_cls.__post_init__.__wrapped__ # type: ignore[attr-defined] + except AttributeError: + post_init = dc_cls.__post_init__ @wraps(post_init) def new_post_init(self: 'Dataclass', *args: Any, **kwargs: Any) -> None: diff --git a/tests/test_dataclasses.py b/tests/test_dataclasses.py index 9d7d55595a..65151801cc 100644 --- a/tests/test_dataclasses.py +++ b/tests/test_dataclasses.py @@ -1474,6 +1474,31 @@ def __post_init__(self): assert C().a == 6 # 1 * 3 + 3 +def test_inheritance_post_init_2(): + post_init_calls = 0 + post_init_post_parse_calls = 0 + + @pydantic.dataclasses.dataclass + class BaseClass: + def __post_init__(self): + nonlocal post_init_calls + post_init_calls += 1 + + @pydantic.dataclasses.dataclass + class AbstractClass(BaseClass): + pass + + @pydantic.dataclasses.dataclass + class ConcreteClass(AbstractClass): + def __post_init_post_parse__(self): + nonlocal post_init_post_parse_calls + post_init_post_parse_calls += 1 + + ConcreteClass() + assert post_init_calls == 1 + assert post_init_post_parse_calls == 1 + + def test_dataclass_setattr(): class Foo: bar: str = 'cat'