Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions integration_tests/test_fix_dataclass_defaults.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from codemodder.codemods.test import (
BaseIntegrationTest,
original_and_expected_from_code_path,
)
from core_codemods.fix_dataclass_defaults import FixDataclassDefaults


class TestFixDataclassDefaults(BaseIntegrationTest):
    """Integration-test configuration for the FixDataclassDefaults codemod."""

    codemod = FixDataclassDefaults
    # Sample file used as the codemod input.
    code_path = "tests/samples/fix_dataclass_defaults.py"
    # Each pair looks like (0-based line index in the sample, replacement line):
    # index 0 lines up with the sample's import line and 5-7 with its three
    # mutable defaults -- TODO confirm against original_and_expected_from_code_path.
    # NOTE(review): the replacement lines have to reproduce the sample's
    # class-body indentation exactly; verify the leading whitespace inside
    # these literals against the sample file.
    original_code, expected_new_code = original_and_expected_from_code_path(
        code_path,
        [
            (0, """from dataclasses import field, dataclass\n"""),
            (5, """ phones: list = field(default_factory=list)\n"""),
            (6, """ friends: dict = field(default_factory=dict)\n"""),
            (7, """ family: set = field(default_factory=set)\n"""),
        ],
    )

    # Unified diff expected for the sample: one import line and the three
    # mutable defaults are replaced, everything else is context.
    # fmt: off
    expected_diff =(
    """--- \n"""
    """+++ \n"""
    """@@ -1,8 +1,8 @@\n"""
    """-from dataclasses import dataclass\n"""
    """+from dataclasses import field, dataclass\n"""
    """ \n"""
    """ @dataclass\n"""
    """ class Test:\n"""
    """ name: str = ""\n"""
    """- phones: list = []\n"""
    """- friends: dict = {}\n"""
    """- family: set = set()\n"""
    """+ phones: list = field(default_factory=list)\n"""
    """+ friends: dict = field(default_factory=dict)\n"""
    """+ family: set = field(default_factory=set)\n"""
    )
    # fmt: on

    # Presumably the 1-based line of the first reported change (`phones`) -- verify.
    expected_line_change = "6"
    change_description = FixDataclassDefaults.change_description
    # One change per rewritten default: phones, friends, family.
    num_changes = 3
