Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions benchmarks/asv.conf.json
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
"scikit-image": [""], // https://github.com/conda-forge/scikit-misc-feedstock/pull/29
"scikit-misc": [""],
},
"default_benchmark_timeout": 500,

// Combinations of libraries/python versions can be excluded/included
// from the set to test. Each entry is a dictionary containing additional
Expand Down
8 changes: 5 additions & 3 deletions benchmarks/benchmarks/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,12 @@ def bmmc(n_obs: int = 400) -> AnnData:

@cache
def _lung93k() -> AnnData:
path = pooch.retrieve(
url="https://figshare.com/ndownloader/files/45788454",
known_hash="md5:4f28af5ff226052443e7e0b39f3f9212",
registry = pooch.create(
path=pooch.os_cache("pooch"),
base_url="doi:10.6084/m9.figshare.25664775.v1/",
)
registry.load_registry_from_doi()
path = registry.fetch("adata.raw_compressed.h5ad")
adata = sc.read_h5ad(path)
assert isinstance(adata.X, CSRBase)
adata.layers["counts"] = adata.X.astype(np.int32, copy=True)
Expand Down
78 changes: 44 additions & 34 deletions src/scanpy/get/_aggregated.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,15 @@

import numpy as np
import pandas as pd
from anndata import AnnData, utils
from anndata import AnnData
from fast_array_utils.stats._power import power as fau_power # TODO: upstream
from scipy import sparse
from sklearn.utils.sparsefuncs import csc_median_axis_0

from scanpy._compat import CSBase, CSRBase, DaskArray

from .._utils import _resolve_axis, get_literal_vals
from ._kernels import agg_sum_csc, agg_sum_csr, mean_var_csc, mean_var_csr
from .get import _check_mask

if TYPE_CHECKING:
Expand All @@ -25,7 +26,7 @@
type AggType = ConstantDtypeAgg | Literal["mean", "var"]


class Aggregate:
class Aggregate[ArrayT: np.ndarray | CSBase]:
"""Functionality for generic grouping and aggregating.

There is currently support for count_nonzero, sum, mean, and variance.
Expand Down Expand Up @@ -53,7 +54,7 @@ class Aggregate:
def __init__(
self,
groupby: pd.Categorical,
data: Array,
data: ArrayT,
*,
mask: NDArray[np.bool] | None = None,
) -> None:
Expand All @@ -64,8 +65,8 @@ def __init__(
self.data = data

groupby: pd.Categorical
indicator_matrix: sparse.coo_matrix
data: Array
indicator_matrix: CSRBase
data: ArrayT

def count_nonzero(self) -> NDArray[np.integer]:
"""Count the number of observations in each group.
Expand All @@ -75,19 +76,30 @@ def count_nonzero(self) -> NDArray[np.integer]:
Array of counts.

"""
# pattern = self.data._with_data(np.broadcast_to(1, len(self.data.data)))
# return self.indicator_matrix @ pattern
return utils.asarray(self.indicator_matrix @ (self.data != 0))
return self._sum(data=(self.data != 0).astype("uint8"))

def _sum(self, data: ArrayT):
if isinstance(data, np.ndarray):
res = self.indicator_matrix @ data
if isinstance(res, CSBase):
return res.toarray()
return res
dtype = np.int64 if np.issubdtype(data.dtype, np.integer) else np.float64
out = np.zeros((self.indicator_matrix.shape[0], data.shape[1]), dtype=dtype)
(agg_sum_csr if isinstance(data, CSRBase) else agg_sum_csc)(
self.indicator_matrix, data, out
)
return out

def sum(self) -> Array:
def sum(self) -> np.ndarray:
"""Compute the sum per feature per group of observations.

Returns
-------
Array of sum.

"""
return utils.asarray(self.indicator_matrix @ self.data)
return self._sum(self.data)

def mean(self) -> Array:
"""Compute the mean per feature per group of observations.
Expand All @@ -97,10 +109,7 @@ def mean(self) -> Array:
Array of mean.

"""
return (
utils.asarray(self.indicator_matrix @ self.data)
/ np.bincount(self.groupby.codes)[:, None]
)
return self.sum() / np.bincount(self.groupby.codes)[:, None]

def mean_var(self, dof: int = 1) -> tuple[np.ndarray, np.ndarray]:
"""Compute the count, as well as mean and variance per feature, per group of observations.
Expand All @@ -124,14 +133,17 @@ def mean_var(self, dof: int = 1) -> tuple[np.ndarray, np.ndarray]:
assert dof >= 0

group_counts = np.bincount(self.groupby.codes)
mean_ = self.mean()
# sparse matrices do not support ** for elementwise power.
mean_sq = (
utils.asarray(self.indicator_matrix @ _power(self.data, 2))
/ group_counts[:, None]
)
sq_mean = mean_**2
var_ = mean_sq - sq_mean
if isinstance(self.data, np.ndarray):
mean_ = self.mean()
# sparse matrices do not support ** for elementwise power.
mean_sq = self._sum(_power(self.data, 2)) / group_counts[:, None]
sq_mean = mean_**2
var_ = mean_sq - sq_mean
else:
mean_, var_ = (
mean_var_csr if isinstance(self.data, CSRBase) else mean_var_csc
)(self.indicator_matrix, self.data)
sq_mean = mean_**2
# TODO: Why these values exactly? Because they are high relative to the datatype?
# (unchanged from original code: https://github.com/scverse/anndata/pull/564)
precision = 2 << (42 if self.data.dtype == np.float64 else 20)
Expand Down Expand Up @@ -550,18 +562,16 @@ def sparse_indicator(
categorical: pd.Categorical,
*,
mask: NDArray[np.bool] | None = None,
weight: NDArray[np.floating] | None = None,
) -> sparse.coo_matrix:
if mask is not None and weight is None:
weight = mask.astype(np.float32)
elif mask is not None and weight is not None:
weight = mask * weight
elif mask is None and weight is None:
weight = np.broadcast_to(1.0, len(categorical))
) -> CSRBase:
if mask is None:
# TODO: why is this float64. This is a scanpy 2.0 problem maybe?
mask = np.broadcast_to(1.0, len(categorical))
else:
mask = mask.astype("uint8")
# can’t have -1s in the codes, but (as long as it’s valid), the value is ignored, so set to 0 where masked
codes = categorical.codes if mask is None else np.where(mask, categorical.codes, 0)
a = sparse.coo_matrix(
(weight, (codes, np.arange(len(categorical)))),
codes = np.where(mask, categorical.codes, 0)
a = sparse.coo_array(
(mask, (codes, np.arange(len(categorical)))),
shape=(len(categorical.categories), len(categorical)),
)
).tocsr()
return a
119 changes: 119 additions & 0 deletions src/scanpy/get/_kernels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import numba
import numpy as np
from fast_array_utils.numba import njit

