In [1]:
import anndata
import numpy as np
import sys
import torch
sys.path.append('../../../')
import velovae as vv

In [2]:
# Dataset identifier; the input file and all output paths below are derived from it.
dataset = 'HIO'
# Build the input path from `dataset` instead of repeating the name, so that
# changing `dataset` cannot silently keep loading the HIO file.
# NOTE(review): the `_pp` suffix presumably marks the preprocessed AnnData — confirm.
adata = anndata.read_h5ad(f'data/{dataset}_pp.h5ad')

In [None]:
# Disabled preprocessing step, kept for reference only (this cell was not executed).
# NOTE(review): presumably this is how the `*_pp.h5ad` file loaded above was
# produced — confirm, then either document that in markdown or remove the cell.
#marker_genes = ['S100B','PLP1','STMN2','ELAVL4','CDH5','KDR','ECSCR','CLDN5','COL1A1','COL1A2',\
#                'DCN','ACTA2','TAGLN','ACTG2','MYLK','EPCAM','CDH1','CDX2','CLDN4']
#vv.preprocess(adata, 
#              n_gene=2000, 
#              min_shared_counts=20, 
#              genes_retain=marker_genes, 
#              compute_umap=True, 
#              keep_raw=True)

In [3]:
# Genes passed as `gene_plot` to every train() call and as `genes` to the
# evaluation at the end of the notebook.
gene_plot = [
    'PLP1',
    'ECSCR',
    'COL1A1',
    'EPCAM',
]

# Per-dataset output locations. Each model cell appends its own sub-folder to
# the two *_base paths; `data_path` is shared by all models.
data_path = f'data/velovae/continuous/{dataset}'
figure_path_base = f'figures/{dataset}'
model_path_base = f'checkpoints/{dataset}'

# Informative Time Prior

In [4]:
# Capture-day label of every cell.
# NOTE(review): assumes labels are one prefix character followed by a number
# (e.g. 'D8') — confirm against adata.obs['Day'].
day = adata.obs['Day']

# Drop the first character of each label and parse the remainder as a float.
tprior = np.fromiter((float(label[1:]) for label in day), dtype=float)

# Store the prior shifted so that the earliest day maps to 0; the models below
# read it through the 'tprior' key.
earliest_day = tprior.min()
adata.obs['tprior'] = tprior - earliest_day

# Vanilla VAE

In [5]:
# Output folders for this model.
model_path = f'{model_path_base}/Vanilla'
figure_path = f'{figure_path_base}/Vanilla'

# Seed NumPy and PyTorch before building the model.
np.random.seed(2022)
torch.manual_seed(2022)

# Constructor options. The three 'tprior' entries all point at the
# adata.obs['tprior'] column created in the time-prior cell.
vanilla_kwargs = dict(
    tmax=20,
    device='cuda:0',
    init_method='tprior',
    init_key='tprior',
    tprior='tprior',
)
vanilla_vae = vv.VanillaVAE(adata, **vanilla_kwargs)

vanilla_vae.train(
    adata,
    plot=False,
    gene_plot=gene_plot,
    figure_path=figure_path,
)

# Persist the trained networks and write the results into `adata` under the
# 'vanilla' key, which the evaluation cell refers to.
# NOTE(review): 'encoder'/'decoder' are presumably the checkpoint file names — confirm.
vanilla_vae.save_model(model_path, 'encoder', 'decoder')
vanilla_vae.save_anndata(
    adata,
    'vanilla',
    data_path,
    file_name=f'{dataset}.h5ad',
)

Initialization using prior time.
Estimating ODE parameters...


  0%|          | 0/2002 [00:00<?, ?it/s]

Estimating the variance...


  0%|          | 0/2002 [00:00<?, ?it/s]

Reinitialize the regular ODE parameters based on estimated global latent time.


  0%|          | 0/2002 [00:00<?, ?it/s]

Gaussian Prior.
Using informative time prior.
------------------------- Train a Vanilla VAE -------------------------
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
Total Number of Iterations Per Epoch: 149, test iteration: 296
********* Early Stop Triggered at epoch 106. *********
*********              Finished. Total Time =   0 h :  3 m : 25 s             *********
Final: Train ELBO = 3643.568,           Test ELBO = 3639.790


# VeloVAE

In [6]:
# Output folders for this model.
model_path = f'{model_path_base}/VeloVAE'
figure_path = f'{figure_path_base}/VeloVAE'

# Seed NumPy and PyTorch before building the model.
np.random.seed(2022)
torch.manual_seed(2022)

# Constructor options. The three 'tprior' entries all point at the
# adata.obs['tprior'] column created in the time-prior cell.
vae_kwargs = dict(
    tmax=20,
    dim_z=5,
    device='cuda:0',
    init_method='tprior',
    init_key='tprior',
    tprior='tprior',
    init_ton_zero=False,
)
vae = vv.VAE(adata, **vae_kwargs)

vae.train(
    adata,
    plot=False,
    gene_plot=gene_plot,
    figure_path=figure_path,
)

