In [1]:
import anndata
import numpy as np
import sys
import torch
import scvelo as scv
sys.path.append('../../../')
import velovae as vv

In [2]:
# Dataset name; later cells interpolate it into checkpoint, figure and
# output paths.
dataset = 'Erythroid'

# Fetch the erythroid gastrulation data through scVelo's dataset helper.
# NOTE(review): presumably the file is downloaded to this path on first use
# and read from it afterwards -- confirm against the scVelo docs.
raw_h5ad_path = 'data/download/Erythroid.h5ad'
adata = scv.datasets.gastrulation_erythroid(file_path=raw_h5ad_path)

# Copy the cell-type annotation into a "clusters" column, the key the
# later BrODE cell passes as its cluster key.
celltype_labels = adata.obs["celltype"].to_numpy()
adata.obs["clusters"] = celltype_labels

In [3]:
# Run VeloVAE preprocessing in place on `adata`: genes are filtered with a
# minimum of 20 shared counts and 1000 genes are kept (see the log below for
# the normalisation, KNN-graph and moment steps it performs).
vv.preprocess(adata, n_gene=1000, min_shared_counts=20)

# Persist the preprocessed object. The file name is derived from `dataset`
# (instead of a hard-coded 'Erythroid') so it stays consistent with the other
# dataset-dependent paths; for dataset == 'Erythroid' the path is unchanged.
adata.write_h5ad(f'data/{dataset}_pp.h5ad')

Filtered out 47628 genes that are detected 20 counts (shared).
Normalized count data: X, spliced, unspliced.
Exctracted 1000 highly variable genes.
Logarithmized X.
Computing the KNN graph based on X_pca
computing neighbors
    finished (0:02:29) --> added 
    'distances' and 'connectivities', weighted adjacency matrices (adata.obsp)
computing moments based on connectivities
    finished (0:00:01) --> added 
    'Ms' and 'Mu', moments of un/spliced abundances (adata.layers)
Keep raw unspliced/spliced count data.


In [4]:
# Per-dataset root directories; each model cell appends its own sub-folder
# (e.g. '/Vanilla', '/VeloVAE') to these bases.
model_path_base = f'checkpoints/{dataset}'
figure_path_base = f'figures/{dataset}'
# Directory handed to every model's save_anndata(...) call.
data_path = f'data/velovae/continuous/{dataset}'
# Genes passed as `gene_plot` to the train(...) calls and as `genes` to
# vv.post_analysis(...).
gene_plot = ['Smim1', 'Blvrb', 'Hba-x', 'Lmo2']

# Informative Time Prior

In [5]:
# Build a per-cell time prior from the capture-stage labels: drop the first
# character of each label, parse the remainder as a float, then shift the
# values so the earliest stage maps to 0.
# NOTE(review): assumes every label is one prefix character followed by a
# number (e.g. 'E7.5') -- confirm for other datasets.
capture_time = adata.obs['stage'].to_numpy()
tprior = np.fromiter((float(label[1:]) for label in capture_time), dtype=float)
adata.obs['tprior'] = tprior - tprior.min()

# Vanilla VAE

In [6]:
# Output locations for the Vanilla VAE run.
figure_path = f'{figure_path_base}/Vanilla'
model_path = f'{model_path_base}/Vanilla'

# Seed both NumPy and PyTorch so the run is repeatable.
np.random.seed(2022)
torch.manual_seed(2022)

# Constructor options. 'tprior' names the adata.obs column holding the
# capture-time prior; it is used both for initialisation (init_method /
# init_key) and as the time prior itself (tprior).
vanilla_init = dict(
    tmax=20,
    device='cuda:0',
    init_method='tprior',
    init_key='tprior',
    tprior='tprior',
)
vanilla_vae = vv.VanillaVAE(adata, **vanilla_init)

# Train with the default configuration (empty config) and without plotting.
vanilla_vae.train(
    adata,
    config={},
    plot=False,
    gene_plot=gene_plot,
    figure_path=figure_path,
)

# Persist the networks and write the results into `adata` / to disk under
# the 'vanilla' key, the same key later passed to vv.post_analysis.
vanilla_vae.save_model(model_path, 'encoder', 'decoder')
vanilla_vae.save_anndata(adata, 'vanilla', data_path, file_name=f'{dataset}.h5ad')

Initialization using prior time.
Estimating ODE parameters...


  0%|          | 0/1000 [00:00<?, ?it/s]

Estimating the variance...


  0%|          | 0/1000 [00:00<?, ?it/s]

Reinitialize the regular ODE parameters based on estimated global latent time.


  0%|          | 0/1000 [00:00<?, ?it/s]

Gaussian Prior.
Using informative time prior.
------------------------- Train a Vanilla VAE -------------------------
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
Total Number of Iterations Per Epoch: 54, test iteration: 106
********* Early Stop Triggered at epoch 189. *********
*********              Finished. Total Time =   0 h :  1 m : 45 s             *********
Final: Train ELBO = 448.076,           Test ELBO = 446.249


# VeloVAE

