Skip to content

Fix float64 precision loss in sparse Pearson residual kernels#658

Merged
Intron7 merged 4 commits into
scverse:mainfrom
Arshammik:fix/pearson-residuals-sqrtf-precision
May 15, 2026
Merged

Fix float64 precision loss in sparse Pearson residual kernels#658
Intron7 merged 4 commits into
scverse:mainfrom
Arshammik:fix/pearson-residuals-sqrtf-precision

Conversation

@Arshammik
Copy link
Copy Markdown
Contributor

Summary

The CSR and CSC Pearson residual kernels in src/rapids_singlecell/_cuda/pr/kernels_pr.cuh divide by sqrtf, the single-precision square-root intrinsic. Both kernels are templated on the element type T, so a T=double instantiation silently narrows the variance term mu + mu * mu * inv_theta to float32, evaluates the square root at single precision, and promotes the result back to double.

As a result, the float64 path of pp.normalize_pearson_residuals (and pp.highly_variable_genes with flavor='pearson_residuals') was capped at ~7 significant digits regardless of the requested dtype. The dense kernel dense_norm_res_kernel already uses the overloaded sqrt and was unaffected.

The fix replaces sqrtf with the overloaded sqrt on both sparse paths. sqrt dispatches to the single-precision root for T=float and the double-precision root for T=double, so the float32 path is byte-identical to before and only the float64 path changes.

Verification

A standalone harness compiled the real sparse_norm_res_csr_kernel verbatim and ran it on a 4000_4000 synthetic CSR count matrix against a host float64 reference (NVIDIA H100 80GB HBM3, CUDA 12.6, sm_90):

dtype before fix after fix
double max rel. error 8.83e-08 (~7.1 digits) max rel. error 3.97e-16 (~15.4 digits)
float _ bit-identical (max abs diff 0.0)

The float64 path is now ~8 orders of magnitude more accurate; the float32 path is provably unchanged.

Changes

  • _cuda/pr/kernels_pr.cuh _ sqrtf _ sqrt in sparse_norm_res_csc_kernel and sparse_norm_res_csr_kernel.
  • tests/test_normalization.py _ adds test_normalize_pearson_residuals_float64_precision, which pins the float64 CSR/CSC output to an analytic float64 reference at rtol/atol 1e-9 (tight enough to fail on a single-precision result, pass on a genuine float64 one) across theta in {100, inf}.
  • docs/release-notes/0.15.1.md _ bug-fix entry.

The CSR and CSC Pearson residual kernels in _cuda/pr/kernels_pr.cuh
divided by `sqrtf`, the single-precision square-root intrinsic. Because
both kernels are templated on the element type `T`, a `T=double`
instantiation silently narrowed the variance term
`mu + mu * mu * inv_theta` to float32, evaluated the square root at
single precision, and promoted the result back to double. The float64
path of `pp.normalize_pearson_residuals` (and `pp.highly_variable_genes`
with `flavor='pearson_residuals'`) was therefore capped at ~7
significant digits regardless of the requested dtype. The dense kernel
`dense_norm_res_kernel` already used the overloaded `sqrt` and was
unaffected.

Replace `sqrtf` with the overloaded `sqrt` on both sparse paths. `sqrt`
dispatches to the single-precision root for `T=float` and the
double-precision root for `T=double`, so the float32 path is
byte-identical to before and only the float64 path changes.

Hardware verification (NVIDIA H100 80GB HBM3, CUDA 12.6, sm_90):
A standalone harness compiled the real `sparse_norm_res_csr_kernel`
verbatim and ran it on a 4000x4000 synthetic CSR count matrix against a
host float64 reference.

  T=double, before fix:   max relative error 8.83e-08  (~7.1 digits)
  T=double, after fix:    max relative error 3.97e-16  (~15.4 digits)
  T=float,  before/after: bit-identical (max abs diff 0.0)

The float64 path is now ~8 orders of magnitude more accurate; the
float32 path is provably unchanged.

Add `test_normalize_pearson_residuals_float64_precision` to
tests/test_normalization.py. It pins the float64 CSR/CSC output to an
analytic float64 reference at rtol/atol 1e-9 -- tight enough to fail on
a single-precision result and pass on a genuine float64 one -- across
theta in {100, inf}.
@codecov-commenter
Copy link
Copy Markdown

codecov-commenter commented May 15, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 88.16%. Comparing base (61eda66) to head (2479eff).
⚠️ Report is 1 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #658      +/-   ##
==========================================
+ Coverage   88.13%   88.16%   +0.02%     
==========================================
  Files          96       96              
  Lines        7045     7045              
==========================================
+ Hits         6209     6211       +2     
+ Misses        836      834       -2     

see 1 file with indirect coverage changes

@Intron7
Copy link
Copy Markdown
Member

Intron7 commented May 15, 2026

@coderabbitai review

@coderabbitai
Copy link
Copy Markdown

coderabbitai Bot commented May 15, 2026

✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

@coderabbitai
Copy link
Copy Markdown

coderabbitai Bot commented May 15, 2026

Warning

Rate limit exceeded

@Intron7 has exceeded the limit for the number of commits that can be reviewed per hour. Please wait 59 minutes and 22 seconds before requesting another review.

You’ve run out of usage credits. Purchase more in the billing tab.

⌛ How to resolve this issue?

After the wait time has elapsed, a review can be triggered using the @coderabbitai review command as a PR comment. Alternatively, push new commits to this PR.

We recommend that you space out your commits to avoid hitting the rate limit.

🚦 How do rate limits work?

CodeRabbit enforces hourly rate limits for each developer per organization.

Our paid plans have higher rate limits than the trial, open-source and free plans. In all cases, we re-allow further reviews after a brief timeout.

Please see our FAQ for further information.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 1c26cb94-5177-429f-994f-0ce025a9cf1a

📥 Commits

Reviewing files that changed from the base of the PR and between 2afa3c6 and 949aba4.

📒 Files selected for processing (3)
  • docs/release-notes/0.15.1.md
  • src/rapids_singlecell/_cuda/pr/kernels_pr.cuh
  • tests/test_normalization.py
✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Tip

💬 Introducing Slack Agent: The best way for teams to turn conversations into code.

Slack Agent is built on CodeRabbit's deep understanding of your code, so your team can collaborate across the entire SDLC without losing context.

  • Generate code and open pull requests
  • Plan features and break down work
  • Investigate incidents and troubleshoot customer tickets together
  • Automate recurring tasks and respond to alerts with triggers
  • Summarize progress and report instantly

Built for teams:

  • Shared memory across your entire org—no repeating context
  • Per-thread sandboxes to safely plan and execute work
  • Governance built-in—scoped access, auditability, and budget controls

One agent for your entire SDLC. Right inside Slack.

👉 Get started


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai
Copy link
Copy Markdown

coderabbitai Bot commented May 15, 2026

✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

@Intron7 Intron7 enabled auto-merge (squash) May 15, 2026 09:46
@Intron7 Intron7 disabled auto-merge May 15, 2026 10:09
@Intron7
Copy link
Copy Markdown
Member

Intron7 commented May 15, 2026

Thanks for tracking this down. I tested it locally and want to fold in one tweak before merging.
The PR's sqrtf → sqrt swap is correct but carries a measurable fp64 cost on consumer-class Blackwell silicon (RTX PRO 6000, 1:64 fp64):

┌─────────────┬────────┬───────┬────────────────┬────────────────┬────────────────┐
│    Size     │ Format │ dtype │ sqrtf (broken) │ sqrt (your PR) │ *= rsqrt(...)  │
├─────────────┼────────┼───────┼────────────────┼────────────────┼────────────────┤
│ 50k × 5k    │ CSR    │ fp64  │ 79.4 ms        │ 86.1 ms (+8%)  │ 75.0 ms (−6%)  │
├─────────────┼────────┼───────┼────────────────┼────────────────┼────────────────┤
│ 50k × 5k    │ CSC    │ fp64  │ 52.6 ms        │ 60.0 ms (+14%) │ 47.3 ms (−10%) │
├─────────────┼────────┼───────┼────────────────┼────────────────┼────────────────┤
│ 200k × 2k   │ CSR    │ fp64  │ 43.1 ms        │ 53.2 ms (+23%) │ 36.4 ms (−16%) │
├─────────────┼────────┼───────┼────────────────┼────────────────┼────────────────┤
│ 200k × 2k   │ CSC    │ fp64  │ 197 ms         │ 227 ms (+15%)  │ 176 ms (−11%)  │
├─────────────┼────────┼───────┼────────────────┼────────────────┼────────────────┤
│ (fp32 rows) │ —      │ —     │ —              │ identical      │ 1–5% faster    │
└─────────────┴────────┴───────┴────────────────┴────────────────┴────────────────┘

So I replaced

  residuals[res_index] -= mu;
  residuals[res_index] /= sqrt(mu + mu * mu * inv_theta);

with

  residuals[res_index] -= mu;
  residuals[res_index] *= rsqrt(mu + mu * mu * inv_theta);

One small correction
The PR description mentions pp.highly_variable_genes(flavor='pearson_residuals') being affected — actually that path runs through kernels_pr_hvg.cuh, which has been using rsqrt all along, so HVG outputs are unchanged. The fix is pp.normalize_pearson_residuals only.

Thank again for the PR and the great find

@Intron7 Intron7 enabled auto-merge (squash) May 15, 2026 10:21
@Intron7 Intron7 merged commit 42123b3 into scverse:main May 15, 2026
18 of 19 checks passed
@Arshammik
Copy link
Copy Markdown
Contributor Author

NP! Happy to help!

@Arshammik Arshammik deleted the fix/pearson-residuals-sqrtf-precision branch May 15, 2026 19:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants