Rank refactor 1 #4118
Conversation
zboldyga
commented
May 11, 2026
- Closes #
- Tests included or not required because:
- Release notes not necessary because:
@ilan-gold here's a proof of concept for the stats speedup. The stats are the biggest remaining performance improvement on the scanpy side of the illico integration -- here's the total scanpy illico time before vs. after the patch. Note that this only affects vs_rest mode (comparing against all other cells); using an individual group as the reference was already fast.
I used aggregate as you mentioned. That said, two additional points:
There's a better algorithm for calculating variance that doesn't have these issues: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm . I did an implementation of this (didn't commit), and in my initial tests it was roughly the same speed as the current aggregate-based approach (possibly 20% faster, but it's too early to say). I would need to think more carefully about where this fits in scanpy, e.g. it might be best as a util in get alongside aggregate, or as a replacement for aggregate. So that in itself is a separate issue, and perhaps it needs to be addressed before we can finish this basic stats speedup work: with that in place, I can simplify this code a bit more, and we avoid numerical stability issues. (Note that this issue already exists in aggregate.)
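For reference, here's a minimal pure-Python sketch of the Welford update linked above (not the uncommitted implementation, which would presumably be vectorized and chunked):

```python
import numpy as np

def welford_mean_var(x):
    """One-pass (Welford) mean/variance. Numerically stable because it
    never forms the large sum-of-squares term whose subtraction causes
    catastrophic cancellation."""
    mean = 0.0
    m2 = 0.0  # running sum of squared deviations from the current mean
    for n, xi in enumerate(x, start=1):
        delta = xi - mean
        mean += delta / n
        m2 += delta * (xi - mean)
    var = m2 / len(x)  # population variance; divide by len(x) - 1 for sample
    return mean, var

# Agrees with NumPy even on data where the naive E[x^2] - mean^2
# formulation loses precision (large offset, small spread):
x = 1e8 + np.random.default_rng(0).standard_normal(1000)
mean, var = welford_mean_var(x)
assert np.isclose(var, np.var(x))
```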
Thoughts on these 2 points and the current PR?
So, to your points:
With the cast-to-float64 fix in aggregate's variance computation
(arriving via main rebase from PR aggregate-welford), three
workarounds in _rank_genes_groups can go:
1. The float32 cast-back in _aggregate_group_stats that downgraded
aggregate's float64 output to match legacy mean_var precision.
2. The _compute_rest_stats_for_t_test slow path that recomputed
vars_rest via direct mean_var(X[~mask]) because the previous
sum-decomp couldn't produce accurate-enough values.
3. The zero-initialization of vars_rest in _derive_rest_stats; it is
   now computed via sum-decomp from group/global totals, with a
   max(var, 0) clamp for the all-values-equal cancellation edge case
   (mirroring the band-aid in Aggregate.mean_var).
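The clamp edge case in point 3 is easy to reproduce; a small illustration in plain NumPy (not the actual _derive_rest_stats code):

```python
import numpy as np

# All-values-equal edge case: computing variance as E[x^2] - mean^2
# can cancel to a tiny negative number in floating point, which is why
# the derived vars_rest gets a max(var, 0) clamp (mirroring the
# band-aid in Aggregate.mean_var).
x = np.full(1000, 0.1)                       # true variance is exactly 0
naive_var = (x**2).mean() - x.mean() ** 2    # may come out slightly negative
clamped_var = np.maximum(naive_var, 0.0)     # clamp restores a valid variance
assert clamped_var >= 0.0
```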
Net effect: dropped ~25 lines of dead workaround, simpler control
flow in compute_statistics. Existing 70 tests pass (224 subtests).
One test-tolerance bump: added atol=1e-10 alongside the existing
rtol=1e-5 in test_results' score assertion. The new code produces
sub-machine-precision noise (~1e-15) at a position where the legacy
path produced exact 0.0 (one gene where group and rest means are
equal). Both represent the same mathematical zero; atol accepts both
without weakening the non-zero-score tolerance.
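The tolerance change can be illustrated directly (plain NumPy, independent of the test suite):

```python
import numpy as np

legacy_score = 0.0   # legacy path: exact zero at this gene
new_score = 1e-15    # new path: sub-machine-precision noise

# rtol alone cannot accept this: the allowed error is rtol * |desired|,
# which is exactly 0 when the desired value is 0.0.
try:
    np.testing.assert_allclose(new_score, legacy_score, rtol=1e-5)
    raise RuntimeError("unexpectedly passed")
except AssertionError:
    pass

# A tiny atol admits the noise without loosening the relative
# tolerance on genuinely non-zero scores.
np.testing.assert_allclose(new_score, legacy_score, rtol=1e-5, atol=1e-10)
```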
This commit assumes the kernel fix is in scanpy main. Until that
merges, this branch's CI may fail; rebase on main pulls in the fix.
❌ 5 Tests Failed:
Am still working on this one, simplifying and limiting the scope properly. Will push some more changes tomorrow most likely to wrap it up.
```python
sum_total = sum_g.sum(0) if mask_all else np.asarray(X.sum(axis=0)).ravel()
self.means_rest = (sum_total - sum_g) / n_r

# TODO: if `aggregate` exposed `sum_of_squares` (an additive
```
@ilan-gold thoughts on this?
Basically we need to compute the variance for the 'rest' sets somehow for the t-test. I kept the old code for now, to minimize surface area. So it's slow, and mean_var does suffer from the same (rare) catastrophic-cancellation situation.
We could defer this to another PR to limit scope now.
I figured I'd bring it to your attention because exposing the sum_of_squares values in aggregate would enable a quicker var computation here.
My impression is that aggregate itself isn't something we'd want to manipulate to allow the vs_rest logic, so it seems like for the vs_rest stats we're doing the computation in here (and exposing sum of squares reduces that computation).
Additionally, we could abstract some of the chunking logic into a small util and use it both in aggregate and here?
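To make the sum_of_squares point concrete: sums of squares are additive across groups, so rest-group variance falls out of a subtraction instead of a second pass over X[~mask]. A sketch in plain NumPy (`aggregate` does not currently expose such an output; the names here are hypothetical):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((200, 5))     # cells x genes
mask = rng.random(200) < 0.3          # cells in the group

# Additive totals: rest = global - group, no pass over X[~mask] needed.
sum_rest = X.sum(axis=0) - X[mask].sum(axis=0)
sq_sum_rest = (X**2).sum(axis=0) - (X[mask] ** 2).sum(axis=0)
n_rest = (~mask).sum()

mean_rest = sum_rest / n_rest
# Clamp guards the same cancellation edge case mentioned above.
var_rest = np.maximum(sq_sum_rest / n_rest - mean_rest**2, 0.0)

assert np.allclose(var_rest, X[~mask].var(axis=0))
```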
If we implemented Welford's algorithm inside https://github.com/scverse/fast-array-utils/ for mean_var, would that take care of the catastrophic cancellation potential? Then it seems like we wouldn't have to worry about the numeric stability here.
So if I have that right, I would think the route forward would be, as you say, to put that in a separate PR there, and then we'll automatically start benefitting from it everywhere in the scanpy codebase.