In [1]:
# Standard library
import sys

# Third-party
import anndata
import numpy as np
import torch

# Make the local velovae checkout importable. A relative sys.path entry is
# resolved against the kernel's working directory, so start the notebook from
# its own folder for this to point at the repository root.
sys.path.append('../../../../')
import velovae as vv

In [2]:
dataset = "Hindbrain_pons"

# Load the cached, preprocessed AnnData object for this dataset.
adata = anndata.read_h5ad(f'data/{dataset}_pp.h5ad')

# The velovae calls below look up cell types under the 'clusters' key
# (e.g. vv.BrODE(adata, 'clusters', ...)); mirror 'celltype' there as a plain array.
cell_types = adata.obs['celltype'].to_numpy()
adata.obs['clusters'] = cell_types

In [None]:
# One-off preprocessing, cached back to disk (n_gene / min_shared_counts are
# forwarded to vv.preprocess; see its docs for their exact meaning).
# NOTE: this writes to the SAME file the loading cell reads. Running it on a file
# this cell already produced would apply vv.preprocess a second time and overwrite
# the cache, so it is opt-in. Set RUN_PREPROCESS = True only when
# data/{dataset}_pp.h5ad still holds unprocessed input.
RUN_PREPROCESS = False
if RUN_PREPROCESS:
    vv.preprocess(adata, n_gene=2000, min_shared_counts=20)
    adata.write_h5ad(f'data/{dataset}_pp.h5ad')

In [3]:
# Output locations, each namespaced by the dataset name.
model_path_base = f'checkpoints/{dataset}'         # saved encoder/decoder weights
figure_path_base = f'figures/{dataset}'            # training / evaluation figures
data_path = f'data/velovae/continuous/{dataset}'   # AnnData files with model results

# Genes passed as gene_plot / genes to the training and evaluation calls.
gene_plot = [
    'Ptprz1',
    'Enpp6',
    'Rras2',
    'Mal',
]

# Vanilla VAE

In [5]:
figure_path = f'{figure_path_base}/Vanilla'
model_path = f'{model_path_base}/Vanilla'

# Use the first GPU when available; fall back to CPU so the notebook still runs
# on machines without CUDA (training will be considerably slower there).
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

torch.manual_seed(2022)
np.random.seed(2022)

# Baseline model: vanilla VAE with init_method='random'.
vanilla_vae = vv.VanillaVAE(adata,
                            tmax=20,
                            device=device,
                            init_method='random')

vanilla_vae.train(adata, plot=False, gene_plot=gene_plot, figure_path=figure_path)
vanilla_vae.save_model(model_path, 'encoder', 'decoder')
# Results are stored under the 'vanilla' key, which the evaluation cell refers to.
vanilla_vae.save_anndata(adata, 'vanilla', data_path, file_name=f'{dataset}.h5ad')

Random Initialization.
Estimating ODE parameters...


  0%|          | 0/1000 [00:00<?, ?it/s]

Estimating the variance...


  0%|          | 0/1000 [00:00<?, ?it/s]

Gaussian Prior.
------------------------- Train a Vanilla VAE -------------------------
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
Total Number of Iterations Per Epoch: 35, test iteration: 68
********* Early Stop Triggered at epoch 230. *********
*********              Finished. Total Time =   0 h :  1 m : 10 s             *********
Final: Train ELBO = -846.607,           Test ELBO = -823.433


# VeloVAE

In [6]:
figure_path = f'{figure_path_base}/VeloVAE'
model_path = f'{model_path_base}/VeloVAE'

# Use the first GPU when available; fall back to CPU so the notebook still runs
# on machines without CUDA (training will be considerably slower there).
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

torch.manual_seed(2022)
np.random.seed(2022)

# VeloVAE with a 5-dimensional latent cell state and init_method='steady'.
vae = vv.VAE(adata,
             tmax=20,
             dim_z=5,
             device=device,
             init_method='steady')

# plot=True: figures for gene_plot are written to figure_path during training.
vae.train(adata, plot=True, gene_plot=gene_plot, figure_path=figure_path)

vae.save_model(model_path, 'encoder', 'decoder')
# Results are stored under the 'velovae' key, which the evaluation cell refers to.
vae.save_anndata(adata, 'velovae', data_path, file_name=f'{dataset}.h5ad')

Initialization using the steady-state and dynamical models.
Estimating ODE parameters...


  0%|          | 0/1000 [00:00<?, ?it/s]

Estimating the variance...


  0%|          | 0/1000 [00:00<?, ?it/s]

Reinitialize the regular ODE parameters based on estimated global latent time.


  0%|          | 0/1000 [00:00<?, ?it/s]

Gaussian Prior.
--------------------------- Train a VeloVAE ---------------------------
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
*********                      Stage  1                       *********
Total Number of Iterations Per Epoch: 35, test iteration: 68
Epoch 1: Train ELBO = -7027.202, Test ELBO = -817273.625, 	 Total Time =   0 h :  0 m :  2 s
Epoch 100: Train ELBO = 647.939, Test ELBO = 659.443, 	 Total Time =   0 h :  0 m : 36 s
Epoch 200: Train ELBO = 1049.552, Test ELBO = 1042.311, 	 Total Time =   0 h :  1 m : 11 s
Epoch 300: Train ELBO = 1106.680, Test ELBO = 1107.914, 	 Total Time =   0 h :  1 m : 45 s
Epoch 400: Train ELBO = 1138.378, Test ELBO = 1141.417,

  0%|          | 0/6253 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.422
