In [1]:
import anndata
import numpy as np
import sys
import torch
sys.path.append('../../../../')
import velovae as vv

In [2]:
import pandas as pd

# Dataset identifier; it is interpolated into every data/figure/checkpoint path.
dataset = 'IO_EU'

# Read the stored AnnData object for this dataset.
adata = anndata.read_h5ad(f'data/{dataset}_pp.h5ad')

# Copy the cell-type annotation into the 'clusters' column as a plain array;
# 'clusters' is the key later handed to vv.BrODE.
cell_type_labels = adata.obs['cell_type'].to_numpy()
adata.obs['clusters'] = cell_type_labels

# Re-index the gene table by the values of its 'gene' column.
adata.var.index = pd.Index(adata.var['gene'])

In [None]:
# Run velovae preprocessing on `adata` and write the object to disk.
# The return value is discarded, so this presumably modifies `adata` in place — TODO confirm.
# NOTE(review): the output path is the same file the loading cell reads from;
# confirm that overwriting the input is intended before re-running this cell.
preprocess_options = dict(n_gene=2000, min_shared_counts=20)
vv.preprocess(adata, **preprocess_options)
adata.write_h5ad(f'data/{dataset}_pp.h5ad')

In [3]:
# Genes forwarded as `gene_plot` to every train() call and as `genes` to
# vv.post_analysis.
gene_plot = ['Lgr5', 'Apoa1', 'Dgat1', 'Gsta4']

# Per-dataset base locations. Each model section appends its own sub-folder
# to the figure and checkpoint bases; all models share `data_path`.
data_path = f'data/velovae/continuous/{dataset}'
figure_path_base = f'figures/{dataset}'
model_path_base = f'checkpoints/{dataset}'

# Vanilla VAE

In [4]:
# Output folders for this model's checkpoints and figures.
model_path = f'{model_path_base}/Vanilla'
figure_path = f'{figure_path_base}/Vanilla'

# Seed NumPy and PyTorch before the model is built.
np.random.seed(2022)
torch.manual_seed(2022)

vanilla_vae = vv.VanillaVAE(adata, tmax=20, device='cuda:0', init_method='steady')

vanilla_vae.train(adata,
                  plot=False,
                  gene_plot=gene_plot,
                  figure_path=figure_path)

# Persist the trained networks, then write the results into `adata` under the
# 'vanilla' key (the same key the evaluation cell passes to vv.post_analysis).
vanilla_vae.save_model(model_path, 'encoder', 'decoder')
vanilla_vae.save_anndata(adata, 'vanilla', data_path, file_name=f'{dataset}.h5ad')

Initialization using the steady-state and dynamical models.
Estimating ODE parameters...


  0%|          | 0/2000 [00:00<?, ?it/s]

Estimating the variance...


  0%|          | 0/2000 [00:00<?, ?it/s]

Reinitialize the regular ODE parameters based on estimated global latent time.


  0%|          | 0/2000 [00:00<?, ?it/s]

Gaussian Prior.
------------------------- Train a Vanilla VAE -------------------------
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
Total Number of Iterations Per Epoch: 21, test iteration: 40
********* Early Stop Triggered at epoch 481. *********
*********              Finished. Total Time =   0 h :  1 m : 58 s             *********
Final: Train ELBO = -601.520,           Test ELBO = -662.537


# VeloVAE

In [5]:
# Output folders for this model's checkpoints and figures.
model_path = f'{model_path_base}/VeloVAE'
figure_path = f'{figure_path_base}/VeloVAE'

# Seed NumPy and PyTorch before the model is built.
np.random.seed(2022)
torch.manual_seed(2022)

vae = vv.VAE(adata, tmax=20, dim_z=5, device='cuda:0', init_method='steady')

vae.train(adata,
          plot=False,
          gene_plot=gene_plot,
          figure_path=figure_path)

# Persist the trained networks, then write the results into `adata` under the
# 'velovae' key (the same key the evaluation cell passes to vv.post_analysis).
vae.save_model(model_path, 'encoder', 'decoder')
vae.save_anndata(adata, 'velovae', data_path, file_name=f'{dataset}.h5ad')

Initialization using the steady-state and dynamical models.
Estimating ODE parameters...


  0%|          | 0/2000 [00:00<?, ?it/s]

Estimating the variance...


  0%|          | 0/2000 [00:00<?, ?it/s]

Reinitialize the regular ODE parameters based on estimated global latent time.


  0%|          | 0/2000 [00:00<?, ?it/s]

Gaussian Prior.
--------------------------- Train a VeloVAE ---------------------------
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
*********                      Stage  1                       *********
Total Number of Iterations Per Epoch: 21, test iteration: 40
*********       Stage 1: Early Stop Triggered at epoch 513.       *********
*********                      Stage  2                       *********
Cell-wise KNN Estimation.


  0%|          | 0/3829 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.142
Average Set Size: 179
Finished. Actual Time:   0 h :  0 m :  3 s
*********       Stage 2: Early Stop Triggered at epoch 822.       *********
*********              Finished. Total Time =   0 h :  3 m : 41 s             *********
Final: Train ELBO = 3518.875,           Test ELBO = 3487.642


