In [1]:
import anndata
import numpy as np
import sys
import torch
sys.path.append('../../../')
import velovae as vv

In [2]:
# Name of the dataset; every input/output path in this notebook is built from it.
dataset = 'Braindev'

# Load the preprocessed AnnData object. The file name is derived from
# `dataset` (instead of repeating the literal) so that changing the name
# above cannot silently keep loading the old file.
adata = anndata.read_h5ad(f'data/{dataset}_pp.h5ad')

# Expose the 'Class' annotation under the 'clusters' key, which is the
# cluster key the BrODE cell later passes to velovae.
adata.obs['clusters'] = adata.obs['Class'].to_numpy()

In [None]:
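# Uncomment this if the data is not preprocessed (the raw AnnData must already be loaded into `adata`).
# The output path matches the file read by anndata.read_h5ad above ('data/Braindev_pp.h5ad').
#vv.preprocess(adata, n_gene=2000, min_shared_counts=20, compute_umap=True)
#adata.write_h5ad(f'data/{dataset}_pp.h5ad')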

In [3]:
# Genes handed to the train() and post_analysis() calls as `gene_plot` / `genes`.
gene_plot = [
    'Aldh1l1',
    'Mapt',
    'Myt1l',
    'Lum',
]

# Per-dataset root folders. Each model cell appends its own sub-folder to the
# figure and checkpoint roots; `data_path` is passed to save_anndata().
data_path = f'data/velovae/continuous/{dataset}'
figure_path_base = f'figures/{dataset}'
model_path_base = f'checkpoints/{dataset}'

# Informative Time Prior

In [4]:
# Per-cell capture age labels taken from the 'Age' column of adata.obs.
capture_time = adata.obs['Age'].to_numpy()

# Strip the first character of every label and parse the remainder as a float
# (presumably labels look like 'e10.5' — TODO confirm against the data).
tprior = np.fromiter((float(label[1:]) for label in capture_time), dtype=float)

# Shift the prior so that the earliest capture time becomes 0.
adata.obs['tprior'] = tprior - tprior.min()

# Vanilla VAE

In [6]:
# Output folders for the Vanilla VAE run.
figure_path = f'{figure_path_base}/Vanilla'
model_path = f'{model_path_base}/Vanilla'

# Fix the torch and numpy seeds before the model is built and trained.
torch.manual_seed(2022)
np.random.seed(2022)

# Use the first GPU when CUDA is usable; otherwise fall back to the CPU
# instead of failing on a machine without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# init_method / init_key / tprior all name 'tprior', the adata.obs column
# written by the time-prior cell (the training log reports "Initialization
# using prior time" and "Using informative time prior").
vanilla_vae = vv.VanillaVAE(adata,
                            tmax=20,
                            device=device,
                            init_method='tprior',
                            init_key='tprior',
                            tprior='tprior')

# plot=False: gene_plot / figure_path are still passed through to train().
vanilla_vae.train(adata, plot=False, gene_plot=gene_plot, figure_path=figure_path)

# Persist the trained networks and store the results in the shared .h5ad
# file under the 'vanilla' key.
vanilla_vae.save_model(model_path, 'encoder', 'decoder')
vanilla_vae.save_anndata(adata, 'vanilla', data_path, file_name=f'{dataset}.h5ad')

Initialization using prior time.
Estimating ODE parameters...


  0%|          | 0/2000 [00:00<?, ?it/s]

Estimating the variance...


  0%|          | 0/2000 [00:00<?, ?it/s]

Reinitialize the regular ODE parameters based on estimated global latent time.


  0%|          | 0/2000 [00:00<?, ?it/s]

Gaussian Prior.
Using informative time prior.
------------------------- Train a Vanilla VAE -------------------------
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
Total Number of Iterations Per Epoch: 164, test iteration: 326
********* Early Stop Triggered at epoch 142. *********
*********              Finished. Total Time =   0 h :  4 m : 43 s             *********
Final: Train ELBO = 963.353,           Test ELBO = 991.745


# VeloVAE

In [7]:
# Output folders for the VeloVAE run.
figure_path = f'{figure_path_base}/VeloVAE'
model_path = f'{model_path_base}/VeloVAE'

# Fix the torch and numpy seeds before the model is built and trained.
torch.manual_seed(2022)
np.random.seed(2022)

# Use the first GPU when CUDA is usable; otherwise fall back to the CPU
# instead of failing on a machine without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Same 'tprior'-based initialisation as the Vanilla VAE cell, with a
# 5-dimensional latent state (dim_z=5) and init_ton_zero disabled.
vae = vv.VAE(adata,
             tmax=20,
             dim_z=5,
             device=device,
             init_method='tprior',
             init_key='tprior',
             tprior='tprior',
             init_ton_zero=False)

# plot=False: gene_plot / figure_path are still passed through to train().
vae.train(adata, plot=False, gene_plot=gene_plot, figure_path=figure_path)

# Persist the trained networks and store the results in the shared .h5ad
# file under the 'velovae' key.
vae.save_model(model_path, 'encoder', 'decoder')
vae.save_anndata(adata, 'velovae', data_path, file_name=f'{dataset}.h5ad')

Initialization using prior time.
Estimating ODE parameters...


  0%|          | 0/2000 [00:00<?, ?it/s]

Estimating the variance...


  0%|          | 0/2000 [00:00<?, ?it/s]

Reinitialize the regular ODE parameters based on estimated global latent time.


  0%|          | 0/2000 [00:00<?, ?it/s]

Gaussian Prior.
Using informative time prior.
--------------------------- Train a VeloVAE ---------------------------
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
*********                      Stage  1                       *********
Total Number of Iterations Per Epoch: 164, test iteration: 326
*********       Stage 1: Early Stop Triggered at epoch 311.       *********
*********                      Stage  2                       *********
Cell-wise KNN Estimation.


  0%|          | 0/29948 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.002
Average Set Size: 748
Finished. Actual Time:   0 h :  0 m : 37 s
*********       Stage 2: Early Stop Triggered at epoch 405.       *********
*********              Finished. Total Time =   0 h : 16 m :  0 s             *********
Final: Train ELBO = 2483.334,           Test ELBO = 2490.559


# Full VB

In [8]:
# Output folders for the Full VB run.
figure_path = f'{figure_path_base}/FullVB'
model_path = f'{model_path_base}/FullVB'

# Fix the torch and numpy seeds before the model is built and trained.
torch.manual_seed(2022)
np.random.seed(2022)

# Use the first GPU when CUDA is usable; otherwise fall back to the CPU
# instead of failing on a machine without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Same configuration as the VeloVAE cell, but using the VAEFullVB class.
full_vb = vv.VAEFullVB(adata,
                       tmax=20,
                       dim_z=5,
                       device=device,
                       init_method='tprior',
                       init_key='tprior',
                       tprior='tprior',
                       init_ton_zero=False)

# plot=False: gene_plot / figure_path are still passed through to train().
full_vb.train(adata, plot=False, gene_plot=gene_plot, figure_path=figure_path)

# Persist the trained networks and store the results in the shared .h5ad
# file under the 'fullvb' key (the BrODE cell refers to 'fullvb_time' and
# 'fullvb_z', presumably written here — TODO confirm).
full_vb.save_model(model_path, 'encoder', 'decoder')
full_vb.save_anndata(adata, 'fullvb', data_path, file_name=f'{dataset}.h5ad')

Initialization using prior time.
Estimating ODE parameters...


  0%|          | 0/2000 [00:00<?, ?it/s]

Estimating the variance...


  0%|          | 0/2000 [00:00<?, ?it/s]

Reinitialize the regular ODE parameters based on estimated global latent time.


  0%|          | 0/2000 [00:00<?, ?it/s]

Gaussian Prior.
Using informative time prior.
--------------------------- Train a VeloVAE ---------------------------
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
*********                      Stage  1                       *********
Total Number of Iterations Per Epoch: 164, test iteration: 326
*********       Stage 1: Early Stop Triggered at epoch 227.       *********
*********                      Stage  2                       *********
Cell-wise KNN Estimation.


  0%|          | 0/29948 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.002
Average Set Size: 888
Finished. Actual Time:   0 h :  0 m : 38 s
*********       Stage 2: Early Stop Triggered at epoch 291.       *********
*********              Finished. Total Time =   0 h : 13 m : 37 s             *********
Final: Train ELBO = 2461.378,           Test ELBO = 2465.405


# Branching ODE

In [9]:
# Output folders for the branching-ODE run.
model_path = f'{model_path_base}/BrODE'
figure_path = f'{figure_path_base}/BrODE'

# Fix the numpy and torch seeds before the model is built and trained.
np.random.seed(2022)
torch.manual_seed(2022)

# Build the branching ODE from the 'clusters' annotation and the
# 'fullvb_time' / 'fullvb_z' keys (presumably the time and latent state
# stored by the Full VB cell — TODO confirm).
brode = vv.BrODE(adata, 'clusters', 'fullvb_time', 'fullvb_z')

# Show the cell-type transition weights before training.
brode.print_weight()

# Keyword options collected once, then forwarded unchanged to train().
train_options = dict(plot=False, gene_plot=gene_plot, figure_path=figure_path)
brode.train(adata, 'fullvb_time', 'clusters', **train_options)

# Persist the trained model and store the results under the 'brode' key.
brode.save_model(model_path, 'brode')
brode.save_anndata(adata, 'brode', data_path, file_name=f'{dataset}.h5ad')

# The transition graph goes to the dataset-level figure folder.
transition_figure = f'{figure_path_base}/transition.png'
vv.plot_transition_graph(adata, save=transition_figure)

Graph Partition
Number of partitions:  1
Computing type-to-type transition probability
Obtaining the MST in each partition
Initialization using type-specific dynamical model.
Estimating ODE parameters...


  0%|          | 0/2000 [00:00<?, ?it/s]

Estimating the variance...


  0%|          | 0/2000 [00:00<?, ?it/s]

                 Ependymal  Fibroblast  Glioblast  Mesenchyme  Neural crest  Neural tube  Neuroblast  Neuron  Oligodendrocyte  Radial glia
Ependymal              0.0         0.0        0.0         0.0           0.0          0.0         0.0     0.0              0.0          1.0
Fibroblast             0.0         0.0        0.0         1.0           0.0          0.0         0.0     0.0              0.0          0.0
Glioblast              0.0         0.0        0.0         0.0           0.0          0.0         0.0     0.0              0.0          1.0
Mesenchyme             0.0         0.0        0.0         0.0           1.0          0.0         0.0     0.0              0.0          0.0
Neural crest           0.0         0.0        0.0         0.0           0.0          1.0         0.0     0.0              0.0          0.0
Neural tube            0.0         0.0        0.0         0.0           0.0          1.0         0.0     0.0              0.0          0.0
Neuroblast             0.0 

# Evaluation

In [None]:
# Cell-type pairs handed to post_analysis as `cluster_edges`
# (presumably (progenitor, descendant) transitions used by the metrics —
# TODO confirm against the velovae documentation).
cluster_edges = [('Neural tube','Radial glia'),
                 ('Radial glia', 'Neuroblast'),
                 ('Radial glia', 'Glioblast'),
                 ('Radial glia', 'Oligodendrocyte'),
                 ('Radial glia', 'Ependymal'),
                 ('Neural crest', 'Mesenchyme'),
                 ('Mesenchyme','Fibroblast')]

# Compare all four models. The display names in the first list pair up
# positionally with the keys in the second list, which are the keys used
# in the save_anndata() calls of the model cells.
#
# Figures are written to the dataset-level folder `figure_path_base`.
# The previous `figure_path` was whatever the most recently executed model
# cell had assigned, so the output location of this joint comparison
# depended on cell execution order.
vv.post_analysis(adata,
                 'eval',
                 ['Vanilla VAE','VeloVAE','FullVB','BrODE'],
                 ['vanilla','velovae','fullvb','brode'],
                 compute_metrics=True,
                 genes=gene_plot,
                 grid_size=(1,4),
                 plot_type=['all'],
                 save_path=figure_path_base,
                 cluster_edges=cluster_edges)