From ecfaa8cfdb2ac351b8f6f588318acd08ae27c165 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 18 Feb 2025 13:26:23 +0100 Subject: [PATCH 1/2] numpy sum --- tests/test_stats.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/test_stats.py b/tests/test_stats.py index 91eb374..ca08d0f 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -96,5 +96,10 @@ def test_sum_benchmark( except NotImplementedError: pytest.skip("random_array not implemented for dtype") - stats.sum(arr, axis=axis) # type: ignore[arg-type] # warmup: numba compile - benchmark(stats.sum, arr, axis=axis) + def sum(arr, axis): + if hasattr(arr, sum): + return arr.sum(axis=axis) + return np.sum(arr, axis=axis) + + sum(arr, axis=axis) # type: ignore[arg-type] # warmup: numba compile + benchmark(sum, arr, axis=axis) From 5e497b3867dc7b1dbd7d240ffd70a6d1378bd035 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 18 Feb 2025 13:29:36 +0100 Subject: [PATCH 2/2] types --- tests/test_stats.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_stats.py b/tests/test_stats.py index ca08d0f..818c8cf 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -96,10 +96,10 @@ def test_sum_benchmark( except NotImplementedError: pytest.skip("random_array not implemented for dtype") - def sum(arr, axis): - if hasattr(arr, sum): + def sum(arr: Array[Any], axis: int | None) -> Array[Any]: # noqa: A001 + if hasattr(arr, "sum"): return arr.sum(axis=axis) - return np.sum(arr, axis=axis) + return np.sum(arr, axis=axis) # type: ignore[arg-type] - sum(arr, axis=axis) # type: ignore[arg-type] # warmup: numba compile + sum(arr, axis=axis) # warmup: numba compile benchmark(sum, arr, axis=axis)