diff --git a/tests/test_stats.py b/tests/test_stats.py index 91eb374..818c8cf 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -96,5 +96,10 @@ def test_sum_benchmark( except NotImplementedError: pytest.skip("random_array not implemented for dtype") - stats.sum(arr, axis=axis) # type: ignore[arg-type] # warmup: numba compile - benchmark(stats.sum, arr, axis=axis) + def sum(arr: Array[Any], axis: int | None) -> Array[Any]: # noqa: A001 + if hasattr(arr, "sum"): + return arr.sum(axis=axis) + return np.sum(arr, axis=axis) # type: ignore[arg-type] + + sum(arr, axis=axis) # warmup: numba compile + benchmark(sum, arr, axis=axis)