Skip to content

Commit

Permalink
Handle dataclass kw_only keyword correctly (#1764)
Browse files Browse the repository at this point in the history
Co-authored-by: Jacob Walls <jacobtylerwalls@gmail.com>
  • Loading branch information
2 people authored and Pierre-Sassoulas committed Sep 5, 2022
1 parent 8f8448e commit 1f5dc45
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 36 deletions.
4 changes: 4 additions & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ Release date: TBA

Closes PyCQA/pylint#7375

* The ``dataclass`` brain now understands the ``kw_only`` keyword in dataclass decorators.

Closes PyCQA/pylint#7290


What's New in astroid 2.12.5?
=============================
Expand Down
101 changes: 67 additions & 34 deletions astroid/brain/brain_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,7 @@
from astroid import bases, context, helpers, inference_tip
from astroid.builder import parse
from astroid.const import PY39_PLUS, PY310_PLUS
from astroid.exceptions import (
AstroidSyntaxError,
InferenceError,
MroError,
UseInferenceDefault,
)
from astroid.exceptions import AstroidSyntaxError, InferenceError, UseInferenceDefault
from astroid.manager import AstroidManager
from astroid.nodes.node_classes import (
AnnAssign,
Expand Down Expand Up @@ -89,21 +84,22 @@ def dataclass_transform(node: ClassDef) -> None:
if not _check_generate_dataclass_init(node):
return

try:
reversed_mro = list(reversed(node.mro()))
except MroError:
reversed_mro = [node]

field_assigns = {}
field_order = []
for klass in (k for k in reversed_mro if is_decorated_with_dataclass(k)):
for assign_node in _get_dataclass_attributes(klass, init=True):
name = assign_node.target.name
if name not in field_assigns:
field_order.append(name)
field_assigns[name] = assign_node

init_str = _generate_dataclass_init([field_assigns[name] for name in field_order])
kw_only_decorated = False
if PY310_PLUS and node.decorators.nodes:
for decorator in node.decorators.nodes:
if not isinstance(decorator, Call):
kw_only_decorated = False
break
for keyword in decorator.keywords:
if keyword.arg == "kw_only":
kw_only_decorated = keyword.value.bool_value()

init_str = _generate_dataclass_init(
node,
list(_get_dataclass_attributes(node, init=True)),
kw_only_decorated,
)

try:
init_node = parse(init_str)["__init__"]
except AstroidSyntaxError:
Expand Down Expand Up @@ -179,22 +175,24 @@ def _check_generate_dataclass_init(node: ClassDef) -> bool:
return True

# Check for keyword arguments of the form init=False
return all(
keyword.arg != "init"
and keyword.value.bool_value() # type: ignore[union-attr] # value is never None
return not any(
keyword.arg == "init"
and not keyword.value.bool_value() # type: ignore[union-attr] # value is never None
for keyword in found.keywords
)


def _generate_dataclass_init(assigns: list[AnnAssign]) -> str:
def _generate_dataclass_init(
node: ClassDef, assigns: list[AnnAssign], kw_only_decorated: bool
) -> str:
"""Return an init method for a dataclass given the targets."""
target_names = []
params = []
assignments = []
params: list[str] = []
assignments: list[str] = []
assign_names: list[str] = []

for assign in assigns:
name, annotation, value = assign.target.name, assign.annotation, assign.value
target_names.append(name)
assign_names.append(name)

if _is_init_var(annotation): # type: ignore[arg-type] # annotation is never None
init_var = True
Expand All @@ -208,10 +206,7 @@ def _generate_dataclass_init(assigns: list[AnnAssign]) -> str:
init_var = False
assignment_str = f"self.{name} = {name}"

if annotation:
param_str = f"{name}: {annotation.as_string()}"
else:
param_str = name
param_str = f"{name}: {annotation.as_string()}"

if value:
if isinstance(value, Call) and _looks_like_dataclass_field_call(
Expand All @@ -235,7 +230,45 @@ def _generate_dataclass_init(assigns: list[AnnAssign]) -> str:
if not init_var:
assignments.append(assignment_str)

params_string = ", ".join(["self"] + params)
try:
base: ClassDef = next(next(iter(node.bases)).infer())
base_init: FunctionDef | None = base.locals["__init__"][0]
except (StopIteration, InferenceError, KeyError):
base_init = None

prev_pos_only = ""
prev_kw_only = ""
if base_init and base.is_dataclass:
# Skip the self argument and check for duplicate arguments
all_arguments = base_init.args.format_args()[6:].split(", ")
arguments = ", ".join(
i for i in all_arguments if i.split(":")[0] not in assign_names
)
try:
prev_pos_only, prev_kw_only = arguments.split("*, ")
except ValueError:
prev_pos_only, prev_kw_only = arguments, ""

if prev_pos_only and not prev_pos_only.endswith(", "):
prev_pos_only += ", "

# Construct the new init method paramter string
params_string = "self, "
if prev_pos_only:
params_string += prev_pos_only
if not kw_only_decorated:
params_string += ", ".join(params)

if not params_string.endswith(", "):
params_string += ", "

if prev_kw_only:
params_string += "*, " + prev_kw_only + ", "
if kw_only_decorated:
params_string += ", ".join(params) + ", "
elif kw_only_decorated:
params_string += "*, " + ", ".join(params) + ", "

assignments_string = "\n ".join(assignments) if assignments else "pass"
return f"def __init__({params_string}) -> None:\n {assignments_string}"

Expand Down
81 changes: 79 additions & 2 deletions tests/unittest_brain_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,12 +625,12 @@ class B(A):
"""
)
init = next(node.infer())
assert [a.name for a in init.args.args] == ["self", "arg0", "arg2", "arg1"]
assert [a.name for a in init.args.args] == ["self", "arg0", "arg1", "arg2"]
assert [a.as_string() if a else None for a in init.args.annotations] == [
None,
"float",
"list", # not str
"int",
"list", # not str
]


Expand Down Expand Up @@ -747,3 +747,80 @@ class B:

init = next(node_two.infer())
assert [a.name for a in init.args.args] == expected


def test_kw_only_decorator() -> None:
"""Test that we update the signature correctly based on the keyword.
kw_only was introduced in PY310.
"""
foodef, bardef, cee, dee = astroid.extract_node(
"""
from dataclasses import dataclass
@dataclass(kw_only=True)
class Foo:
a: int
e: str
@dataclass(kw_only=False)
class Bar(Foo):
c: int
@dataclass(kw_only=False)
class Cee(Bar):
d: int
@dataclass(kw_only=True)
class Dee(Cee):
ee: int
Foo.__init__ #@
Bar.__init__ #@
Cee.__init__ #@
Dee.__init__ #@
"""
)

foo_init: bases.UnboundMethod = next(foodef.infer())
if PY310_PLUS:
assert [a.name for a in foo_init.args.args] == ["self"]
assert [a.name for a in foo_init.args.kwonlyargs] == ["a", "e"]
else:
assert [a.name for a in foo_init.args.args] == ["self", "a", "e"]
assert [a.name for a in foo_init.args.kwonlyargs] == []

bar_init: bases.UnboundMethod = next(bardef.infer())
if PY310_PLUS:
assert [a.name for a in bar_init.args.args] == ["self", "c"]
assert [a.name for a in bar_init.args.kwonlyargs] == ["a", "e"]
else:
assert [a.name for a in bar_init.args.args] == ["self", "a", "e", "c"]
assert [a.name for a in bar_init.args.kwonlyargs] == []

cee_init: bases.UnboundMethod = next(cee.infer())
if PY310_PLUS:
assert [a.name for a in cee_init.args.args] == ["self", "c", "d"]
assert [a.name for a in cee_init.args.kwonlyargs] == ["a", "e"]
else:
assert [a.name for a in cee_init.args.args] == ["self", "a", "e", "c", "d"]
assert [a.name for a in cee_init.args.kwonlyargs] == []

dee_init: bases.UnboundMethod = next(dee.infer())
if PY310_PLUS:
assert [a.name for a in dee_init.args.args] == ["self", "c", "d"]
assert [a.name for a in dee_init.args.kwonlyargs] == ["a", "e", "ee"]
else:
assert [a.name for a in dee_init.args.args] == [
"self",
"a",
"e",
"c",
"d",
"ee",
]
assert [a.name for a in dee_init.args.kwonlyargs] == []

0 comments on commit 1f5dc45

Please sign in to comment.