Skip to content

add wasserstein to pertpy GPU Distance#683

Merged
Intron7 merged 7 commits into
mainfrom
add-wasserstein
Jun 2, 2026
Merged

add wasserstein to pertpy GPU Distance#683
Intron7 merged 7 commits into
mainfrom
add-wasserstein

Conversation

@Intron7
Copy link
Copy Markdown
Member

@Intron7 Intron7 commented Jun 1, 2026

This adds the Wasserstein metric to pertpy-GPU Distance. For larger Dataset this is up to 400x faster than Pertpy using Jax-GPU. It also fixed an issue with empty categories for edistance

@Intron7
Copy link
Copy Markdown
Member Author

Intron7 commented Jun 1, 2026

@coderabbitai review

@coderabbitai
Copy link
Copy Markdown

coderabbitai Bot commented Jun 1, 2026

✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

@coderabbitai
Copy link
Copy Markdown

coderabbitai Bot commented Jun 1, 2026

Review Change Stack

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: b35cfc3b-3061-4d2c-82ac-11718161a657

📥 Commits

Reviewing files that changed from the base of the PR and between ab0d310 and 0070466.

📒 Files selected for processing (7)
  • docs/release-notes/0.15.2.md
  • src/rapids_singlecell/_cuda/sinkhorn/kernels_sinkhorn.cuh
  • src/rapids_singlecell/_cuda/sinkhorn/sinkhorn.cu
  • src/rapids_singlecell/pertpy_gpu/_distance.py
  • src/rapids_singlecell/pertpy_gpu/_metrics/_sinkhorn.py
  • src/rapids_singlecell/pertpy_gpu/_metrics/_wasserstein.py
  • tests/pertpy/test_distance_wasserstein.py
✅ Files skipped from review due to trivial changes (1)
  • docs/release-notes/0.15.2.md
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/rapids_singlecell/pertpy_gpu/_distance.py

📝 Walkthrough

Summary by CodeRabbit

Release Notes

  • New Features

    • Added Wasserstein distance metric with GPU-accelerated batched Sinkhorn solver for optimal transport computation
    • Added bootstrap resampling support for distance estimation, computing per-pair mean and variance
  • Bug Fixes

    • Distance calculations now properly drop unused categorical groups from output, preventing NaN values in empty categories

Walkthrough

This PR adds GPU-accelerated Wasserstein distance computation to rapids_singlecell via batched log-domain Sinkhorn optimal transport. It includes CUDA kernels for potential updates, Python solver orchestration with multi-GPU support, metric class implementation with bootstrap variance, Distance API integration, and comprehensive test coverage validating correctness against CPU reference implementations and upstream pertpy.

Changes

Wasserstein Distance Metric Implementation

Layer / File(s) Summary
CUDA log-domain Sinkhorn kernels
src/rapids_singlecell/_cuda/sinkhorn/kernels_sinkhorn.cuh
Implements batched Sinkhorn kernels: log-sum-exp primitives (lse_acc, lse_merge), auto-epsilon computation, column/row potential updates with over-relaxation support, fused pairwise squared-Euclidean cost construction, and per-pair convergence detection via max-change reduction.
CUDA/pybind bindings and module registration
src/rapids_singlecell/_cuda/sinkhorn/sinkhorn.cu
Wraps all Sinkhorn kernels as Python-callable functions via nanobind: auto_eps, update_g, update_f, check_convergence, build_cost. Extracts tensor metadata, converts stream pointers, computes grid dimensions, and registers the _sinkhorn_cuda extension module.
Python-level Sinkhorn solver orchestration
src/rapids_singlecell/pertpy_gpu/_metrics/_sinkhorn.py
Implements batched solver with state allocation (make_state), asynchronous multi-stream iteration control (_step, run_async), convergence checking, and finalized OT cost computation via deterministic prefix-sum segment accumulation (finalize).
WassersteinMetric class with multi-GPU orchestration
src/rapids_singlecell/pertpy_gpu/_metrics/_wasserstein.py
Implements full WassersteinMetric class: ragged layout construction for contiguous group indexing, memory-aware batch planning, multi-GPU workload distribution with round-robin/cumulative-work split, bootstrap variance via resampling, and public methods (pairwise, onesided_distances, contrast_distances, compute_distance, bootstrap, bootstrap_arrays) wiring GPU solver to AnnData/categorical inputs.
BaseMetric refactoring and Distance integration
src/rapids_singlecell/pertpy_gpu/_metrics/_base_metric.py, src/rapids_singlecell/pertpy_gpu/_metrics/_edistance.py, src/rapids_singlecell/pertpy_gpu/_distance.py
Extracts shared categorical group subsetting logic (_subset_to_groups) into BaseMetric for reuse; removes duplicate from EDistanceMetric. Extends Distance public interface: updated Metric type and SUPPORTED_METRICS to include "wasserstein", added metric-specific **kwargs forwarding through __init__, new "wasserstein" dispatch branch in _initialize_metric.
Comprehensive test suite
tests/pertpy/test_distance_wasserstein.py, tests/pertpy/test_distances.py
Validates GPU solver correctness (vs NumPy reference), auto-epsilon behavior, Distance API (initialization, relaxation bounds, convergence across omega), bootstrap variance (non-negativity, reproducibility, mean closeness), multi-GPU correctness (parity across devices), parity with upstream pertpy, dtype consistency (float32/float64), categorical group filtering, contrast distance computation, and edge cases; adds unused-category dropping test for edistance.
Build system and release documentation
CMakeLists.txt, docs/release-notes/0.15.2.md
Registers _sinkhorn_cuda nanobind module in CUDA build section (conditional on RSC_BUILD_EXTENSIONS); documents wasserstein metric feature and unused-category dropping bug fix for all Distance metrics.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

  • scverse/rapids-singlecell#676: Extends Distance metric support via the same Metric/SUPPORTED_METRICS dispatch mechanism; this PR adds "wasserstein" while the referenced PR adds pseudobulk metrics.

