Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions benchmarks/asv.conf.json
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
"pooch": [""],
"scikit-image": [""], // https://github.com/conda-forge/scikit-misc-feedstock/pull/29
// "scikit-misc": [""],
"dask": [""],
},

// Combinations of libraries/python versions can be excluded/included
Expand Down
32 changes: 25 additions & 7 deletions benchmarks/benchmarks/preprocessing_counts.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import TYPE_CHECKING

import anndata as ad
import zarr

import scanpy as sc
from scanpy._utils import get_literal_vals
Expand Down Expand Up @@ -151,17 +152,34 @@ def peakmem_log1p(self, *_) -> None:


class Agg: # noqa: D101
params: tuple[AggType] = tuple(get_literal_vals(AggType))
param_names = ("agg_name",)
params: tuple[list[str], list[bool]] = (
list(get_literal_vals(AggType)),
[True, False],
)
param_names = ("agg_name", "use_dask")

def setup_cache(self) -> None:
"""Without this caching, asv was running several processes which meant the data was repeatedly downloaded."""
adata, _ = get_dataset("lung93k")
adata.write_h5ad("lung93k.h5ad")

def setup(self, agg_name: AggType) -> None:
self.adata = ad.read_h5ad("lung93k.h5ad")
self.agg_name = agg_name
adata.write_zarr("lung93k.zarr")

def setup(self, agg_name: AggType, use_dask: bool) -> None: # noqa: FBT001
if use_dask:
if agg_name == "median":
# Skip this one: https://asv.readthedocs.io/en/stable/writing_benchmarks.html#setup-and-teardown-functions
raise NotImplementedError()
z = zarr.open("lung93k.zarr")
self.adata = ad.AnnData(
obs=ad.io.read_elem(z["obs"]),
var=ad.io.read_elem(z["var"]),
layers={
"counts": ad.experimental.read_elem_lazy(z["layers"]["counts"])
},
X=ad.experimental.read_elem_lazy(z["X"]),
)
else:
self.adata = ad.read_zarr("lung93k.zarr")
self.agg_name: AggType = agg_name

