fix float32 precision loss in aggregate variance: cast to float64 before squaring #4122
zboldyga wants to merge 2 commits into
Conversation
fix float32 precision loss in aggregate variance: cast to float64 before squaring

In all three variance paths of `sc.get.aggregate` — the sparse CSR/CSC kernels, the dense branch of `Aggregate.mean_var`, and the dask `aggregate_dask_mean_var` — per-element squaring was happening at the input dtype. For float32 input (scanpy's default for log1p-normalized X), each squared value carried only ~7 digits of precision before being promoted into the float64 accumulator. The downstream `mean_sq - sq_mean` cancellation then amplifies that accumulated absolute error by the `mean²/var` ratio.

This commit makes intermediate squaring happen at float64 across all three paths:

- sparse kernels: remove the dead `value = data.data[j]` line that was overwriting the existing `np.float64(...)` cast (so the cast on the prior line actually applies);
- dense path: cast `self.data` to float64 before `_power(..., 2)`;
- dask path: cast `data` to float64 before `fau_power(..., 2)` (lazy, so no eager copy).

Reduces synthetic-adversarial variance error from ~3.5e-4 to ~1.6e-10 on the sparse kernel (6 OOM). On real perturb-seq pseudobulk-of-pseudobulks workloads, error drops 30-60× from ~2e-5 to ~5e-7.

Perf impact: sparse paths are essentially free (within ±5% noise). Dense paths take ~10-26% longer per call; dask paths ~27-46% longer. The full table is at de-optimization/welford_explore/results/aggregate_fix_perf.md.
[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
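To make the fix concrete, here is a minimal dense-path sketch of the pattern the commit message describes: cast to float64 before squaring, then form `mean_sq - sq_mean`. This is not scanpy's actual `Aggregate.mean_var` implementation; the function name `group_mean_var` and its signature are invented for illustration, and it assumes every group label appears at least once.

```python
import numpy as np

def group_mean_var(x: np.ndarray, labels: np.ndarray, n_groups: int):
    """Hypothetical dense group-wise mean/variance, mirroring the fix:
    cast to float64 *before* squaring, then form mean_sq - sq_mean.
    Assumes every group in 0..n_groups-1 has at least one row."""
    x64 = x.astype(np.float64)                    # the added cast
    counts = np.bincount(labels, minlength=n_groups).astype(np.float64)[:, None]
    sums = np.zeros((n_groups, x.shape[1]))
    sq_sums = np.zeros_like(sums)
    np.add.at(sums, labels, x64)
    np.add.at(sq_sums, labels, x64 ** 2)          # squaring now happens at float64
    mean = sums / counts
    var = sq_sums / counts - mean ** 2            # mean_sq - sq_mean
    return mean, var
```

The only behavioral difference from squaring at the input dtype is where the rounding happens; the accumulators were already float64 in both versions.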
Lengthy AI writeup that attempts to explain the underlying stability issue with the float32 squaring in this data regime 😅

**Why float32 squaring is problematic for scanpy data**

**The math**

Variance is computed as `mean_sq - sq_mean`, i.e. E[x²] − (E[x])². Specifically: each input to that subtraction (`mean_sq` and `sq_mean`) is large compared to the variance itself, so the ~7-significant-digit error introduced by squaring at float32 gets amplified by roughly the `mean²/var` ratio in the result.

**Concrete numbers**

For a ribosomal gene with

That's the 9-orders-of-magnitude difference between the two regimes. The fix is one cast — to float64 before squaring.

**Why scanpy's typical data lives in this regime**

scanpy stores log1p-normalized counts as float32 by default. Pseudobulk-of-pseudobulks workflows (the documented use case for `aggregate`)

**What this looks like empirically**

On the wessels23 RPS15 pseudobulk scenario (real data,

Predictions overshoot observed values by 10–100× because real

The 5e-7 residual after the fix isn't kernel error anymore — it's
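To see the regime in action, here is a small self-contained demo (synthetic numbers of my own choosing, not the wessels23 data): a high-mean, tiny-variance column where the only difference between the two runs is the dtype used for squaring, with a float64 accumulator in both cases. The two printed errors typically differ by several orders of magnitude.

```python
import numpy as np

rng = np.random.default_rng(0)
# high mean, tiny variance -> large mean**2 / var ratio (the adversarial regime)
x = (8.0 + 1e-3 * rng.standard_normal(256)).astype(np.float32)

def naive_var(values, square_dtype):
    """mean_sq - sq_mean with a float64 accumulator; only the squaring dtype varies."""
    v = values.astype(square_dtype)
    mean_sq = (v * v).mean(dtype=np.float64)
    sq_mean = v.mean(dtype=np.float64) ** 2
    return mean_sq - sq_mean

reference = np.var(x.astype(np.float64))  # float64 two-pass reference on the same stored data
print("float32 squaring error:", abs(naive_var(x, np.float32) - reference))
print("float64 squaring error:", abs(naive_var(x, np.float64) - reference))
```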
And some evidence of the resolution via float64 usage (AI writeup again here, lengthy data). The gist is that this was impacting the results enough that it would have caused a regression if `aggregate` were used in the wilcoxon-illico stats computation. And I'd say it is likely the float32 squaring would cause quite tangible issues in other use cases too: pseudobulk-related workflows, manipulated data that is not standard log1p counts, or maybe newer datasets with higher library sizes.

**Reference note**

"Correct" is defined as two-pass shifted-mean variance computed at float64.

**Numerical impact**

**Synthetic adversarial scenarios**

Each row is a single

Reading: on the ribosomal scenario, the kernel returned variances that

**Real pseudobulk workloads**

The same measurement on real perturb-seq pseudobulk-of-pseudobulks

Reading: routine biological workflows (housekeeping/ribosomal/abundant

**Cross-backend confirmation**

The bug is in the squaring step, which appears identically across all four backends (CSR, CSC, dense, dask).

Reading: identical errors across all four backends — same root cause,
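For context, here is a sketch of what a "two-pass shifted-mean variance computed at float64" reference usually looks like, the kind of oracle described in the Reference note above. This is an assumed implementation, not the PR's actual benchmark harness.

```python
import numpy as np

def two_pass_shifted_var(values):
    """Corrected two-pass variance at float64 (ddof=0): center on the mean
    first, then average the squared deviations, with a small correction term
    for residual rounding in the mean itself."""
    v = np.asarray(values, dtype=np.float64)
    mu = v.mean()              # pass 1: the shift (mean)
    d = v - mu                 # pass 2: deviations from the shift
    return (d * d).mean() - d.mean() ** 2
```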
❌ 5 Tests Failed:
I started one for dask, will post now. I think we should just replace the algorithm IMO.

Update: #4123
Sweet, ok. I'll close this and am happy to review or contribute on that one as needed.
Your review would be extremely welcome!!!!
While working on the wilcoxon illico integration, I ran into numerical stability issues with the 'aggregate' function.
Upon further inspection, the issues primarily stem from float32 squaring before the `mean_of_sq - sq_of_mean` variance calculation. It seems there was some code to cast to float64, e.g. in the CSC and CSR cases, but the next line cast it right back to float32 via `value = data.data[j]`. Other cases, like the dask and dense paths, were not using float64 at all.

I'll add a writeup from AI in another message below about the error levels / empirical findings.
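To make the dead-line issue concrete, here is an illustrative (not verbatim) reconstruction of the CSR sum-of-squares pattern after the fix: the `np.float64(...)` cast now actually applies because the re-assignment `value = data.data[j]` that used to follow it is gone. The function name and signature here are invented for the sketch.

```python
import numpy as np
from scipy import sparse

def csr_group_sq_sums(mat: sparse.csr_matrix, labels: np.ndarray, n_groups: int):
    """Toy CSR sum-of-squares accumulator showing the fixed pattern.
    In the buggy version, `value = np.float64(mat.data[j])` was immediately
    followed by `value = mat.data[j]`, so squaring still happened at float32."""
    sq_sums = np.zeros((n_groups, mat.shape[1]))
    for row in range(mat.shape[0]):
        g = labels[row]
        for j in range(mat.indptr[row], mat.indptr[row + 1]):
            value = np.float64(mat.data[j])   # cast BEFORE squaring
            sq_sums[g, mat.indices[j]] += value * value
    return sq_sums
```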
...
So this small fix solves a large part of the numerical stability problem. It is sufficient to let the stats computation code in the wilcoxon illico work be completed using `aggregate`, which speeds it up by an order of magnitude.
That said, 'aggregate' has further issues -- the variance calculation used here is known to have a 'catastrophic cancellation' scenario when the mean/var ratio is high, e.g. 10,000+. I will open another PR for this soon, but it's a much bigger change, probably requiring more testing and review.
Note: this uses float64, so it's a bit slower. I tested a few perturb-seq datasets and it was about 30% slower on the dask path, but only about 5% slower on the CSR path and ~10% on CSC. However, I am fairly sure that once we add Welford's online approach (or similar) this will speed back up.
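Since Welford's online approach is mentioned as the likely follow-up, here is the textbook single-pass update rule as a minimal sketch (not the code in #4123); it avoids the `mean_sq - sq_mean` cancellation entirely by updating the mean and the sum of squared deviations incrementally.

```python
import numpy as np

def welford_mean_var(values):
    """Welford's online mean/variance (ddof=0). Assumes a non-empty input."""
    mean = 0.0
    m2 = 0.0
    n = 0
    for v in np.asarray(values, dtype=np.float64):
        n += 1
        delta = v - mean
        mean += delta / n
        m2 += delta * (v - mean)   # uses the *updated* mean
    return mean, m2 / n
```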
@ilan-gold tagging you here since this resolves the issue in wilcoxon-illico work / unblocks completion of my PR on stats computation speedup