In [7]:
# Output locations for the VeloVAE run.
figure_path = f'{figure_path_base}/VeloVAE'
model_path = f'{model_path_base}/VeloVAE'

# Seed both NumPy and PyTorch so the run is repeatable.
np.random.seed(2022)
torch.manual_seed(2022)

# Constructor options: a 5-dimensional cell state (dim_z) and the
# capture-time prior column 'tprior' used for initialisation and as prior.
velovae_init = dict(
    tmax=20,
    dim_z=5,
    device='cuda:0',
    init_method='tprior',
    init_key='tprior',
    tprior='tprior',
    init_ton_zero=False,
)
vae = vv.VAE(adata, **velovae_init)

# Training overrides on top of the library defaults.
# NOTE(review): 'early_stop' looks like an early-stopping patience and
# 'train_ton' like a switch for learning the switch-on time -- confirm
# against the velovae configuration docs.
config = dict(early_stop=9, train_ton=False)
vae.train(
    adata,
    config=config,
    plot=False,
    gene_plot=gene_plot,
    figure_path=figure_path,
)

# Persist the networks and write the results under the 'velovae' key, the
# same key later passed to vv.post_analysis.
vae.save_model(model_path, 'encoder', 'decoder')
vae.save_anndata(adata, 'velovae', data_path, file_name=f'{dataset}.h5ad')

Initialization using prior time.
Estimating ODE parameters...


  0%|          | 0/1000 [00:00<?, ?it/s]

Estimating the variance...


  0%|          | 0/1000 [00:00<?, ?it/s]

Reinitialize the regular ODE parameters based on estimated global latent time.


  0%|          | 0/1000 [00:00<?, ?it/s]

Gaussian Prior.
Using informative time prior.
--------------------------- Train a VeloVAE ---------------------------
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
*********                      Stage  1                       *********
Total Number of Iterations Per Epoch: 54, test iteration: 106
*********                      Stage  2                       *********
Cell-wise KNN Estimation.


  0%|          | 0/9815 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.011
Average Set Size: 314
Finished. Actual Time:   0 h :  0 m :  9 s
*********              Finished. Total Time =   0 h : 20 m : 38 s             *********
Final: Train ELBO = 524.388,           Test ELBO = 517.612


# FullVB

In [8]:
# Output locations for the FullVB run.
figure_path = f'{figure_path_base}/FullVB'
model_path = f'{model_path_base}/FullVB'

# Seed both NumPy and PyTorch so the run is repeatable.
np.random.seed(2022)
torch.manual_seed(2022)

# Same constructor options as the VeloVAE cell, applied to the full
# variational Bayes model class.
fullvb_init = dict(
    tmax=20,
    dim_z=5,
    device='cuda:0',
    init_method='tprior',
    init_key='tprior',
    tprior='tprior',
    init_ton_zero=False,
)
full_vb = vv.VAEFullVB(adata, **fullvb_init)

# Training overrides on top of the library defaults (same as VeloVAE).
config = dict(early_stop=9, train_ton=False)
full_vb.train(
    adata,
    config=config,
    plot=False,
    gene_plot=gene_plot,
    figure_path=figure_path,
)

# Persist the networks and write the results under the 'fullvb' key. The
# BrODE cell below reads the 'fullvb_time' and 'fullvb_z' entries, so this
# cell has to run before it.
full_vb.save_model(model_path, 'encoder', 'decoder')
full_vb.save_anndata(adata, 'fullvb', data_path, file_name=f'{dataset}.h5ad')

Initialization using prior time.
Estimating ODE parameters...


  0%|          | 0/1000 [00:00<?, ?it/s]

Estimating the variance...


  0%|          | 0/1000 [00:00<?, ?it/s]

Reinitialize the regular ODE parameters based on estimated global latent time.


  0%|          | 0/1000 [00:00<?, ?it/s]

Gaussian Prior.
Using informative time prior.
--------------------------- Train a VeloVAE ---------------------------
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
*********                      Stage  1                       *********
Total Number of Iterations Per Epoch: 54, test iteration: 106
*********       Stage 1: Early Stop Triggered at epoch 377.       *********
*********                      Stage  2                       *********
Cell-wise KNN Estimation.


  0%|          | 0/9815 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.014
Average Set Size: 366
Finished. Actual Time:   0 h :  0 m :  9 s
*********              Finished. Total Time =   0 h : 17 m : 19 s             *********
Final: Train ELBO = 517.093,           Test ELBO = 511.114


# Branching ODE

In [9]:
# Output locations for the Branching ODE run.
figure_path = f'{figure_path_base}/BrODE'
model_path = f'{model_path_base}/BrODE'

# Seed both NumPy and PyTorch so the run is repeatable.
np.random.seed(2022)
torch.manual_seed(2022)

# Keys shared by the constructor and the training call. The time and cell
# state entries carry the 'fullvb' prefix, i.e. BrODE is fitted on top of
# the FullVB results, so the FullVB cell must have been run first.
cluster_key = 'clusters'
time_key = 'fullvb_time'
state_key = 'fullvb_z'