13 changes: 13 additions & 0 deletions src/codemodder/codemods/utils_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,19 @@ def find_immediate_function_def(
break
return first_fdef

def find_immediate_class_def(self, node: cst.CSTNode) -> Optional[cst.ClassDef]:
    """
    Return the closest class definition on the path from node up to the root, or None when there is none.
    """
    # The path is scanned in order and the first ClassDef wins, so with
    # nested classes the outer ones are never returned.
    enclosing_classes = (
        candidate
        for candidate in self.path_to_root(node)
        if isinstance(candidate, cst.ClassDef)
    )
    return next(enclosing_classes, None)

def path_to_root(self, node: cst.CSTNode) -> list[cst.CSTNode]:
"""
Returns node's path to root. Includes self.
Expand Down
4 changes: 4 additions & 0 deletions src/codemodder/scripts/generate_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,10 @@ class DocMetadata:
importance="High",
guidance_explained="This change may impact performance in some cases, but it is recommended when handling untrusted data.",
),
"fix-dataclass-defaults": DocMetadata(
importance="Medium",
guidance_explained="This change is safe and will prevent runtime `ValueError`.",
),
}

METADATA = CORE_METADATA | {
Expand Down
2 changes: 2 additions & 0 deletions src/core_codemods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .file_resource_leak import FileResourceLeak
from .fix_assert_tuple import FixAssertTuple
from .fix_async_task_instantiation import FixAsyncTaskInstantiation
from .fix_dataclass_defaults import FixDataclassDefaults
from .fix_deprecated_abstractproperty import FixDeprecatedAbstractproperty
from .fix_deprecated_logging_warn import FixDeprecatedLoggingWarn
from .fix_empty_sequence_comparison import FixEmptySequenceComparison
Expand Down Expand Up @@ -125,6 +126,7 @@
FixAsyncTaskInstantiation,
DjangoModelWithoutDunderStr,
TransformFixHasattrCall,
FixDataclassDefaults,
],
)

Expand Down
18 changes: 18 additions & 0 deletions src/core_codemods/docs/pixee_python_fix-dataclass-defaults.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
When defining a Python dataclass it is not safe to use mutable datatypes (such as `list`, `dict`, or `set`) as defaults for the attributes. This is because the defined attribute would be shared by all instances of the dataclass type. To guard against this, the `dataclass` decorator rejects such a default and raises a `ValueError` as soon as the class is defined. This codemod updates attributes of `dataclasses.dataclass` with mutable defaults to use `dataclasses.field` instead. The [dataclass documentation](https://docs.python.org/3/library/dataclasses.html#mutable-default-values) provides more details about why using `field(default_factory=...)` is the recommended pattern.

Our changes look something like this:

```diff
-from dataclasses import dataclass
+from dataclasses import field, dataclass

@dataclass
class Person:
name: str = ""
- phones: list = []
- friends: dict = {}
- family: set = set()
+ phones: list = field(default_factory=list)
+ friends: dict = field(default_factory=dict)
+ family: set = field(default_factory=set)
```
60 changes: 60 additions & 0 deletions src/core_codemods/fix_dataclass_defaults.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import libcst as cst

from codemodder.codemods.base_visitor import UtilsMixin
from codemodder.codemods.utils_mixin import NameAndAncestorResolutionMixin
from core_codemods.api import Metadata, Reference, ReviewGuidance, SimpleCodemod


class FixDataclassDefaults(SimpleCodemod, NameAndAncestorResolutionMixin, UtilsMixin):
metadata = Metadata(
name="fix-dataclass-defaults",
summary="Replace `dataclass` Mutable Default Values with Call to `field`",
review_guidance=ReviewGuidance.MERGE_WITHOUT_REVIEW,
references=[
Reference(
url="https://docs.python.org/3/library/dataclasses.html#mutable-default-values"
)
],
)
change_description = (
"Replace `dataclass` mutable default values with call to `field`"
)

def leave_AnnAssign(
self, original_node: cst.Assign, updated_node: cst.Assign
) -> cst.CSTNode:
if not self.filter_by_path_includes_or_excludes(
self.node_position(original_node)
):
return updated_node

maybe_classdef = self.find_immediate_class_def(original_node)
maybe_has_decorator = (
self._has_dataclass_decorator(maybe_classdef) if maybe_classdef else False
)
if not maybe_has_decorator:
return updated_node

match original_node.value:
# TODO: add support for populated elements
case cst.List(elements=[]) | cst.Dict(elements=[]) | cst.Tuple(elements=[]):
self.add_needed_import("dataclasses", "field")
self.report_change(original_node)
return updated_node.with_changes(
value=cst.parse_expression(
f"field(default_factory={ type(original_node.value).__name__.lower()})"
)
)
case cst.Call(func=cst.Name(value="set"), args=[]):
self.add_needed_import("dataclasses", "field")
self.report_change(original_node)
return updated_node.with_changes(
value=cst.parse_expression("field(default_factory=set)")
)
return updated_node

def _has_dataclass_decorator(self, node: cst.ClassDef) -> bool:
for decorator in node.decorators:
if self.find_base_name(decorator.decorator) == "dataclasses.dataclass":
return True
return False
132 changes: 132 additions & 0 deletions tests/codemods/test_fix_dataclass_defaults.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import pytest

from codemodder.codemods.test import BaseCodemodTest
from core_codemods.fix_dataclass_defaults import FixDataclassDefaults


class TestFixDataclassDefaults(BaseCodemodTest):
    """Unit tests for the FixDataclassDefaults codemod."""

    codemod = FixDataclassDefaults

    def test_name(self):
        assert self.codemod.name == "fix-dataclass-defaults"

    def test_import(self, tmpdir):
        """Decorator referenced as `dataclasses.dataclass`: a `field` import is added."""
        input_code = """
        import dataclasses

        @dataclasses.dataclass
        class Test:
            name: str = ""
            phones: list = []
            friends: dict = {} # I collect friends as I go :)
            family: set = set()
        """
        expected = """
        import dataclasses
        from dataclasses import field

        @dataclasses.dataclass
        class Test:
            name: str = ""
            phones: list = field(default_factory=list)
            friends: dict = field(default_factory=dict) # I collect friends as I go :)
            family: set = field(default_factory=set)
        """
        self.run_and_assert(tmpdir, input_code, expected, num_changes=3)

    def test_import_from(self, tmpdir):
        """Decorator imported by name: `field` joins the existing from-import."""
        input_code = """
        from dataclasses import dataclass

        @dataclass
        class Test:
            name: str = ""
            phones: list = []
            friends: dict = {} # I collect friends as I go :)
            family: set = set()
        """
        expected = """
        from dataclasses import field, dataclass

        @dataclass
        class Test:
            name: str = ""
            phones: list = field(default_factory=list)
            friends: dict = field(default_factory=dict) # I collect friends as I go :)
            family: set = field(default_factory=set)
        """
        self.run_and_assert(tmpdir, input_code, expected, num_changes=3)

    def test_populated_defaults(self, tmpdir):
        """Non-empty mutable defaults are currently left untouched."""
        # TODO: support later using lambda.
        input_code = """
        import dataclasses

        @dataclasses.dataclass
        class Test:
            name: str = ""
            phones: list = [1, 2, 3]
            friends: dict = {"friend": "one"}
            family: set = set((1, 2, 3))
        """
        self.run_and_assert(tmpdir, input_code, input_code)

    # Every sample must open with exactly three quotes: a fourth quote becomes
    # part of the sample, which then is not valid Python and the case would
    # pass without exercising the codemod at all.
    @pytest.mark.parametrize(
        "code",
        [
            # Not a dataclass.
            """
            class Test:
                name: str = ""
                phones: list = []
                friends: dict = {} # I collect friends as I go :)
                family: set = set()
            """,
            # Mutable values only on un-annotated assignments.
            """
            from dataclasses import dataclass

            @dataclass
            class Test:
                name: str
                last = ""
                nums: list
                friends = []
                family = None
            """,
            # Immutable defaults (including empty tuples) and fields that
            # already use `field(default_factory=...)`.
            """
            from dataclasses import dataclass, field

            class Timer:
                pass

            @dataclass
            class Test:
                name: str
                last_name: str = None
                address: str = ""
                friends: tuple = ()
                nums: list[int] = field(default_factory=list)
                timer: Timer = field(default_factory=Timer)
                family: set = () # says set, actually a tuple
            """,
        ],
    )
    def test_no_change(self, tmpdir, code):
        self.run_and_assert(tmpdir, code, code)

    def test_exclude_line(self, tmpdir):
        """A mutable default on an excluded line is not rewritten."""
        input_code = expected = """
        import dataclasses

        @dataclasses.dataclass
        class Test:
            phones: list = []
        """
        # Line 6 of the sample (counting the leading blank line) is `phones`.
        lines_to_exclude = [6]
        self.run_and_assert(
            tmpdir,
            input_code,
            expected,
            lines_to_exclude=lines_to_exclude,
        )
8 changes: 8 additions & 0 deletions tests/samples/fix_dataclass_defaults.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from dataclasses import dataclass

@dataclass
class Test:
name: str = ""
phones: list = []
friends: dict = {}
family: set = set()