add wasserstein to pertpy GPU Distance #683
@coderabbitai review
✅ Actions performed: Review triggered.
No actionable comments were generated in the recent review. 🎉
ℹ️ Recent review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID:
📒 Files selected for processing (7)
Walkthrough
This PR adds GPU-accelerated Wasserstein distance computation to rapids_singlecell via batched log-domain Sinkhorn optimal transport. It includes CUDA kernels for potential updates, Python solver orchestration with multi-GPU support, a metric class implementation with bootstrap variance, Distance API integration, and comprehensive test coverage validating correctness against CPU reference implementations and upstream pertpy.
Changes: Wasserstein Distance Metric Implementation
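The walkthrough refers to a log-domain Sinkhorn solve. As a rough illustration only, the pure-Python sketch below runs the standard log-domain updates on one small problem. It is not the PR's CUDA kernel or its batched solver, and the names `sinkhorn_log`, `logsumexp`, `epsilon`, `max_iter`, and `tol` are assumptions chosen for this example.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v))) for a list of floats."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def sinkhorn_log(cost, a, b, epsilon=0.1, max_iter=500, tol=1e-9):
    """Entropic OT between histograms a and b (hypothetical helper).

    cost is an n x m list of lists. Returns (transport_cost, plan).
    """
    n, m = len(a), len(b)
    log_a = [math.log(x) for x in a]
    log_b = [math.log(x) for x in b]
    f = [0.0] * n  # dual potential for the rows
    g = [0.0] * m  # dual potential for the columns
    for _ in range(max_iter):
        # Row update: f_i = eps * (log a_i - LSE_j((g_j - C_ij) / eps))
        f_new = [
            epsilon * (log_a[i] - logsumexp([(g[j] - cost[i][j]) / epsilon for j in range(m)]))
            for i in range(n)
        ]
        # Column update uses the freshly updated row potentials.
        g_new = [
            epsilon * (log_b[j] - logsumexp([(f_new[i] - cost[i][j]) / epsilon for i in range(n)]))
            for j in range(m)
        ]
        delta = max(
            max(abs(x - y) for x, y in zip(f_new, f)),
            max(abs(x - y) for x, y in zip(g_new, g)),
        )
        f, g = f_new, g_new
        if delta < tol:  # potentials stopped moving
            break
    # Recover the plan P_ij = exp((f_i + g_j - C_ij) / eps) and its cost.
    plan = [[math.exp((f[i] + g[j] - cost[i][j]) / epsilon) for j in range(m)] for i in range(n)]
    transport_cost = sum(plan[i][j] * cost[i][j] for i in range(n) for j in range(m))
    return transport_cost, plan
```

Working with potentials and log-sum-exp avoids the underflow that the kernel `exp(-C / epsilon)` hits for small `epsilon`. It does not remove the slow convergence the plain iteration can show when `epsilon` is small relative to the cost scale.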
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 5
🧹 Nitpick comments (1)
tests/pertpy/test_distance_wasserstein.py (1)
123-123: ⚡ Quick win: Make paired test inputs fail fast on length drift.
`zip(pairs, epsilons)` at lines 123 and 147 can silently truncate if the lists ever diverge; use `strict=True` (repo requires-python is `>=3.12`) to fail loudly.
♻️ Proposed fix
- for (X, Y), eps in zip(pairs, epsilons):
+ for (X, Y), eps in zip(pairs, epsilons, strict=True):
@@
- for b, ((X, Y), eps) in enumerate(zip(pairs, epsilons)):
+ for b, ((X, Y), eps) in enumerate(zip(pairs, epsilons, strict=True)):
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/pertpy/test_distance_wasserstein.py` at line 123, The test loops use zip(pairs, epsilons) which can silently truncate if lengths diverge; update both occurrences (the loops iterating "for (X, Y), eps in zip(pairs, epsilons):" and the similar loop later) to call zip with strict=True so the test will raise immediately on length mismatch and fail fast.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@src/rapids_singlecell/_cuda/sinkhorn/sinkhorn.cu`:
- Around line 114-137: The switch on metric currently treats unknown enum values
as distances::SqEuclidean, which can hide ABI/enum drift; update the switch in
sinkhorn.cu to handle distances::SQEUCLIDEAN explicitly by adding a case that
launches sinkhorn::pairwise_cost_kernel with distances::SqEuclidean<T>, and
change the default branch to raise an error (e.g., throw std::invalid_argument
or use CUDF/rapids error macro) indicating an unknown metric id so invalid
values fail loudly; keep the CUDA_CHECK_LAST_ERROR(pairwise_cost_kernel) call
after the switch.
In `@src/rapids_singlecell/pertpy_gpu/_distance.py`:
- Around line 117-120: The docstring for Distance.__init__ incorrectly
advertises forwarding kwargs like epsilon, max_iter, and tol to the Wasserstein
metric; update the text and behavior so only the options accepted by
WassersteinMetric are exposed. Specifically, in Distance.__init__ (and its
docstring) remove/replace mentions of epsilon, max_iter, and tol and only
document/forward the `relaxation` parameter accepted by WassersteinMetric;
optionally add a validation step in Distance.__init__ that raises a clear
TypeError if unexpected kwargs are passed to WassersteinMetric to prevent
confusing errors.
In `@src/rapids_singlecell/pertpy_gpu/_metrics/_wasserstein.py`:
- Around line 465-489: Currently loc_n, loc_m, gl, gr, cidx_l_all and cidx_r_all
are allocated for all units and bootstraps at once causing OOM; instead, move
bootstrap/resampling into the chunked solve so you only materialize samples for
the current chunk. Concretely: stop creating loc_n/loc_m/gl/gr/cidx_* globally;
create a per-chunk RNG seed from the original random_state plus a chunk index
(or derive per-unit seeds deterministically), then inside the chunking loop that
uses chunk = _pair_batch_size(...) call cp.random.default_rng(per_chunk_seed) to
draw only the (chunk_n, chunk_m) samples needed, compute gl/gr and cidx_l/cidx_r
for that chunk, run the Sinkhorn solve, and free those arrays before the next
chunk. Keep references to rng/random_state, loc_n/loc_m, gl/gr,
cidx_l_all/cidx_r_all and chunk/_pair_batch_size in your changes so the reviewer
can find the modified logic.
- Around line 64-70: _batch planning currently reads GPU memory via
cp.cuda.runtime.memGetInfo() on whatever device is active, causing wrong sizing
in multi-GPU runs; update _pair_batch_size to accept a device id (e.g., dev) and
call memGetInfo inside that device's context (use cp.cuda.Device(dev): or
cp.cuda.Device(dev).use()/context manager) so the free/budget is measured on the
target GPU, then return the same computed batch; update callers (notably
_plan_batches and any places that call _pair_batch_size, such as where batches
are planned before switching devices) to pass the correct device id so sizing
uses the intended GPU.
- Around line 440-463: Before calling .max() on n_row/n_col, handle empty
workloads: validate n_bootstrap is positive and raise a ValueError for
non-positive counts, and if n_pairs == 0 (e.g. after filtering or pair_left
empty) return the function's empty outputs immediately instead of proceeding;
specifically add a guard after computing n_pairs and n_bootstrap to raise on
n_bootstrap <= 0 and to return early when n_pairs == 0 so that subsequent uses
of n_row.max(), n_col.max(), and reductions do not run on empty arrays
(affecting variables n_row, n_col, max_n, max_m).
---
Nitpick comments:
In `@tests/pertpy/test_distance_wasserstein.py`:
- Line 123: The test loops use zip(pairs, epsilons) which can silently truncate
if lengths diverge; update both occurrences (the loops iterating "for (X, Y),
eps in zip(pairs, epsilons):" and the similar loop later) to call zip with
strict=True so the test will raise immediately on length mismatch and fail fast.
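Two of the inline comments above (materialize bootstrap samples per chunk, and guard empty workloads) can be sketched together. This is a hypothetical CPU-only illustration: it uses Python's `random` module instead of `cp.random.default_rng`, takes the per-unit seeding option that the review mentions as an alternative, and treats `solve` as a stand-in for the Sinkhorn solve. The function name `bootstrap_in_chunks` and all parameter names are assumptions, not code from the PR.

```python
import random


def bootstrap_in_chunks(n_rows, n_cols, n_bootstrap, chunk, solve, random_state=0):
    """Resample and solve one chunk of (pair, bootstrap) units at a time."""
    if n_bootstrap <= 0:
        raise ValueError("n_bootstrap must be positive")
    if chunk <= 0:
        raise ValueError("chunk must be positive")
    n_pairs = len(n_rows)
    if n_pairs == 0:
        # Nothing to do: return before any max()/reduction sees empty input.
        return []

    # Only the lightweight (pair, bootstrap) ids exist for all units at once.
    units = [(p, b) for p in range(n_pairs) for b in range(n_bootstrap)]
    results = []
    for start in range(0, len(units), chunk):
        batch = []
        for p, b in units[start:start + chunk]:
            # Per-unit seed, so the draw does not depend on the chunk size.
            rng = random.Random(f"{random_state}-{p}-{b}")
            idx_l = [rng.randrange(n_rows[p]) for _ in range(n_rows[p])]
            idx_r = [rng.randrange(n_cols[p]) for _ in range(n_cols[p])]
            batch.append((idx_l, idx_r))
        results.extend(solve(idx_l, idx_r) for idx_l, idx_r in batch)
        del batch  # drop this chunk's index lists before drawing the next
    return results
```

Seeding per unit rather than per chunk keeps results reproducible even when the chunk size changes with available memory; a per-chunk seed, as the comment also suggests, would be simpler but ties the samples to the batch plan.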
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 357c40cb-9bdb-4624-9974-473148f9ac1b
📒 Files selected for processing (12)
- CMakeLists.txt
- docs/release-notes/0.15.2.md
- src/rapids_singlecell/_cuda/distances/distance_metrics.cuh
- src/rapids_singlecell/_cuda/sinkhorn/kernels_sinkhorn.cuh
- src/rapids_singlecell/_cuda/sinkhorn/sinkhorn.cu
- src/rapids_singlecell/pertpy_gpu/_distance.py
- src/rapids_singlecell/pertpy_gpu/_metrics/_base_metric.py
- src/rapids_singlecell/pertpy_gpu/_metrics/_edistance.py
- src/rapids_singlecell/pertpy_gpu/_metrics/_sinkhorn.py
- src/rapids_singlecell/pertpy_gpu/_metrics/_wasserstein.py
- tests/pertpy/test_distance_wasserstein.py
- tests/pertpy/test_distances.py
💤 Files with no reviewable changes (1)
- src/rapids_singlecell/pertpy_gpu/_metrics/_edistance.py
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #683 +/- ##
==========================================
+ Coverage 88.46% 88.77% +0.31%
==========================================
Files 101 103 +2
Lines 7817 8296 +479
==========================================
+ Hits 6915 7365 +450
- Misses 902 931 +29
@coderabbitai review
✅ Actions performed: Review triggered.
This adds the Wasserstein metric to the pertpy-GPU Distance. For larger datasets, this is up to 400x faster than pertpy using JAX on the GPU. It also fixes an issue with empty categories for edistance.