# Train an scVI model using Census data

This notebook demonstrates a scalable approach to training an [scVI](https://docs.scvi-tools.org/en/latest/user_guide/models/scvi.html) model on Census data. The [scvi-tools](https://scvi-tools.org/) library is built around [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/). [TileDB-SOMA-ML](https://github.com/single-cell-data/TileDB-SOMA-ML) assists with streaming Census query results to PyTorch in batches, allowing for training datasets larger than available RAM.

## Contents

1. Training the model
2. Generate cell embeddings
3. Analyzing the results

## Training the model 

Let's start by importing the necessary dependencies.

In [1]:
print("Sta5rt")

Sta5rt


In [2]:
import warnings
from typing import Any, Dict, List
import torch
import cellxgene_census
import numpy as np
import pandas as pd
import scanpy as sc
import scvi
import tiledbsoma as soma
from tiledbsoma_ml import SCVIDataModule
import torch
from cellxgene_census.experimental.pp import highly_variable_genes
from lightning import LightningDataModule
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader

import logging
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s %(message)s")
logging.getLogger("tiledbsoma_ml.dataset").setLevel(logging.DEBUG)

import time
from torch.profiler import schedule, ProfilerActivity, tensorboard_trace_handler
from pytorch_lightning.profilers import PyTorchProfiler
import os, torch

logdir = "./log/scvi_prof"                     # one folder per run
os.makedirs(logdir, exist_ok=True)

warnings.filterwarnings("ignore")

  from .autonotebook import tqdm as notebook_tqdm


We'll now prepare the necessary parameters for running a training pass of the model.

For this notebook, we'll use a stable version of the Census:

We'll also do two types of filtering.

For **cells**, we will apply a filter to only select primary cells, with at least 300 expressed genes (nnz >= 300). For notebook demonstration purposes, we will also apply a tissue filtering so that the training can happen on a laptop. The same approach can be used on datasets much larger than available RAM. (A GPU is recommended, though.)

For **genes**, we will apply a filter so that only the top 8000 highly variable genes (HVG) are included in the training. This is a commonly used dimensionality reduction approach and is recommended on production models as well.

Let's define a few parameters:

In [8]:
experiment_name = "mus_musculus"
obs_value_filter = 'is_primary_data == True and tissue_general in ["pancreas", "kidney"] and nnz >= 300'
top_n_hvg = 8000
hvg_batch = ["assay", "suspension_type"]

In [6]:
LOCAL_URI = "/home/ec2-user/new/mm_pan_kidney_subset_soma"       # same path you wrote above
# census = cellxgene_census.open_soma(uri=LOCAL_URI)
exp = soma.open(LOCAL_URI, mode="r")


For HVG, we can use the `highly_variable_genes` function provided in `cellxgene_census`, which can compute HVGs in constant memory:

In [11]:
%%time
query = exp.axis_query(
    measurement_name="RNA",
    obs_query=soma.AxisQuery(value_filter=obs_value_filter),
)
print(f"{query.n_obs} obs")
query.obs(column_names=['dataset_id']).concat().to_pandas().dataset_id.astype(str).value_counts()
print(exp.ms["RNA"].X.keys()) 

522076 obs
KeysView(<Collection 'file:///home/ec2-user/mm_pan_kidney_subset_soma/ms/RNA/X' (open for 'r') (1 item)
    'data': 'file:///home/ec2-user/mm_pan_kidney_subset_soma/ms/RNA/X/data' (unopened)>)
CPU times: user 202 ms, sys: 226 ms, total: 428 ms
Wall time: 262 ms


In [12]:
%%time
hvgs_df = highly_variable_genes(
    query,
    n_top_genes=top_n_hvg,
    batch_key=hvg_batch,
    layer="data"
)
hv = hvgs_df.highly_variable
hv_idx = hv[hv].index

CPU times: user 2min 52s, sys: 1min 16s, total: 4min 9s
Wall time: 39.7 s


We will now introduce a helper class `SCVIDataModule` to connect TileDB-SOMA-ML with PyTorch Lightning. It subclasses [`LightningDataModule`](https://lightning.ai/docs/pytorch/stable/data/datamodule.html) and:

1. Uses TileDB-SOMA-ML to prepare a DataLoader for the results of a SOMA [`ExperimentAxisQuery`](https://tiledbsoma.readthedocs.io/en/1.15.0/python-tiledbsoma-experimentaxisquery.html) on the Census.
1. Derives each cell's scVI batch label as a tuple of obs attributes: `dataset_id`, `assay`, `suspension_type`, `donor_id`.
    * *Don't confuse each cell's label for scVI "batch" integration with a training data "batch" generated by the DataLoader.*
1. Converts the RNA counts and batch labels to a dict of tensors for each training data batch, as scVI expects.

In [18]:
%%time
hvg_query = exp.axis_query(
    measurement_name="RNA",
    obs_query=soma.AxisQuery(value_filter=obs_value_filter),
    var_query=soma.AxisQuery(coords=(list(hv_idx),)),
)

datamodule = SCVIDataModule(
    hvg_query,
    layer_name="data",
    batch_size=1024,
    shuffle=True,
    seed=42,
    dataloader_kwargs={"num_workers": 8, "persistent_workers": False, "pin_memory":True,},
)

(datamodule.n_obs, datamodule.n_vars, datamodule.n_batch)

{'num_workers': 8, 'persistent_workers': False, 'pin_memory': True}
CPU times: user 2.97 s, sys: 501 ms, total: 3.47 s
Wall time: 3.24 s


(522076, 8000, 109)

Most parameters to `SCVIDataModule` are passed through to the [`tiledbsoma_ml.ExperimentDataset`](https://single-cell-data.github.io/TileDB-SOMA-ML/#tiledbsoma_ml.ExperimentDataset) initializer; see that documentation to understand how it can be tuned.

In particular, here are some parameters of interest:

* `shuffle`: shuffles the result cell order, which is often advisable for model training.
* `batch_size`: controls the size (number of cells) in each training data batch, in turn controlling memory usage.
* `dataloader_kwargs`: [`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) tuning, for example controlling parallelization.

We can now create the scVI model object:

In [19]:
n_layers = 1
n_latent = 50

model = scvi.model.SCVI(
    n_layers=n_layers,
    n_latent=n_latent,
    gene_likelihood="nb",
    encode_covariates=False,
)

Then, we can invoke the `.train` method which will start the training loop. For this demonstration, we'll only do a single epoch, but this should likely be increased for a production model. The scVI models hosted in CELLxGENE have been trained for 100 epochs.

In [20]:
def make_profiler(logdir):
    """
    Returns a PyTorch-Lightning profiler that
    1) records full CPU + CUDA traces
    2) keeps Python line-level stacks (`with_stack=True`)
    3) writes both TensorBoard traces *and* folded-stack files for flame graphs
    """
    def _trace_ready(p):                       # runs every ACTIVE window
        step = p.step_num
        # p.export_chrome_trace(f"{logdir}/trace_{step}.json")
        p.export_stacks(                      # for flamegraph.pl
            f"{logdir}/stacks_{step}.txt",
            metric="self_cuda_time_total",
        )
        tensorboard_trace_handler(logdir)(p)   # TB plugin

    return PyTorchProfiler(
        activities=[ProfilerActivity.CPU],
        on_trace_ready=_trace_ready,
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
    )

pl_profiler = make_profiler(logdir)

In [21]:
t0 = time.perf_counter()

model.train(
    datamodule=datamodule,
    max_epochs=3,
    early_stopping=False,
    devices=-1,
    strategy="ddp_notebook_find_unused_parameters_true",
    profiler=pl_profiler,
)

print("The time: ", time.perf_counter() - t0)

2025-07-07 07:13:18,194 GPU available: True (cuda), used: True
2025-07-07 07:13:18,195 TPU available: False, using: 0 TPU cores
2025-07-07 07:13:18,195 HPU available: False, using: 0 HPUs
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
2025-07-07 07:13:18,408 ----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 1 processes
----------------------------------------------------------------------------------------------------

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
2025-07-07 07:13:18,858 switching torch multiprocessing start method from "fork" to "spawn"


Epoch 1/3:   0%|                                                     | 0/3 [00:00<?, ?it/s]

[rank0]:[W707 07:14:32.685530531 CPUAllocator.cpp:245] Memory block of unknown size was allocated before the profiling started, profiler results will not include the deallocation event


We have n_rows= 1024 and shape be 8000
[PID 341868] slice_tonumpy called for rows 0:1024

withing slice_tonumpy, we have t1     0.0; t2     0.0; t3     0.1; t4   297.7;

We have n_rows= 1024 and shape be 8000
[PID 341868] slice_tonumpy called for rows 1024:2048

withing slice_tonumpy, we have t1     0.0; t2     0.0; t3     0.1; t4   176.2;

We have n_rows= 1024 and shape be 8000
[PID 341868] slice_tonumpy called for rows 2048:3072

withing slice_tonumpy, we have t1     0.0; t2     0.0; t3     0.1; t4    86.6;

We have n_rows= 1024 and shape be 8000
[PID 341868] slice_tonumpy called for rows 3072:4096

withing slice_tonumpy, we have t1     0.0; t2     0.0; t3    30.5; t4     1.1;

We have n_rows= 1024 and shape be 8000
[PID 341868] slice_tonumpy called for rows 4096:5120

withing slice_tonumpy, we have t1     0.0; t2     0.0; t3    25.5; t4    51.2;

We have n_rows= 1024 and shape be 8000
[PID 341868] slice_tonumpy called for rows 5120:6144

withing slice_tonumpy, we have t1     0.0; t2

2025-07-07 07:18:48,379 `Trainer.fit` stopped: `max_epochs=3` reached.


Epoch 3/3: 100%|█| 3/3 [04:52<00:00, 97.60s/it, v_num=1, train_loss_step=2.54e+3, train_los

We have n_rows= 1024 and shape be 8000
[PID 344648] slice_tonumpy called for rows 48128:49152

withing slice_tonumpy, we have t1     0.0; t2     0.0; t3    50.1; t4    32.1;

We have n_rows= 1024 and shape be 8000
[PID 344648] slice_tonumpy called for rows 49152:50176

withing slice_tonumpy, we have t1     0.0; t2     0.0; t3    26.9; t4    51.0;

We have n_rows= 1024 and shape be 8000
[PID 344648] slice_tonumpy called for rows 50176:51200

withing slice_tonumpy, we have t1     0.0; t2     0.0; t3    59.0; t4    29.3;

We have n_rows= 1024 and shape be 8000
[PID 344648] slice_tonumpy called for rows 51200:52224

withing slice_tonumpy, we have t1     0.0; t2     0.0; t3    26.6; t4    53.2;

We have n_rows= 1024 and shape be 8000
[PID 344648] slice_tonumpy called for rows 52224:53248

withing slice_tonumpy, we have t1     0.0; t2     0.0; t3    27.3; t4    52.9;

We have n_rows= 1024 and shape b

2025-07-07 07:18:49,971 FIT Profiler Report
Profile stats for: records
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          ProfilerStep*         2.14%     134.951ms        95.99%        6.052s        2.017s           8 b          28 b             3  
[pl][profile][_TrainingEpochLoop].train_dataloader_n...         0.01%     833.902us        89.08%        5.616s        1.872s           0 b           0 b             3  
enumerate(DataLoader)#_MultiProcessingDataLoaderIter...        89.06%        5.

The time:  332.3622770039947


We can now save the trained model. As of the current writing, scvi-tools doesn't support saving a model that wasn't generated through an AnnData loader, so we'll use some custom code:

In [None]:
model_state_dict = model.module.state_dict()
var_names = hv_idx.to_numpy()
user_attributes = model._get_user_attributes()
user_attributes = {a[0]: a[1] for a in user_attributes if a[0][-1] == "_"}

user_attributes.update(
    {
        "n_batch": datamodule.n_batch,
        "n_extra_categorical_covs": 0,
        "n_extra_continuous_covs": 0,
        "n_labels": 1,
        "n_vars": datamodule.n_vars,
        "batch_labels": datamodule.batch_labels,
    }
)

with open("model.pt", "wb") as f:
    torch.save(
        {
            "model_state_dict": model_state_dict,
            "var_names": var_names,
            "attr_dict": user_attributes,
        },
        f,
    )

We will now load the model back and use it to generate cell embeddings (the latent space), which can then be used for further analysis. Loading the model similarly involves some custom code.

In [None]:
with open("model.pt", "rb") as f:
    torch_model = torch.load(f)

    adict = torch_model["attr_dict"]
    params = adict["init_params_"]["non_kwargs"]

    n_batch = adict["n_batch"]
    n_extra_categorical_covs = adict["n_extra_categorical_covs"]
    n_extra_continuous_covs = adict["n_extra_continuous_covs"]
    n_labels = adict["n_labels"]
    n_vars = adict["n_vars"]

    latent_distribution = params["latent_distribution"]
    dispersion = params["dispersion"]
    n_hidden = params["n_hidden"]
    dropout_rate = params["dropout_rate"]
    gene_likelihood = params["gene_likelihood"]

    model = scvi.model.SCVI(
        n_layers=params["n_layers"],
        n_latent=params["n_latent"],
        gene_likelihood=params["gene_likelihood"],
        encode_covariates=False,
    )

    module = model._module_cls(
        n_input=n_vars,
        n_batch=n_batch,
        n_labels=n_labels,
        n_continuous_cov=n_extra_continuous_covs,
        n_cats_per_cov=None,
        n_hidden=n_hidden,
        n_latent=n_latent,
        n_layers=n_layers,
        dropout_rate=dropout_rate,
        dispersion=dispersion,
        gene_likelihood=gene_likelihood,
        latent_distribution=latent_distribution,
    )
    model.module = module

    model.module.load_state_dict(torch_model["model_state_dict"])

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model.to_device(device)
    model.module.eval()
    model.is_trained = True

## Generate cell embeddings

We will now generate the cell embeddings for this model, using the `get_latent_representation` function available in scvi-tools. 

We can use another instance of the `SCVIDataModule` for the forward pass, so we don't need to load the whole dataset in memory. This will have shuffling disabled to make it easier to join the embeddings later. We also want to restore the list of scVI batch labels from the training data, ensuring our forward pass will map batch labels to tensors in the expected way (although this specific example would work regardless, since it reuses the same query).

In [None]:
inference_datamodule = SCVIDataModule(
    hvg_query,
    layer_name="raw",
    batch_labels=adict["batch_labels"],
    batch_size=1024,
    shuffle=False,
    dataloader_kwargs={"num_workers": 0, "persistent_workers": False},
)

To feed the data to `get_latent_representation`, we operate `inference_datamodule` as PyTorch Lightning would during training:

In [None]:
inference_datamodule.setup()
inference_dataloader = (
    inference_datamodule.on_before_batch_transfer(batch, None) for batch in inference_datamodule.train_dataloader()
)
latent = model.get_latent_representation(dataloader=inference_dataloader)
latent.shape

We successfully trained the model and generated embeddings using limited memory. Even on the full Census, this has been tested to run with less than 30G of memory.

## Analyzing the results

We will now take a look at the UMAP for the generated embedding. Note that this model was only trained for one epoch (for demo purposes), so we don't expect the UMAP to show significant integration patterns, but it is nonetheless a good way to check the overall health of the generated embedding.

In order to do this, we'll use `scanpy` which accepts an AnnData object, so we'll generate one using the `get_anndata` utility function:

In [None]:
adata = cellxgene_census.get_anndata(
    census,
    organism=experiment_name,
    obs_value_filter=obs_value_filter,
)

Add the generated embedding (stored in `latent`) in the obsm slot of the AnnData object:

In [None]:
# verify cell order:
assert np.array_equal(np.array(adata.obs["soma_joinid"]), inference_datamodule.train_dataset.query_ids.obs_joinids)

adata.obsm["scvi"] = latent

We can now generate the neighbors and the UMAP.

In [None]:
sc.pp.neighbors(adata, use_rep="scvi", key_added="scvi")
sc.tl.umap(adata, neighbors_key="scvi")
sc.pl.umap(adata, color="dataset_id", title="SCVI")

In [None]:
sc.pl.umap(adata, color="tissue_general", title="SCVI")

In [None]:
sc.pl.umap(adata, color="cell_type", title="SCVI")