brode = vv.BrODE(adata, cluster_key, time_key, state_key)

# Note the argument order differs from the constructor: time key first,
# then cluster key.
brode.train(
    adata,
    time_key,
    cluster_key,
    plot=False,
    gene_plot=gene_plot,
    figure_path=figure_path,
)

# Persist the model and write the results under the 'brode' key, the same
# key later passed to vv.post_analysis.
brode.save_model(model_path, 'brode')
brode.save_anndata(adata, 'brode', data_path, file_name=f'{dataset}.h5ad')

# Save the transition-graph figure next to the per-model figure folders.
vv.plot_transition_graph(adata, save=f'{figure_path_base}/transition.png')

Graph Partition
Number of partitions:  1
Computing type-to-type transition probability
Obtaining the MST in each partition
Initialization using type-specific dynamical model.
Estimating ODE parameters...


  0%|          | 0/1000 [00:00<?, ?it/s]

Estimating the variance...


  0%|          | 0/1000 [00:00<?, ?it/s]

------------------------ Train a Branching ODE ------------------------
                     Blood progenitors 1  Blood progenitors 2  Erythroid1  Erythroid2  Erythroid3
Blood progenitors 1                  1.0                  0.0         0.0         0.0         0.0
Blood progenitors 2                  1.0                  0.0         0.0         0.0         0.0
Erythroid1                           0.0                  1.0         0.0         0.0         0.0
Erythroid2                           0.0                  0.0         1.0         0.0         0.0
Erythroid3                           0.0                  0.0         0.0         1.0         0.0
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training         

# Evaluation

In [10]:
# Cell types in lineage order; each consecutive pair forms one
# (source, target) edge handed to post_analysis as `cluster_edges`.
lineage_order = [
    'Blood progenitors 1',
    'Blood progenitors 2',
    'Erythroid1',
    'Erythroid2',
    'Erythroid3',
]
cluster_edges = list(zip(lineage_order[:-1], lineage_order[1:]))

# Display names and the matching key prefixes under which each model's
# results were saved into `adata` by the cells above (same order).
method_names = ['Vanilla VAE', 'VeloVAE', 'FullVB', 'BrODE']
method_keys = ['vanilla', 'velovae', 'fullvb', 'brode']

# Compute the comparison metrics for all four models; no plot types are
# requested (empty plot_type list).
vv.post_analysis(
    adata,
    'eval',
    method_names,
    method_keys,
    compute_metrics=True,
    genes=gene_plot,
    grid_size=(1, 4),
    plot_type=[],
    save_path=figure_path_base,
    cluster_edges=cluster_edges,
)

Computing velocity embedding using scVelo
computing velocity graph (using 1/32 cores)


  0%|          | 0/9815 [00:00<?, ?cells/s]

    finished (0:00:45) --> added 
    'vanilla_velocity_graph', sparse matrix with cosine correlations (adata.uns)
computing velocity embedding
    finished (0:00:02) --> added
    'vanilla_velocity_umap', embedded velocity vectors (adata.obsm)
Computing velocity embedding using scVelo
computing velocity graph (using 1/32 cores)


  0%|          | 0/9815 [00:00<?, ?cells/s]

    finished (0:00:46) --> added 
    'velovae_velocity_graph', sparse matrix with cosine correlations (adata.uns)
computing velocity embedding
    finished (0:00:02) --> added
    'velovae_velocity_umap', embedded velocity vectors (adata.obsm)
Computing velocity embedding using scVelo
computing velocity graph (using 1/32 cores)


  0%|          | 0/9815 [00:00<?, ?cells/s]

    finished (0:00:46) --> added 
    'fullvb_velocity_graph', sparse matrix with cosine correlations (adata.uns)
computing velocity embedding
    finished (0:00:02) --> added
    'fullvb_velocity_umap', embedded velocity vectors (adata.obsm)
Computing velocity embedding using scVelo
computing velocity graph (using 1/32 cores)


  0%|          | 0/9815 [00:00<?, ?cells/s]

    finished (0:00:45) --> added 
    'brode_velocity_graph', sparse matrix with cosine correlations (adata.uns)
computing velocity embedding
    finished (0:00:02) --> added
    'brode_velocity_umap', embedded velocity vectors (adata.obsm)
---     Computing Peformance Metrics     ---
Dataset Size: 9815 cells, 1000 genes
---   Plotting  Results   ---


Unnamed: 0,Vanilla VAE,VeloVAE,FullVB,BrODE
MSE Train,22.185,15.405,27.464,21.985
MSE Test,23.773,15.743,27.07,21.843
MAE Train,0.513,0.341,0.416,0.496
MAE Test,0.519,0.348,0.419,0.493
LL Train,452.223,536.54,533.618,2312.738
LL Test,450.413,529.765,531.745,2351.34
corr,0.803,0.883,0.863,0.863
Cross-Boundary Direction Correctness (embed),0.832,0.306,0.483,0.436
Cross-Boundary Direction Correctness,0.558,0.169,0.309,0.471
In-Cluster Coherence,0.983,0.561,0.633,0.943
