# **Multi-GPU Showcase with `persist`**  

**Author:** [Severin Dicks](https://github.com/Intron7)  
**Copyright:** [scverse](https://scverse.org)  

## **Overview**  

In this notebook, we showcase the **multi-GPU computation capabilities** of **rapids-singlecell**  
using **Dask**, enabling the analysis of **11 million cells at unprecedented speed**.  

This notebook is run on a **DGX system with 8 NVIDIA H100 GPUs**, demonstrating how **Dask** can efficiently distribute computations across multiple GPUs.  

### **Key Advantages of Multi-GPU Computation**  
By leveraging **Dask and RAPIDS**, we can:  
- **Process massive single-cell datasets** without exceeding memory limits.  
- **Fully utilize all available GPUs**, scaling performance across multiple devices.  
- **Enable chunk-based execution**, efficiently managing memory by loading only necessary data.  

### **Combining Multi-GPU with Out-of-Core Processing**  
- **Multi-GPU processing and out-of-core execution** can be combined to analyze even larger datasets that exceed GPU memory.  
- However, in this notebook, **we focus purely on multi-GPU scaling** without out-of-core execution.  

This approach significantly accelerates **large-scale single-cell analysis**,  
making it feasible on high-performance hardware like **DGX systems**,  
while also being adaptable to multi-GPU workstations.

In [1]:
import dask
import time


from dask_cuda import LocalCUDACluster
from dask.distributed import Client

## **Initializing a Multi-GPU Dask Cluster for RAPIDS**  

To fully utilize all **8 NVIDIA H100 GPUs** on the DGX system,  
we initialize a **multi-GPU Dask cluster** and configure **RAPIDS Memory Manager (RMM)**  
for efficient memory handling across GPUs.  

### **Setting Up Memory Management with RMM**  
RAPIDS **RMM (RAPIDS Memory Manager)** helps optimize GPU memory usage by enabling **managed memory**,  
which improves memory efficiency when working with large-scale datasets.

### **Launching a Multi-GPU Dask Cluster**
We create a Dask CUDA cluster that utilizes all 8 GPUs for preprocessing and analysis.

* `rmm.reinitialize(managed_memory=True)` → Enables managed (unified) memory, so allocations can be paged between host and GPU memory as needed.
* `cp.cuda.set_allocator(rmm_cupy_allocator)` → Sets CuPy to use the RAPIDS memory allocator.
* `LocalCUDACluster(CUDA_VISIBLE_DEVICES=preprocessing_gpus)` → Launches a Dask cluster with one worker per listed GPU.
* `client.run(set_mem)` → Runs `set_mem` on every worker, so each worker process uses the same memory configuration.
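
The device list passed as `CUDA_VISIBLE_DEVICES` can also be derived from a GPU count instead of being typed by hand. A minimal sketch, assuming a hypothetical helper named `visible_devices` (it is not part of `dask_cuda`):

```python
def visible_devices(n_gpus: int) -> str:
    """Build a CUDA_VISIBLE_DEVICES-style string such as "0,1,2".

    Hypothetical helper for illustration; not part of dask_cuda.
    """
    if n_gpus < 1:
        raise ValueError("n_gpus must be at least 1")
    return ",".join(str(i) for i in range(n_gpus))


# Eight GPUs, as on the DGX system used in this notebook.
preprocessing_gpus = visible_devices(8)
```

The resulting string has the same form as the `preprocessing_gpus` value used in the next cell.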

In [2]:
%%time
import rmm
import cupy as cp

from rmm.allocators.cupy import rmm_cupy_allocator

def set_mem():
    rmm.reinitialize(managed_memory=True)
    cp.cuda.set_allocator(rmm_cupy_allocator)

preprocessing_gpus="0,1,2,3,4,5,6,7"
cluster = LocalCUDACluster(CUDA_VISIBLE_DEVICES=preprocessing_gpus)

client = Client(cluster)

client.run(set_mem)

client

CPU times: user 668 ms, sys: 241 ms, total: 909 ms
Wall time: 19.2 s


Client and cluster summary:

* Connection method: Cluster object
* Cluster type: `dask_cuda.LocalCUDACluster`
* Dashboard: http://127.0.0.1:8787/status
* Scheduler: tcp://127.0.0.1:37877
* Status: running, using processes: True
* Workers: 8, total threads: 8 (1 thread per worker)
* Total memory: 1.73 TiB (221.46 GiB per worker)


In [3]:
import rapids_singlecell as rsc
import anndata as ad



## **Loading Large Datasets into AnnData with Dask**  

To efficiently handle large-scale single-cell datasets, we load data directly from an **HDF5 (`h5`) or Zarr file**  
into an **AnnData object** using **Dask arrays**. This enables **lazy loading**, allowing data to be processed in chunks  
without exceeding memory limits.  

We achieve this using **`read_elem_as_dask`**, which loads the expression matrix (`X`) as a **Dask array** split into row chunks of `SPARSE_CHUNK_SIZE` cells, each spanning all genes.

In [4]:
from anndata.experimental import read_elem_as_dask

import h5py

SPARSE_CHUNK_SIZE = 200_000
data_pth = "./h5/cell_atlas.h5ad"


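# Open the file read-only; the context manager closes it afterwards.
with h5py.File(data_pth, "r") as f:
    X = f["X"]
    shape = X.attrs["shape"]
    adata = ad.AnnData(
        # Wrap X lazily in row chunks that span all columns.
        X=read_elem_as_dask(X, (SPARSE_CHUNK_SIZE, shape[1])),
        # obs and var are small and are read eagerly.
        obs=ad.io.read_elem(f["obs"]),
        var=ad.io.read_elem(f["var"]),
    )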



## **Transferring AnnData to GPU and Persisting Data**  

To leverage **multi-GPU acceleration**, we transfer the AnnData object to GPU memory  
and persist its **Dask-backed expression matrix** for efficient computation. 

**Step-by-Step Breakdown:**
1. Move AnnData to GPU → `rsc.get.anndata_to_GPU(adata)`
    * Converts the expression matrix (`.X`) to its GPU (CuPy) counterpart.
2. Persist the Expression Matrix → `adata.X = adata.X.persist()`
    * Keeps `adata.X` in memory across Dask workers, avoiding redundant recomputation.
3. Resolve Chunk Sizes → `adata.X.compute_chunk_sizes()`
    * Computes the exact size of every chunk in place, so that Dask knows the true shape of the array.

In [5]:
rsc.get.anndata_to_GPU(adata)
adata.X = adata.X.persist()
adata.X.compute_chunk_sizes()

| Property | Value |
|---|---|
| Array shape | (11441407, 45854) |
| Chunk shape | (200000, 45854) |
| Dask graph | 58 chunks in 1 graph layer |
| Data type | float32 cupyx.scipy.sparse._csr.csr_matrix |
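
The chunk count reported above follows directly from the number of rows and the chunk size. A small sketch of that arithmetic, using a hypothetical helper `row_chunk_sizes` (plain Python, independent of Dask):

```python
def row_chunk_sizes(n_rows: int, chunk_rows: int) -> tuple:
    """Row counts of consecutive chunks covering n_rows (hypothetical helper)."""
    full, rest = divmod(n_rows, chunk_rows)
    # All chunks are full-sized except possibly a smaller final one.
    return (chunk_rows,) * full + ((rest,) if rest else ())


# Values taken from this notebook: 11,441,407 cells, 200,000 rows per chunk.
chunks = row_chunk_sizes(11_441_407, 200_000)
n_chunks = len(chunks)          # 58
last_chunk_rows = chunks[-1]    # 41,407
```

With 8 workers and 58 chunks, each GPU holds several chunks at this stage.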


## **Quality Control (QC) Metrics Calculation**  

Before proceeding with further analysis, we compute **quality control (QC) metrics**  
to assess dataset quality and filter out low-quality cells or genes.  

We use **`rsc.pp.calculate_qc_metrics()`** to calculate key QC metrics, such as `n_genes_by_counts` per cell and `n_cells_by_counts` per gene.

Although we are working with Dask-backed AnnData, this operation requires a synchronization step.
This means that Dask computations must be evaluated immediately,
so the process is not completely lazy like other out-of-core operations.

In [6]:
%%time
rsc.pp.calculate_qc_metrics(adata)

CPU times: user 313 ms, sys: 134 ms, total: 448 ms
Wall time: 519 ms
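
To make the meaning of these metrics concrete, here is a toy illustration in plain Python. It is not how `rsc.pp.calculate_qc_metrics` is implemented; the 3 × 4 count matrix is made up for the example, with cells as rows and genes as columns.

```python
# Made-up counts: 3 cells (rows) x 4 genes (columns).
counts = [
    [0, 3, 1, 0],
    [2, 0, 0, 0],
    [5, 1, 4, 2],
]
n_genes = len(counts[0])

# Per cell: total counts and number of genes with at least one count.
total_counts = [sum(row) for row in counts]
n_genes_by_counts = [sum(1 for v in row if v > 0) for row in counts]

# Per gene: number of cells in which the gene is detected.
n_cells_by_counts = [
    sum(1 for row in counts if row[g] > 0) for g in range(n_genes)
]
```

Per-gene metrics need information from every row chunk, which is one reason this step cannot stay fully lazy.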


## **Filtering Cells and Genes Without Additional Computation**  

Instead of using **`sc.pp.filter_cells`** and **`sc.pp.filter_genes`**,  
we apply filtering directly using boolean indexing to **avoid extra computation**.

**Why Use Direct Indexing Instead of Built-in Functions?**
* More Efficient with Dask → Avoids triggering additional computations.
* Preserves Lazy Execution → Filtering is applied without forcing full dataset evaluation.
* Copy is Essential → Using `.copy()` prevents views, which may not work reliably with Dask-backed AnnData.

In [7]:
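# Keep cells with 200 to 10,000 detected genes (both bounds inclusive).
adata = adata[(adata.obs["n_genes_by_counts"] <= 10000)
              & (adata.obs["n_genes_by_counts"] >= 200)].copy()
# Keep genes detected in at least 10 cells.
adata = adata[:, adata.var["n_cells_by_counts"] >= 10].copy()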
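
The thresholds are inclusive on both ends, which is easy to check on a few made-up values. A plain-Python sketch of the same mask logic (the numbers are illustrative, not taken from the dataset):

```python
# Made-up per-cell gene counts, including both boundary values.
n_genes_by_counts = [150, 200, 4_200, 10_000, 12_500]
cell_mask = [200 <= n <= 10_000 for n in n_genes_by_counts]
kept_cells = [i for i, keep in enumerate(cell_mask) if keep]

# Made-up per-gene cell counts around the threshold of 10.
n_cells_by_counts = [0, 9, 10, 3_000]
gene_mask = [n >= 10 for n in n_cells_by_counts]
```

Because the masks come from `obs` and `var`, which are already in memory, building them does not touch the expression matrix.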

## **Persisting and Optimizing Chunk Sizes After QC and Subsetting**  

After performing **quality control (QC) and subsetting** the dataset,  
we persist the **Dask-backed expression matrix** and optimize its chunk sizes for efficient multi-GPU execution.  
Persisting after filtering ensures that only high-quality, relevant data remains in memory.

In [8]:
adata.X = adata.X.persist()
adata.X.compute_chunk_sizes()

This may cause some slowdown.
Consider loading the data with Dask directly
 or using futures or delayed objects to embed the data into the graph without repetition.
See also https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask for more information.


| Property | Value |
|---|---|
| Array shape | (11441244, 41291) |
| Chunk shape | (197265, 41291) |
| Dask graph | 58 chunks in 1 graph layer |
| Data type | float32 cupyx.scipy.sparse._csr.csr_matrix |


## **Log Normalization (Fully Lazy Execution)**  

Next, we apply **log normalization** to scale gene expression values.  
This step ensures that differences in sequencing depth across cells do not dominate downstream analysis.  

In [9]:
%%time
rsc.pp.normalize_total(adata,target_sum = 10000)
rsc.pp.log1p(adata)

CPU times: user 9.6 ms, sys: 1.16 ms, total: 10.8 ms
Wall time: 9.59 ms
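
For a single cell, the two calls amount to rescaling the counts so they sum to `target_sum` and then applying `log(1 + x)`. A minimal plain-Python sketch of that formula, using a hypothetical function `normalize_log1p` (it illustrates the math, not the library's implementation):

```python
import math


def normalize_log1p(row, target_sum=10_000):
    """Scale one cell's counts to target_sum, then apply log(1 + x)."""
    total = sum(row)
    if total == 0:
        # Assumption for this sketch: empty cells stay all-zero.
        return [0.0] * len(row)
    scale = target_sum / total
    return [math.log1p(v * scale) for v in row]


# A made-up cell with 4 counts in total: scaled values are 2500 and 7500.
normalized = normalize_log1p([1, 3])
# Undoing the log recovers values that sum to target_sum again.
recovered_total = sum(math.expm1(v) for v in normalized)
```

Each cell is transformed independently of all others, so the operation can be applied chunk by chunk without synchronization.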


## **Storing the Processed Data in Memory**  

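After log normalization, we persist the updated expression matrix.  
Because the previous step only extended the task graph, `persist()` is what actually runs the normalization on the workers and keeps the result in memory for efficient access.  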

In [10]:
adata.X = adata.X.persist()
adata.X.compute_chunk_sizes()

| Property | Value |
|---|---|
| Array shape | (11441244, 41291) |
| Chunk shape | (197265, 41291) |
| Dask graph | 58 chunks in 1 graph layer |
| Data type | float32 cupyx.scipy.sparse._csr.csr_matrix |


## **Selecting Highly Variable Genes**  

To focus on the most informative features, we identify **highly variable genes (HVGs)**  
using the **Cell Ranger** method and subset the dataset accordingly.  

* Copy is Essential → Using `.copy()` prevents views, ensuring the operation works properly with Dask-backed AnnData.

In [11]:
%%time
rsc.pp.highly_variable_genes(adata,n_top_genes=5000, flavor="cell_ranger")
adata = adata[:,adata.var.highly_variable].copy()

CPU times: user 36.2 s, sys: 933 ms, total: 37.2 s
Wall time: 36.6 s
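
The core idea behind dispersion-based selection can be sketched in a few lines of plain Python. This is a simplified illustration only: it ranks genes by raw dispersion (variance divided by mean), whereas the `cell_ranger` flavor additionally normalizes dispersions within bins of similar mean expression. The function name and the toy matrix are made up for the example.

```python
import statistics


def top_dispersion_genes(matrix, n_top):
    """Boolean mask marking the n_top genes with the highest variance/mean."""
    n_genes = len(matrix[0])
    dispersions = []
    for g in range(n_genes):
        column = [row[g] for row in matrix]
        mean = statistics.fmean(column)
        var = statistics.variance(column)  # sample variance
        dispersions.append(var / mean if mean > 0 else 0.0)
    ranked = sorted(range(n_genes), key=lambda g: dispersions[g], reverse=True)
    selected = set(ranked[:n_top])
    return [g in selected for g in range(n_genes)]


# Made-up data: gene 0 is constant, gene 1 varies strongly, gene 2 slightly.
toy = [[1, 0, 5], [1, 10, 5], [1, 0, 6], [1, 10, 4]]
highly_variable = top_dispersion_genes(toy, n_top=2)
```

The resulting boolean mask plays the same role as `adata.var.highly_variable` in the cell above.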


## **Rechunking the Expression Matrix for Multi-GPU Execution**  

To optimize performance across **8 GPUs**, we rechunk the expression matrix (`adata.X`)  
so that each GPU processes an equal portion of the dataset.  

In [12]:
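n_rows = adata.shape[0]
n_cols = adata.shape[1]
n_gpus = 8
# Ceiling division: rows per chunk so that the matrix splits into one row chunk per GPU.
rows_per_worker = (n_rows + n_gpus - 1) // n_gpus
adata.X = adata.X.rechunk((rows_per_worker, n_cols)).persist()
adata.X.compute_chunk_sizes()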

| Property | Value |
|---|---|
| Array shape | (11441244, 5000) |
| Chunk shape | (1430156, 5000) |
| Dask graph | 8 chunks in 1 graph layer |
| Data type | float32 cupyx.scipy.sparse._csr.csr_matrix |
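
The chunk shape in this output can be reproduced with the same ceiling division used in the cell. A short worked example with the row count from this notebook (the variable names are chosen for illustration):

```python
n_rows = 11_441_244   # cells remaining after QC filtering
n_workers = 8         # one row chunk per GPU

# Ceiling division: the smallest chunk height that covers all rows in 8 chunks.
rows_per_chunk = (n_rows + n_workers - 1) // n_workers

# Seven full chunks; the final chunk takes whatever is left.
last_chunk_rows = n_rows - rows_per_chunk * (n_workers - 1)
```

The last chunk is only four rows smaller than the others, so the load is close to evenly balanced across the GPUs.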


## **Scaling Gene Expression (Requires Synchronization)**  

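To standardize gene expression values, we apply **feature scaling** without zero-centering (`zero_center=False`), which keeps the matrix sparse.  
We also `persist` the result so that it stays in memory for the following PCA step.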

In [13]:
%%time
rsc.pp.scale(adata, zero_center= False)
adata.X = adata.X.persist()
adata.X.compute_chunk_sizes()

This may cause some slowdown.
Consider loading the data with Dask directly
 or using futures or delayed objects to embed the data into the graph without repetition.
See also https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask for more information.


CPU times: user 286 ms, sys: 232 ms, total: 518 ms
Wall time: 4.05 s


| Property | Value |
|---|---|
| Array shape | (11441244, 5000) |
| Chunk shape | (1430156, 5000) |
| Dask graph | 8 chunks in 1 graph layer |
| Data type | float32 cupyx.scipy.sparse._csr.csr_matrix |
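
Scaling without zero-centering divides each gene by its standard deviation and leaves zeros untouched, which is why the result above is still a sparse CSR matrix. A toy plain-Python sketch of that idea, assuming the sample standard deviation (ddof = 1) and leaving constant genes unscaled; the library's exact conventions may differ:

```python
import statistics


def scale_columns(matrix):
    """Divide each column by its sample standard deviation, without centering."""
    n_cols = len(matrix[0])
    # Assumption: ddof=1; a zero std (constant column) is replaced by 1.0.
    stds = [
        statistics.stdev([row[j] for row in matrix]) or 1.0
        for j in range(n_cols)
    ]
    return [[row[j] / stds[j] for j in range(n_cols)] for row in matrix]


# Made-up data: column 0 has std 2, column 1 is constant.
toy = [[0, 2], [2, 2], [4, 2]]
scaled = scale_columns(toy)
```

The zero entry in the first column remains zero after scaling. Subtracting the mean first would turn it into a non-zero value and make the matrix dense.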


## **Principal Component Analysis (PCA) on GPU**  

To reduce dimensionality while preserving meaningful variation,  
we perform **Principal Component Analysis (PCA)** using **GPU acceleration**.

**Finalizing the Transformation with `.compute()`**
* After computing the principal components, the data remains lazy (Dask CuPy array).
* Calling `.compute()` on `adata.obsm["X_pca"]` performs the final transformation, projecting the data onto the computed PCs and materializing the result as a fully computed CuPy array.

In [14]:
%%time
rsc.pp.pca(adata, n_comps = 100,mask_var=None)
adata.obsm["X_pca"]=adata.obsm["X_pca"].compute()

CPU times: user 1.74 s, sys: 3.79 s, total: 5.53 s
Wall time: 9.16 s
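
One common way to organize PCA over row chunks is to have each worker compute small summary statistics for its chunk and then combine them into a single gene-by-gene covariance matrix, whose leading eigenvectors are the principal components. The sketch below illustrates only that combination step in plain Python. Whether rapids-singlecell uses exactly this scheme internally is an assumption, and the function names are hypothetical.

```python
def chunk_stats(chunk):
    """Row count, column sums, and sum of outer products for one row chunk."""
    d = len(chunk[0])
    sums = [0.0] * d
    outer = [[0.0] * d for _ in range(d)]
    for row in chunk:
        for i in range(d):
            sums[i] += row[i]
            for j in range(d):
                outer[i][j] += row[i] * row[j]
    return len(chunk), sums, outer


def combine_covariance(stats):
    """Merge per-chunk statistics into one sample covariance matrix."""
    n = sum(s[0] for s in stats)
    d = len(stats[0][1])
    sums = [sum(s[1][i] for s in stats) for i in range(d)]
    cov = [[0.0] * d for _ in range(d)]
    for i in range(d):
        for j in range(d):
            sxx = sum(s[2][i][j] for s in stats)
            cov[i][j] = (sxx - sums[i] * sums[j] / n) / (n - 1)
    return cov


# Made-up data split into two row chunks, as if held by two workers.
data = [[1.0, 2.0], [2.0, 1.0], [3.0, 5.0], [4.0, 4.0]]
cov_chunked = combine_covariance([chunk_stats(data[:2]), chunk_stats(data[2:])])
cov_full = combine_covariance([chunk_stats(data)])
```

The chunked result matches the covariance computed on the full matrix, and only the small per-chunk statistics need to be exchanged between workers. The row-wise projection onto the components can then stay lazy until `.compute()` is called.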