if TYPE_CHECKING:
from numpy.typing import NDArray

from .._compat import CSCBase, CSRBase


@njit
def agg_sum_csr(indicator: CSRBase, data: CSRBase, out: NDArray):
    """Accumulate per-category row sums of a CSR matrix into ``out`` in place.

    ``indicator`` is an (n_categories x n_obs) CSR membership matrix: for each
    category, the rows of ``data`` assigned to it are summed into
    ``out[category, :]``. Categories are processed in parallel, so no two
    threads write the same ``out`` row.
    """
    for cat in numba.prange(indicator.shape[0]):
        # observations belonging to this category
        for ptr in range(indicator.indptr[cat], indicator.indptr[cat + 1]):
            obs = indicator.indices[ptr]
            # walk the nonzeros of this observation's row in `data`
            for nz in range(data.indptr[obs], data.indptr[obs + 1]):
                out[cat, data.indices[nz]] += data.data[nz]


@njit
def agg_sum_csc(indicator: CSRBase, data: CSCBase, out: np.ndarray):
    """Accumulate per-category sums of a CSC matrix into ``out`` in place.

    The (n_categories x n_obs) CSR ``indicator`` is first inverted into an
    observation -> category lookup; each column of ``data`` is then scanned
    once, columns in parallel.
    """
    # observation -> category code; -1 marks observations in no category
    obs_cat = np.full(data.shape[0], -1, dtype=np.int64)
    for cat in range(indicator.shape[0]):
        for ptr in range(indicator.indptr[cat], indicator.indptr[cat + 1]):
            obs_cat[indicator.indices[ptr]] = cat

    for col in numba.prange(data.shape[1]):
        for nz in range(data.indptr[col], data.indptr[col + 1]):
            cat = obs_cat[data.indices[nz]]
            if cat != -1:
                out[cat, col] += data.data[nz]


