From c4357045028fd131c56a5dbeb6346f571c4c2905 Mon Sep 17 00:00:00 2001 From: Phil Schaf Date: Mon, 13 Oct 2025 12:05:29 +0200 Subject: [PATCH 01/12] Add precision test --- tests/test_stats.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/test_stats.py b/tests/test_stats.py index a33b1d3..00c7b17 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -90,7 +90,7 @@ def dtype_in(request: pytest.FixtureRequest, array_type: ArrayType) -> type[DTyp return dtype -@pytest.fixture(scope="session", params=[np.float32, np.float64, None]) +@pytest.fixture(scope="session", params=[np.float32, np.float64, np.int64, None]) def dtype_arg(request: pytest.FixtureRequest) -> type[DTypeOut] | None: return cast("type[DTypeOut] | None", request.param) @@ -98,6 +98,8 @@ def dtype_arg(request: pytest.FixtureRequest) -> type[DTypeOut] | None: @pytest.fixture def np_arr(dtype_in: type[DTypeIn], ndim: Literal[1, 2]) -> NDArray[DTypeIn]: np_arr = cast("NDArray[DTypeIn]", np.array([[1, 0], [3, 0], [5, 6]], dtype=dtype_in)) + if np.dtype(dtype_in).kind == "f": + np_arr /= 3 np_arr.flags.writeable = False if ndim == 1: np_arr = np_arr.flatten() @@ -158,7 +160,10 @@ def test_sum( assert sum_.dtype == dtype_in expected = np.sum(np_arr, axis=axis, dtype=dtype_arg) - np.testing.assert_array_equal(sum_, expected) + if np.dtype(dtype_arg).kind == np.dtype(dtype_in).kind and np.dtype(dtype_arg).itemsize >= np.dtype(dtype_in).itemsize: + np.testing.assert_array_equal(sum_, expected) + else: + np.testing.assert_array_almost_equal(sum_, expected) @pytest.mark.parametrize( From 008e6ea98ed571c52d6452cfb26c91920feaaac6 Mon Sep 17 00:00:00 2001 From: Phil Schaf Date: Mon, 13 Oct 2025 13:15:00 +0200 Subject: [PATCH 02/12] =?UTF-8?q?fix=20float32=E2=86=92float64?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/fast_array_utils/stats/_sum.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fast_array_utils/stats/_sum.py b/src/fast_array_utils/stats/_sum.py index a0a14f4..8a8132e 100644 --- a/src/fast_array_utils/stats/_sum.py +++ b/src/fast_array_utils/stats/_sum.py @@ -92,7 +92,7 @@ def _sum_dask( rv = da.reduction( x, - sum_dask_inner, # type: ignore[arg-type] + partial(sum_dask_inner, dtype=dtype), # type: ignore[arg-type] partial(sum_dask_inner, dtype=dtype), # pyright: ignore[reportArgumentType] axis=axis, dtype=dtype, From 29333a8f3dd1cae56640cecb3195ee753e5d7cba Mon Sep 17 00:00:00 2001 From: Phil Schaf Date: Mon, 13 Oct 2025 13:28:40 +0200 Subject: [PATCH 03/12] only for dask --- tests/test_stats.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_stats.py b/tests/test_stats.py index 00c7b17..2784166 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -160,7 +160,7 @@ def test_sum( assert sum_.dtype == dtype_in expected = np.sum(np_arr, axis=axis, dtype=dtype_arg) - if np.dtype(dtype_arg).kind == np.dtype(dtype_in).kind and np.dtype(dtype_arg).itemsize >= np.dtype(dtype_in).itemsize: + if array_type.cls is not types.DaskArray: np.testing.assert_array_equal(sum_, expected) else: np.testing.assert_array_almost_equal(sum_, expected) From f5f270ffc0fc7fa9c1482e6f4e26a804a4f9b8e4 Mon Sep 17 00:00:00 2001 From: Phil Schaf Date: Mon, 13 Oct 2025 13:35:29 +0200 Subject: [PATCH 04/12] simplify --- src/fast_array_utils/stats/_sum.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/fast_array_utils/stats/_sum.py b/src/fast_array_utils/stats/_sum.py index 8a8132e..f819265 100644 --- a/src/fast_array_utils/stats/_sum.py +++ b/src/fast_array_utils/stats/_sum.py @@ -66,8 +66,6 @@ def _sum_cs( if isinstance(x, types.CSMatrix): x = sp.csr_array(x) if x.format == "csr" else sp.csc_array(x) - if axis is None: - return cast("np.number[Any]", x.data.sum(dtype=dtype)) return cast("NDArray[Any] | np.number[Any]", x.sum(axis=axis, dtype=dtype)) From a4c4111a0df32d50f76fe8f8111c8abf5dd3dc44 Mon Sep 17 00:00:00 2001 From: Phil Schaf Date: Mon, 13 Oct 2025 13:57:27 +0200 Subject: [PATCH 05/12] revert --- src/fast_array_utils/stats/_sum.py | 2 +- tests/test_stats.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/fast_array_utils/stats/_sum.py b/src/fast_array_utils/stats/_sum.py index f819265..5a39148 100644 --- a/src/fast_array_utils/stats/_sum.py +++ b/src/fast_array_utils/stats/_sum.py @@ -90,7 +90,7 @@ def _sum_dask( rv = da.reduction( x, - partial(sum_dask_inner, dtype=dtype), # type: ignore[arg-type] + sum_dask_inner, # type: ignore[arg-type] partial(sum_dask_inner, dtype=dtype), # pyright: ignore[reportArgumentType] axis=axis, dtype=dtype, diff --git a/tests/test_stats.py b/tests/test_stats.py index 2784166..9b65163 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -99,7 +99,7 @@ def dtype_arg(request: pytest.FixtureRequest) -> type[DTypeOut] | None: def np_arr(dtype_in: type[DTypeIn], ndim: Literal[1, 2]) -> NDArray[DTypeIn]: np_arr = cast("NDArray[DTypeIn]", np.array([[1, 0], [3, 0], [5, 6]], dtype=dtype_in)) if np.dtype(dtype_in).kind == "f": - np_arr /= 3 + np_arr /= 3 # type: ignore[misc] np_arr.flags.writeable = False if ndim == 1: np_arr = np_arr.flatten() From 4c60e1e018e953644432c6721a8ed74357c3c780 Mon Sep 17 00:00:00 2001 From: Phil Schaf Date: Mon, 13 Oct 2025 14:06:51 +0200 Subject: [PATCH 06/12] oof --- src/fast_array_utils/stats/_sum.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/fast_array_utils/stats/_sum.py b/src/fast_array_utils/stats/_sum.py index 5a39148..7a2af93 100644 --- a/src/fast_array_utils/stats/_sum.py +++ b/src/fast_array_utils/stats/_sum.py @@ -88,9 +88,10 @@ def _sum_dask( # Explicitly use numpy result dtype (e.g. `NDArray[bool].sum().dtype == int64`) dtype = np.zeros(1, dtype=x.dtype).sum().dtype + # TODO(flying-sheep): chunk=sum_dask_inner fixes mean_var(<1d array dtype=float32>) # noqa: TD003 rv = da.reduction( x, - sum_dask_inner, # type: ignore[arg-type] + partial(sum_dask_inner, dtype=dtype), # type: ignore[arg-type] partial(sum_dask_inner, dtype=dtype), # pyright: ignore[reportArgumentType] axis=axis, dtype=dtype, From eb6526bbcc69b1eec49b185ddd5931055268d94e Mon Sep 17 00:00:00 2001 From: Phil Schaf Date: Mon, 13 Oct 2025 14:12:32 +0200 Subject: [PATCH 07/12] less broken --- tests/test_stats.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_stats.py b/tests/test_stats.py index 9b65163..07b8998 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -99,7 +99,7 @@ def dtype_arg(request: pytest.FixtureRequest) -> type[DTypeOut] | None: def np_arr(dtype_in: type[DTypeIn], ndim: Literal[1, 2]) -> NDArray[DTypeIn]: np_arr = cast("NDArray[DTypeIn]", np.array([[1, 0], [3, 0], [5, 6]], dtype=dtype_in)) if np.dtype(dtype_in).kind == "f": - np_arr /= 3 # type: ignore[misc] + np_arr /= 4 # type: ignore[misc] np_arr.flags.writeable = False if ndim == 1: np_arr = np_arr.flatten() From 5c802fd62e96b234849d7bc633ecc8ef5d4da30c Mon Sep 17 00:00:00 2001 From: Phil Schaf Date: Mon, 13 Oct 2025 14:24:34 +0200 Subject: [PATCH 08/12] huh --- src/fast_array_utils/stats/_sum.py | 1 - tests/test_stats.py | 12 ++++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/src/fast_array_utils/stats/_sum.py b/src/fast_array_utils/stats/_sum.py index 7a2af93..f819265 100644 --- a/src/fast_array_utils/stats/_sum.py +++ b/src/fast_array_utils/stats/_sum.py @@ -88,7 +88,6 @@ def _sum_dask( # Explicitly use numpy result dtype (e.g. `NDArray[bool].sum().dtype == int64`) dtype = np.zeros(1, dtype=x.dtype).sum().dtype - # TODO(flying-sheep): chunk=sum_dask_inner fixes mean_var(<1d array dtype=float32>) # noqa: TD003 rv = da.reduction( x, partial(sum_dask_inner, dtype=dtype), # type: ignore[arg-type] diff --git a/tests/test_stats.py b/tests/test_stats.py index 07b8998..e45247c 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -166,6 +166,18 @@ def test_sum( np.testing.assert_array_almost_equal(sum_, expected) +@pytest.mark.array_type(skip=ATS_SPARSE_DS) +def test_sum_to_int(array_type: ArrayType[CpuArray | GpuArray | DiskArray | types.DaskArray], axis: Literal[0, 1] | None) -> None: + rng = np.random.default_rng(0) + np_arr = rng.random((100, 100)) + arr = array_type(np_arr) + sum_ = stats.sum(arr, axis=axis, dtype=np.int64) + if axis is None: + assert sum_ == np.int64(0) + else: + np.testing.assert_array_equal(sum_, np.zeros(arr.shape[axis], dtype=np.int64)) + + @pytest.mark.parametrize( "data", [ From 25b87c87eae9afa5bcf6ebab8d9de70bf5dcefce Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 14 Oct 2025 09:59:57 +0200 Subject: [PATCH 09/12] all fixed --- .pre-commit-config.yaml | 3 ++- src/fast_array_utils/stats/_sum.py | 11 +++++++---- tests/test_stats.py | 11 +++-------- 3 files changed, 12 insertions(+), 13 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7b7495c..5fca7a9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -25,7 +25,8 @@ repos: rev: v1.18.2 hooks: - id: mypy - args: [--config-file=pyproject.toml] + args: [--config-file=pyproject.toml, .] + pass_filenames: false additional_dependencies: - pytest - pytest-codspeed!=4.0.0 # https://github.com/CodSpeedHQ/pytest-codspeed/pull/84 diff --git a/src/fast_array_utils/stats/_sum.py b/src/fast_array_utils/stats/_sum.py index f819265..abd6043 100644 --- a/src/fast_array_utils/stats/_sum.py +++ b/src/fast_array_utils/stats/_sum.py @@ -63,10 +63,13 @@ def _sum_cs( del keep_cupy_as_array import scipy.sparse as sp - if isinstance(x, types.CSMatrix): - x = sp.csr_array(x) if x.format == "csr" else sp.csc_array(x) + dtype = np.dtype(dtype) if dtype is not None else None + # convert to array so dimensions collapse as expected + x = (sp.csr_array if x.format == "csr" else sp.csc_array)(x, dtype=dtype) - return cast("NDArray[Any] | np.number[Any]", x.sum(axis=axis, dtype=dtype)) + # TODO(flying-sheep): use `dtype=dtype` here when of above once scipy fixes this + # https://github.com/scipy/scipy/issues/23768 + return cast("NDArray[Any] | np.number[Any]", x.sum(axis=axis)) @sum_.register(types.DaskArray) @@ -90,7 +93,7 @@ def _sum_dask( rv = da.reduction( x, - partial(sum_dask_inner, dtype=dtype), # type: ignore[arg-type] + partial(sum_dask_inner, dtype=dtype), # pyright: ignore[reportArgumentType] partial(sum_dask_inner, dtype=dtype), # pyright: ignore[reportArgumentType] axis=axis, dtype=dtype, diff --git a/tests/test_stats.py b/tests/test_stats.py index e45247c..649a507 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -160,10 +160,7 @@ def test_sum( assert sum_.dtype == dtype_in expected = np.sum(np_arr, axis=axis, dtype=dtype_arg) - if array_type.cls is not types.DaskArray: - np.testing.assert_array_equal(sum_, expected) - else: - np.testing.assert_array_almost_equal(sum_, expected) + np.testing.assert_array_equal(sum_, expected) @pytest.mark.array_type(skip=ATS_SPARSE_DS) @@ -172,10 +169,8 @@ def test_sum_to_int(array_type: ArrayType[CpuArray | GpuArray | DiskArray | type np_arr = rng.random((100, 100)) arr = array_type(np_arr) sum_ = stats.sum(arr, axis=axis, dtype=np.int64) - if axis is None: - assert sum_ == np.int64(0) - else: - np.testing.assert_array_equal(sum_, np.zeros(arr.shape[axis], dtype=np.int64)) + expected = np.zeros(() if axis is None else arr.shape[axis], dtype=np.int64) + np.testing.assert_array_equal(sum_, expected) @pytest.mark.parametrize( From e67eeffc66cd423050e0327432410b38389e3527 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 14 Oct 2025 12:57:46 +0200 Subject: [PATCH 10/12] fix coercion --- tests/test_stats.py | 46 +++++++++++++--------- typings/cupy/_core/core.pyi | 3 +- typings/cupyx/scipy/sparse/_compressed.pyi | 1 + 3 files changed, 31 insertions(+), 19 deletions(-) diff --git a/tests/test_stats.py b/tests/test_stats.py index 649a507..9a2ed90 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -106,6 +106,26 @@ def np_arr(dtype_in: type[DTypeIn], ndim: Literal[1, 2]) -> NDArray[DTypeIn]: return np_arr +def to_np_dense_checked( + stat: NDArray[DTypeOut] | np.number[Any] | types.DaskArray, axis: Literal[0, 1] | None, arr: CpuArray | GpuArray | DiskArray | types.DaskArray +) -> NDArray[DTypeOut] | np.number[Any]: + match axis, arr: + case _, types.DaskArray(): + assert isinstance(stat, types.DaskArray), type(stat) + stat = stat.compute() # type: ignore[assignment] + return to_np_dense_checked(stat, axis, arr.compute()) + case None, _: + assert isinstance(stat, np.floating | np.integer), type(stat) + case 0 | 1, types.CupyArray() | types.CupyCSRMatrix() | types.CupyCSCMatrix() | types.CupyCOOMatrix(): + assert isinstance(stat, types.CupyArray), type(stat) + return to_np_dense_checked(stat.get(), axis, arr.get()) + case 0 | 1, _: + assert isinstance(stat, np.ndarray), type(stat) + case _: + pytest.fail(f"Unhandled case axis {axis} for {type(arr)}: {type(stat)}") + return stat + + @pytest.mark.array_type(skip={*ATS_SPARSE_DS, Flags.Matrix}) @pytest.mark.parametrize("func", STAT_FUNCS) @pytest.mark.parametrize(("ndim", "axis"), [(1, 0), (2, 3), (2, -1)], ids=["1d-ax0", "2d-ax3", "2d-axneg"]) @@ -129,26 +149,13 @@ def test_sum( axis: Literal[0, 1] | None, np_arr: NDArray[DTypeIn], ) -> None: + if np.dtype(dtype_arg).kind in "iu" and (array_type.flags & Flags.Gpu) and (array_type.flags & Flags.Sparse): + pytest.skip("GPU sparse matrices don’t support int dtypes") arr = array_type(np_arr.copy()) assert arr.dtype == dtype_in sum_ = stats.sum(arr, axis=axis, dtype=dtype_arg) - - match axis, arr: - case _, types.DaskArray(): - assert isinstance(sum_, types.DaskArray), type(sum_) - sum_ = sum_.compute() # type: ignore[assignment] - if isinstance(sum_, types.CupyArray): - sum_ = sum_.get() - case None, _: - assert isinstance(sum_, np.floating | np.integer), type(sum_) - case 0 | 1, types.CupyArray() | types.CupyCSRMatrix() | types.CupyCSCMatrix(): - assert isinstance(sum_, types.CupyArray), type(sum_) - sum_ = sum_.get() - case 0 | 1, _: - assert isinstance(sum_, np.ndarray), type(sum_) - case _: - pytest.fail(f"Unhandled case axis {axis} for {type(arr)}: {type(sum_)}") + sum_ = to_np_dense_checked(sum_, axis, arr) # type: ignore[arg-type] assert sum_.shape == () if axis is None else arr.shape[axis], (sum_.shape, arr.shape) @@ -163,12 +170,15 @@ def test_sum( np.testing.assert_array_equal(sum_, expected) -@pytest.mark.array_type(skip=ATS_SPARSE_DS) -def test_sum_to_int(array_type: ArrayType[CpuArray | GpuArray | DiskArray | types.DaskArray], axis: Literal[0, 1] | None) -> None: +@pytest.mark.array_type(skip={*ATS_SPARSE_DS, Flags.Gpu}) +def test_sum_to_int(array_type: ArrayType[CpuArray | DiskArray | types.DaskArray], axis: Literal[0, 1] | None) -> None: rng = np.random.default_rng(0) np_arr = rng.random((100, 100)) arr = array_type(np_arr) + sum_ = stats.sum(arr, axis=axis, dtype=np.int64) + sum_ = to_np_dense_checked(sum_, axis, arr) + expected = np.zeros(() if axis is None else arr.shape[axis], dtype=np.int64) np.testing.assert_array_equal(sum_, expected) diff --git a/typings/cupy/_core/core.pyi b/typings/cupy/_core/core.pyi index f8d459e..e1a5231 100644 --- a/typings/cupy/_core/core.pyi +++ b/typings/cupy/_core/core.pyi @@ -5,7 +5,7 @@ from typing import Any, Literal, Self, overload import numpy as np from cupy.cuda import Stream from numpy._core.multiarray import flagsobj -from numpy.typing import NDArray +from numpy.typing import DTypeLike, NDArray class ndarray: dtype: np.dtype[Any] @@ -41,6 +41,7 @@ class ndarray: def flatten(self, order: Literal["C", "F", "A", "K"] = "C") -> Self: ... @property def flat(self) -> _FlatIter: ... + def sum(self, axis: int | None = None, dtype: DTypeLike | None = None, out: ndarray | None = None, keepdims: bool = False) -> ndarray: ... class _FlatIter: def __next__(self) -> np.float32 | np.float64: ... diff --git a/typings/cupyx/scipy/sparse/_compressed.pyi b/typings/cupyx/scipy/sparse/_compressed.pyi index a53c190..b697183 100644 --- a/typings/cupyx/scipy/sparse/_compressed.pyi +++ b/typings/cupyx/scipy/sparse/_compressed.pyi @@ -20,3 +20,4 @@ class _compressed_sparse_matrix(spmatrix): # methods def power(self, n: int, dtype: DTypeLike | None = None) -> Self: ... + def sum(self, axis: Literal[0, 1] | None = None, dtype: DTypeLike | None = None, out: Self | None = None) -> ndarray: ... From 41840cc10e858dbaae1883fb468eaecb90cb5833 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 14 Oct 2025 13:11:12 +0200 Subject: [PATCH 11/12] fast path for 0d sum --- src/fast_array_utils/stats/_sum.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/fast_array_utils/stats/_sum.py b/src/fast_array_utils/stats/_sum.py index abd6043..3c7d128 100644 --- a/src/fast_array_utils/stats/_sum.py +++ b/src/fast_array_utils/stats/_sum.py @@ -63,8 +63,12 @@ def _sum_cs( del keep_cupy_as_array import scipy.sparse as sp - dtype = np.dtype(dtype) if dtype is not None else None + if axis is None: + return cast("NDArray[Any] | np.number[Any]", x.data.sum(dtype=dtype)) + # convert to array so dimensions collapse as expected + if TYPE_CHECKING: + dtype = np.dtype(dtype) if dtype is not None else None x = (sp.csr_array if x.format == "csr" else sp.csc_array)(x, dtype=dtype) # TODO(flying-sheep): use `dtype=dtype` here when of above once scipy fixes this From 7b08275d77fe862c5af7b49084ef01c1b3cfa440 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 14 Oct 2025 13:18:52 +0200 Subject: [PATCH 12/12] reformat --- src/fast_array_utils/stats/_sum.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/fast_array_utils/stats/_sum.py b/src/fast_array_utils/stats/_sum.py index 3c7d128..d3f09c8 100644 --- a/src/fast_array_utils/stats/_sum.py +++ b/src/fast_array_utils/stats/_sum.py @@ -63,16 +63,17 @@ def _sum_cs( del keep_cupy_as_array import scipy.sparse as sp + # TODO(flying-sheep): once scipy fixes this issue, instead of all this, + # just convert to sparse array, then `return x.sum(dtype=dtype)` + # https://github.com/scipy/scipy/issues/23768 + if axis is None: return cast("NDArray[Any] | np.number[Any]", x.data.sum(dtype=dtype)) + if TYPE_CHECKING: # scipy-stubs thinks e.g. "int64" is invalid, which isn’t true + assert isinstance(dtype, np.dtype | type | None) # convert to array so dimensions collapse as expected - if TYPE_CHECKING: - dtype = np.dtype(dtype) if dtype is not None else None x = (sp.csr_array if x.format == "csr" else sp.csc_array)(x, dtype=dtype) - - # TODO(flying-sheep): use `dtype=dtype` here when of above once scipy fixes this - # https://github.com/scipy/scipy/issues/23768 return cast("NDArray[Any] | np.number[Any]", x.sum(axis=axis))