Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/9820.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix comparison of ``dataclasses`` with ``InitVar``.
6 changes: 4 additions & 2 deletions src/_pytest/assertion/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,8 +437,10 @@ def _compare_eq_cls(left: Any, right: Any, verbose: int) -> List[str]:
if not has_default_eq(left):
return []
if isdatacls(left):
all_fields = left.__dataclass_fields__
fields_to_check = [field for field, info in all_fields.items() if info.compare]
import dataclasses

all_fields = dataclasses.fields(left)
fields_to_check = [info.name for info in all_fields if info.compare]
elif isattrs(left):
all_fields = left.__attrs_attrs__
fields_to_check = [field.name for field in all_fields if getattr(field, "eq")]
Expand Down
12 changes: 12 additions & 0 deletions testing/example_scripts/dataclasses/test_compare_initvar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from dataclasses import dataclass
from dataclasses import InitVar


@dataclass
class Foo:
init_only: InitVar[int]
real_attr: int


def test_demonstrate():
assert Foo(1, 2) == Foo(1, 3)
7 changes: 7 additions & 0 deletions testing/test_assertion.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,6 +882,13 @@ def test_data_classes_with_custom_eq(self, pytester: Pytester) -> None:
result.assert_outcomes(failed=1, passed=0)
result.stdout.no_re_match_line(".*Differing attributes.*")

def test_data_classes_with_initvar(self, pytester: Pytester) -> None:
p = pytester.copy_example("dataclasses/test_compare_initvar.py")
# issue 9820
result = pytester.runpytest(p, "-vv")
result.assert_outcomes(failed=1, passed=0)
result.stdout.no_re_match_line(".*AttributeError.*")


class TestAssert_reprcompare_attrsclass:
def test_attrs(self) -> None:
Expand Down