def time_agg(self, *_) -> None:
sc.get.aggregate(
Expand Down
134 changes: 123 additions & 11 deletions src/scanpy/get/_aggregated.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import numpy as np
import pandas as pd
from anndata import AnnData
from fast_array_utils.stats._power import power as fau_power # TODO: upstream
from scipy import sparse
from sklearn.utils.sparsefuncs import csc_median_axis_0

Expand Down Expand Up @@ -371,16 +370,129 @@ def aggregate_dask_mean_var(
mask: NDArray[np.bool] | None = None,
dof: int = 1,
) -> MeanVarDict:
mean = aggregate_dask(data, by, "mean", mask=mask, dof=dof)["mean"]
sq_mean = aggregate_dask(fau_power(data, 2), by, "mean", mask=mask, dof=dof)["mean"]
# TODO: If we don't compute here, the results are not deterministic under the process cluster for sparse.
if isinstance(data._meta, CSRBase):
sq_mean = sq_mean.compute()
var = sq_mean - fau_power(mean, 2)
if dof != 0:
group_counts = np.bincount(by.codes)
var *= (group_counts / (group_counts - dof))[:, np.newaxis]
return MeanVarDict(mean=mean, var=var)
"""Compute group-wise mean and variance for a dask array.

Per chunk we compute ``(count, mean, M2)`` (where ``M2 = sum((x - mean)**2)``),
then combine across chunks with the pairwise parallel algorithm from
Chan et al. (https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm)
so the across-chunk reduction avoids the catastrophic cancellation of
``E[X**2] - E[X]**2``.
"""
import dask.array as da

n_categories = len(by.categories)
n_features = data.shape[1]
chunked_axis = 0 if isinstance(data._meta, CSRBase | np.ndarray) else 1

if chunked_axis == 1:
# Each block already sees every observation, so mean/var per chunk is final.
def per_block_col(chunk: Array) -> NDArray[np.float64]:
mean_, var_ = Aggregate(groupby=by, data=chunk, mask=mask).mean_var(dof=dof)
return np.concatenate([mean_, var_], axis=0)

combined = data.map_blocks(
per_block_col,
chunks=((2 * n_categories,), data.chunks[1]),
meta=np.array([], dtype=np.float64),
)
return MeanVarDict(mean=combined[:n_categories], var=combined[n_categories:])

n_blocks = data.numblocks[0]

def per_block_row(
chunk: Array, block_info: dict | None = None
) -> NDArray[np.float64]:
row_subset = slice(*block_info[0]["array-location"][0])
by_sub = by[row_subset]
mask_sub = mask[row_subset] if mask is not None else None
return _block_moments(chunk, by_sub, mask=mask_sub, n_categories=n_categories)[
None
]

per_block_stats = data.map_blocks(
per_block_row,
chunks=((1,) * n_blocks, (3,), (n_categories,), (n_features,)),
new_axis=(1, 2),
meta=np.array([], dtype=np.float64),
)

combined = da.reduction(
per_block_stats,
chunk=lambda x, axis=None, keepdims=False: x,
aggregate=_chan_reduce_axis_0,
axis=0,
keepdims=False,
concatenate=True,
dtype=np.float64,
meta=np.array([], dtype=np.float64),
)
counts = combined[0]
mean_ = combined[1]
m2 = combined[2]
denom = counts - dof if dof > 0 else counts
return MeanVarDict(mean=mean_, var=m2 / denom)


def _block_moments(
data: np.ndarray | CSBase,
by: pd.Categorical,
*,
mask: NDArray[np.bool] | None,
n_categories: int,
) -> NDArray[np.float64]:
"""Per-chunk ``(count, mean, M2)`` array of shape ``(3, n_categories, n_features)``.

Groups with no observations in the chunk get zeros for mean and M2 so
they combine cleanly under ``_chan_combine``.
"""
codes = by.codes
valid = codes >= 0
if mask is not None:
valid = valid & mask
counts = np.bincount(codes[valid], minlength=n_categories).astype(np.float64)

out = np.zeros((3, n_categories, data.shape[1]), dtype=np.float64)
out[0] = counts[:, None]
nonempty = counts > 0
if not nonempty.any():
return out

agg = Aggregate(groupby=by, data=data, mask=mask)
sum_ = agg.sum()
sum_sq = agg._sum(_power(data, 2))
safe_counts = np.where(nonempty, counts, 1)[:, None]
mean_ = sum_ / safe_counts
# M2 = sum((x - mean)**2) = sum_sq - count * mean**2; clip cancellation noise to 0.
m2 = np.maximum(sum_sq - sum_ * mean_, 0)
out[1, nonempty] = mean_[nonempty]
out[2, nonempty] = m2[nonempty]
return out


def _chan_combine(
a: NDArray[np.float64], b: NDArray[np.float64]
) -> NDArray[np.float64]:
"""Combine two ``(3, K, F)`` ``(count, mean, M2)`` stat blocks pairwise."""
n_a, mean_a, m2_a = a[0], a[1], a[2]
n_b, mean_b, m2_b = b[0], b[1], b[2]
n = n_a + n_b
safe_n = np.where(n > 0, n, 1)
delta = mean_b - mean_a
new_mean = mean_a + delta * n_b / safe_n
new_m2 = m2_a + m2_b + delta * delta * n_a * n_b / safe_n
return np.stack([n, new_mean, new_m2])


def _chan_reduce_axis_0(
stats: NDArray[np.float64],
axis: int | None,
keepdims: bool, # noqa: FBT001
) -> NDArray[np.float64]:
"""Aggregate per-block stats along axis 0 with the parallel variance algorithm."""
result = stats[0]
for i in range(1, stats.shape[0]):
result = _chan_combine(result, stats[i])
return result[None] if keepdims else result


@_aggregate.register(DaskArray)
Expand Down
Loading