In [1]:
import anndata
import numpy as np
import sys
import torch
sys.path.append('../../../')
import velovae as vv

In [2]:
dataset = 'Pancreas'
import os
import scvelo as scv

# Load the cached, preprocessed AnnData. On a fresh checkout the cache does not
# exist yet, so instead of failing on a missing file we download the raw Pancreas
# data, preprocess it (2000 genes, min_shared_counts=20) and write the cache so
# later runs take the fast path.
# NOTE: scv.datasets.pancreas() is specific to dataset == 'Pancreas'.
pp_file = f'data/{dataset}_pp.h5ad'
if os.path.exists(pp_file):
    adata = anndata.read_h5ad(pp_file)
else:
    adata = scv.datasets.pancreas()
    vv.preprocess(adata, n_gene=2000, min_shared_counts=20)
    os.makedirs(os.path.dirname(pp_file), exist_ok=True)
    adata.write_h5ad(pp_file)

In [None]:
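# Preprocessing (only needed once): load the raw data with scv.datasets.pancreas(), then uncomment the lines below to preprocess it and cache the result as data/{dataset}_pp.h5ad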
#vv.preprocess(adata, n_gene=2000, min_shared_counts=20)
#adata.write_h5ad(f'data/{dataset}_pp.h5ad')

In [3]:
# Per-dataset output roots: model checkpoints, figures, and the AnnData files
# written by each method's save_anndata call.
model_path_base = 'checkpoints/' + dataset
figure_path_base = 'figures/' + dataset
data_path = '/'.join(['data', 'velovae', 'continuous', dataset])

# Genes passed to the plotting options of training and of post_analysis.
gene_plot = [
    'Cpe',
    'Gng12',
    'Ppp3ca',
    'Smoc1',
]

# Vanilla VAE

In [4]:
figure_path = f'{figure_path_base}/Vanilla'
model_path = f'{model_path_base}/Vanilla'

# Seed torch and numpy so the run is reproducible.
torch.manual_seed(2022)
np.random.seed(2022)

# Use the first GPU when one is present; fall back to CPU so the notebook still
# runs on machines without CUDA instead of crashing.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

vanilla_vae = vv.VanillaVAE(adata, tmax=20, device=device)

vanilla_vae.train(adata, plot=False, gene_plot=gene_plot, figure_path=figure_path)

# Persist the encoder/decoder weights, then store the results in adata under the
# 'vanilla' key prefix and write the AnnData to data_path.
vanilla_vae.save_model(model_path, 'encoder', 'decoder')
vanilla_vae.save_anndata(adata, 'vanilla', data_path, file_name=f'{dataset}.h5ad')

Initialization using the steady-state and dynamical models.
Estimating ODE parameters...


  0%|          | 0/2000 [00:00<?, ?it/s]

Estimating the variance...


  0%|          | 0/2000 [00:00<?, ?it/s]

Reinitialize the regular ODE parameters based on estimated global latent time.


  0%|          | 0/2000 [00:00<?, ?it/s]

Gaussian Prior.
------------------------- Train a Vanilla VAE -------------------------
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
Total Number of Iterations Per Epoch: 21, test iteration: 40
********* Early Stop Triggered at epoch 471. *********
*********              Finished. Total Time =   0 h :  2 m :  6 s             *********
Final: Train ELBO = 1434.077,           Test ELBO = 1455.904


# VeloVAE

In [5]:
figure_path = f'{figure_path_base}/VeloVAE'
model_path = f'{model_path_base}/VeloVAE'

# Seed torch and numpy so the run is reproducible.
torch.manual_seed(2022)
np.random.seed(2022)

# Use the first GPU when one is present; fall back to CPU so the notebook still
# runs on machines without CUDA instead of crashing.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

vae = vv.VAE(adata,
             tmax=20,
             dim_z=5,
             device=device)

vae.train(adata, plot=False, gene_plot=gene_plot, figure_path=figure_path)

# Persist the encoder/decoder weights, then store the results in adata under the
# 'velovae' key prefix and write the AnnData to data_path.
vae.save_model(model_path, 'encoder', 'decoder')
vae.save_anndata(adata, 'velovae', data_path, file_name=f'{dataset}.h5ad')

Initialization using the steady-state and dynamical models.
Estimating ODE parameters...


  0%|          | 0/2000 [00:00<?, ?it/s]

Estimating the variance...


  0%|          | 0/2000 [00:00<?, ?it/s]

Reinitialize the regular ODE parameters based on estimated global latent time.


  0%|          | 0/2000 [00:00<?, ?it/s]

Gaussian Prior.
--------------------------- Train a VeloVAE ---------------------------
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
*********                      Stage  1                       *********
Total Number of Iterations Per Epoch: 21, test iteration: 40
*********       Stage 1: Early Stop Triggered at epoch 631.       *********
*********                      Stage  2                       *********
Cell-wise KNN Estimation.


  0%|          | 0/3696 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.069
Average Set Size: 107
Finished. Actual Time:   0 h :  0 m :  3 s
*********       Stage 2: Early Stop Triggered at epoch 950.       *********
*********              Finished. Total Time =   0 h :  4 m : 50 s             *********
Final: Train ELBO = 2025.948,           Test ELBO = 2028.874


# Full VB (VeloVAE with full variational Bayes)

In [6]:
figure_path = f'{figure_path_base}/FullVB'
model_path = f'{model_path_base}/FullVB'

# Seed torch and numpy so the run is reproducible.
torch.manual_seed(2022)
np.random.seed(2022)

# Use the first GPU when one is present; fall back to CPU so the notebook still
# runs on machines without CUDA instead of crashing.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

full_vb = vv.VAEFullVB(adata,
                       tmax=20,
                       dim_z=5,
                       device=device)

full_vb.train(adata, plot=False, gene_plot=gene_plot, figure_path=figure_path)

# Persist the encoder/decoder weights, then store the results in adata under the
# 'fullvb' key prefix (the BrODE cell reads 'fullvb_time' and 'fullvb_z') and
# write the AnnData to data_path.
full_vb.save_model(model_path, 'encoder', 'decoder')
full_vb.save_anndata(adata, 'fullvb', data_path, file_name=f'{dataset}.h5ad')

Initialization using the steady-state and dynamical models.
Estimating ODE parameters...


  0%|          | 0/2000 [00:00<?, ?it/s]

Estimating the variance...


  0%|          | 0/2000 [00:00<?, ?it/s]

Reinitialize the regular ODE parameters based on estimated global latent time.


  0%|          | 0/2000 [00:00<?, ?it/s]

Gaussian Prior.
--------------------------- Train a VeloVAE ---------------------------
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
*********                      Stage  1                       *********
Total Number of Iterations Per Epoch: 21, test iteration: 40
*********       Stage 1: Early Stop Triggered at epoch 562.       *********
*********                      Stage  2                       *********
Cell-wise KNN Estimation.


  0%|          | 0/3696 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.061
Average Set Size: 97
Finished. Actual Time:   0 h :  0 m :  6 s
*********       Stage 2: Early Stop Triggered at epoch 949.       *********
*********              Finished. Total Time =   0 h :  9 m : 17 s             *********
Final: Train ELBO = 1993.427,           Test ELBO = 1960.510


# Train a Branching ODE

In [7]:
brode_tag = 'BrODE'
model_path = f'{model_path_base}/{brode_tag}'
figure_path = f'{figure_path_base}/{brode_tag}'

np.random.seed(2022) if False else None
torch.manual_seed(2022)
np.random.seed(2022)

# BrODE is built from annotations that the FullVB run stored in adata.
# Presumably 'fullvb_time' is the inferred cell time and 'fullvb_z' the latent
# embedding -- TODO confirm against vv.BrODE's documentation.
cluster_key = 'clusters'
time_key = 'fullvb_time'
embed_key = 'fullvb_z'
brode = vv.BrODE(adata, cluster_key, time_key, embed_key)

# Print the cell-type x cell-type weight table before and after training.
brode.print_weight()
brode.train(adata,
            time_key,
            cluster_key,
            plot=False,
            gene_plot=gene_plot,
            figure_path=figure_path)
brode.print_weight()

# Persist the model and store its results in adata under the 'brode' prefix.
brode.save_model(model_path, 'brode')
brode.save_anndata(adata, 'brode', data_path, file_name=f'{dataset}.h5ad')

transition_png = f'{figure_path_base}/transition.png'
vv.plot_transition_graph(adata, save=transition_png)

Graph Partition
Number of partitions:  1
Computing type-to-type transition probability
Obtaining the MST in each partition
{0: [1, 2], 1: [2], 2: [], 3: [0, 1, 2, 3, 4, 5, 6, 7], 4: [0, 1, 2], 5: [0, 1, 2, 4, 7], 6: [0, 1, 2, 4, 5, 7], 7: [0, 1, 2, 4]}
Initialization using type-specific dynamical model.
Estimating ODE parameters...


  0%|          | 0/2000 [00:00<?, ?it/s]

Estimating the variance...


  0%|          | 0/2000 [00:00<?, ?it/s]

               Alpha  Beta  Delta  Ductal  Epsilon  Ngn3 high EP  Ngn3 low EP  Pre-endocrine
Alpha            0.0   0.0    0.0     0.0      0.0           0.0          0.0            1.0
Beta             0.0   0.0    0.0     0.0      0.0           0.0          0.0            1.0
Delta            1.0   0.0    0.0     0.0      0.0           0.0          0.0            0.0
Ductal           0.0   0.0    0.0     1.0      0.0           0.0          0.0            0.0
Epsilon          0.0   0.0    0.0     0.0      0.0           0.0          0.0            1.0
Ngn3 high EP     0.0   0.0    0.0     0.0      0.0           0.0          1.0            0.0
Ngn3 low EP      0.0   0.0    0.0     1.0      0.0           0.0          0.0            0.0
Pre-endocrine    0.0   0.0    0.0     0.0      0.0           1.0          0.0            0.0
------------------------ Train a Branching ODE ------------------------
               Alpha  Beta  Delta  Ductal  Epsilon  Ngn3 high EP  Ngn3 low EP  Pre-endocrin

# Evaluation

In [None]:
# Expected (source, target) cell-type transitions supplied to post_analysis:
# Ngn3 high EP -> Pre-endocrine, then Pre-endocrine fans out to each endocrine fate.
endocrine_fates = ('Delta', 'Beta', 'Epsilon', 'Alpha')
cluster_edges = [('Ngn3 high EP', 'Pre-endocrine')]
cluster_edges += [('Pre-endocrine', fate) for fate in endocrine_fates]

# Display names and the adata key prefixes each method was saved under.
method_names = ['Vanilla VAE', 'VeloVAE', 'FullVB', 'BrODE']
method_keys = ['vanilla', 'velovae', 'fullvb', 'brode']

vv.post_analysis(
    adata,
    'eval',
    method_names,
    method_keys,
    compute_metrics=True,
    genes=gene_plot,
    grid_size=(1, 4),
    plot_type=['all'],
    save_path=data_path,
    cluster_edges=cluster_edges,
)