
perf: speed up scanpy.get.aggregate for sparse#4041

Closed
ilan-gold wants to merge 16 commits into ig/two_pass_hvg_v3 from ig/numba_aggregate

Conversation

Contributor

@ilan-gold ilan-gold commented Apr 8, 2026

A nice testing script:

# /// script
# requires-python = ">=3.12"
# dependencies = [
#   "numba",
#   "fast-array-utils[accel,sparse]",
#   "scipy",
#   "numpy"
# ]
# ///
#
# Benchmarks the numba-based grouped-sum kernels against sparse matrix
# multiplication with an indicator matrix (the current implementation).

from __future__ import annotations

import time

import numba
import numpy as np
from fast_array_utils.numba import njit
from scipy.sparse import coo_matrix, csc_matrix, csr_matrix, random


@njit
def agg_sum_csr(  # noqa: D103
    indicator: csr_matrix,
    data: csr_matrix,
):
    out = np.zeros((indicator.shape[0], data.shape[1]), dtype="float64")
    for cat_num in numba.prange(indicator.shape[0]):
        start_cat_idx = indicator.indptr[cat_num]
        stop_cat_idx = indicator.indptr[cat_num + 1]
        for row_num in range(start_cat_idx, stop_cat_idx):
            obs_per_cat = indicator.indices[row_num]

            start_obs = data.indptr[obs_per_cat]
            end_obs = data.indptr[obs_per_cat + 1]

            for j in range(start_obs, end_obs):
                col = data.indices[j]
                out[cat_num, col] += float(data.data[j])
    return out


@njit
def agg_sum_csc(
    indicator: csr_matrix,
    data: csc_matrix,
):
    out = np.zeros((indicator.shape[0], data.shape[1]), dtype="float64")

    # Precompute: observation → category
    obs_to_cat = np.full(data.shape[0], -1, dtype=np.int64)

    for cat in range(indicator.shape[0]):
        for k in range(indicator.indptr[cat], indicator.indptr[cat + 1]):
            obs_to_cat[indicator.indices[k]] = cat

    # Now iterate CSC efficiently
    for col in numba.prange(data.shape[1]):
        start = data.indptr[col]
        end = data.indptr[col + 1]

        for j in range(start, end):
            obs = data.indices[j]
            cat = obs_to_cat[obs]

            if cat != -1:
                out[cat, col] += float(data.data[j])

    return out


mat = random(70_000, 50_000, density=0.02, format="csr", rng=np.random.default_rng())
categories = np.random.randint(0, 20, size=mat.shape[0])

# One-hot indicator matrix: row k selects the observations in category k
rows = categories
cols = np.arange(mat.shape[0])
vals = np.ones(mat.shape[0], dtype=int)
membership_matrix = coo_matrix(
    (vals, (rows, cols)), shape=(20, mat.shape[0])
).tocsr()

# Warm-up calls to trigger JIT compilation
agg_sum_csr(membership_matrix, mat)
agg_sum_csc(membership_matrix, mat.tocsc())

# The functions are now compiled; re-time them executing from cache
start = time.time()
agg_sum_csr(membership_matrix, mat)
end = time.time()
print("numba csr time = %s" % (end - start))

start = time.time()
agg_sum_csc(membership_matrix, mat.tocsc())
end = time.time()
print("numba csr->csc time = %s" % (end - start))

csc_mat = mat.tocsc()
start = time.time()
agg_sum_csr(membership_matrix, csc_mat.tocsr())
end = time.time()
print("numba csc->csr time = %s" % (end - start))

csc_mat = mat.tocsc()
start = time.time()
agg_sum_csc(membership_matrix, csc_mat)
end = time.time()
print("numba csc time = %s" % (end - start))

start = time.time()
(membership_matrix @ mat).toarray()
end = time.time()
print("mul time = %s" % (end - start))
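As an aside, the indicator matrix built via `coo_matrix` in the script could also be constructed directly in CSR form with a stable argsort, skipping the COO round-trip. A sketch (not from this PR; `indicator_csr`, `labels`, and `n_cats` are hypothetical names):

```python
import numpy as np
from scipy.sparse import csr_matrix

def indicator_csr(labels: np.ndarray, n_cats: int) -> csr_matrix:
    # Stable sort groups observation indices by category
    order = np.argsort(labels, kind="stable")
    # Number of observations per category gives the row pointer
    counts = np.bincount(labels, minlength=n_cats)
    indptr = np.concatenate(([0], np.cumsum(counts)))
    data = np.ones(labels.size, dtype=np.int64)
    return csr_matrix((data, order, indptr), shape=(n_cats, labels.size))

labels = np.array([2, 0, 1, 0, 2])
ind = indicator_csr(labels, 3)
# Row k of `ind` is the one-hot selector for the observations with label k.
```

Either construction yields the same CSR indicator; the direct form just avoids materializing the COO triplets.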

For me locally, this is faster than the multiplication (the current implementation) even with numba's plain njit (i.e., parallel=False). That suggests things will also be faster for dask. Converting between sparse formats first is never worth it.
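For checking the kernels on small inputs, a plain-NumPy reference of the grouped sum is handy. A sketch (not from this PR; names are illustrative) that compares an un-jitted version of the CSR kernel's loop against both `np.add.at` and the indicator-matrix product:

```python
import numpy as np
from scipy.sparse import coo_matrix, random as sparse_random

rng = np.random.default_rng(0)
m = sparse_random(200, 50, density=0.1, format="csr", random_state=rng)
cats = rng.integers(0, 5, size=m.shape[0])

# Reference: dense groupby-sum via unbuffered scatter-add
ref = np.zeros((5, m.shape[1]))
np.add.at(ref, cats, m.toarray())

# Indicator-matrix product, as the script's "mul time" baseline
memb = coo_matrix(
    (np.ones(m.shape[0]), (cats, np.arange(m.shape[0]))), shape=(5, m.shape[0])
).tocsr()
assert np.allclose((memb @ m).toarray(), ref)

# Un-jitted version of the agg_sum_csr loop, to validate the kernel logic
out = np.zeros((memb.shape[0], m.shape[1]))
for cat in range(memb.shape[0]):
    for k in range(memb.indptr[cat], memb.indptr[cat + 1]):
        obs = memb.indices[k]
        for j in range(m.indptr[obs], m.indptr[obs + 1]):
            out[cat, m.indices[j]] += m.data[j]
assert np.allclose(out, ref)
```

All three paths agree, so any timing difference between them is purely about traversal order and memory access, not results.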

  • Closes #
  • Tests included or not required because:

@ilan-gold ilan-gold changed the title feat: speed up numba sums feat: speed up scanpy.get.aggregate summation Apr 8, 2026
@codecov

codecov bot commented Apr 8, 2026

Codecov Report

❌ Patch coverage is 97.22222% with 1 line in your changes missing coverage. Please review.
✅ Project coverage is 78.53%. Comparing base (17be530) to head (baef45c).
✅ All tests successful. No failed tests found.

Files with missing lines        Patch %   Lines
src/scanpy/get/_aggregated.py   96.77%    1 Missing ⚠️
Additional details and impacted files
@@                  Coverage Diff                   @@
##           ig/two_pass_hvg_v3    #4041      +/-   ##
======================================================
+ Coverage               77.85%   78.53%   +0.68%     
======================================================
  Files                     117      118       +1     
  Lines                   12774    12790      +16     
======================================================
+ Hits                     9945    10045     +100     
+ Misses                   2829     2745      -84     
Flag                  Coverage Δ
hatch-test.low-vers   77.84% <97.22%> (+0.02%) ⬆️
hatch-test.pre        78.42% <97.22%> (+57.91%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

Files with missing lines        Coverage Δ
src/scanpy/get/_kernels.py      100.00% <100.00%> (ø)
src/scanpy/get/_aggregated.py   93.75% <96.77%> (+4.07%) ⬆️

... and 8 files with indirect coverage changes

@scverse-benchmark

scverse-benchmark bot commented Apr 8, 2026

Benchmark changes

Change   Before [17be530]   After [baef45c]   Ratio   Benchmark (Parameter)
+        56.1±1ms           142±30ms          2.53    preprocessing_log.PreprocessingSuite.time_pca('pbmc68k_reduced', 'off-axis')

Comparison: https://github.com/scverse/scanpy/compare/17be530c8822cbaeb629ffd41175a4558031f7eb..baef45cef7f22ef8329b2dca94381dcc8343d5aa
Last changed: Fri, 10 Apr 2026 14:53:31 +0000

More details: https://github.com/scverse/scanpy/pull/4041/checks?check_run_id=70786088065

@ilan-gold ilan-gold force-pushed the ig/numba_aggregate branch from f4fa89b to febc107 Compare April 8, 2026 18:23
@ilan-gold ilan-gold changed the title feat: speed up scanpy.get.aggregate summation feat: speed up scanpy.get.aggregate for sparse Apr 10, 2026
@ilan-gold ilan-gold changed the title feat: speed up scanpy.get.aggregate for sparse perf: speed up scanpy.get.aggregate for sparse Apr 10, 2026