Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(python): address several edge-cases found when asserting NaN equality #5732

Merged
merged 1 commit into from
Dec 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions py-polars/polars/testing/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,14 @@ def _assert_series_inner(
if left.dtype != right.dtype:
raise_assert_detail(obj, "Dtype mismatch", left.dtype, right.dtype)

# confirm that we can call 'is_nan' on both sides
left_is_float = left.dtype in (Float32, Float64)
right_is_float = right.dtype in (Float32, Float64)
comparing_float_dtypes = left_is_float and right_is_float

# create mask of which (if any) values are unequal
unequal = left != right
if unequal.any() and nans_compare_equal and left.dtype in (Float32, Float64):
if unequal.any() and nans_compare_equal and comparing_float_dtypes:
# handle NaN values (which compare unequal to themselves)
unequal = unequal & ~(
(left.is_nan() & right.is_nan()).fill_null(pli.lit(False))
Expand All @@ -182,13 +187,25 @@ def _assert_series_inner(
obj, "Exact value mismatch", left=list(left), right=list(right)
)
else:
# apply check with tolerance, but only to the known-unequal matches
# apply check with tolerance (to the known-unequal matches).
left, right = left.filter(unequal), right.filter(unequal)
mismatch, nan_info = False, ""
if (((left - right).abs() > (atol + rtol * right.abs())).sum() != 0) or (
(left.is_null() != right.is_null()).any()
):
left.is_null() != right.is_null()
).any():
mismatch = True
elif comparing_float_dtypes:
# note: take special care with NaN values.
if not nans_compare_equal and (left.is_nan() == right.is_nan()).any():
nan_info = " (nans_compare_equal=False)"
mismatch = True
elif (left.is_nan() != right.is_nan()).any():
nan_info = f" (nans_compare_equal={nans_compare_equal})"
mismatch = True

if mismatch:
raise_assert_detail(
obj, "Value mismatch", left=list(left), right=list(right)
obj, f"Value mismatch{nan_info}", left=list(left), right=list(right)
)


Expand Down
46 changes: 39 additions & 7 deletions py-polars/tests/unit/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,49 @@ def test_compare_series_empty_equal() -> None:


def test_compare_series_nans_assert_equal() -> None:
# NaN values do not _compare_ equal, but should _assert_ as equal here
# note: NaN values do not _compare_ equal, but should _assert_ equal (by default)
nan = float("NaN")

srs1 = pl.Series([1.0, 2.0, nan])
srs2 = pl.Series([1.0, 2.0, nan])
assert_series_equal(srs1, srs2)
srs1 = pl.Series([1.0, 2.0, nan, 4.0, None, 6.0])
srs2 = pl.Series([1.0, nan, 3.0, 4.0, None, 6.0])
srs3 = pl.Series([1.0, 2.0, 3.0, 4.0, None, 6.0])

for srs in (srs1, srs2, srs3):
assert_series_equal(srs, srs)
assert_series_equal(srs, srs, check_exact=True)

srs1 = pl.Series([1.0, 2.0, nan])
srs2 = pl.Series([1.0, nan, 3.0])
with pytest.raises(AssertionError):
assert_series_equal(srs1, srs2, check_exact=True)
assert_series_equal(srs1, srs1, nans_compare_equal=False)
with pytest.raises(AssertionError):
assert_series_equal(srs1, srs1, nans_compare_equal=False, check_exact=True)

for check_exact, nans_equal in (
(False, False),
(False, True),
(True, False),
(True, True),
):
if check_exact:
check_msg = "Exact value mismatch"
else:
check_msg = f"Value mismatch.*nans_compare_equal={nans_equal}"

with pytest.raises(AssertionError, match=check_msg):
assert_series_equal(
srs1, srs2, check_exact=check_exact, nans_compare_equal=nans_equal
)
with pytest.raises(AssertionError, match=check_msg):
assert_series_equal(
srs1, srs3, check_exact=check_exact, nans_compare_equal=nans_equal
)

srs4 = pl.Series([1.0, 2.0, 3.0, 4.0, None, 6.0])
srs5 = pl.Series([1.0, 2.0, 3.0, 4.0, nan, 6.0])
srs6 = pl.Series([1, 2, 3, 4, None, 6])

assert_series_equal(srs4, srs6, check_dtype=False)
with pytest.raises(AssertionError):
assert_series_equal(srs5, srs6, check_dtype=False)


def test_compare_series_nulls() -> None:
Expand Down