Skip to content

Commit

Permalink
improve perf of testing functions, run rust code, not python
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Nov 28, 2021
1 parent 031b5e9 commit 8f198f4
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions py-polars/polars/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,12 @@ def assert_series_equal(
check_exact: bool = False,
rtol: float = 1.0e-5,
atol: float = 1.0e-8,
obj: str = "Series",
) -> None:
if obj == "Series":
if not (isinstance(left, Series) and isinstance(right, Series)):
raise_assert_detail(obj, "Type mismatch", type(left), type(right))
obj = "Series"
can_be_subtracted = hasattr(dtype_to_py_type(left.dtype), "__sub__")
check_exact = check_exact or not can_be_subtracted or left.dtype == Boolean
if not (isinstance(left, Series) and isinstance(right, Series)):
raise_assert_detail(obj, "Type mismatch", type(left), type(right))

if left.shape != right.shape:
raise_assert_detail(obj, "Shape mismatch", left.shape, right.shape)
Expand All @@ -56,14 +57,13 @@ def assert_series_equal(
if left.name != right.name:
raise_assert_detail(obj, "Name mismatch", left.name, right.name)

_can_be_subtracted = hasattr(dtype_to_py_type(left.dtype), "__sub__")
if check_exact or not _can_be_subtracted or dtype_to_py_type(left.dtype) == bool:
if any((left != right).to_list()):
if check_exact:
if (left != right).sum() != 0:
raise_assert_detail(
obj, "Exact value mismatch", left=list(left), right=list(right)
)
else:
if any((left - right).abs() > (atol + rtol * right.abs())):
if ((left - right).abs() > (atol + rtol * right.abs())).sum() != 0:
raise_assert_detail(
obj, "Value mismatch", left=list(left), right=list(right)
)
Expand All @@ -74,7 +74,6 @@ def raise_assert_detail(
message: str,
left: Any,
right: Any,
diff: Series = None,
) -> None:
__tracebackhide__ = True

Expand Down

0 comments on commit 8f198f4

Please sign in to comment.