Skip to content

Commit

Permalink
Add testing.assert_frame_equal (#2181)
Browse files Browse the repository at this point in the history
* Add testing.assert_frame_equal

Closes #1167
  • Loading branch information
zundertj committed Dec 28, 2021
1 parent f6d7a18 commit 2894ba5
Show file tree
Hide file tree
Showing 4 changed files with 230 additions and 51 deletions.
3 changes: 3 additions & 0 deletions py-polars/polars/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def version() -> str:
# this is only useful for documentation
warnings.warn("polars binary missing!")

import polars.testing as testing
from polars.cfg import ( # flake8: noqa. We do not export in __all__
Config,
toggle_string_cache,
Expand Down Expand Up @@ -202,6 +203,8 @@ def version() -> str:
"from_records",
"from_arrow",
"from_pandas",
# testing
"testing",
]

__version__ = version()
Expand Down
137 changes: 127 additions & 10 deletions py-polars/polars/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
Utf8,
dtype_to_py_type,
)
from polars.internals import Series
from polars.internals import DataFrame, Series

_NUMERIC_COL_TYPES = (
Int16,
Expand All @@ -31,6 +31,76 @@
)


def assert_frame_equal(
left: DataFrame,
right: DataFrame,
check_dtype: bool = True,
check_exact: bool = False,
rtol: float = 1.0e-5,
atol: float = 1.0e-8,
) -> None:
"""
Raise detailed AssertionError if `left` does not equal `right`.
Parameters
----------
left
the dataframe to compare
right
the dataframe to compare with
check_dtype
if True, data types need to match exactly
check_exact
if False, test if values are within tolerance of each other (see `rtol` & `atol`)
rtol
relative tolerance for inexact checking. Fraction of values in `right`
atol
absolute tolerance for inexact checking.
Returns
-------
Examples
--------
>>> df1 = pl.DataFrame({"a": [1, 2, 3]})
>>> df2 = pl.DataFrame({"a": [2, 3, 4]})
>>> pl.testing.assert_frame_equal(df1, df2) # doctest: +SKIP
"""

obj = "DataFrame"
check_column_order = True

if not (isinstance(left, DataFrame) and isinstance(right, DataFrame)):
raise_assert_detail(obj, "Type mismatch", type(left), type(right))

if left.shape[0] != right.shape[0]:
raise_assert_detail(obj, "Length mismatch", left.shape, right.shape)

# this assumes we want it in the same order
union_cols = list(set(left.columns).union(set(right.columns)))
for c in union_cols:
if c not in right.columns:
raise AssertionError(
f"column {c} in left dataframe, but not in right dataframe"
)
if c not in left.columns:
raise AssertionError(
f"column {c} in right dataframe, but not in left dataframe"
)

if check_column_order:
if left.columns != right.columns:
raise AssertionError("Columns are not in same order")

# this does not assume a particular order
for col in left.columns:
_assert_series_inner(
left[col], right[col], check_dtype, check_exact, atol, rtol, obj
)


