Skip to content

Commit

Permalink
perf(python): never import hypothesis in user code (#5282)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Oct 21, 2022
1 parent 760fe0e commit 33da4a8
Show file tree
Hide file tree
Showing 14 changed files with 328 additions and 296 deletions.
8 changes: 4 additions & 4 deletions py-polars/docs/source/reference/testing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,14 @@ testing strategies and strategy helpers:
.. autosummary::
:toctree: api/

testing.dataframes
testing.series
testing._parametric.dataframes
testing._parametric.series

Strategy helpers
~~~~~~~~~~~~~~~~

.. autosummary::
:toctree: api/

testing.column
testing.columns
testing._parametric.column
testing._parametric.columns
11 changes: 11 additions & 0 deletions py-polars/polars/testing/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from polars.testing.asserts import (
assert_frame_equal,
assert_frame_equal_local_categoricals,
assert_series_equal,
)

__all__ = [
"assert_series_equal",
"assert_frame_equal",
"assert_frame_equal_local_categoricals",
]
287 changes: 12 additions & 275 deletions py-polars/polars/testing.py → py-polars/polars/testing/_parametric.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
import warnings
from dataclasses import dataclass
from datetime import datetime, timedelta
from functools import reduce
from math import isfinite
from typing import Any, Sequence

from polars.testing.asserts import is_categorical_dtype

try:
from hypothesis import settings
from hypothesis.errors import InvalidArgument, NonInteractiveExampleWarning
Expand Down Expand Up @@ -55,7 +56,6 @@
UInt32,
UInt64,
Utf8,
dtype_to_py_type,
is_polars_dtype,
py_type_to_dtype,
)
Expand Down Expand Up @@ -85,253 +85,6 @@
MAX_DATA_SIZE = 10
MAX_COLS = 8


def assert_frame_equal(
left: pli.DataFrame | pli.LazyFrame,
right: pli.DataFrame | pli.LazyFrame,
check_dtype: bool = True,
check_exact: bool = False,
check_column_names: bool = True,
rtol: float = 1.0e-5,
atol: float = 1.0e-8,
nans_compare_equal: bool = True,
) -> None:
"""
Raise detailed AssertionError if `left` does not equal `right`.
Parameters
----------
left
the dataframe to compare.
right
the dataframe to compare with.
check_dtype
if True, data types need to match exactly.
check_exact
if False, test if values are within tolerance of each other
(see `rtol` & `atol`).
check_column_names
if True, dataframes must have the same column names in the same order.
rtol
relative tolerance for inexact checking. Fraction of values in `right`.
atol
absolute tolerance for inexact checking.
nans_compare_equal
if your assert/test requires float NaN != NaN, set this to False.
Examples
--------
>>> df1 = pl.DataFrame({"a": [1, 2, 3]})
>>> df2 = pl.DataFrame({"a": [2, 3, 4]})
>>> pl.testing.assert_frame_equal(df1, df2) # doctest: +SKIP
"""
if isinstance(left, pli.LazyFrame) and isinstance(right, pli.LazyFrame):
left, right = left.collect(), right.collect()
obj = "pli.LazyFrame"
else:
obj = "pli.DataFrame"

if not (isinstance(left, pli.DataFrame) and isinstance(right, pli.DataFrame)):
raise_assert_detail(obj, "Type mismatch", type(left), type(right))
elif left.shape[0] != right.shape[0]:
raise_assert_detail(obj, "Length mismatch", left.shape, right.shape)

# this assumes we want it in the same order
union_cols = list(set(left.columns).union(set(right.columns)))
for c in union_cols:
if c not in right.columns:
raise AssertionError(f"column {c} in left frame, but not in right")
if c not in left.columns:
raise AssertionError(f"column {c} in right frame, but not in left")

if check_column_names:
if left.columns != right.columns:
raise AssertionError("Columns are not in the same order")

# this does not assume a particular order
for c in left.columns:
_assert_series_inner(
left[c], # type: ignore[arg-type, index]
right[c], # type: ignore[arg-type, index]
check_dtype,
check_exact,
nans_compare_equal,
atol,
rtol,
obj,
)


