Skip to content

Commit

Permalink
Fixed assert_frame_equal and assert_series_equal for NaN values (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie committed Jul 8, 2022
1 parent 0377c80 commit f663838
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 21 deletions.
50 changes: 34 additions & 16 deletions py-polars/polars/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
is_polars_dtype,
py_type_to_dtype,
)
from polars.internals import DataFrame, LazyFrame, Series, col
from polars.internals import DataFrame, LazyFrame, Series, col, lit

if HYPOTHESIS_INSTALLED:
# TODO: increase the number of iterations during CI checkins?
Expand All @@ -80,6 +80,7 @@ def assert_frame_equal(
check_column_names: bool = True,
rtol: float = 1.0e-5,
atol: float = 1.0e-8,
nans_compare_equal: bool = True,
) -> None:
"""
Raise detailed AssertionError if `left` does not equal `right`.
Expand All @@ -88,19 +89,21 @@ def assert_frame_equal(
Parameters
----------
left
the dataframe to compare
the dataframe to compare.
right
the dataframe to compare with
the dataframe to compare with.
check_dtype
if True, data types need to match exactly
if True, data types need to match exactly.
check_exact
if False, test if values are within tolerance of each other (see `rtol` & `atol`)
if False, test if values are within tolerance of each other (see `rtol` & `atol`).
check_column_names
if True, dataframes must have the same column names in the same order
if True, dataframes must have the same column names in the same order.
rtol
relative tolerance for inexact checking. Fraction of values in `right`
relative tolerance for inexact checking. Fraction of values in `right`.
atol
absolute tolerance for inexact checking.
nans_compare_equal
if True (the default), float NaN values at the same position are treated as equal; set this to False if your assert/test requires NaN != NaN.
Examples
--------
Expand Down Expand Up @@ -136,7 +139,14 @@ def assert_frame_equal(
# this does not assume a particular order
for c in left.columns:
_assert_series_inner(
left[c], right[c], check_dtype, check_exact, atol, rtol, obj
left[c],
right[c],
check_dtype,
check_exact,
nans_compare_equal,
atol,
rtol,
obj,
)


Expand All @@ -148,26 +158,29 @@ def assert_series_equal(
check_exact: bool = False,
rtol: float = 1.0e-5,
atol: float = 1.0e-8,
nans_compare_equal: bool = True,
) -> None:
"""
Raise detailed AssertionError if `left` does not equal `right`.
Parameters
----------
left
the series to compare
the series to compare.
right
the series to compare with
the series to compare with.
check_dtype
if True, data types need to match exactly
if True, data types need to match exactly.
check_names
if True, names need to match
if True, names need to match.
check_exact
if False, test if values are within tolerance of each other (see `rtol` & `atol`)
if False, test if values are within tolerance of each other (see `rtol` & `atol`).
rtol
relative tolerance for inexact checking. Fraction of values in `right`
relative tolerance for inexact checking. Fraction of values in `right`.
atol
absolute tolerance for inexact checking.
nans_compare_equal
if True (the default), float NaN values at the same position are treated as equal; set this to False if your assert/test requires NaN != NaN.
Examples
--------
Expand All @@ -187,14 +200,17 @@ def assert_series_equal(
if left.name != right.name:
raise_assert_detail(obj, "Name mismatch", left.name, right.name)

_assert_series_inner(left, right, check_dtype, check_exact, atol, rtol, obj)
_assert_series_inner(
left, right, check_dtype, check_exact, nans_compare_equal, atol, rtol, obj
)


def _assert_series_inner(
left: Series,
right: Series,
check_dtype: bool,
check_exact: bool,
nans_compare_equal: bool,
atol: float,
rtol: float,
obj: str,
Expand All @@ -208,13 +224,15 @@ def _assert_series_inner(
can_be_subtracted = False

check_exact = check_exact or not can_be_subtracted or left.dtype == Boolean

if check_dtype:
if left.dtype != right.dtype:
raise_assert_detail(obj, "Dtype mismatch", left.dtype, right.dtype)

# create mask of which (if any) values are unequal
unequal = left != right
if unequal.any() and nans_compare_equal and left.dtype in (Float32, Float64):
# handle NaN values (which compare unequal to themselves)
unequal = unequal & ~((left.is_nan() & right.is_nan()).fill_null(lit(False)))

# assert exact, or with tolerance
if unequal.any():
Expand Down
32 changes: 32 additions & 0 deletions py-polars/tests/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,20 @@ def test_compare_series_empty_equal() -> None:
assert_series_equal(srs1, srs2)


def test_compare_series_nans_assert_equal() -> None:
# NaN values do not _compare_ equal, but should _assert_ as equal here
nan = float("NaN")

srs1 = pl.Series([1.0, 2.0, nan])
srs2 = pl.Series([1.0, 2.0, nan])
assert_series_equal(srs1, srs2)

srs1 = pl.Series([1.0, 2.0, nan])
srs2 = pl.Series([1.0, nan, 3.0])
with pytest.raises(AssertionError):
assert_series_equal(srs1, srs2, check_exact=True)


def test_compare_series_nulls_are_equal() -> None:
srs1 = pl.Series([1, 2, None])
srs2 = pl.Series([1, 2, None])
Expand Down Expand Up @@ -66,6 +80,24 @@ def test_compare_series_value_exact_mismatch() -> None:
assert_series_equal(srs1, srs2, check_exact=True)


def test_compare_frame_equal_nans() -> None:
# NaN values do not _compare_ equal, but should _assert_ as equal here
nan = float("NaN")

df1 = pl.DataFrame(
data={"x": [1.0, nan], "y": [nan, 2.0]},
columns=[("x", pl.Float32), ("y", pl.Float64)],
)
assert_frame_equal(df1, df1, check_exact=True)

df2 = pl.DataFrame(
data={"x": [1.0, nan], "y": [None, 2.0]},
columns=[("x", pl.Float32), ("y", pl.Float64)],
)
with pytest.raises(AssertionError):
assert_frame_equal(df1, df2, check_exact=True)


def test_assert_frame_equal_pass() -> None:
df1 = pl.DataFrame({"a": [1, 2]})
df2 = pl.DataFrame({"a": [1, 2]})
Expand Down
9 changes: 5 additions & 4 deletions py-polars/tests_parametric/test_dataframe.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
# -------------------------------------------------
# Validate Series behaviour with parametric tests
# -------------------------------------------------
# ----------------------------------------------------
# Validate DataFrame behaviour with parametric tests
# ----------------------------------------------------
from hypothesis import example, given, settings
from hypothesis.strategies import integers

import polars as pl
from polars.testing import column, dataframes
from polars.testing import assert_frame_equal, column, dataframes


@given(df=dataframes())
def test_repr(df: pl.DataFrame) -> None:
assert isinstance(repr(df), str)
assert_frame_equal(df, df, check_exact=True)
# print(df)


Expand Down
3 changes: 2 additions & 1 deletion py-polars/tests_parametric/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from hypothesis.strategies import sampled_from

import polars as pl
from polars.testing import series # , verify_series_and_expr_api
from polars.testing import assert_series_equal, series # , verify_series_and_expr_api

# # TODO: exclude obvious/known overflow inside the strategy before commenting back in
# @given(s=series(allowed_dtypes=_NUMERIC_COL_TYPES, name="a"))
Expand Down Expand Up @@ -38,3 +38,4 @@ def test_series_slice(
sliced_pl_data = srs[s].to_list()

assert sliced_py_data == sliced_pl_data, f"slice [{start}:{stop}:{step}] failed"
assert_series_equal(srs, srs, check_exact=True)

0 comments on commit f663838

Please sign in to comment.