Skip to content

Commit

Permalink
fix[python]: fix "assert_series_equal" for specific arrangements of n…
Browse files Browse the repository at this point in the history
…ull (#4683)
  • Loading branch information
alexander-beedie committed Sep 2, 2022
1 parent 0f60051 commit 383fdb7
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 5 deletions.
4 changes: 3 additions & 1 deletion py-polars/polars/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,9 @@ def _assert_series_inner(
else:
# apply check with tolerance, but only to the known-unequal matches
left, right = left.filter(unequal), right.filter(unequal)
if ((left - right).abs() > (atol + rtol * right.abs())).sum() != 0:
if (((left - right).abs() > (atol + rtol * right.abs())).sum() != 0) or (
(left.is_null() != right.is_null()).any()
):
raise_assert_detail(
obj, "Value mismatch", left=list(left), right=list(right)
)
Expand Down
17 changes: 14 additions & 3 deletions py-polars/tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1873,9 +1873,20 @@ def test_exp() -> None:

def test_cumulative_eval() -> None:
s = pl.Series("values", [1, 2, 3, 4, 5])
expr = pl.element().first() - pl.element().last() ** 2
expected = pl.Series("values", [None, -3.0, -8.0, -15.0, -24.0])
verify_series_and_expr_api(s, expected, "cumulative_eval", expr)

# evaluate expressions individually
expr1 = pl.element().first()
expr2 = pl.element().last() ** 2

expected1 = pl.Series("values", [1, 1, 1, 1, 1])
expected2 = pl.Series("values", [1.0, 4.0, 9.0, 16.0, 25.0])
verify_series_and_expr_api(s, expected1, "cumulative_eval", expr1)
verify_series_and_expr_api(s, expected2, "cumulative_eval", expr2)

# evaluate combined expressions and validate
expr3 = expr1 - expr2
expected3 = pl.Series("values", [0.0, -3.0, -8.0, -15.0, -24.0])
verify_series_and_expr_api(s, expected3, "cumulative_eval", expr3)


def test_drop_nan_ignore_null_3525() -> None:
Expand Down
9 changes: 8 additions & 1 deletion py-polars/tests/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,18 @@ def test_compare_series_nans_assert_equal() -> None:
assert_series_equal(srs1, srs2, check_exact=True)


def test_compare_series_nulls_are_equal() -> None:
def test_compare_series_nulls() -> None:
srs1 = pl.Series([1, 2, None])
srs2 = pl.Series([1, 2, None])
assert_series_equal(srs1, srs2)

srs1 = pl.Series([1, 2, 3])
srs2 = pl.Series([1, None, None])
with pytest.raises(AssertionError, match="Value mismatch"):
assert_series_equal(srs1, srs2)
with pytest.raises(AssertionError, match="Exact value mismatch"):
assert_series_equal(srs1, srs2, check_exact=True)


def test_compare_series_value_mismatch_string() -> None:
srs1 = pl.Series(["hello", "no"])
Expand Down

0 comments on commit 383fdb7

Please sign in to comment.