In [3]:
import anndata
import numpy as np
import sys
import torch
sys.path.append('../../../')
import velovae as vv

In [4]:
# Name of the dataset; reused below to build every input/output path.
dataset = 'Braindev'

# Read the AnnData object stored under data/.
adata = anndata.read_h5ad(f'data/{dataset}_pp.h5ad')

# Mirror the 'Class' annotation into a 'clusters' column of adata.obs.
adata.obs['clusters'] = adata.obs['Class'].to_numpy()

In [None]:
# NOTE(review): this cell carries no execution count in the saved notebook,
# so the recorded outputs may come from a session that skipped it -- confirm.
vv.preprocess(
    adata,
    n_gene=1000,
    min_shared_counts=20,
    compute_umap=True,
    keep_raw=True,
)

In [5]:
# Per-dataset base directories; each model appends its own sub-folder to the
# checkpoint and figure bases.
model_path_base = f'checkpoints/{dataset}'
figure_path_base = f'figures/{dataset}'

# Directory the models' save_anndata calls write to.
data_path = f'data/velovae/discrete/{dataset}'

# Genes passed to the training and evaluation plotting arguments.
gene_plot = ['Mapt', 'Tmsb10', 'Fabp7', 'Npm1']

# Informative Time Prior

In [6]:
# Per-cell 'Age' labels as a NumPy array.
capture_time = adata.obs['Age'].to_numpy()

# Drop the first character of every label and parse the remainder as a float.
# NOTE(review): assumes each label is a one-character prefix followed by a
# number -- confirm against the dataset.
tprior = np.array([float(age[1:]) for age in capture_time])

# Shift so the earliest value becomes 0, then store it under 'tprior'.
adata.obs['tprior'] = tprior - tprior.min()

## Discrete VAE

In [7]:
# Output locations for this model's checkpoints and figures.
model_path = f'{model_path_base}/DVAE'
figure_path = f'{figure_path_base}/DVAE'

# Seed NumPy and PyTorch before the model is built.
np.random.seed(2022)
torch.manual_seed(2022)

# The string 'tprior' names the adata.obs column written by the time-prior cell.
dvae = vv.model.DVAE(
    adata,
    tmax=20,
    dim_z=5,
    device='cuda:0',
    init_method='tprior',
    init_key='tprior',
    tprior='tprior',
    init_ton_zero=False,
)

dvae.train(
    adata,
    plot=True,
    gene_plot=gene_plot,
    figure_path=figure_path,
)

# Save the networks, then write adata under the 'dvae' key
# (the same key the evaluation cell passes to post_analysis).
dvae.save_model(model_path, 'encoder', 'decoder')
dvae.save_anndata(adata, 'dvae', data_path, f'{dataset}.h5ad')

Detecting zero scaling factors: 5, 0
Initialization using prior time.


  0%|          | 0/2000 [00:00<?, ?it/s]

Reinitialize the regular ODE parameters based on estimated global latent time.


  0%|          | 0/2000 [00:00<?, ?it/s]

Gaussian Prior.
Using informative time prior.
Library scale (U): Max=10.47, Min=0.00, Mean=0.94
Library scale (S): Max=17.49, Min=0.16, Mean=1.00
--------------------------- Train a VeloVAE ---------------------------
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
*********                      Stage  1                       *********
Total Number of Iterations Per Epoch: 164, test iteration: 326
Epoch 1: Train ELBO = -2553.180, Test ELBO = -19347.024, 	 Total Time =   0 h :  0 m : 10 s
Epoch 100: Train ELBO = -1296.207, Test ELBO = -1306.282, 	 Total Time =   0 h :  3 m : 27 s
*********       Stage 1: Early Stop Triggered at epoch 177.       *********
*********                 

  0%|          | 0/29948 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.002
Average Set Size: 788
Epoch 178: Train ELBO = -1317.090, Test ELBO = -2312.852, 	 Total Time =   0 h :  6 m : 43 s
*********       Stage 2: Early Stop Triggered at epoch 195.       *********
*********              Finished. Total Time =   0 h :  9 m :  7 s             *********
Final: Train ELBO = -1285.481,           Test ELBO = -1291.189
       Training MSE = 2.424, Test MSE = 2.533


