# veloVAE (VAE) benchmark on cell cycle data

This notebook benchmarks velocity, latent time inference, and cross-boundary correctness using veloVAE (VAE) on cell cycle data.

In [1]:
import velovae as vv

import numpy as np
import pandas as pd
import torch

import anndata as ad
import scvelo as scv
from cellrank.kernels import VelocityKernel

from rgv_tools import DATA_DIR
from rgv_tools.benchmarking import get_time_correlation



## General settings

In [2]:
# Set scVelo's logging verbosity level to 3 for the whole notebook.
scv.settings.verbosity = 3

## Constants

In [3]:
# Seed the global PyTorch and NumPy random number generators with a fixed value
# so that repeated runs start from the same random state.
torch.manual_seed(0)
np.random.seed(0)

In [4]:
# Name of the dataset sub-directory under DATA_DIR used for all inputs and outputs.
DATASET = "cell_cycle"

In [5]:
# (source, target) pairs of cell cycle phases for which cross-boundary correctness is scored.
STATE_TRANSITIONS = [("G1", "S"), ("S", "G2M")]

In [6]:
# Toggle for persisting results; when enabled, make sure both output directories exist.
SAVE_DATA = True
if SAVE_DATA:
    for output_dir in (
        DATA_DIR / DATASET / "results",
        DATA_DIR / DATASET / "processed" / "velovae_vae",
    ):
        output_dir.mkdir(parents=True, exist_ok=True)

## Data loading

In [7]:
# Load the preprocessed dataset for this benchmark.
adata = ad.io.read_h5ad(DATA_DIR / DATASET / "processed" / "adata_processed.h5ad")
# Reconstruct the 'Ms' and 'Mu' moment layers, because veloVAE runs on continuous space.
# NOTE(review): n_pcs=None / n_neighbors=None presumably reuse the neighbor graph already
# stored in adata rather than recomputing it — confirm against the scVelo documentation.
scv.pp.moments(adata, n_pcs=None, n_neighbors=None)
adata

computing moments based on connectivities
    finished (0:00:00) --> added 
    'Ms' and 'Mu', moments of un/spliced abundances (adata.layers)


AnnData object with n_obs × n_vars = 1146 × 395
    obs: 'phase', 'fucci_time', 'initial_size_unspliced', 'initial_size_spliced', 'initial_size', 'n_counts'
    var: 'ensum_id', 'gene_count_corr', 'means', 'dispersions', 'dispersions_norm', 'highly_variable', 'velocity_gamma', 'velocity_qreg_ratio', 'velocity_r2', 'velocity_genes'
    uns: 'log1p', 'neighbors', 'pca', 'umap', 'velocity_params'
    obsm: 'X_pca', 'X_umap'
    varm: 'PCs', 'true_skeleton'
    layers: 'Ms', 'Mu', 'spliced', 'total', 'unspliced', 'velocity'
    obsp: 'connectivities', 'distances'

## Velocity pipeline

In [8]:
# Train on the first CUDA device when one is available and fall back to the CPU otherwise,
# so the notebook still runs on machines without a GPU (results may differ slightly
# between devices).
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Fit veloVAE with tmax=20 and dim_z=5, using the default training configuration
# (empty override dict), no plotting, and the PCA embedding.
vae = vv.VAE(adata, tmax=20, dim_z=5, device=device)
config = {}
vae.train(adata, config=config, plot=False, embed="pca")

model_dir = DATA_DIR / DATASET / "processed" / "velovae_vae"
if SAVE_DATA:
    vae.save_model(model_dir, "encoder_vae", "decoder_vae")

# Later cells read adata.obs["vae_time"] and adata.layers["vae_velocity"]; save_anndata is
# the only call that is handed the "vae" key, so it has to run even when SAVE_DATA is False.
# NOTE(review): file_name=None is expected to update adata without writing an .h5ad file
# (the output directory may still be created) — confirm against the installed velovae version.
vae.save_anndata(adata, "vae", model_dir, file_name="velovae.h5ad" if SAVE_DATA else None)

Estimating ODE parameters...


  0%|          | 0/395 [00:00<?, ?it/s]

Detected 313 velocity genes.
Estimating the variance...


  0%|          | 0/395 [00:00<?, ?it/s]

Initialization using the steady-state and dynamical models.
Reinitialize the regular ODE parameters based on estimated global latent time.


  0%|          | 0/395 [00:00<?, ?it/s]