Average Set Size: 116
Finished. Actual Time:   0 h :  0 m :  3 s
Epoch 608: Train ELBO = 845.005, Test ELBO = 755.300, 	 Total Time =   0 h :  3 m : 38 s
Epoch 700: Train ELBO = 1107.348, Test ELBO = 1097.820, 	 Total Time =   0 h :  4 m : 14 s
Epoch 800: Train ELBO = 1103.250, Test ELBO = 1116.254, 	 Total Time =   0 h :  4 m : 52 s
Epoch 900: Train ELBO = 1104.767, Test ELBO = 1134.724, 	 Total Time =   0 h :  5 m : 31 s
*********       Stage 2: Early Stop Triggered at epoch 938.       *********
*********              Finished. Total Time =   0 h :  6 m :  8 s             *********
Final: Train ELBO = 1113.453,           Test ELBO = 1117.616


# Full VB

In [7]:
figure_path = f'{figure_path_base}/FullVB'
model_path = f'{model_path_base}/FullVB'

# Use the first GPU when available; fall back to CPU so the notebook still runs
# on machines without CUDA (training will be considerably slower there).
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

torch.manual_seed(2022)
np.random.seed(2022)

# Full variational-Bayes variant, same architecture settings as the VeloVAE cell.
full_vb = vv.VAEFullVB(adata,
                       tmax=20,
                       dim_z=5,
                       device=device,
                       init_method='steady')

full_vb.train(adata, plot=False, gene_plot=gene_plot, figure_path=figure_path)

full_vb.save_model(model_path, 'encoder', 'decoder')
# Results are stored under the 'fullvb' key; the BrODE cell reads
# 'fullvb_time' / 'fullvb_z', so this cell must run before it.
full_vb.save_anndata(adata, 'fullvb', data_path, file_name=f'{dataset}.h5ad')

Initialization using the steady-state and dynamical models.
Estimating ODE parameters...


  0%|          | 0/1000 [00:00<?, ?it/s]

Estimating the variance...


  0%|          | 0/1000 [00:00<?, ?it/s]

Reinitialize the regular ODE parameters based on estimated global latent time.


  0%|          | 0/1000 [00:00<?, ?it/s]

Gaussian Prior.
--------------------------- Train a VeloVAE ---------------------------
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
*********                      Stage  1                       *********
Total Number of Iterations Per Epoch: 35, test iteration: 68
*********       Stage 1: Early Stop Triggered at epoch 927.       *********
*********                      Stage  2                       *********
Cell-wise KNN Estimation.


  0%|          | 0/6253 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.407
Average Set Size: 106
Finished. Actual Time:   0 h :  0 m :  4 s
*********       Stage 2: Early Stop Triggered at epoch 1170.       *********
*********              Finished. Total Time =   0 h :  7 m : 54 s             *********
Final: Train ELBO = 1100.022,           Test ELBO = 1095.810


# Branching ODE

In [8]:
figure_path = f'{figure_path_base}/BrODE'
model_path = f'{model_path_base}/BrODE'

torch.manual_seed(2022)
np.random.seed(2022)

# BrODE is fit on top of the FullVB results. The time/latent keys below are
# presumably the ones written by full_vb.save_anndata(adata, 'fullvb', ...), so
# the FullVB cell has to run first.
cluster_key = 'clusters'
time_key = 'fullvb_time'
brode = vv.BrODE(adata, cluster_key, time_key, 'fullvb_z')

brode.train(adata, time_key, cluster_key, plot=False, gene_plot=gene_plot, figure_path=figure_path)

brode.save_model(model_path, 'brode')
brode.save_anndata(adata, 'brode', data_path, file_name=f'{dataset}.h5ad')

# Save the cell-type transition graph figure next to the other dataset figures.
vv.plot_transition_graph(adata, save=f'{figure_path_base}/transition.png')

Graph Partition
Number of partitions:  2
Computing type-to-type transition probability
Obtaining the MST in each partition
Initialization using type-specific dynamical model.
Estimating ODE parameters...


  0%|          | 0/1000 [00:00<?, ?it/s]

Estimating the variance...


  0%|          | 0/1000 [00:00<?, ?it/s]

------------------------ Train a Branching ODE ------------------------
       COPs  MFOLs  NFOLs  OPCs
COPs    0.0    0.0    1.0   0.0
MFOLs   0.0    1.0    0.0   0.0
NFOLs   0.0    1.0    0.0   0.0
OPCs    0.0    0.0    0.0   1.0
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
Total Number of Iterations Per Epoch: 35, test iteration: 68
*********           Early Stop Triggered at epoch 241.            *********
*********              Finished. Total Time =   0 h :  4 m : 50 s             *********


# Evaluation

In [None]:
# Compare all four trained models on the same AnnData object.
# cluster_edges: presumably the expected (source, target) cell-type transitions.
# NOTE(review): ('OPCs', 'COPs') is not listed, and the BrODE transition matrix
# printed above maps OPCs only to itself; confirm that leaving OPCs out is intended.
cluster_edges = [('COPs', 'NFOLs'), ('NFOLs', 'MFOLs')]
method_names = ['Vanilla VAE', 'VeloVAE', 'FullVB', 'BrODE']
method_keys = ['vanilla', 'velovae', 'fullvb', 'brode']  # keys given to save_anndata above

vv.post_analysis(adata,
                 'eval',
                 method_names,
                 method_keys,
                 compute_metrics=True,
                 genes=gene_plot,
                 grid_size=(1, 4),
                 plot_type=['all'],
                 save_path=data_path,
                 cluster_edges=cluster_edges)