## Discrete FullVB

In [8]:
# Output locations for this model's checkpoints and figures.
model_path = f'{model_path_base}/DFullVB'
figure_path = f'{figure_path_base}/DFullVB'

# Seed NumPy and PyTorch before the model is built.
np.random.seed(2022)
torch.manual_seed(2022)

# The string 'tprior' names the adata.obs column written by the time-prior cell.
dfullvb = vv.model.DVAEFullVB(
    adata,
    tmax=20,
    dim_z=5,
    device='cuda:0',
    init_method='tprior',
    init_key='tprior',
    tprior='tprior',
    init_ton_zero=False,
)

dfullvb.train(
    adata,
    plot=True,
    gene_plot=gene_plot,
    figure_path=figure_path,
)

# Save the networks, then write adata under the 'dfullvb' key
# (the same key the evaluation cell passes to post_analysis).
dfullvb.save_model(model_path, 'encoder', 'decoder')
dfullvb.save_anndata(adata, 'dfullvb', data_path, f'{dataset}.h5ad')

Detecting zero scaling factors: 5, 0
Initialization using prior time.


  0%|          | 0/2000 [00:00<?, ?it/s]

Reinitialize the regular ODE parameters based on estimated global latent time.


  0%|          | 0/2000 [00:00<?, ?it/s]

Gaussian Prior.
Using informative time prior.
Library scale (U): Max=10.47, Min=0.00, Mean=0.94
Library scale (S): Max=17.49, Min=0.16, Mean=1.00
--------------------------- Train a VeloVAE ---------------------------
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
*********                      Stage  1                       *********
Total Number of Iterations Per Epoch: 164, test iteration: 326
Epoch 1: Train ELBO = -2590.209, Test ELBO = -19448.491, 	 Total Time =   0 h :  0 m : 10 s
Epoch 100: Train ELBO = -1449.742, Test ELBO = -1441.458, 	 Total Time =   0 h :  4 m : 19 s
Epoch 200: Train ELBO = -1430.903, Test ELBO = -1438.991, 	 Total Time =   0 h :  8 m : 15 s
Epoch 300

  0%|          | 0/29948 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.002
Average Set Size: 782
Epoch 954: Train ELBO = -1332.143, Test ELBO = -5388.501, 	 Total Time =   0 h : 37 m : 59 s
Epoch 1000: Train ELBO = -1321.808, Test ELBO = -1320.518, 	 Total Time =   0 h : 40 m :  9 s
*********       Stage 2: Early Stop Triggered at epoch 1009.       *********
*********              Finished. Total Time =   0 h : 42 m : 25 s             *********
Final: Train ELBO = -1324.847,           Test ELBO = -1325.150
       Training MSE = 2.575, Test MSE = 2.568


# Evaluation

In [None]:
# Pairs of cluster labels handed to post_analysis as `cluster_edges`.
cluster_edges = [
    ('Neural tube', 'Radial glia'),
    ('Radial glia', 'Neuroblast'),
    ('Radial glia', 'Glioblast'),
    ('Radial glia', 'Oligodendrocyte'),
    ('Radial glia', 'Ependymal'),
    ('Neural crest', 'Mesenchyme'),
    ('Mesenchyme', 'Fibroblast'),
]

# Evaluate both models; the keys 'dvae' and 'dfullvb' match the ones used in
# the save_anndata calls of the two training cells.
vv.post_analysis(
    adata,
    'eval',
    ['Discrete VeloVAE', 'Discrete FullVB'],
    ['dvae', 'dfullvb'],
    compute_metrics=True,
    genes=gene_plot,
    grid_size=(1, 4),
    plot_type=['all'],
    # NOTE(review): results are directed to data_path rather than a figures
    # directory -- confirm this is intended.
    save_path=data_path,
    cluster_edges=cluster_edges,
)