3 clusters detected based on gene co-expression.
(0.55, 0.7640522104081641), (0.45, 0.2995065388021425)
(0.44, 0.7271833832917461), (0.56, 0.2872846055146947)
KS-test result: [0. 0. 1.]
Initial induction: 227, repression: 168/395
Learning Rate based on Data Sparsity: 0.0000
--------------------------- Train a VeloVAE ---------------------------
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
*********                      Stage  1                       *********
Total Number of Iterations Per Epoch: 7, test iteration: 12
*********                      Stage  2                       *********
*********             Velocity Refinement Round 1             *********


  0%|          | 0/1146 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.028
Average Set Size: 23
*********     Round 1: Early Stop Triggered at epoch 1195.    *********
Change in noise variance: 0.4184
*********             Velocity Refinement Round 2             *********
*********     Round 2: Early Stop Triggered at epoch 1204.    *********
Change in noise variance: 0.0009
Change in x0: 0.4345
*********             Velocity Refinement Round 3             *********
*********     Round 3: Early Stop Triggered at epoch 1213.    *********
Change in noise variance: 0.0000
Change in x0: 0.3464
*********             Velocity Refinement Round 4             *********
*********     Round 4: Early Stop Triggered at epoch 1222.    *********
Change in noise variance: 0.0000
Change in x0: 0.2845
*********             Velocity Refinement Round 5             *********
*********     Round 5: Early Stop Triggered at epoch 1240.    *********
Change in noise variance: 0.0000
Change in x0: 0.2936
*********             Velocity Refinement Round 

In [9]:
# Score veloVAE's inferred latent time ("vae_time") against the FUCCI-derived time used as
# ground truth; the single result is wrapped in a one-element list.
time_correlation = [get_time_correlation(ground_truth=adata.obs["fucci_time"], estimated=adata.obs["vae_time"])]

In [10]:
# Expose the veloVAE velocities under the generic "velocity" key; this overwrites the
# "velocity" layer that was already present in the loaded AnnData object.
adata.layers["velocity"] = adata.layers["vae_velocity"].copy()
# Compute the velocity graph from that layer using a single job.
scv.tl.velocity_graph(adata, vkey="velocity", n_jobs=1)
# Adds 'velocity_length', 'velocity_confidence' and 'velocity_confidence_transition' to adata.obs.
scv.tl.velocity_confidence(adata, vkey="velocity")

computing velocity graph (using 1/112 cores)


  0%|          | 0/1146 [00:00<?, ?cells/s]

    finished (0:00:01) --> added 
    'velocity_graph', sparse matrix with cosine correlations (adata.uns)
--> added 'velocity_length' (adata.obs)
--> added 'velocity_confidence' (adata.obs)
--> added 'velocity_confidence_transition' (adata.obs)


## Cross-boundary correctness

In [11]:
# Build a CellRank velocity kernel on the veloVAE velocities and compute its transition matrix.
vk = VelocityKernel(adata, vkey="vae_velocity").compute_transition_matrix()

cluster_key = "phase"
rep = "X_pca"

# Score every transition in STATE_TRANSITIONS and collect one long-format frame per transition;
# the frames are concatenated into a single table at the end.
cbc_frames = []
for source, target in STATE_TRANSITIONS:
    cbc = vk.cbc(source=source, target=target, cluster_key=cluster_key, rep=rep)
    transition_label = f"{source} - {target}"
    cbc_frames.append(pd.DataFrame({"State transition": [transition_label] * len(cbc), "CBC": cbc}))
score_df = pd.concat(cbc_frames)

  0%|          | 0/1146 [00:00<?, ?cell/s]

  0%|          | 0/1146 [00:00<?, ?cell/s]

## Data saving

In [12]:
# Persist the three benchmark outputs (time correlation, velocity confidence, CBC scores)
# as parquet files in the dataset's results directory.
if SAVE_DATA:
    results_dir = DATA_DIR / DATASET / "results"

    # NOTE(review): time_correlation holds a single entry while the index has one label per
    # cell — confirm that the resulting per-cell table is what downstream readers expect.
    time_df = pd.DataFrame({"time": time_correlation}, index=adata.obs_names)
    time_df.to_parquet(path=results_dir / "velovae_vae_correlation.parquet")

    confidence_df = adata.obs[["velocity_confidence"]]
    confidence_df.to_parquet(path=results_dir / "velovae_vae_confidence.parquet")

    score_df.to_parquet(path=results_dir / "velovae_vae_cbc.parquet")