Merged
21 commits
1 change: 1 addition & 0 deletions docs/release-notes/3872.feat.md
@@ -0,0 +1 @@
Add `csc`-in-{doc}`dask:index` support for {func}`scanpy.get.aggregate` {smaller}`I Gold`
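The note above is terse, so here is a hedged usage sketch of what the new support enables: calling {func}`scanpy.get.aggregate` on an AnnData whose `X` is a dask array with CSC blocks, chunked along the feature axis only (the observation axis must stay unchunked, per the check added below). The toy data, the `group` column, and the `da.from_array` construction are illustrative assumptions rather than code from this PR, and whether `da.from_array` accepts a scipy sparse matrix directly can depend on the dask version.

```python
import anndata as ad
import dask.array as da
import numpy as np
import pandas as pd
import scanpy as sc
from scipy import sparse

rng = np.random.default_rng(0)
n_obs, n_var = 100, 20

# CSC matrix wrapped in a dask array, chunked along features only
# (the obs axis stays unchunked, as the new CSC path requires).
X = sparse.random(n_obs, n_var, density=0.1, format="csc", random_state=0)
X_dask = da.from_array(X, chunks=(n_obs, 5))

adata = ad.AnnData(
    X=X_dask,
    obs=pd.DataFrame(
        {"group": pd.Categorical(rng.choice(["a", "b", "c"], size=n_obs))},
        index=[f"cell_{i}" for i in range(n_obs)],
    ),
)

# With this change, CSC-backed dask arrays are accepted; "median" remains
# unsupported for dask inputs. Results are stored as layers keyed by the
# aggregation function names.
agg = sc.get.aggregate(adata, by="group", func=["sum", "mean", "var"])
print(agg.layers["mean"].shape)  # (n_groups, n_var)
```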
50 changes: 33 additions & 17 deletions src/scanpy/get/_aggregated.py
@@ -10,7 +10,7 @@
from scipy import sparse
from sklearn.utils.sparsefuncs import csc_median_axis_0

from scanpy._compat import CSBase, CSCBase, CSRBase, DaskArray
from scanpy._compat import CSBase, CSRBase, DaskArray

from .._utils import _resolve_axis, get_literal_vals
from .get import _check_mask
@@ -357,9 +357,6 @@ def aggregate_dask_mean_var(
# TODO: If we don't compute here, the results are not deterministic under the process cluster for sparse.
if isinstance(data._meta, CSRBase):
sq_mean = sq_mean.compute()
elif isinstance(data._meta, CSCBase): # pragma: no-cover
msg = "Cannot handle CSC matrices as dask meta."
raise ValueError(msg)
var = sq_mean - fau_power(mean, 2)
if dof != 0:
group_counts = np.bincount(by.codes)
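For orientation, a small self-contained numpy sketch (not scanpy's code) of the identity this mean/var path relies on: the per-group variance is E[X²] − E[X]², rescaled by n/(n − dof) when `dof != 0`. The shapes and the `np.add.at` accumulation below are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(6, 3))            # observations × features
codes = np.array([0, 0, 1, 1, 1, 0])   # group label per observation
n_groups, dof = 2, 1

counts = np.bincount(codes, minlength=n_groups)[:, None]
mean = np.zeros((n_groups, x.shape[1]))
sq_mean = np.zeros_like(mean)
np.add.at(mean, codes, x)        # per-group sums
np.add.at(sq_mean, codes, x**2)  # per-group sums of squares
mean /= counts
sq_mean /= counts

var = sq_mean - mean**2          # population variance per group
if dof != 0:
    var *= counts / (counts - dof)

# Matches the per-group variance with delta degrees of freedom.
assert np.allclose(var[0], np.var(x[codes == 0], axis=0, ddof=dof))
```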
@@ -376,47 +373,66 @@ def aggregate_dask(
mask: NDArray[np.bool_] | None = None,
dof: int = 1,
) -> dict[AggType, DaskArray]:
if not isinstance(data._meta, CSRBase | np.ndarray):
if not isinstance(data._meta, CSBase | np.ndarray):
msg = f"Got {type(data._meta)} meta in DaskArray but only csr_matrix/csr_array and ndarray are supported."
raise ValueError(msg)
if data.chunksize[1] != data.shape[1]:
chunked_axis, unchunked_axis = (
(0, 1) if isinstance(data._meta, CSRBase | np.ndarray) else (1, 0)
)
if data.chunksize[unchunked_axis] != data.shape[unchunked_axis]:
msg = "Feature axis must be unchunked"
raise ValueError(msg)

def aggregate_chunk_sum_or_count_nonzero(
chunk: Array, *, func: Literal["count_nonzero", "sum"], block_info=None
):
# See https://docs.dask.org/en/stable/generated/dask.array.map_blocks.html
# for what is contained in `block_info`.
subset = slice(*block_info[0]["array-location"][0])
by_subsetted = by[subset]
mask_subsetted = mask[subset] if mask is not None else mask
# Only subset `by` and `mask` when needed, i.e.,
# when there is chunking along the same axis as `by` and `mask`.
if chunked_axis == 0:
# See https://docs.dask.org/en/stable/generated/dask.array.map_blocks.html
# for what is contained in `block_info`.
subset = slice(*block_info[0]["array-location"][0])
by_subsetted = by[subset]
mask_subsetted = mask[subset] if mask is not None else mask
else:
by_subsetted = by
mask_subsetted = mask
res = _aggregate(chunk, by_subsetted, func, mask=mask_subsetted, dof=dof)[func]
return res[None, :]
return res[None, :] if unchunked_axis == 1 else res

funcs = set([func] if isinstance(func, str) else func)
if "median" in funcs:
msg = "Dask median calculation not supported. If you want a median-of-medians calculation, please open an issue."
raise NotImplementedError(msg)
has_mean, has_var = (v in funcs for v in ["mean", "var"])
funcs_no_var_or_mean = funcs - {"var", "mean"}
# aggregate each row chunk individually,
# producing a #chunks × #categories × #features array,
# aggregate each row chunk or column chunk individually,
# producing a #chunks × #categories × #features array (row chunks)
# or per-chunk #categories × #features-in-chunk arrays (column chunks),
# then aggregate the per-chunk results.
chunks = (
((1,) * data.blocks.size, (len(by.categories),), data.shape[1])
if unchunked_axis == 1
else (len(by.categories), data.chunks[1])
)
aggregated = {
f: data.map_blocks(
partial(aggregate_chunk_sum_or_count_nonzero, func=func),
new_axis=(1,),
chunks=((1,) * data.blocks.size, (len(by.categories),), (data.shape[1],)),
new_axis=(1,) if unchunked_axis == 1 else None,
chunks=chunks,
meta=np.array(
[],
dtype=np.float64
if func not in get_args(ConstantDtypeAgg)
else data.dtype, # TODO: figure out best dtype for aggs like sum where dtype can change from original
),
).sum(axis=0)
)
for f in funcs_no_var_or_mean
}
# If we have row chunking, handle the extra leading axis by summing over the per-chunk #categories × #features matrices.
# Otherwise (column chunking), dask concatenates the per-chunk #categories × #features-in-chunk arrays along the feature axis, yielding a #categories × #features matrix.
if unchunked_axis == 1:
for k, v in aggregated.items():
aggregated[k] = v.sum(axis=chunked_axis)
if has_var:
aggregated_mean_var = aggregate_dask_mean_var(data, by, mask=mask, dof=dof)
aggregated["var"] = aggregated_mean_var["var"]
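To make the control flow above easier to follow, here is a standalone sketch (not scanpy's implementation) of the row-chunked branch of this pattern: aggregate each chunk with `dask.array.map_blocks`, stack the per-chunk #categories × #features results along a new leading axis, then sum that axis away. The helper name below is invented for illustration. For column chunks (CSC meta) each block already covers all observations, so no new axis is added and dask simply concatenates the per-chunk results along the feature axis.

```python
import dask.array as da
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 4))
by = pd.Categorical(rng.choice(["a", "b"], size=8))
data = da.from_array(x, chunks=(4, 4))  # row-chunked; feature axis unchunked

def chunk_sum(chunk, block_info=None):
    # block_info[0]["array-location"] holds this chunk's slice in the full
    # array, so the matching slice of `by` can be selected.
    rows = slice(*block_info[0]["array-location"][0])
    codes = by.codes[rows]
    out = np.zeros((len(by.categories), chunk.shape[1]))
    np.add.at(out, codes, chunk)   # per-group sums for this chunk
    return out[None, ...]          # add a leading "chunk" axis

per_chunk = data.map_blocks(
    chunk_sum,
    new_axis=(1,),
    chunks=((1,) * data.blocks.size, (len(by.categories),), (data.shape[1],)),
    meta=np.array([], dtype=np.float64),
)
grouped_sum = per_chunk.sum(axis=0)    # reduce over the chunk axis

expected = np.stack(
    [x[np.asarray(by) == c].sum(axis=0) for c in by.categories]
)
np.testing.assert_allclose(grouped_sum.compute(), expected)
```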
3 changes: 0 additions & 3 deletions tests/test_aggregated.py
@@ -29,8 +29,6 @@
not in {
"dask_array_dense",
"dask_array_sparse",
"dask_array_sparse-1d_chunked-csc_array",
"dask_array_sparse-1d_chunked-csc_matrix",
}
]

@@ -246,7 +244,6 @@ def to_csc(x: CSRBase):
@pytest.mark.parametrize(
("func", "error_msg"),
[
pytest.param(to_csc, r"only csr_matrix", id="csc"),
pytest.param(
to_bad_chunking, r"Feature axis must be unchunked", id="bad_chunking"
),