Suggested reviewers

  • Zethson
🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 46.67% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Title check ✅ Passed The pull request title "add wasserstein to pertpy GPU Distance" clearly and concisely summarizes the main change—adding Wasserstein metric support to the pertpy GPU Distance implementation.
Description check ✅ Passed The pull request description is relevant to the changeset, mentioning the Wasserstein metric addition and performance benefits, plus an associated bug fix for empty categories in edistance.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
📝 Generate docstrings
  • Create stacked PR
  • Commit on current branch
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch add-wasserstein

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 5

🧹 Nitpick comments (1)
tests/pertpy/test_distance_wasserstein.py (1)

123-123: ⚡ Quick win

Make paired test inputs fail fast on length drift.

zip(pairs, epsilons) at lines 123 and 147 can silently truncate if the lists ever diverge; use strict=True (repo requires-python is >=3.12) to fail loudly.

♻️ Proposed fix
-    for (X, Y), eps in zip(pairs, epsilons):
+    for (X, Y), eps in zip(pairs, epsilons, strict=True):
@@
-    for b, ((X, Y), eps) in enumerate(zip(pairs, epsilons)):
+    for b, ((X, Y), eps) in enumerate(zip(pairs, epsilons, strict=True)):
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/pertpy/test_distance_wasserstein.py` at line 123, The test loops use
zip(pairs, epsilons) which can silently truncate if lengths diverge; update both
occurrences (the loops iterating "for (X, Y), eps in zip(pairs, epsilons):" and
the similar loop later) to call zip with strict=True so the test will raise
immediately on length mismatch and fail fast.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@src/rapids_singlecell/_cuda/sinkhorn/sinkhorn.cu`:
- Around line 114-137: The switch on metric currently treats unknown enum values
as distances::SqEuclidean, which can hide ABI/enum drift; update the switch in
sinkhorn.cu to handle distances::SQEUCLIDEAN explicitly by adding a case that
launches sinkhorn::pairwise_cost_kernel with distances::SqEuclidean<T>, and
change the default branch to raise an error (e.g., throw std::invalid_argument
or use CUDF/rapids error macro) indicating an unknown metric id so invalid
values fail loudly; keep the CUDA_CHECK_LAST_ERROR(pairwise_cost_kernel) call
after the switch.

In `@src/rapids_singlecell/pertpy_gpu/_distance.py`:
- Around line 117-120: The docstring for Distance.__init__ incorrectly
advertises forwarding kwargs like epsilon, max_iter, and tol to the Wasserstein
metric; update the text and behavior so only the options accepted by
WassersteinMetric are exposed. Specifically, in Distance.__init__ (and its
docstring) remove/replace mentions of epsilon, max_iter, and tol and only
document/forward the `relaxation` parameter accepted by WassersteinMetric;
optionally add a validation step in Distance.__init__ that raises a clear
TypeError if unexpected kwargs are passed to WassersteinMetric to prevent
confusing errors.

In `@src/rapids_singlecell/pertpy_gpu/_metrics/_wasserstein.py`:
- Around line 465-489: Currently loc_n, loc_m, gl, gr, cidx_l_all and cidx_r_all
are allocated for all units and bootstraps at once causing OOM; instead, move
bootstrap/resampling into the chunked solve so you only materialize samples for
the current chunk. Concretely: stop creating loc_n/loc_m/gl/gr/cidx_* globally;
create a per-chunk RNG seed from the original random_state plus a chunk index
(or derive per-unit seeds deterministically), then inside the chunking loop that
uses chunk = _pair_batch_size(...) call cp.random.default_rng(per_chunk_seed) to
draw only the (chunk_n, chunk_m) samples needed, compute gl/gr and cidx_l/cidx_r
for that chunk, run the Sinkhorn solve, and free those arrays before the next
chunk. Keep references to rng/random_state, loc_n/loc_m, gl/gr,
cidx_l_all/cidx_r_all and chunk/_pair_batch_size in your changes so the reviewer
can find the modified logic.
- Around line 64-70: _batch planning currently reads GPU memory via
cp.cuda.runtime.memGetInfo() on whatever device is active, causing wrong sizing
in multi-GPU runs; update _pair_batch_size to accept a device id (e.g., dev) and
call memGetInfo inside that device's context (use cp.cuda.Device(dev): or
cp.cuda.Device(dev).use()/context manager) so the free/budget is measured on the
target GPU, then return the same computed batch; update callers (notably
_plan_batches and any places that call _pair_batch_size, such as where batches
are planned before switching devices) to pass the correct device id so sizing
uses the intended GPU.
- Around line 440-463: Before calling .max() on n_row/n_col, handle empty
workloads: validate n_bootstrap is positive and raise a ValueError for
non-positive counts, and if n_pairs == 0 (e.g. after filtering or pair_left
empty) return the function's empty outputs immediately instead of proceeding;
specifically add a guard after computing n_pairs and n_bootstrap to raise on
n_bootstrap <= 0 and to return early when n_pairs == 0 so that subsequent uses
of n_row.max(), n_col.max(), and reductions do not run on empty arrays
(affecting variables n_row, n_col, max_n, max_m).