@njit
def mean_var_csr(
    indicator: CSRBase,
    data: CSRBase,
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
    """Per-category mean and variance of a CSR matrix.

    Parameters
    ----------
    indicator
        (n_categories x n_obs) CSR membership matrix.
    data
        (n_obs x n_features) CSR data matrix.

    Returns
    -------
    ``(mean, var)`` of shape (n_categories, n_features), with
    ``var = E[x^2] - E[x]^2`` (no degrees-of-freedom correction here;
    the caller applies it).
    """
    mean = np.zeros((indicator.shape[0], data.shape[1]), dtype="float64")
    var = np.zeros((indicator.shape[0], data.shape[1]), dtype="float64")

    for cat_num in numba.prange(indicator.shape[0]):
        start_cat_idx = indicator.indptr[cat_num]
        stop_cat_idx = indicator.indptr[cat_num + 1]
        for row_num in range(start_cat_idx, stop_cat_idx):
            obs_per_cat = indicator.indices[row_num]

            start_obs = data.indptr[obs_per_cat]
            end_obs = data.indptr[obs_per_cat + 1]

            for j in range(start_obs, end_obs):
                col = data.indices[j]
                # Promote BEFORE squaring: int32/float32 inputs would
                # otherwise overflow / lose precision in value * value.
                # (Previously the cast was dead code — immediately
                # overwritten by the raw element.)
                value = np.float64(data.data[j])
                mean[cat_num, col] += value
                var[cat_num, col] += value * value

        # NOTE(review): n_obs == 0 for an empty category yields nan/inf
        # here, matching the previous behavior — confirm callers filter
        # empty categories.
        n_obs = stop_cat_idx - start_cat_idx
        mean_cat = mean[cat_num, :] / n_obs
        mean[cat_num, :] = mean_cat
        var[cat_num, :] = (var[cat_num, :] / n_obs) - (mean_cat * mean_cat)
    return mean, var


@njit
def mean_var_csc(
    indicator: CSRBase, data: CSCBase
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
    """Per-category mean and variance of a CSC matrix.

    Parameters
    ----------
    indicator
        (n_categories x n_obs) CSR membership matrix.
    data
        (n_obs x n_features) CSC data matrix.

    Returns
    -------
    ``(mean, var)`` of shape (n_categories, n_features), with
    ``var = E[x^2] - E[x]^2`` (no degrees-of-freedom correction here;
    the caller applies it).
    """
    # observation -> category code; -1 marks observations in no category
    obs_to_cat = np.full(data.shape[0], -1, dtype=np.int64)

    mean = np.zeros((indicator.shape[0], data.shape[1]), dtype="float64")
    var = np.zeros((indicator.shape[0], data.shape[1]), dtype="float64")

    for cat in range(indicator.shape[0]):
        for k in range(indicator.indptr[cat], indicator.indptr[cat + 1]):
            obs_to_cat[indicator.indices[k]] = cat

    for col in numba.prange(data.shape[1]):
        start = data.indptr[col]
        end = data.indptr[col + 1]

        for j in range(start, end):
            obs = data.indices[j]
            cat = obs_to_cat[obs]

            if cat != -1:
                # Promote BEFORE squaring: int32/float32 inputs would
                # otherwise overflow / lose precision in value * value.
                # (Previously the cast was dead code — immediately
                # overwritten by the raw element.)
                value = np.float64(data.data[j])
                mean[cat, col] += value
                var[cat, col] += value * value

    for cat_num in numba.prange(indicator.shape[0]):
        start_cat_idx = indicator.indptr[cat_num]
        stop_cat_idx = indicator.indptr[cat_num + 1]
        # NOTE(review): n_obs == 0 for an empty category yields nan/inf
        # here, matching the previous behavior — confirm callers filter
        # empty categories.
        n_obs = stop_cat_idx - start_cat_idx
        mean_cat = mean[cat_num, :] / n_obs
        mean[cat_num, :] = mean_cat
        var[cat_num, :] = (var[cat_num, :] / n_obs) - (mean_cat * mean_cat)
    return mean, var
Loading