Skip to content

Commit

Permalink
Parse default values in dataclass attributes correctly (#1771)
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielNoord committed Sep 6, 2022
1 parent beb6da7 commit 3331d62
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 7 deletions.
2 changes: 2 additions & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ What's New in astroid 2.12.8?
=============================
Release date: TBA

* Fixed parsing of default values in ``dataclass`` attributes.

Closes PyCQA/pylint#7425

What's New in astroid 2.12.7?
=============================
Expand Down
5 changes: 1 addition & 4 deletions astroid/brain/brain_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,10 +242,7 @@ def _generate_dataclass_init(
prev_kw_only = ""
if base_init and base.is_dataclass:
# Skip the self argument and check for duplicate arguments
all_arguments = base_init.args.format_args()[6:].split(", ")
arguments = ", ".join(
i for i in all_arguments if i.split(":")[0] not in assign_names
)
arguments = base_init.args.format_args(skippable_names=assign_names)[6:]
try:
prev_pos_only, prev_kw_only = arguments.split("*, ")
except ValueError:
Expand Down
17 changes: 14 additions & 3 deletions astroid/nodes/node_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -784,7 +784,7 @@ def arguments(self):
"""Get all the arguments for this node, including positional only and positional and keyword"""
return list(itertools.chain((self.posonlyargs or ()), self.args or ()))

def format_args(self):
def format_args(self, *, skippable_names: set[str] | None = None) -> str:
"""Get the arguments formatted as string.
:returns: The formatted arguments.
Expand All @@ -804,6 +804,7 @@ def format_args(self):
self.posonlyargs,
positional_only_defaults,
self.posonlyargs_annotations,
skippable_names=skippable_names,
)
)
result.append("/")
Expand All @@ -813,6 +814,7 @@ def format_args(self):
self.args,
positional_or_keyword_defaults,
getattr(self, "annotations", None),
skippable_names=skippable_names,
)
)
if self.vararg:
Expand All @@ -822,7 +824,10 @@ def format_args(self):
result.append("*")
result.append(
_format_args(
self.kwonlyargs, self.kw_defaults, self.kwonlyargs_annotations
self.kwonlyargs,
self.kw_defaults,
self.kwonlyargs_annotations,
skippable_names=skippable_names,
)
)
if self.kwarg:
Expand Down Expand Up @@ -929,7 +934,11 @@ def _find_arg(argname, args, rec=False):
return None, None


def _format_args(args, defaults=None, annotations=None):
def _format_args(
args, defaults=None, annotations=None, skippable_names: set[str] | None = None
) -> str:
if skippable_names is None:
skippable_names = set()
values = []
if args is None:
return ""
Expand All @@ -939,6 +948,8 @@ def _format_args(args, defaults=None, annotations=None):
default_offset = len(args) - len(defaults)
packed = itertools.zip_longest(args, annotations)
for i, (arg, annotation) in enumerate(packed):
if arg.name in skippable_names:
continue
if isinstance(arg, Tuple):
values.append(f"({_format_args(arg.elts)})")
else:
Expand Down
41 changes: 41 additions & 0 deletions tests/unittest_brain_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,3 +847,44 @@ class MyDataclass(Unknown):
)

assert next(node.infer())


def test_dataclass_with_default_factory() -> None:
"""Regression test for dataclasses with default values.
Reported in https://github.com/PyCQA/pylint/issues/7425
"""
bad_node, good_node = astroid.extract_node(
"""
from dataclasses import dataclass
from typing import Union
@dataclass
class BadExampleParentClass:
xyz: Union[str, int]
@dataclass
class BadExampleClass(BadExampleParentClass):
xyz: str = ""
BadExampleClass.__init__ #@
@dataclass
class GoodExampleParentClass:
xyz: str
@dataclass
class GoodExampleClass(GoodExampleParentClass):
xyz: str = ""
GoodExampleClass.__init__ #@
"""
)

bad_init: bases.UnboundMethod = next(bad_node.infer())
assert bad_init.args.defaults
assert [a.name for a in bad_init.args.args] == ["self", "xyz"]

good_init: bases.UnboundMethod = next(good_node.infer())
assert bad_init.args.defaults
assert [a.name for a in good_init.args.args] == ["self", "xyz"]

0 comments on commit 3331d62

Please sign in to comment.