Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix __post_init__ cause infinite recursion in inheritance #606

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
History
-------

v0.30 (unreleased)
..................
* fix infinite recursion with dataclass inheritance and ``__post_init__``, #606 by @Hanaasagi

v0.29 (2019-06-19)
..................
* support dataclasses.InitVar, #592 by @pfrederiks
Expand Down
24 changes: 10 additions & 14 deletions pydantic/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@

class DataclassType:
__pydantic_model__: Type[BaseModel]
__post_init_original__: Callable[..., None]
__post_init_post_parse__: Callable[..., None]
__initialised__: bool

def __init__(self, *args: Any, **kwargs: Any) -> None:
Expand All @@ -25,16 +23,6 @@ def __validate__(cls, v: Any) -> 'DataclassType':
pass


def _pydantic_post_init(self: 'DataclassType', *initvars: Any) -> None:
if self.__post_init_original__:
self.__post_init_original__(*initvars)
d = validate_model(self.__pydantic_model__, self.__dict__, cls=self.__class__)[0]
object.__setattr__(self, '__dict__', d)
object.__setattr__(self, '__initialised__', True)
if self.__post_init_post_parse__:
self.__post_init_post_parse__()


def _validate_dataclass(cls: Type['DataclassType'], v: Any) -> 'DataclassType':
if isinstance(v, cls):
return v
Expand Down Expand Up @@ -75,15 +63,23 @@ def _process_class(
post_init_post_parse = getattr(_cls, '__post_init_post_parse__', None)
if post_init_original and post_init_original.__name__ == '_pydantic_post_init':
post_init_original = None

def _pydantic_post_init(self: 'DataclassType', *initvars: Any) -> None:
if post_init_original is not None:
post_init_original(self, *initvars)
d = validate_model(self.__pydantic_model__, self.__dict__, cls=self.__class__)[0]
object.__setattr__(self, '__dict__', d)
object.__setattr__(self, '__initialised__', True)
if post_init_post_parse is not None:
post_init_post_parse(self)

_cls.__post_init__ = _pydantic_post_init
cls = dataclasses._process_class(_cls, init, repr, eq, order, unsafe_hash, frozen) # type: ignore

fields: Dict[str, Any] = {
field.name: (field.type, field.default if field.default != dataclasses.MISSING else Required)
for field in dataclasses.fields(cls)
}
cls.__post_init_original__ = post_init_original
cls.__post_init_post_parse__ = post_init_post_parse

validators = gather_validators(cls)
cls.__pydantic_model__ = create_model(
Expand Down
28 changes: 28 additions & 0 deletions tests/test_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,34 @@ def __post_init__(self):
assert post_init_called


def test_post_init_inheritance_chain():
parent_post_init_called = False
post_init_called = False

@pydantic.dataclasses.dataclass
class ParentDataclass:
a: int

def __post_init__(self):
nonlocal parent_post_init_called
parent_post_init_called = True

@pydantic.dataclasses.dataclass
class MyDataclass(ParentDataclass):
b: int

def __post_init__(self):
super().__post_init__()
nonlocal post_init_called
post_init_called = True

d = MyDataclass(a=1, b=2)
assert d.a == 1
assert d.b == 2
assert parent_post_init_called
assert post_init_called


def test_post_init_post_parse():
post_init_post_parse_called = False

Expand Down