def assert_series_equal(
left: Series,
right: Series,
Expand All @@ -40,27 +110,74 @@ def assert_series_equal(
rtol: float = 1.0e-5,
atol: float = 1.0e-8,
) -> None:
"""
Raise detailed AssertionError if `left` does not equal `right`.
Parameters
----------
left
the series to compare
right
the series to compare with
check_dtype
if True, data types need to match exactly
check_names
if True, names need to match
check_exact
if False, test if values are within tolerance of each other (see `rtol` & `atol`)
rtol
relative tolerance for inexact checking. Fraction of values in `right`
atol
absolute tolerance for inexact checking.
Returns
-------
Examples
--------
>>> s1 = pl.Series([1, 2, 3])
>>> s2 = pl.Series([2, 3, 4])
>>> pl.testing.assert_series_equal(s1, s2) # doctest: +SKIP
"""
obj = "Series"
try:
can_be_subtracted = hasattr(dtype_to_py_type(left.dtype), "__sub__")
except NotImplementedError:
can_be_subtracted = False

check_exact = check_exact or not can_be_subtracted or left.dtype == Boolean
if not (isinstance(left, Series) and isinstance(right, Series)):
raise_assert_detail(obj, "Type mismatch", type(left), type(right))

if left.shape != right.shape:
raise_assert_detail(obj, "Shape mismatch", left.shape, right.shape)

if check_dtype:
if left.dtype != right.dtype:
raise_assert_detail(obj, "Dtype mismatch", left.dtype, right.dtype)

if check_names:
if left.name != right.name:
raise_assert_detail(obj, "Name mismatch", left.name, right.name)

_assert_series_inner(left, right, check_dtype, check_exact, atol, rtol, obj)


def _assert_series_inner(
left: Series,
right: Series,
check_dtype: bool,
check_exact: bool,
atol: float,
rtol: float,
obj: str,
) -> None:
"""
Compares Series dtype + values
"""
try:
can_be_subtracted = hasattr(dtype_to_py_type(left.dtype), "__sub__")
except NotImplementedError:
can_be_subtracted = False

check_exact = check_exact or not can_be_subtracted or left.dtype == Boolean

if check_dtype:
if left.dtype != right.dtype:
raise_assert_detail(obj, "Dtype mismatch", left.dtype, right.dtype)

if check_exact:
if (left != right).sum() != 0:
raise_assert_detail(
Expand Down
41 changes: 0 additions & 41 deletions py-polars/tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1174,47 +1174,6 @@ def test_dt_datetimes() -> None:
)


def test_compare_series_value_mismatch() -> None:
srs1 = pl.Series([1, 2, 3])
srs2 = pl.Series([2, 3, 4])
with pytest.raises(AssertionError, match="Series are different\n\nValue mismatch"):
testing.assert_series_equal(srs1, srs2)


def test_compare_series_type_mismatch() -> None:
srs1 = pl.Series([1, 2, 3])
srs2 = pl.DataFrame({"col1": [2, 3, 4]})
with pytest.raises(AssertionError, match="Series are different\n\nType mismatch"):
testing.assert_series_equal(srs1, srs2) # type: ignore

srs3 = pl.Series([1.0, 2.0, 3.0])
with pytest.raises(AssertionError, match="Series are different\n\nDtype mismatch"):
testing.assert_series_equal(srs1, srs3)


def test_compare_series_name_mismatch() -> None:
srs1 = pl.Series(values=[1, 2, 3], name="srs1")
srs2 = pl.Series(values=[1, 2, 3], name="srs2")
with pytest.raises(AssertionError, match="Series are different\n\nName mismatch"):
testing.assert_series_equal(srs1, srs2)


def test_compare_series_shape_mismatch() -> None:
srs1 = pl.Series(values=[1, 2, 3, 4], name="srs1")
srs2 = pl.Series(values=[1, 2, 3], name="srs2")
with pytest.raises(AssertionError, match="Series are different\n\nShape mismatch"):
testing.assert_series_equal(srs1, srs2)


def test_compare_series_value_exact_mismatch() -> None:
srs1 = pl.Series([1.0, 2.0, 3.0])
srs2 = pl.Series([1.0, 2.0 + 1e-7, 3.0])
with pytest.raises(
AssertionError, match="Series are different\n\nExact value mismatch"
):
testing.assert_series_equal(srs1, srs2, check_exact=True)


def test_reshape() -> None:
s = pl.Series("a", [1, 2, 3, 4])
out = s.reshape((-1, 2))
Expand Down
100 changes: 100 additions & 0 deletions py-polars/tests/test_testing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import pytest

import polars as pl


def test_compare_series_value_mismatch() -> None:
srs1 = pl.Series([1, 2, 3])
srs2 = pl.Series([2, 3, 4])
with pytest.raises(AssertionError, match="Series are different\n\nValue mismatch"):
pl.testing.assert_series_equal(srs1, srs2)


def test_compare_series_nulls_are_equal() -> None:
srs1 = pl.Series([1, 2, None])
srs2 = pl.Series([1, 2, None])
pl.testing.assert_series_equal(srs1, srs2)


def test_compare_series_value_mismatch_string() -> None:
srs1 = pl.Series(["hello", "no"])
srs2 = pl.Series(["hello", "yes"])
with pytest.raises(
AssertionError, match="Series are different\n\nExact value mismatch"
):
pl.testing.assert_series_equal(srs1, srs2)


def test_compare_series_type_mismatch() -> None:
srs1 = pl.Series([1, 2, 3])
srs2 = pl.DataFrame({"col1": [2, 3, 4]})
with pytest.raises(AssertionError, match="Series are different\n\nType mismatch"):
pl.testing.assert_series_equal(srs1, srs2) # type: ignore

srs3 = pl.Series([1.0, 2.0, 3.0])
with pytest.raises(AssertionError, match="Series are different\n\nDtype mismatch"):
pl.testing.assert_series_equal(srs1, srs3)


def test_compare_series_name_mismatch() -> None:
srs1 = pl.Series(values=[1, 2, 3], name="srs1")
srs2 = pl.Series(values=[1, 2, 3], name="srs2")
with pytest.raises(AssertionError, match="Series are different\n\nName mismatch"):
pl.testing.assert_series_equal(srs1, srs2)


def test_compare_series_shape_mismatch() -> None:
srs1 = pl.Series(values=[1, 2, 3, 4], name="srs1")
srs2 = pl.Series(values=[1, 2, 3], name="srs2")
with pytest.raises(AssertionError, match="Series are different\n\nShape mismatch"):
pl.testing.assert_series_equal(srs1, srs2)


def test_compare_series_value_exact_mismatch() -> None:
srs1 = pl.Series([1.0, 2.0, 3.0])
srs2 = pl.Series([1.0, 2.0 + 1e-7, 3.0])
with pytest.raises(
AssertionError, match="Series are different\n\nExact value mismatch"
):
pl.testing.assert_series_equal(srs1, srs2, check_exact=True)


def test_assert_frame_equal_pass() -> None:
df1 = pl.DataFrame({"a": [1, 2]})
df2 = pl.DataFrame({"a": [1, 2]})
pl.testing.assert_frame_equal(df1, df2)


def test_assert_frame_equal_types() -> None:
df1 = pl.DataFrame({"a": [1, 2]})
srs1 = pl.Series(values=[1, 2], name="a")
with pytest.raises(AssertionError):
pl.testing.assert_frame_equal(df1, srs1) # type: ignore


def test_assert_frame_equal_length_mismatch() -> None:
df1 = pl.DataFrame({"a": [1, 2]})
df2 = pl.DataFrame({"a": [1, 2, 3]})
with pytest.raises(AssertionError):
pl.testing.assert_frame_equal(df1, df2)


def test_assert_frame_equal_column_mismatch() -> None:
df1 = pl.DataFrame({"a": [1, 2]})
df2 = pl.DataFrame({"b": [1, 2]})
with pytest.raises(AssertionError):
pl.testing.assert_frame_equal(df1, df2)


def test_assert_frame_equal_column_mismatch2() -> None:
df1 = pl.DataFrame({"a": [1, 2]})
df2 = pl.DataFrame({"a": [1, 2], "b": [3, 4]})
with pytest.raises(AssertionError):
pl.testing.assert_frame_equal(df1, df2)


def test_assert_frame_equal_column_mismatch_order() -> None:
df1 = pl.DataFrame({"b": [3, 4], "a": [1, 2]})
df2 = pl.DataFrame({"a": [1, 2], "b": [3, 4]})
with pytest.raises(AssertionError):
pl.testing.assert_frame_equal(df1, df2)

0 comments on commit 2894ba5

Please sign in to comment.