# Full VB

In [6]:
# Output folders for this model's checkpoints and figures.
model_path = f'{model_path_base}/FullVB'
figure_path = f'{figure_path_base}/FullVB'

# Seed NumPy and PyTorch before the model is built.
np.random.seed(2022)
torch.manual_seed(2022)

full_vb = vv.VAEFullVB(adata, tmax=20, dim_z=5, device='cuda:0', init_method='steady')

full_vb.train(adata,
              plot=False,
              gene_plot=gene_plot,
              figure_path=figure_path)

# Persist the trained networks, then write the results into `adata` under the
# 'fullvb' key. The Branching ODE section reads 'fullvb_time' / 'fullvb_z',
# which are presumably created by this save — TODO confirm.
full_vb.save_model(model_path, 'encoder', 'decoder')
full_vb.save_anndata(adata, 'fullvb', data_path, file_name=f'{dataset}.h5ad')

Initialization using the steady-state and dynamical models.
Estimating ODE parameters...


  0%|          | 0/2000 [00:00<?, ?it/s]

Estimating the variance...


  0%|          | 0/2000 [00:00<?, ?it/s]

Reinitialize the regular ODE parameters based on estimated global latent time.


  0%|          | 0/2000 [00:00<?, ?it/s]

Gaussian Prior.
--------------------------- Train a VeloVAE ---------------------------
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
*********                      Stage  1                       *********
Total Number of Iterations Per Epoch: 21, test iteration: 40
*********       Stage 1: Early Stop Triggered at epoch 490.       *********
*********                      Stage  2                       *********
Cell-wise KNN Estimation.


  0%|          | 0/3829 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.090
Average Set Size: 167
Finished. Actual Time:   0 h :  0 m :  3 s
*********       Stage 2: Early Stop Triggered at epoch 900.       *********
*********              Finished. Total Time =   0 h :  4 m : 40 s             *********
Final: Train ELBO = 3433.425,           Test ELBO = 3356.509


# Branching ODE

In [7]:
# Output folders for this model's checkpoints and figures.
model_path = f'{model_path_base}/BrODE'
figure_path = f'{figure_path_base}/BrODE'

# Seed NumPy and PyTorch before the model is built.
np.random.seed(2022)
torch.manual_seed(2022)

# The branching ODE is built from the cluster labels together with the
# 'fullvb'-prefixed time / latent-state keys and parameters.
brode = vv.BrODE(adata,
                 'clusters',
                 'fullvb_time',
                 'fullvb_z',
                 param_key='fullvb')

brode.train(adata,
            'fullvb_time',
            'clusters',
            plot=False,
            gene_plot=gene_plot,
            figure_path=figure_path)

# Persist the trained model, then write the results into `adata` under the
# 'brode' key (the same key the evaluation cell passes to vv.post_analysis).
brode.save_model(model_path, 'brode')
brode.save_anndata(adata, 'brode', data_path, file_name=f'{dataset}.h5ad')

# Save the transition graph next to the per-model figure folders.
vv.plot_transition_graph(adata, save=f'{figure_path_base}/transition.png')

Graph Partition
Number of partitions:  2
Computing type-to-type transition probability
Obtaining the MST in each partition
Initialization using type-specific dynamical model.


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = um.true_divide(


------------------------ Train a Branching ODE ------------------------
                            Enterocytes  Enteroendocrine  Enteroendocrine progenitor  Goblet cells  Paneth cells  Stem cells  TA cells  Tuft cells
Enterocytes                         0.0              0.0                         0.0           0.0           0.0         0.0       1.0         0.0
Enteroendocrine                     0.0              0.0                         1.0           0.0           0.0         0.0       0.0         0.0
Enteroendocrine progenitor          0.0              0.0                         0.0           0.0           0.0         1.0       0.0         0.0
Goblet cells                        0.0              0.0                         0.0           0.0           0.0         1.0       0.0         0.0
Paneth cells                        0.0              0.0                         0.0           1.0           0.0         0.0       0.0         0.0
Stem cells                          0.0       

# Evaluation

In [None]:
# Cluster pairs passed to vv.post_analysis as `cluster_edges`.
# Labels are spelled exactly as they appear in adata.obs['clusters'] (see the
# transition table printed by the Branching ODE cell): the goblet cluster is
# 'Goblet cells'. The previous spelling 'Globet Cells' matched no cluster.
cluster_edges = [('Stem cells', 'TA cells'), ('Stem cells', 'Goblet cells')]

# Evaluate all four trained models together. The second list holds the keys
# used in the save_anndata() calls above, in the same order as the display
# names in the first list.
vv.post_analysis(adata,
                 'eval',
                 ['Vanilla VAE', 'VeloVAE', 'FullVB', 'BrODE'],
                 ['vanilla', 'velovae', 'fullvb', 'brode'],
                 compute_metrics=True,
                 genes=gene_plot,
                 grid_size=(1, 4),
                 plot_type=['all'],
                 save_path=data_path,
                 cluster_edges=cluster_edges)