def assert_series_equal(
left: pli.Series,
right: pli.Series,
check_dtype: bool = True,
check_names: bool = True,
check_exact: bool = False,
rtol: float = 1.0e-5,
atol: float = 1.0e-8,
nans_compare_equal: bool = True,
) -> None:
"""
Raise detailed AssertionError if `left` does not equal `right`.
Parameters
----------
left
the series to compare.
right
the series to compare with.
check_dtype
if True, data types need to match exactly.
check_names
if True, names need to match.
check_exact
if False, test if values are within tolerance of each other
(see `rtol` & `atol`).
rtol
relative tolerance for inexact checking. Fraction of values in `right`.
atol
absolute tolerance for inexact checking.
nans_compare_equal
if your assert/test requires float NaN != NaN, set this to False.
Examples
--------
>>> s1 = pl.Series([1, 2, 3])
>>> s2 = pl.Series([2, 3, 4])
>>> pl.testing.assert_series_equal(s1, s2) # doctest: +SKIP
"""
obj = "Series"

if not (
isinstance(left, pli.Series) # type: ignore[redundant-expr]
and isinstance(right, pli.Series)
):
raise_assert_detail(obj, "Type mismatch", type(left), type(right))

if left.shape != right.shape:
raise_assert_detail(obj, "Shape mismatch", left.shape, right.shape)

if check_names:
if left.name != right.name:
raise_assert_detail(obj, "Name mismatch", left.name, right.name)

_assert_series_inner(
left, right, check_dtype, check_exact, nans_compare_equal, atol, rtol, obj
)


def _assert_series_inner(
left: pli.Series,
right: pli.Series,
check_dtype: bool,
check_exact: bool,
nans_compare_equal: bool,
atol: float,
rtol: float,
obj: str,
) -> None:
"""Compare Series dtype + values."""
try:
can_be_subtracted = hasattr(dtype_to_py_type(left.dtype), "__sub__")
except NotImplementedError:
can_be_subtracted = False

check_exact = check_exact or not can_be_subtracted or left.dtype == Boolean
if check_dtype:
if left.dtype != right.dtype:
raise_assert_detail(obj, "Dtype mismatch", left.dtype, right.dtype)

# create mask of which (if any) values are unequal
unequal = left != right
if unequal.any() and nans_compare_equal and left.dtype in (Float32, Float64):
# handle NaN values (which compare unequal to themselves)
unequal = unequal & ~(
(left.is_nan() & right.is_nan()).fill_null(pli.lit(False))
)

# assert exact, or with tolerance
if unequal.any():
if check_exact:
raise_assert_detail(
obj, "Exact value mismatch", left=list(left), right=list(right)
)
else:
# apply check with tolerance, but only to the known-unequal matches
left, right = left.filter(unequal), right.filter(unequal)
if (((left - right).abs() > (atol + rtol * right.abs())).sum() != 0) or (
(left.is_null() != right.is_null()).any()
):
raise_assert_detail(
obj, "Value mismatch", left=list(left), right=list(right)
)


def raise_assert_detail(
obj: str,
message: str,
left: Any,
right: Any,
) -> None:
__tracebackhide__ = True

msg = f"""{obj} are different
{message}"""

msg += f"""
[left]: {left}
[right]: {right}"""

raise AssertionError(msg)


def _getattr_multi(obj: object, op: str) -> Any:
"""
Allow `op` to be multiple layers deep.
For example, op="str.lengths" will mean we first get the attribute "str", and then
the attribute "lengths".
"""
op_list = op.split(".")
return reduce(lambda o, m: getattr(o, m), op_list, obj)


