diff --git a/tests/test_stats.py b/tests/test_stats.py index 818c8cf..91eb374 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -96,10 +96,5 @@ def test_sum_benchmark( except NotImplementedError: pytest.skip("random_array not implemented for dtype") - def sum(arr: Array[Any], axis: int | None) -> Array[Any]: # noqa: A001 - if hasattr(arr, "sum"): - return arr.sum(axis=axis) - return np.sum(arr, axis=axis) # type: ignore[arg-type] - - sum(arr, axis=axis) # warmup: numba compile - benchmark(sum, arr, axis=axis) + stats.sum(arr, axis=axis) # type: ignore[arg-type] # warmup: numba compile + benchmark(stats.sum, arr, axis=axis)