Skip to content

Commit

Permalink
feat(python): Improve assert_frame_equal messages (#5962)
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed Dec 31, 2022
1 parent 5895424 commit b062858
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 16 deletions.
21 changes: 12 additions & 9 deletions py-polars/polars/testing/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,22 +57,25 @@ def assert_frame_equal(
     """
     if isinstance(left, pli.LazyFrame) and isinstance(right, pli.LazyFrame):
         left, right = left.collect(), right.collect()
-        obj = "pli.LazyFrame"
+        obj = "LazyFrames"
     else:
-        obj = "pli.DataFrame"
+        obj = "DataFrames"

     if not (isinstance(left, pli.DataFrame) and isinstance(right, pli.DataFrame)):
         raise_assert_detail(obj, "Type mismatch", type(left), type(right))
     elif left.shape[0] != right.shape[0]:
         raise_assert_detail(obj, "Length mismatch", left.shape, right.shape)

-    # this assumes we want it in the same order
-    union_cols = list(set(left.columns).union(set(right.columns)))
-    for c in union_cols:
-        if c not in right.columns:
-            raise AssertionError(f"column {c} in left frame, but not in right")
-        if c not in left.columns:
-            raise AssertionError(f"column {c} in right frame, but not in left")
+    left_not_right = [c for c in left.columns if c not in right.columns]
+    if left_not_right:
+        raise AssertionError(
+            f"Columns {left_not_right} in left frame, but not in right"
+        )
+    right_not_left = [c for c in right.columns if c not in left.columns]
+    if right_not_left:
+        raise AssertionError(
+            f"Columns {right_not_left} in right frame, but not in left"
+        )

     if check_column_names:
         if left.columns != right.columns:
Expand Down
24 changes: 17 additions & 7 deletions py-polars/tests/unit/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,9 @@ def test_compare_frame_equal_nans() -> None:
         data={"x": [1.0, nan], "y": [None, 2.0]},
         columns=[("x", pl.Float32), ("y", pl.Float64)],
     )
-    with pytest.raises(AssertionError):
+    with pytest.raises(
+        AssertionError, match="DataFrames are different\n\nExact value mismatch"
+    ):
         assert_frame_equal(df1, df2, check_exact=True)


Expand All @@ -148,35 +150,43 @@ def test_assert_frame_equal_pass() -> None:
 def test_assert_frame_equal_types() -> None:
     df1 = pl.DataFrame({"a": [1, 2]})
     srs1 = pl.Series(values=[1, 2], name="a")
-    with pytest.raises(AssertionError):
+    with pytest.raises(
+        AssertionError, match="DataFrames are different\n\nType mismatch"
+    ):
         assert_frame_equal(df1, srs1) # type: ignore[arg-type]


 def test_assert_frame_equal_length_mismatch() -> None:
     df1 = pl.DataFrame({"a": [1, 2]})
     df2 = pl.DataFrame({"a": [1, 2, 3]})
-    with pytest.raises(AssertionError):
+    with pytest.raises(
+        AssertionError, match="DataFrames are different\n\nLength mismatch"
+    ):
         assert_frame_equal(df1, df2)


 def test_assert_frame_equal_column_mismatch() -> None:
     df1 = pl.DataFrame({"a": [1, 2]})
     df2 = pl.DataFrame({"b": [1, 2]})
-    with pytest.raises(AssertionError):
+    with pytest.raises(
+        AssertionError, match="Columns \\['a'\\] in left frame, but not in right"
+    ):
         assert_frame_equal(df1, df2)


 def test_assert_frame_equal_column_mismatch2() -> None:
     df1 = pl.DataFrame({"a": [1, 2]})
-    df2 = pl.DataFrame({"a": [1, 2], "b": [3, 4]})
-    with pytest.raises(AssertionError):
+    df2 = pl.DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]})
+    with pytest.raises(
+        AssertionError, match="Columns \\['b', 'c'\\] in right frame, but not in left"
+    ):
         assert_frame_equal(df1, df2)


 def test_assert_frame_equal_column_mismatch_order() -> None:
     df1 = pl.DataFrame({"b": [3, 4], "a": [1, 2]})
     df2 = pl.DataFrame({"a": [1, 2], "b": [3, 4]})
-    with pytest.raises(AssertionError):
+    with pytest.raises(AssertionError, match="Columns are not in the same order"):
         assert_frame_equal(df1, df2)
     assert_frame_equal(df1, df2, check_column_names=False)

Expand Down

0 comments on commit b062858

Please sign in to comment.