---

Nitpick comments:
In `@tests/pertpy/test_distance_wasserstein.py`:
- Line 123: The test loops use zip(pairs, epsilons) which can silently truncate
if lengths diverge; update both occurrences (the loops iterating "for (X, Y),
eps in zip(pairs, epsilons):" and the similar loop later) to call zip with
strict=True so the test will raise immediately on length mismatch and fail fast.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 357c40cb-9bdb-4624-9974-473148f9ac1b

📥 Commits

Reviewing files that changed from the base of the PR and between 9cd098c and ab0d310.

📒 Files selected for processing (12)
  • CMakeLists.txt
  • docs/release-notes/0.15.2.md
  • src/rapids_singlecell/_cuda/distances/distance_metrics.cuh
  • src/rapids_singlecell/_cuda/sinkhorn/kernels_sinkhorn.cuh
  • src/rapids_singlecell/_cuda/sinkhorn/sinkhorn.cu
  • src/rapids_singlecell/pertpy_gpu/_distance.py
  • src/rapids_singlecell/pertpy_gpu/_metrics/_base_metric.py
  • src/rapids_singlecell/pertpy_gpu/_metrics/_edistance.py
  • src/rapids_singlecell/pertpy_gpu/_metrics/_sinkhorn.py
  • src/rapids_singlecell/pertpy_gpu/_metrics/_wasserstein.py
  • tests/pertpy/test_distance_wasserstein.py
  • tests/pertpy/test_distances.py
💤 Files with no reviewable changes (1)
  • src/rapids_singlecell/pertpy_gpu/_metrics/_edistance.py

Comment thread src/rapids_singlecell/_cuda/sinkhorn/sinkhorn.cu Outdated
Comment thread src/rapids_singlecell/pertpy_gpu/_distance.py Outdated
Comment thread src/rapids_singlecell/pertpy_gpu/_metrics/_wasserstein.py Outdated
Comment thread src/rapids_singlecell/pertpy_gpu/_metrics/_wasserstein.py Outdated
Comment thread src/rapids_singlecell/pertpy_gpu/_metrics/_wasserstein.py Outdated
@codecov-commenter
Copy link
Copy Markdown

codecov-commenter commented Jun 1, 2026

Codecov Report

❌ Patch coverage is 94.11765% with 29 lines in your changes missing coverage. Please review.
✅ Project coverage is 88.77%. Comparing base (6324228) to head (f06943c).
⚠️ Report is 1 commits behind head on main.

Files with missing lines Patch % Lines
...ids_singlecell/pertpy_gpu/_metrics/_wasserstein.py 93.00% 28 Missing ⚠️
...ids_singlecell/pertpy_gpu/_metrics/_base_metric.py 95.23% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #683      +/-   ##
==========================================
+ Coverage   88.46%   88.77%   +0.31%     
==========================================
  Files         101      103       +2     
  Lines        7817     8296     +479     
==========================================
+ Hits         6915     7365     +450     
- Misses        902      931      +29     
Files with missing lines Coverage Δ
src/rapids_singlecell/pertpy_gpu/_distance.py 87.68% <100.00%> (+0.36%) ⬆️
...apids_singlecell/pertpy_gpu/_metrics/_edistance.py 95.05% <ø> (-0.17%) ⬇️
...rapids_singlecell/pertpy_gpu/_metrics/_sinkhorn.py 100.00% <100.00%> (ø)
...ids_singlecell/pertpy_gpu/_metrics/_base_metric.py 91.17% <95.23%> (+1.59%) ⬆️
...ids_singlecell/pertpy_gpu/_metrics/_wasserstein.py 93.00% <93.00%> (ø)

@Intron7
Copy link
Copy Markdown
Member Author

Intron7 commented Jun 2, 2026

@coderabbitai review

@coderabbitai
Copy link
Copy Markdown

coderabbitai Bot commented Jun 2, 2026

✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

@Intron7 Intron7 enabled auto-merge (squash) June 2, 2026 13:51
@Intron7 Intron7 merged commit e484f58 into main Jun 2, 2026
22 of 26 checks passed
@Intron7 Intron7 deleted the add-wasserstein branch June 2, 2026 14:13
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants