Skip to content

Commit

Permalink
Make frame/series asserts more resilient against integer overflow (#3850
Browse files Browse the repository at this point in the history
)
  • Loading branch information
alexander-beedie committed Jun 29, 2022
1 parent 539043a commit e3d1d52
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 14 deletions.
30 changes: 17 additions & 13 deletions py-polars/polars/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@
)
from hypothesis.strategies._internal.utils import defines_strategy

HYPOTHIS_INSTALLED = True
HYPOTHESIS_INSTALLED = True
except ImportError:
HYPOTHIS_INSTALLED = False
HYPOTHESIS_INSTALLED = False


from polars.datatypes import (
Expand Down Expand Up @@ -54,7 +54,7 @@
)
from polars.internals import DataFrame, LazyFrame, Series, col

if HYPOTHIS_INSTALLED:
if HYPOTHESIS_INSTALLED:
# TODO: increase the number of iterations during CI checkins?
# https://hypothesis.readthedocs.io/en/latest/settings.html#settings-profiles
settings.register_profile(name="polars.default", max_examples=100, print_blob=True)
Expand Down Expand Up @@ -210,18 +210,22 @@ def _assert_series_inner(
if left.dtype != right.dtype:
raise_assert_detail(obj, "Dtype mismatch", left.dtype, right.dtype)

if len(left) == len(right) == 0:
pass # empty series with same name/dtype are equal
elif check_exact:
if (left != right).sum() != 0:
# create mask of which (if any) values are unequal
unequal = left != right

# assert exact, or with tolerance
if unequal.any():
if check_exact:
raise_assert_detail(
obj, "Exact value mismatch", left=list(left), right=list(right)
)
else:
if ((left - right).abs() > (atol + rtol * right.abs())).sum() != 0:
raise_assert_detail(
obj, "Value mismatch", left=list(left), right=list(right)
)
else:
# apply check with tolerance, but only to the known-unequal matches
left, right = left.filter(unequal), right.filter(unequal)
if ((left - right).abs() > (atol + rtol * right.abs())).sum() != 0:
raise_assert_detail(
obj, "Value mismatch", left=list(left), right=list(right)
)


def raise_assert_detail(
Expand Down Expand Up @@ -285,7 +289,7 @@ def is_categorical_dtype(data_type: Any) -> bool:
)


if HYPOTHIS_INSTALLED:
if HYPOTHESIS_INSTALLED:

def between(draw: Callable, type_: type, min_: Any, max_: Any) -> Any:
"""
Expand Down
18 changes: 17 additions & 1 deletion py-polars/tests/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,18 @@ def test_assert_frame_equal_column_mismatch_order() -> None:
assert_frame_equal(df1, df2, check_column_names=False)


def test_assert_series_equal_int_overflow() -> None:
# internally may call 'abs' if not check_exact, which can overflow on signed int
s0 = pl.Series([-128], dtype=pl.Int8)
s1 = pl.Series([0, -128], dtype=pl.Int8)
s2 = pl.Series([1, -128], dtype=pl.Int8)

for check_exact in (True, False):
assert_series_equal(s0, s0, check_exact=check_exact)
with pytest.raises(AssertionError):
assert_series_equal(s1, s2, check_exact=check_exact)


@given(df=dataframes(), lf=dataframes(lazy=True), srs=series())
@settings(max_examples=10)
def test_strategy_classes(df: pl.DataFrame, lf: pl.LazyFrame, srs: pl.Series) -> None:
Expand Down Expand Up @@ -167,13 +179,17 @@ def test_strategy_frame_columns(lf: pl.LazyFrame) -> None:
assert lf.columns == ["a", "b", "c", "d"]
df = lf.collect()

# uint8 cols
# confirm uint cols bounds
uint8_max = (2**8) - 1
assert df["a"].min() >= 0
assert df["b"].min() >= 0
assert df["a"].max() <= uint8_max
assert df["b"].max() <= uint8_max

# confirm uint cols uniqueness
assert df["a"].is_unique().all()
assert df["b"].is_unique().all()

# boolean col
assert all(isinstance(v, bool) for v in df["c"].to_list())

Expand Down

0 comments on commit e3d1d52

Please sign in to comment.