Skip to content

Commit

Permalink
add assert_series_equal function and corresponding tests (#1828)
Browse files Browse the repository at this point in the history
  • Loading branch information
CloseChoice committed Nov 20, 2021
1 parent 2c6f811 commit 4dc472f
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 1 deletion.
2 changes: 1 addition & 1 deletion py-polars/polars/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
var,
)

from . import cfg, convert, datatypes, eager, functions, io, lazy, string_cache
from . import cfg, convert, datatypes, eager, functions, io, lazy, string_cache, testing
from .cfg import *
from .convert import *
from .datatypes import *
Expand Down
98 changes: 98 additions & 0 deletions py-polars/polars/testing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
try:
from typing import Any

from polars import Series
from polars.datatypes import (
_DTYPE_TO_PY_TYPE,
Boolean,
Float32,
Float64,
Int16,
Int32,
Int64,
UInt8,
UInt16,
UInt32,
UInt64,
Utf8,
)

_DOCUMENTING = False
except ImportError:
_DOCUMENTING = True


_NUMERIC_COL_TYPES = (
Int16,
Int32,
Int64,
UInt16,
UInt32,
UInt64,
UInt8,
Utf8,
Float32,
Float64,
Boolean,
)


def assert_series_equal(
left: Series,
right: Series,
check_dtype: bool = True,
check_names: bool = True,
check_exact: bool = False,
rtol: float = 1.0e-5,
atol: float = 1.0e-8,
obj: str = "Series",
) -> None:
if obj == "Series":
if not (isinstance(left, Series) and isinstance(right, Series)):
raise_assert_detail(obj, "Type mismatch", type(left), type(right))

if left.shape != right.shape:
raise_assert_detail(obj, "Shape mismatch", left.shape, right.shape)

if check_dtype:
if left.dtype != right.dtype:
raise_assert_detail(obj, "Dtype mismatch", left.dtype, right.dtype)

if check_names:
if left.name != right.name:
raise_assert_detail(obj, "Name mismatch", left.name, right.name)

_can_be_subtracted = hasattr(_DTYPE_TO_PY_TYPE[left.dtype], "__sub__")
if check_exact or not _can_be_subtracted:
if any((left != right).to_list()):
raise_assert_detail(
obj, "Exact value mismatch", left=list(left), right=list(right)
)
else:
if any((left - right).abs() > (atol + rtol * right.abs())):
raise_assert_detail(
obj, "Value mismatch", left=list(left), right=list(right)
)


def raise_assert_detail(
obj: str,
message: str,
left: Any,
right: Any,
diff: Series = None,
) -> None:
__tracebackhide__ = True

msg = f"""{obj} are different
{message}"""

msg += f"""
[left]: {left}
[right]: {right}"""

if diff is not None:
msg += f"\n[diff left]: {diff}"

raise AssertionError(msg)
31 changes: 31 additions & 0 deletions py-polars/tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest

import polars as pl
from polars import testing


def create_series() -> pl.Series:
Expand Down Expand Up @@ -927,3 +928,33 @@ def test_dt_year_month_week_day_ordinal_day() -> None:
assert a.dt.weekday().to_list() == [0, 4, 1]
assert a.dt.day().to_list() == [19, 4, 20]
assert a.dt.ordinal_day().to_list() == [139, 278, 51]


def test_compare_series_value_mismatch() -> None:
srs1 = pl.Series([1, 2, 3])
srs2 = pl.Series([2, 3, 4])
with pytest.raises(AssertionError, match="Series are different\n\nValue mismatch"):
testing.assert_series_equal(srs1, srs2)


def test_compare_series_name_mismatch() -> None:
srs1 = pl.Series(values=[1, 2, 3], name="srs1")
srs2 = pl.Series(values=[1, 2, 3], name="srs2")
with pytest.raises(AssertionError, match="Series are different\n\nName mismatch"):
testing.assert_series_equal(srs1, srs2)


def test_compare_series_shape_mismatch() -> None:
srs1 = pl.Series(values=[1, 2, 3, 4], name="srs1")
srs2 = pl.Series(values=[1, 2, 3], name="srs2")
with pytest.raises(AssertionError, match="Series are different\n\nShape mismatch"):
testing.assert_series_equal(srs1, srs2)


def test_compare_series_value_exact_mismatch() -> None:
srs1 = pl.Series([1.0, 2.0, 3.0])
srs2 = pl.Series([1.0, 2.0 + 1e-7, 3.0])
with pytest.raises(
AssertionError, match="Series are different\n\nExact value mismatch"
):
testing.assert_series_equal(srs1, srs2, check_exact=True)

0 comments on commit 4dc472f

Please sign in to comment.