# **Multi-GPU Showcase with `persist`**  

**Author:** [Severin Dicks](https://github.com/Intron7)
**Copyright:** [scverse](https://scverse.org)

## **Overview**  

In this notebook, we showcase the **multi-GPU computation capabilities** of **rapids-singlecell**  
using **Dask**, enabling the analysis of **11 million cells** in minutes: the timed section of the run recorded here, from QC through PCA, takes about 72 seconds.  

This notebook is run on a **DGX system with 8 NVIDIA H100 GPUs**, demonstrating how **Dask** can efficiently distribute computations across multiple GPUs.  

### **Key Advantages of Multi-GPU Computation**  
By leveraging **Dask and RAPIDS**, we can:  
- **Process massive single-cell datasets** without exceeding memory limits.  
- **Fully utilize all available GPUs**, scaling performance across multiple devices.  
- **Enable chunk-based execution**, efficiently managing memory by loading only necessary data.  

### **Combining Multi-GPU with Out-of-Core Processing**  
- **Multi-GPU processing and out-of-core execution** can be combined to analyze even larger datasets that exceed GPU memory.  
- However, in this notebook, **we focus purely on multi-GPU scaling** without out-of-core execution.  

This approach significantly accelerates **large-scale single-cell analysis**,  
making it feasible on high-performance hardware like **DGX systems**,  
while also being adaptable to multi-GPU workstations.

In [1]:
import dask
import time
import gc

from dask_cuda import LocalCUDACluster
from dask.distributed import Client

## **Initializing a Multi-GPU Dask Cluster for RAPIDS**  

To fully utilize all **8 NVIDIA H100 GPUs** on the DGX system,  
we initialize a **multi-GPU Dask cluster** and configure **RAPIDS Memory Manager (RMM)**  
for efficient memory handling across GPUs.  

### **Setting Up Memory Management with RMM**  
The **RAPIDS Memory Manager (RMM)** helps optimize GPU memory usage. In this notebook it is configured with a per-worker **memory pool** (`rmm_pool_size`, `rmm_maximum_pool_size`),  
which reduces allocation overhead when working with large-scale datasets.

### **Launching a Multi-GPU Dask Cluster**
We create a Dask CUDA cluster that utilizes all 8 GPUs for preprocessing and analysis.


### Additional parameters for `LocalCUDACluster`

- **CUDA_VISIBLE_DEVICES=preprocessing_gpus**: selects GPUs to use (e.g., `"0,1,2,3,4,5,6,7"`).
- **threads_per_worker=10**: CPU threads per GPU worker; tune for your workload and I/O.
- **protocol="ucx"**: enables UCX for high-throughput GPU-aware communication (NVLink/InfiniBand/RDMA).
- **rmm_pool_size="10GB"**: initial per-worker RAPIDS Memory Manager (RMM) pool; reduces allocation overhead.
- **rmm_maximum_pool_size="110GB"**: maximum pool growth per worker; allows RMM to expand up to this cap.
- **rmm_allocator_external_lib_list="cupy"**: integrates CuPy with RMM so CuPy allocations come from the pool.
- **Client(cluster)**: attaches the Dask client to the cluster (dashboard link available when running).
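
As a small convenience, the device string passed to `CUDA_VISIBLE_DEVICES` can be generated from a GPU count instead of being typed by hand. The sketch below is only an illustration: `gpu_device_string` is a hypothetical helper, not part of dask-cuda. It also shows where the total of 80 threads reported by the cluster comes from, assuming one Dask worker per GPU.

```python
def gpu_device_string(n_gpus: int) -> str:
    """Return a comma-separated device list such as "0,1,2" (hypothetical helper)."""
    if n_gpus < 1:
        raise ValueError("n_gpus must be at least 1")
    return ",".join(str(i) for i in range(n_gpus))


n_gpus = 8
threads_per_worker = 10

preprocessing_gpus = gpu_device_string(n_gpus)
total_threads = n_gpus * threads_per_worker  # one Dask worker per GPU

print(preprocessing_gpus)  # 0,1,2,3,4,5,6,7
print(total_threads)       # 80
```

On a workstation with fewer GPUs, only `n_gpus` would need to change.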


In [None]:
%%time
preprocessing_gpus = "0,1,2,3,4,5,6,7"
cluster = LocalCUDACluster(
    CUDA_VISIBLE_DEVICES=preprocessing_gpus,
    threads_per_worker=10,
    protocol="ucx",
    rmm_pool_size="10GB",
    rmm_maximum_pool_size="110GB",
    rmm_allocator_external_lib_list="cupy",
)

client = Client(cluster)

client

CPU times: user 13.1 s, sys: 6.41 s, total: 19.5 s
Wall time: 23.2 s


**Client**

- Connection method: Cluster object
- Cluster type: dask_cuda.LocalCUDACluster
- Dashboard: http://127.0.0.1:8787/status

**Cluster**

- Dashboard: http://127.0.0.1:8787/status
- Workers: 8
- Total threads: 80
- Total memory: 1.97 TiB
- Status: running
- Using processes: True

**Scheduler**

- Comm: ucx://127.0.0.1:38750
- Dashboard: http://127.0.0.1:8787/status
- Started: Just now
- Workers: 0
- Total threads: 0
- Total memory: 0 B

**Workers**

| Comm | Dashboard | Nanny | Total threads | Memory | Local directory |
|---|---|---|---|---|---|
| ucx://127.0.0.1:44105 | http://127.0.0.1:43457/status | ucx://127.0.0.1:34725 | 10 | 251.81 GiB | `/tmp/dask-scratch-space/worker-xgcxn_n_` |
| ucx://127.0.0.1:44373 | http://127.0.0.1:43563/status | ucx://127.0.0.1:39011 | 10 | 251.81 GiB | `/tmp/dask-scratch-space/worker-j09_qkx3` |
| ucx://127.0.0.1:48916 | http://127.0.0.1:36667/status | ucx://127.0.0.1:38812 | 10 | 251.81 GiB | `/tmp/dask-scratch-space/worker-txmrjly6` |
| ucx://127.0.0.1:40934 | http://127.0.0.1:45939/status | ucx://127.0.0.1:45970 | 10 | 251.81 GiB | `/tmp/dask-scratch-space/worker-pz7kncsd` |
| ucx://127.0.0.1:49821 | http://127.0.0.1:39573/status | ucx://127.0.0.1:34017 | 10 | 251.81 GiB | `/tmp/dask-scratch-space/worker-n5nr2jvd` |
| ucx://127.0.0.1:59782 | http://127.0.0.1:43951/status | ucx://127.0.0.1:45444 | 10 | 251.81 GiB | `/tmp/dask-scratch-space/worker-x1wzlvhz` |
| ucx://127.0.0.1:58712 | http://127.0.0.1:44831/status | ucx://127.0.0.1:44922 | 10 | 251.81 GiB | `/tmp/dask-scratch-space/worker-73mzc4jd` |
| ucx://127.0.0.1:36829 | http://127.0.0.1:43727/status | ucx://127.0.0.1:56777 | 10 | 251.81 GiB | `/tmp/dask-scratch-space/worker-7fna67s9` |


In [3]:
import rapids_singlecell as rsc
import anndata as ad

## **Loading Large Datasets into AnnData with Dask**  

To efficiently handle large-scale single-cell datasets, we load data directly from an **HDF5 (`h5`) or Zarr file**  
into an **AnnData object** using **Dask arrays**. This enables **lazy loading**, allowing data to be processed in chunks  
without exceeding memory limits.  

We achieve this using **`read_elem_lazy`** (named **`read_elem_as_dask`** in anndata versions before 0.12), which loads the expression matrix (`X`) as a **Dask array**.

In [4]:
from packaging.version import parse as parse_version

if parse_version(ad.__version__) < parse_version("0.12.0rc1"):
    from anndata.experimental import read_elem_as_dask as read_dask
else:
    from anndata.experimental import read_elem_lazy as read_dask
import zarr

SPARSE_CHUNK_SIZE = 50_000
data_pth = "/home/scratch.sdicks_gpu/git/rapids_singlecell-notebooks/zarr/cell_atlas.zarr"  # 11 million cells
# data_pth = "zarr/nvidia_1.3M.zarr"  # 1.3 million cells

f = zarr.open(data_pth)
X = f["X"]
shape = X.attrs["shape"]
adata = ad.AnnData(
    X = read_dask(X, (SPARSE_CHUNK_SIZE, shape[1])),
    obs = ad.io.read_elem(f["obs"]),
    var = ad.io.read_elem(f["var"])
)
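
To make the chunking concrete, the sketch below works out how many row blocks a chunk size of 50,000 produces. The row count of 11,441,407 is taken from the array summary reported after the data is moved to the GPU; the variable names `n_cells`, `n_chunks`, and `last_chunk_rows` are illustrative.

```python
SPARSE_CHUNK_SIZE = 50_000
n_cells = 11_441_407  # rows of X, as reported in the array summary

# Ceiling division: every chunk holds SPARSE_CHUNK_SIZE rows except the last one.
n_chunks = (n_cells + SPARSE_CHUNK_SIZE - 1) // SPARSE_CHUNK_SIZE
last_chunk_rows = n_cells - (n_chunks - 1) * SPARSE_CHUNK_SIZE

print(n_chunks)         # 229
print(last_chunk_rows)  # 41407
```

The result agrees with the 229 chunks that Dask reports for the expression matrix.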




## **Transferring AnnData to GPU and Persisting Data**  

To leverage **multi-GPU acceleration**, we transfer the AnnData object to GPU memory  
and persist its **Dask-backed expression matrix** for efficient computation. 

**Step-by-Step Breakdown:**
1. Move AnnData to GPU → `rsc.get.anndata_to_GPU(adata)`
    * Transfers all numerical data (`.X`) to GPU memory.
2. Persist the Expression Matrix → `adata.X = adata.X.persist()`
    * Keeps `adata.X` in memory across Dask workers, avoiding redundant recomputation.
3. Resolve Chunk Sizes → `adata.X.compute_chunk_sizes()`
    * Computes the exact size of each chunk, which Dask may not know after lazy loading or filtering, so that later steps can rely on concrete shapes.

In [5]:
%%time
rsc.get.anndata_to_GPU(adata)
adata.X = adata.X.persist()
adata.X.compute_chunk_sizes()

CPU times: user 21 s, sys: 974 ms, total: 21.9 s
Wall time: 21.8 s


| | Array | Chunk |
|---|---|---|
| Shape | (11441407, 45854) | (50000, 45854) |
| Dask graph | 229 chunks in 1 graph layer | |
| Data type | float32 cupyx.scipy.sparse._csr.csr_matrix | |

## **Quality Control (QC) Metrics Calculation**  

Before proceeding with further analysis, we compute **quality control (QC) metrics**  
to assess dataset quality and filter out low-quality cells or genes.  

We use **`rsc.pp.calculate_qc_metrics()`** to calculate key QC metrics, such as `n_genes_by_counts` per cell and `n_cells_by_counts` per gene.

Although we are working with Dask-backed AnnData, this operation requires a synchronization step.
This means that Dask computations must be evaluated immediately,
so the process is not completely lazy like other out-of-core operations.
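
To illustrate what these metrics mean, here is a minimal pure-Python sketch on a toy count matrix. The values are made up, and the code shows only the definitions, not the library's GPU implementation. The names `n_genes_by_counts` and `n_cells_by_counts` are the columns used for filtering in this notebook; `total_counts` is assumed here as the usual scanpy-style name for the per-cell sum.

```python
# Toy count matrix: rows are cells, columns are genes (illustrative values).
counts = [
    [0, 3, 1, 0],
    [2, 0, 0, 0],
    [5, 1, 4, 2],
]

# Per-cell metrics: total counts and number of genes with at least one count.
total_counts = [sum(row) for row in counts]
n_genes_by_counts = [sum(1 for v in row if v > 0) for row in counts]

# Per-gene metric: number of cells in which the gene has at least one count.
n_genes = len(counts[0])
n_cells_by_counts = [sum(1 for row in counts if row[j] > 0) for j in range(n_genes)]

print(total_counts)       # [4, 2, 12]
print(n_genes_by_counts)  # [2, 1, 4]
print(n_cells_by_counts)  # [2, 2, 2, 1]
```

Because each metric is a reduction over a full row or column, the values have to be computed before they can be stored, which is why this step triggers evaluation.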

In [6]:
t1 = time.time()

In [7]:
%%time
rsc.pp.calculate_qc_metrics(adata)

CPU times: user 6.97 s, sys: 1.03 s, total: 8 s
Wall time: 8.07 s


## **Filtering Cells and Genes Without Additional Computation**  

Instead of using **`sc.pp.filter_cells`** and **`sc.pp.filter_genes`**,  
we apply filtering directly using boolean indexing to **avoid extra computation**.

**Why Use Direct Indexing Instead of Built-in Functions?**
* More Efficient with Dask → Avoids triggering additional computations.
* Preserves Lazy Execution → Filtering is applied without forcing full dataset evaluation.
* Copy is Essential → Using `.copy()` prevents views, which may not work reliably with Dask-backed AnnData.
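
The filtering logic itself is just a boolean mask built from QC values that already exist. A minimal sketch with plain Python lists, using the same thresholds as the filtering cell (200 and 10,000 genes per cell); the sample values are invented for illustration:

```python
# Illustrative per-cell QC values (not taken from the dataset).
n_genes_by_counts = [150, 200, 5_000, 10_000, 12_000]

min_genes, max_genes = 200, 10_000

# Both bounds are inclusive, mirroring the `>=` and `<=` comparisons.
keep_cell = [min_genes <= n <= max_genes for n in n_genes_by_counts]
kept = [n for n, keep in zip(n_genes_by_counts, keep_cell) if keep]

print(keep_cell)  # [False, True, True, True, False]
print(kept)       # [200, 5000, 10000]
```

No new statistics are computed here; the mask only reuses values that the QC step has already stored.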

In [8]:
%%time
adata = adata[
    (adata.obs["n_genes_by_counts"] <= 10000)
    & (adata.obs["n_genes_by_counts"] >= 200)
].copy()
adata = adata[:, adata.var["n_cells_by_counts"] >= 10].copy()

CPU times: user 44.9 s, sys: 6.16 s, total: 51.1 s
Wall time: 29.7 s


## **Persisting and Optimizing Chunk Sizes After QC and Subsetting**  

After performing **quality control (QC) and subsetting** the dataset,  
we persist the **Dask-backed expression matrix** and optimize its chunk sizes for efficient multi-GPU execution.  
Persisting after filtering ensures that only high-quality, relevant data remains in memory.

In [9]:
%%time
adata.X = adata.X.persist()
adata.X.compute_chunk_sizes()

This may cause some slowdown.
Consider loading the data with Dask directly
 or using futures or delayed objects to embed the data into the graph without repetition.
See also https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask for more information.


CPU times: user 13.9 s, sys: 592 ms, total: 14.4 s
Wall time: 15.3 s


| | Array | Chunk |
|---|---|---|
| Shape | (11441244, 41291) | (49962, 41291) |
| Dask graph | 229 chunks in 1 graph layer | |
| Data type | float32 cupyx.scipy.sparse._csr.csr_matrix | |

## **Log Normalization (Fully Lazy Execution)**  

Next, we apply **log normalization** to scale gene expression values.  
This step ensures that differences in sequencing depth across cells do not dominate downstream analysis.  
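
As a worked example of the arithmetic, the sketch below normalizes a single made-up cell to a total of 10,000 counts and then applies `log1p`, i.e. `log(1 + x)`. It uses only the standard library and is meant to illustrate the formula, not the GPU implementation.

```python
import math

counts = [1, 3, 6]    # raw counts for one illustrative cell
target_sum = 10_000   # same target as in the normalization cell

# Scale the cell so that its counts sum to target_sum.
total = sum(counts)
normalized = [c * target_sum / total for c in counts]

# log1p(x) = log(1 + x) keeps zeros at zero and compresses large values.
log_normalized = [math.log1p(v) for v in normalized]

print(normalized)  # [1000.0, 3000.0, 6000.0]
print(log_normalized)
```

After normalization every cell has the same total, so a cell sequenced twice as deeply no longer appears to express every gene twice as strongly.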

In [10]:
gc.collect()  

1503

In [11]:
%%time
rsc.pp.normalize_total(adata, target_sum=10000)
rsc.pp.log1p(adata)
adata.X = adata.X.persist()
adata.X.compute_chunk_sizes()

CPU times: user 414 ms, sys: 18 ms, total: 432 ms
Wall time: 438 ms


| | Array | Chunk |
|---|---|---|
| Shape | (11441244, 41291) | (49962, 41291) |
| Dask graph | 229 chunks in 1 graph layer | |
| Data type | float32 cupyx.scipy.sparse._csr.csr_matrix | |

## **Storing the Processed Data in Memory**  

After log normalization, we persist the updated expression matrix (the `adata.X.persist()` call in the cell above)  
so that the new results are stored in memory for efficient access.  

## **Selecting Highly Variable Genes**  

To focus on the most informative features, we identify **highly variable genes (HVGs)**  
using the **Cell Ranger** method and subset the dataset accordingly.  

* Copy is Essential → Using `.copy()` prevents views, ensuring the operation works properly with Dask-backed AnnData.

In [12]:
%%time
rsc.pp.highly_variable_genes(adata, n_top_genes=5000, flavor="cell_ranger")

CPU times: user 986 ms, sys: 334 ms, total: 1.32 s
Wall time: 1.34 s


In [13]:
%%time
adata = adata[:,adata.var.highly_variable].copy()

CPU times: user 20.7 s, sys: 641 ms, total: 21.3 s
Wall time: 11.5 s


## **Rechunking the Expression Matrix for Multi-GPU Execution**  

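To optimize performance on the multi-GPU cluster, we rechunk the expression matrix (`adata.X`)  
into a few large row blocks of equal height, so that each block can be processed by a single GPU worker.  
The cell below uses 7 blocks, which matches the 7 chunks reported in its output.  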
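
The rechunking cell computes the block height with a ceiling division using a divisor of 7. The sketch below reproduces that arithmetic for the 11,441,244 cells that remain after filtering, and also shows the value for 8 blocks for comparison. The function name `rows_per_chunk` is illustrative.

```python
def rows_per_chunk(n_rows: int, n_chunks: int) -> int:
    """Ceiling division: smallest block height that covers n_rows in n_chunks blocks."""
    return (n_rows + n_chunks - 1) // n_chunks


n_rows = 11_441_244  # cells remaining after QC filtering, from the array summary

for n_chunks in (7, 8):
    height = rows_per_chunk(n_rows, n_chunks)
    last = n_rows - (n_chunks - 1) * height  # rows left for the final block
    print(n_chunks, height, last)
```

For 7 blocks the height is 1,634,464 rows, which is the chunk shape shown in the rechunking output.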

In [14]:
n_rows, n_cols = adata.shape
n_chunks = 7  # number of row blocks; matches the 7 chunks in the output below
rows_per_worker = (n_rows + n_chunks - 1) // n_chunks  # ceiling division
adata.X = adata.X.rechunk((rows_per_worker, n_cols)).persist()

adata.X.compute_chunk_sizes()

| | Array | Chunk |
|---|---|---|
| Shape | (11441244, 5000) | (1634464, 5000) |
| Dask graph | 7 chunks in 1 graph layer | |
| Data type | float32 cupyx.scipy.sparse._csr.csr_matrix | |

## **Scaling Gene Expression (Requires Synchronization)**  

To standardize gene expression values, we apply **feature scaling** with `zero_center=False`, so values are scaled without subtracting the mean.  
We also `persist` the results to ensure fast access.
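
A minimal sketch of why skipping zero-centering matters for sparse data: dividing by the standard deviation leaves zeros at zero, while subtracting the mean first turns every entry into a non-zero value. The example uses made-up values for one gene and the sample standard deviation from the standard library; the exact estimator used by the library is an assumption here.

```python
import statistics

# One gene across five illustrative cells; zeros are unexpressed entries.
values = [0.0, 0.0, 2.0, 4.0, 0.0]

# Assumption: sample standard deviation; the library's estimator may differ.
std = statistics.stdev(values)
mean = statistics.mean(values)

scaled = [v / std for v in values]             # no mean subtraction
centered = [(v - mean) / std for v in values]  # what zero-centering would do

print(scaled)
print(centered)
```

In `scaled` the three zeros are still zero, so a sparse matrix stays sparse; in `centered` all five entries are non-zero.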

In [15]:
%%time
rsc.pp.scale(adata, zero_center=False)
adata.X = adata.X.persist()
adata.X.compute_chunk_sizes()

This may cause some slowdown.
Consider loading the data with Dask directly
 or using futures or delayed objects to embed the data into the graph without repetition.
See also https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask for more information.


CPU times: user 691 ms, sys: 237 ms, total: 927 ms
Wall time: 885 ms


| | Array | Chunk |
|---|---|---|
| Shape | (11441244, 5000) | (1634464, 5000) |
| Dask graph | 7 chunks in 1 graph layer | |
| Data type | float32 cupyx.scipy.sparse._csr.csr_matrix | |

## **Principal Component Analysis (PCA) on GPU**  

To reduce dimensionality while preserving meaningful variation,  
we perform **Principal Component Analysis (PCA)** using **GPU acceleration**.

**Finalizing the Transformation with `.compute()`**
* After computing the principal components, the data remains lazy (Dask CuPy array).
* Calling `.compute()` on `adata.obsm["X_pca"]` performs the final transformation, projecting the data onto the computed PCs and materializing the result as a fully computed CuPy array.

In [16]:
%%time
rsc.pp.pca(adata, n_comps=100, mask_var=None)
adata.obsm["X_pca"] = adata.obsm["X_pca"].persist()
adata.obsm["X_pca"].compute_chunk_sizes()

CPU times: user 1.23 s, sys: 458 ms, total: 1.69 s
Wall time: 1.44 s


| | Array | Chunk |
|---|---|---|
| Bytes | 4.26 GiB | 623.50 MiB |
| Shape | (11441244, 100) | (1634464, 100) |
| Dask graph | 7 chunks in 1 graph layer | |
| Data type | float32 cupy.ndarray | |

In [17]:
print("Total Time",time.time()-t1)

Total Time 72.30928540229797


In [18]:
%%time
adata.obsm["X_pca"]=adata.obsm["X_pca"].compute()

CPU times: user 6.4 s, sys: 385 ms, total: 6.78 s
Wall time: 6.82 s


In [22]:
%%time
rsc.pp.neighbors(adata, n_neighbors=15, n_pcs=50, algorithm="mg_ivfflat")

CPU times: user 1min 9s, sys: 13 s, total: 1min 22s
Wall time: 1min 26s


In [20]:
%%time
rsc.tl.umap(adata, min_dist=0.3)

CPU times: user 12 s, sys: 1.42 s, total: 13.4 s
Wall time: 11.4 s


In [21]:
%%time
rsc.tl.leiden(adata, resolution=1.0)

CPU times: user 19.3 s, sys: 4.23 s, total: 23.6 s
Wall time: 19.1 s
