Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ repos:
rev: v1.18.2
hooks:
- id: mypy
args: [--config-file=pyproject.toml]
args: [--config-file=pyproject.toml, .]
pass_filenames: false
additional_dependencies:
- pytest
- pytest-codspeed!=4.0.0 # https://github.com/CodSpeedHQ/pytest-codspeed/pull/84
Expand Down
16 changes: 11 additions & 5 deletions src/fast_array_utils/stats/_sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,18 @@ def _sum_cs(
del keep_cupy_as_array
import scipy.sparse as sp

if isinstance(x, types.CSMatrix):
x = sp.csr_array(x) if x.format == "csr" else sp.csc_array(x)
# TODO(flying-sheep): once scipy fixes this issue, instead of all this,
# just convert to sparse array, then `return x.sum(dtype=dtype)`
# https://github.com/scipy/scipy/issues/23768

if axis is None:
return cast("np.number[Any]", x.data.sum(dtype=dtype))
return cast("NDArray[Any] | np.number[Any]", x.sum(axis=axis, dtype=dtype))
return cast("NDArray[Any] | np.number[Any]", x.data.sum(dtype=dtype))

if TYPE_CHECKING: # scipy-stubs thinks e.g. "int64" is invalid, which isn’t true
assert isinstance(dtype, np.dtype | type | None)
# convert to array so dimensions collapse as expected
x = (sp.csr_array if x.format == "csr" else sp.csc_array)(x, dtype=dtype)
return cast("NDArray[Any] | np.number[Any]", x.sum(axis=axis))


@sum_.register(types.DaskArray)
Expand All @@ -92,7 +98,7 @@ def _sum_dask(

rv = da.reduction(
x,
sum_dask_inner, # type: ignore[arg-type]
partial(sum_dask_inner, dtype=dtype), # pyright: ignore[reportArgumentType]
partial(sum_dask_inner, dtype=dtype), # pyright: ignore[reportArgumentType]
axis=axis,
dtype=dtype,
Expand Down
56 changes: 39 additions & 17 deletions tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,20 +90,42 @@ def dtype_in(request: pytest.FixtureRequest, array_type: ArrayType) -> type[DTyp
return dtype


@pytest.fixture(scope="session", params=[np.float32, np.float64, np.int64, None])
def dtype_arg(request: pytest.FixtureRequest) -> type[DTypeOut] | None:
    """Provide the ``dtype`` value that tests hand to the function under test.

    Parametrized over ``np.float32``, ``np.float64``, ``np.int64`` and ``None``.
    The fixture decorator must be applied exactly once: pytest refuses to
    register a function that is wrapped by ``@pytest.fixture`` twice.

    Returns
    -------
    The current parameter, i.e. a NumPy scalar type or ``None``.
    """
    return cast("type[DTypeOut] | None", request.param)


@pytest.fixture
def np_arr(dtype_in: type[DTypeIn], ndim: Literal[1, 2]) -> NDArray[DTypeIn]:
    """Build a small read-only NumPy array with the requested dtype and rank.

    The base data is the 3×2 array ``[[1, 0], [3, 0], [5, 6]]``.
    For floating point dtypes every value is divided by 4.
    For ``ndim == 1`` the array is flattened to 6 elements.

    Parameters
    ----------
    dtype_in
        Scalar type of the returned array.
    ndim
        Number of dimensions of the returned array (1 or 2).

    Returns
    -------
    A non-writeable array, so tests cannot mutate the fixture value by accident.
    """
    np_arr = cast("NDArray[DTypeIn]", np.array([[1, 0], [3, 0], [5, 6]], dtype=dtype_in))
    if np.dtype(dtype_in).kind == "f":
        np_arr /= 4  # type: ignore[misc]
    if ndim == 1:
        # ``flatten`` returns a new, writeable copy, so it has to happen
        # before the array is frozen, otherwise the 1-D variant stays mutable.
        np_arr = np_arr.flatten()
    np_arr.flags.writeable = False
    return np_arr


def to_np_dense_checked(
    stat: NDArray[DTypeOut] | np.number[Any] | types.DaskArray, axis: Literal[0, 1] | None, arr: CpuArray | GpuArray | DiskArray | types.DaskArray
) -> NDArray[DTypeOut] | np.number[Any]:
    """Assert that ``stat`` has the container type matching ``arr`` and return it as NumPy data.

    Parameters
    ----------
    stat
        Result of a statistics function applied to ``arr``.
    axis
        Axis the statistic was computed along, or ``None`` for a full reduction.
    arr
        The input array the statistic was computed from.

    Returns
    -------
    ``stat`` itself when it already is a NumPy scalar/array,
    otherwise the result of ``.compute()`` (Dask) or ``.get()`` (CuPy) after re-checking.
    """
    # Dask input: the statistic must be a Dask array too; compute both sides and check again.
    if isinstance(arr, types.DaskArray):
        assert isinstance(stat, types.DaskArray), type(stat)
        stat = stat.compute()  # type: ignore[assignment]
        return to_np_dense_checked(stat, axis, arr.compute())

    # Full reduction: a NumPy floating point or integer scalar is expected.
    if axis is None:
        assert isinstance(stat, np.floating | np.integer), type(stat)
        return stat

    # Anything but axis 0 or 1 is not covered by the checks below.
    if axis not in (0, 1):
        pytest.fail(f"Unhandled case axis {axis} for {type(arr)}: {type(stat)}")

    # CuPy input (dense or sparse): expect a CuPy array, then fetch both sides via ``.get()`` and check again.
    cupy_inputs = (types.CupyArray, types.CupyCSRMatrix, types.CupyCSCMatrix, types.CupyCOOMatrix)
    if isinstance(arr, cupy_inputs):
        assert isinstance(stat, types.CupyArray), type(stat)
        return to_np_dense_checked(stat.get(), axis, arr.get())

    # Every remaining input type has to yield a plain NumPy array.
    assert isinstance(stat, np.ndarray), type(stat)
    return stat


@pytest.mark.array_type(skip={*ATS_SPARSE_DS, Flags.Matrix})
@pytest.mark.parametrize("func", STAT_FUNCS)
@pytest.mark.parametrize(("ndim", "axis"), [(1, 0), (2, 3), (2, -1)], ids=["1d-ax0", "2d-ax3", "2d-axneg"])
Expand All @@ -127,26 +149,13 @@ def test_sum(
axis: Literal[0, 1] | None,
np_arr: NDArray[DTypeIn],
) -> None:
if np.dtype(dtype_arg).kind in "iu" and (array_type.flags & Flags.Gpu) and (array_type.flags & Flags.Sparse):
pytest.skip("GPU sparse matrices don’t support int dtypes")
arr = array_type(np_arr.copy())
assert arr.dtype == dtype_in

sum_ = stats.sum(arr, axis=axis, dtype=dtype_arg)

match axis, arr:
case _, types.DaskArray():
assert isinstance(sum_, types.DaskArray), type(sum_)
sum_ = sum_.compute() # type: ignore[assignment]
if isinstance(sum_, types.CupyArray):
sum_ = sum_.get()
case None, _:
assert isinstance(sum_, np.floating | np.integer), type(sum_)
case 0 | 1, types.CupyArray() | types.CupyCSRMatrix() | types.CupyCSCMatrix():
assert isinstance(sum_, types.CupyArray), type(sum_)
sum_ = sum_.get()
case 0 | 1, _:
assert isinstance(sum_, np.ndarray), type(sum_)
case _:
pytest.fail(f"Unhandled case axis {axis} for {type(arr)}: {type(sum_)}")
sum_ = to_np_dense_checked(sum_, axis, arr) # type: ignore[arg-type]

assert sum_.shape == () if axis is None else arr.shape[axis], (sum_.shape, arr.shape)

Expand All @@ -161,6 +170,19 @@ def test_sum(
np.testing.assert_array_equal(sum_, expected)


@pytest.mark.array_type(skip={*ATS_SPARSE_DS, Flags.Gpu})
def test_sum_to_int(array_type: ArrayType[CpuArray | DiskArray | types.DaskArray], axis: Literal[0, 1] | None) -> None:
    """Check that ``stats.sum`` with ``dtype=np.int64`` on floats in ``[0, 1)`` yields only zeros.

    The input is a 100×100 array drawn from ``rng.random`` with a fixed seed.
    The assertion expects an all-zero ``int64`` result.
    NOTE(review): this presumably pins that values are converted to the
    target dtype before being added up; confirm against ``stats.sum``.
    """
    rng = np.random.default_rng(0)
    np_arr = rng.random((100, 100))
    arr = array_type(np_arr)

    sum_ = stats.sum(arr, axis=axis, dtype=np.int64)
    sum_ = to_np_dense_checked(sum_, axis, arr)

    # Reducing a 2-D array along ``axis`` leaves the length of the *other* axis,
    # so index the shape with ``1 - axis`` instead of relying on the input being square.
    expected = np.zeros(() if axis is None else arr.shape[1 - axis], dtype=np.int64)
    np.testing.assert_array_equal(sum_, expected)


@pytest.mark.parametrize(
"data",
[
Expand Down
3 changes: 2 additions & 1 deletion typings/cupy/_core/core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ from typing import Any, Literal, Self, overload
import numpy as np
from cupy.cuda import Stream
from numpy._core.multiarray import flagsobj
from numpy.typing import NDArray
from numpy.typing import DTypeLike, NDArray

class ndarray:
dtype: np.dtype[Any]
Expand Down Expand Up @@ -41,6 +41,7 @@ class ndarray:
def flatten(self, order: Literal["C", "F", "A", "K"] = "C") -> Self: ...
@property
def flat(self) -> _FlatIter: ...
def sum(self, axis: int | None = None, dtype: DTypeLike | None = None, out: ndarray | None = None, keepdims: bool = False) -> ndarray: ...

class _FlatIter:
def __next__(self) -> np.float32 | np.float64: ...
Expand Down
1 change: 1 addition & 0 deletions typings/cupyx/scipy/sparse/_compressed.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@ class _compressed_sparse_matrix(spmatrix):

# methods
def power(self, n: int, dtype: DTypeLike | None = None) -> Self: ...
def sum(self, axis: Literal[0, 1] | None = None, dtype: DTypeLike | None = None, out: Self | None = None) -> ndarray: ...
Loading