In [1]:
import anndata
import numpy as np
import sys
import torch
# Make the in-repo velovae package importable. The relative path presumably points
# at the repository root from this notebook's directory; adjust it if the notebook moves.
sys.path.append('../../../')
import velovae as vv

In [2]:
dataset = 'HIO'
# The '_pp' suffix suggests this AnnData file is already preprocessed (compare the
# commented-out vv.preprocess call in the next cell). TODO confirm.
adata = anndata.read_h5ad(f'data/{dataset}_pp.h5ad')

In [None]:
# Preprocessing kept for reference only; nothing in this cell executes.
# The '_pp' file loaded in the previous cell is presumably the output of this call.
# NOTE(review): confirm these settings match how data/HIO_pp.h5ad was produced.
#marker_genes = ['S100B','PLP1','STMN2','ELAVL4','CDH5','KDR','ECSCR','CLDN5','COL1A1','COL1A2',\
#                'DCN','ACTA2','TAGLN','ACTG2','MYLK','EPCAM','CDH1','CDX2','CLDN4']
#vv.preprocess(adata, 
#              n_gene=2000, 
#              min_shared_counts=20, 
#              genes_retain=marker_genes, 
#              compute_umap=True, 
#              keep_raw=True)

In [3]:
# Every output location is namespaced by the dataset name.
model_path_base = 'checkpoints/' + dataset
figure_path_base = 'figures/' + dataset
data_path = 'data/velovae/discrete/' + dataset

# Genes passed to training (gene_plot=) and to vv.post_analysis (genes=).
gene_plot = [
    'PLP1',
    'ECSCR',
    'COL1A1',
    'EPCAM',
]

# Informative Time Prior

In [4]:
# Parse each collection-day label by dropping its first character and reading the
# rest as a float. Labels are presumably of the form 'D<number>'; verify against
# adata.obs['Day'].
day = adata.obs['Day']
tprior = np.fromiter((float(label[1:]) for label in day), dtype=float)
# Shift so that the earliest day maps to time 0.
adata.obs['tprior'] = tprior - tprior.min()

# Discrete VeloVAE

In [5]:
figure_path = f'{figure_path_base}/DVAE'
model_path = f'{model_path_base}/DVAE'

# Fall back to CPU so the notebook still runs on machines without a GPU, where a
# hard-coded 'cuda:0' would fail. CPU training is likely much slower.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Fix the torch and numpy RNG seeds for this training run.
torch.manual_seed(2022)
np.random.seed(2022)

# Discrete VeloVAE, initialised from the informative time prior in adata.obs['tprior'].
dvae = vv.DVAE(adata,
               tmax=20,
               dim_z=5,
               device=device,
               init_method='tprior',
               init_key='tprior',
               tprior='tprior',
               init_ton_zero=False)

dvae.train(adata, plot=False, gene_plot=gene_plot, figure_path=figure_path)

# Save the encoder/decoder weights, then write results under the 'dvae' key.
# 'dvae' is the key later passed to vv.post_analysis.
dvae.save_model(model_path, 'encoder', 'decoder')
dvae.save_anndata(adata, 'dvae', data_path, file_name=f'{dataset}.h5ad')

Detecting zero scaling factors: 18, 0
Initialization using prior time.


  0%|          | 0/2002 [00:00<?, ?it/s]

Reinitialize the regular ODE parameters based on estimated global latent time.


  0%|          | 0/2002 [00:00<?, ?it/s]

Gaussian Prior.
Using informative time prior.
Library scale (U): Max=24.07, Min=0.01, Mean=0.87
Library scale (S): Max=48.52, Min=0.02, Mean=1.10
--------------------------- Train a VeloVAE ---------------------------
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
*********                      Stage  1                       *********
Total Number of Iterations Per Epoch: 149, test iteration: 296
*********       Stage 1: Early Stop Triggered at epoch 273.       *********
*********                      Stage  2                       *********


  0%|          | 0/27086 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.015
Average Set Size: 861
*********       Stage 2: Early Stop Triggered at epoch 329.       *********
*********              Finished. Total Time =   0 h : 13 m : 13 s             *********
Final: Train ELBO = -1172.552,           Test ELBO = -1191.927
       Training MSE = 12.341, Test MSE = 16.571


# Discrete Full VB

In [6]:
figure_path = f'{figure_path_base}/DFullVB'
model_path = f'{model_path_base}/DFullVB'

# Fall back to CPU so the notebook still runs on machines without a GPU, where a
# hard-coded 'cuda:0' would fail. CPU training is likely much slower.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Re-seed with the same values as the DVAE cell so the two models start from
# identical RNG states.
torch.manual_seed(2022)
np.random.seed(2022)

# Full variational-Bayes variant. Its constructor arguments match the DVAE cell.
dfull_vb = vv.DVAEFullVB(adata,
                         tmax=20,
                         dim_z=5,
                         device=device,
                         init_method='tprior',
                         init_key='tprior',
                         tprior='tprior',
                         init_ton_zero=False)

dfull_vb.train(adata, plot=False, gene_plot=gene_plot, figure_path=figure_path)

# Save the encoder/decoder weights, then write results under the 'dfullvb' key.
# NOTE(review): the DVAE cell writes the same {dataset}.h5ad in data_path. Because
# adata is shared, this save presumably contains both models' results; confirm.
dfull_vb.save_model(model_path, 'encoder', 'decoder')
dfull_vb.save_anndata(adata, 'dfullvb', data_path, file_name=f'{dataset}.h5ad')

Detecting zero scaling factors: 18, 0
Initialization using prior time.


  0%|          | 0/2002 [00:00<?, ?it/s]

Reinitialize the regular ODE parameters based on estimated global latent time.


  0%|          | 0/2002 [00:00<?, ?it/s]

Gaussian Prior.
Using informative time prior.
Library scale (U): Max=24.07, Min=0.01, Mean=0.87
Library scale (S): Max=48.52, Min=0.02, Mean=1.10
--------------------------- Train a VeloVAE ---------------------------
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
*********                      Stage  1                       *********
Total Number of Iterations Per Epoch: 149, test iteration: 296
*********       Stage 1: Early Stop Triggered at epoch 883.       *********
*********                      Stage  2                       *********


  0%|          | 0/27086 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.002
Average Set Size: 748
*********       Stage 2: Early Stop Triggered at epoch 907.       *********
*********              Finished. Total Time =   0 h : 35 m : 53 s             *********
Final: Train ELBO = -1188.732,           Test ELBO = -1201.923
       Training MSE = 11.048, Test MSE = 15.144


# Evaluation

In [None]:
# Display labels and the result keys used by the two training cells' save_anndata
# calls. The two lists must stay aligned.
method_names = ['Discrete VeloVAE', 'Discrete FullVB']
method_keys = ['dvae', 'dfullvb']

# NOTE(review): save_path receives data_path here. If post_analysis saves figures,
# figure_path_base may have been intended; confirm.
vv.post_analysis(adata, 'eval', method_names, method_keys,
                 compute_metrics=False,
                 plot_type=['all'],
                 genes=gene_plot,
                 grid_size=(1, 4),
                 save_path=data_path)