def verify_series_and_expr_api(
input: pli.Series, expected: pli.Series | None, op: str, *args: Any, **kwargs: Any
) -> None:
"""
Test element-wise functions for both the series and expressions API.
Examples
--------
>>> s = pl.Series([1, 3, 2])
>>> expected = pl.Series([1, 2, 3])
>>> verify_series_and_expr_api(s, expected, "sort")
"""
expr = _getattr_multi(pli.col("*"), op)(*args, **kwargs)
result_expr = input.to_frame().select(expr)[:, 0]
result_series = _getattr_multi(input, op)(*args, **kwargs)
if expected is None:
assert_series_equal(result_series, result_expr)
else:
assert_series_equal(result_expr, expected)
assert_series_equal(result_series, expected)


def is_categorical_dtype(data_type: Any) -> bool:
"""Check if the input is a polars Categorical dtype."""
return (
type(data_type) is type
and issubclass(data_type, Categorical)
or isinstance(data_type, Categorical)
)


if HYPOTHESIS_INSTALLED:
# =====================================================================
# Polars-specific 'hypothesis' strategies and helper functions
Expand Down Expand Up @@ -407,9 +160,13 @@ class column:
Examples
--------
>>> from hypothesis.strategies import sampled_from
>>> pl.testing.column(name="unique_small_ints", dtype=pl.UInt8, unique=True)
>>> pl.testing._parametric.column(
... name="unique_small_ints", dtype=pl.UInt8, unique=True
... )
column(name='unique_small_ints', dtype=<class 'polars.datatypes.UInt8'>, strategy=None, null_probability=None, unique=True)
>>> pl.testing.column(name="ccy", strategy=sampled_from(["GBP", "EUR", "JPY"]))
>>> pl.testing._parametric.column(
... name="ccy", strategy=sampled_from(["GBP", "EUR", "JPY"])
... )
column(name='ccy', dtype=<class 'polars.datatypes.Utf8'>, strategy=sampled_from(['GBP', 'EUR', 'JPY']), null_probability=None, unique=False)
""" # noqa: E501
Expand Down Expand Up @@ -498,7 +255,7 @@ def columns(
Examples
--------
>>> from polars.testing import columns
>>> from polars.testing._parametric import columns
>>> from string import punctuation
>>>
>>> def test_special_char_colname_init() -> None:
Expand All @@ -507,7 +264,7 @@ def columns(
... assert len(cols) == len(df.columns)
... assert 0 == len(df.rows())
...
>>> from polars.testing import columns
>>> from polars.testing._parametric import columns
>>> from hypothesis import given
>>>
>>> @given(dataframes(columns(["x", "y", "z"], unique=True)))
Expand Down Expand Up @@ -606,7 +363,7 @@ def series(
Examples
--------
>>> from polars.testing import series
>>> from polars.testing._parametric import series
>>> from hypothesis import given
>>>
>>> @given(df=series())
Expand Down Expand Up @@ -774,7 +531,7 @@ def dataframes(
generate. Note: in actual use the strategy is applied as a test decorator, not
used standalone.
>>> from polars.testing import column, columns, dataframes
>>> from polars.testing._parametric import column, columns, dataframes
>>> from hypothesis import given
>>>
>>> # generate arbitrary DataFrames
Expand Down Expand Up @@ -888,23 +645,3 @@ def draw_frames(draw: DrawFn) -> pli.DataFrame | pli.LazyFrame:
return df.lazy() if lazy else df

return draw_frames()


def assert_frame_equal_local_categoricals(
df_a: pli.DataFrame, df_b: pli.DataFrame
) -> None:

for ((a_name, a_value), (b_name, b_value)) in zip(
df_a.schema.items(), df_b.schema.items()
):
if a_name != b_name:
print(f"{a_name} != {b_name}")
raise AssertionError
if a_value != b_value:
print(f"{a_value} != {b_value}")
raise AssertionError

cat_to_str = pli.col(Categorical).cast(str)
assert df_a.with_column(cat_to_str).frame_equal(df_b.with_column(cat_to_str))
cat_to_phys = pli.col(Categorical).to_physical()
assert df_a.with_column(cat_to_phys).frame_equal(df_b.with_column(cat_to_phys))

0 comments on commit 33da4a8

Please sign in to comment.