# Persist the trained networks and write the results into `adata` under the
# 'velovae' key, which the evaluation cell refers to.
# NOTE(review): 'encoder'/'decoder' are presumably the checkpoint file names — confirm.
vae.save_model(model_path, 'encoder', 'decoder')
vae.save_anndata(
    adata,
    'velovae',
    data_path,
    file_name=f'{dataset}.h5ad',
)

Initialization using prior time.
Estimating ODE parameters...


  0%|          | 0/2002 [00:00<?, ?it/s]

Estimating the variance...


  0%|          | 0/2002 [00:00<?, ?it/s]

Reinitialize the regular ODE parameters based on estimated global latent time.


  0%|          | 0/2002 [00:00<?, ?it/s]

Gaussian Prior.
Using informative time prior.
--------------------------- Train a VeloVAE ---------------------------
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
*********                      Stage  1                       *********
Total Number of Iterations Per Epoch: 149, test iteration: 296
*********       Stage 1: Early Stop Triggered at epoch 146.       *********
*********                      Stage  2                       *********
Cell-wise KNN Estimation.


  0%|          | 0/27086 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.017
Average Set Size: 1031
Finished. Actual Time:   0 h :  0 m : 35 s
*********       Stage 2: Early Stop Triggered at epoch 168.       *********
*********              Finished. Total Time =   0 h :  6 m : 31 s             *********
Final: Train ELBO = 3886.803,           Test ELBO = 3879.228


# Full VB

In [7]:
# Output folders for this model.
model_path = f'{model_path_base}/FullVB'
figure_path = f'{figure_path_base}/FullVB'

# Seed NumPy and PyTorch before building the model.
np.random.seed(2022)
torch.manual_seed(2022)

# Constructor options. The three 'tprior' entries all point at the
# adata.obs['tprior'] column created in the time-prior cell.
full_vb_kwargs = dict(
    tmax=20,
    dim_z=5,
    device='cuda:0',
    init_method='tprior',
    init_key='tprior',
    tprior='tprior',
    init_ton_zero=False,
)
full_vb = vv.VAEFullVB(adata, **full_vb_kwargs)

full_vb.train(
    adata,
    plot=False,
    gene_plot=gene_plot,
    figure_path=figure_path,
)

# Persist the trained networks and write the results into `adata` under the
# 'fullvb' key, which the evaluation cell refers to.
# NOTE(review): 'encoder'/'decoder' are presumably the checkpoint file names — confirm.
full_vb.save_model(model_path, 'encoder', 'decoder')
full_vb.save_anndata(
    adata,
    'fullvb',
    data_path,
    file_name=f'{dataset}.h5ad',
)

Initialization using prior time.
Estimating ODE parameters...


  0%|          | 0/2002 [00:00<?, ?it/s]

Estimating the variance...


  0%|          | 0/2002 [00:00<?, ?it/s]

Reinitialize the regular ODE parameters based on estimated global latent time.


  0%|          | 0/2002 [00:00<?, ?it/s]

Gaussian Prior.
Using informative time prior.
--------------------------- Train a VeloVAE ---------------------------
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
*********                      Stage  1                       *********
Total Number of Iterations Per Epoch: 149, test iteration: 296
*********       Stage 1: Early Stop Triggered at epoch 185.       *********
*********                      Stage  2                       *********
Cell-wise KNN Estimation.


  0%|          | 0/27086 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.043
Average Set Size: 951
Finished. Actual Time:   0 h :  0 m : 34 s
*********       Stage 2: Early Stop Triggered at epoch 213.       *********
*********              Finished. Total Time =   0 h :  8 m : 58 s             *********
Final: Train ELBO = 3871.534,           Test ELBO = 3860.804


# Evaluation

In [9]:
# Display names for the compared models, and the keys under which each model's
# results were written by save_anndata() in the training cells (same order).
method_names = ['Vanilla VAE', 'VeloVAE', 'FullVB']
method_keys = ['vanilla', 'velovae', 'fullvb']

# The call is deliberately the last expression of the cell so that its return
# value is displayed as the cell output.
# NOTE(review): plot_type=[] presumably disables all figures, leaving only the
# metrics computation — confirm against the velovae documentation.
vv.post_analysis(
    adata,
    'eval',
    method_names,
    method_keys,
    compute_metrics=True,
    plot_type=[],
    genes=gene_plot,
    grid_size=(1, 4),
    save_path=data_path,
)

---     Computing Peformance Metrics     ---
Dataset Size: 27086 cells, 2002 genes
---   Plotting  Results   ---


Unnamed: 0,Vanilla VAE,VeloVAE,FullVB
MSE Train,2.591,0.279,0.388
MSE Test,2.625,0.294,0.413
MAE Train,0.208,0.112,0.12
MAE Test,0.208,0.113,0.122
LL Train,3648.673,3904.186,3890.291
LL Test,3644.886,3896.514,3882.405
corr,-0.336,0.884,0.889
In-Cluster Coherence,0.925,0.615,0.791
