In [1]:
import anndata
import numpy as np
import sys
import torch
# Extend the module search path with the directory three levels up *before*
# importing velovae, so this statement must stay ahead of the import below.
# NOTE(review): presumably '../../../' is the root of a local velovae checkout.
# append() puts the entry at the end of sys.path, so a velovae found earlier on
# the path (e.g. an installed copy) would be imported instead; confirm.
sys.path.append('../../../')
import velovae as vv

In [2]:
# Dataset name; the input file and the output paths below are derived from it.
dataset = 'BMMC'

# Load the preprocessed AnnData object. The file name is built from `dataset`
# (instead of repeating the literal) so that changing `dataset` cannot silently
# load a different dataset than the one the outputs are labelled with.
adata = anndata.read_h5ad(f'data/{dataset}_pp.h5ad')

# Copy the 'celltype.l2' annotation into obs['clusters'] as a NumPy array;
# 'clusters' is the column name handed to BrODE later in the notebook.
adata.obs['clusters'] = adata.obs['celltype.l2'].to_numpy()

In [None]:
# Uncomment this if the data has not been preprocessed yet.
# The result is written under data/, which is where the loading cell reads it from.
#vv.preprocess(adata, n_gene=2000, min_shared_counts=20, compute_umap=True)
#adata.write_h5ad(f'data/{dataset}_pp.h5ad')

In [3]:
# Root folders for this dataset; each model cell appends its own sub-folder
# to the checkpoint and figure roots.
data_path = f'data/velovae/continuous/{dataset}'
figure_path_base = f'figures/{dataset}'
model_path_base = f'checkpoints/{dataset}'

# Genes passed as `gene_plot` to every train() call and as `genes` to the
# final post-analysis.
gene_plot = [
    'SPINK2',
    'AZU1',
    'MPO',
    'LYZ',
    'CD74',
    'HBB',
]

# Vanilla VAE

In [4]:
# Output folders specific to the vanilla model.
model_path = f'{model_path_base}/Vanilla'
figure_path = f'{figure_path_base}/Vanilla'

# Seed NumPy and PyTorch before the model is constructed.
np.random.seed(2023)
torch.manual_seed(2023)

vanilla_vae = vv.VanillaVAE(
    adata,
    tmax=20,
    device='cuda:0',
)

# Train without producing plots during training (plot=False).
vanilla_vae.train(
    adata,
    figure_path=figure_path,
    gene_plot=gene_plot,
    plot=False,
)

# Save the model under `model_path` with the names 'encoder' / 'decoder',
# then save the AnnData under the key 'vanilla' into `data_path`.
vanilla_vae.save_model(model_path, 'encoder', 'decoder')
vanilla_vae.save_anndata(
    adata,
    'vanilla',
    data_path,
    file_name=f'{dataset}.h5ad',
)

Initialization using the steady-state and dynamical models.
Estimating ODE parameters...


  0%|          | 0/2000 [00:00<?, ?it/s]

Estimating the variance...


  0%|          | 0/2000 [00:00<?, ?it/s]

Reinitialize the regular ODE parameters based on estimated global latent time.


  0%|          | 0/2000 [00:00<?, ?it/s]

Gaussian Prior.
------------------------- Train a Vanilla VAE -------------------------
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
Total Number of Iterations Per Epoch: 112, test iteration: 222
********* Early Stop Triggered at epoch 121. *********
*********              Finished. Total Time =   0 h :  5 m : 10 s             *********
Final: Train ELBO = 3495.039,           Test ELBO = 3506.253


# VeloVAE

In [5]:
# Output folders specific to VeloVAE.
model_path = f'{model_path_base}/VeloVAE'
figure_path = f'{figure_path_base}/VeloVAE'

# Seed NumPy and PyTorch before the model is constructed.
np.random.seed(2022)
torch.manual_seed(2022)

vae = vv.VAE(adata, tmax=20, dim_z=5, device='cuda:0')

# Train without producing plots during training (plot=False).
vae.train(
    adata,
    figure_path=figure_path,
    gene_plot=gene_plot,
    plot=False,
)

# Save the model under `model_path` with the names 'encoder' / 'decoder',
# then save the AnnData under the key 'velovae' into `data_path`.
vae.save_model(model_path, 'encoder', 'decoder')
vae.save_anndata(
    adata,
    'velovae',
    data_path,
    file_name=f'{dataset}.h5ad',
)

Initialization using the steady-state and dynamical models.
Estimating ODE parameters...


  0%|          | 0/2000 [00:00<?, ?it/s]

Estimating the variance...


  0%|          | 0/2000 [00:00<?, ?it/s]

Reinitialize the regular ODE parameters based on estimated global latent time.


  0%|          | 0/2000 [00:00<?, ?it/s]

Gaussian Prior.
--------------------------- Train a VeloVAE ---------------------------
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
*********                      Stage  1                       *********
Total Number of Iterations Per Epoch: 112, test iteration: 222
*********       Stage 1: Early Stop Triggered at epoch 217.       *********
*********                      Stage  2                       *********
Cell-wise KNN Estimation.


  0%|          | 0/20400 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.269
Average Set Size: 576
Finished. Actual Time:   0 h :  0 m : 19 s
*********       Stage 2: Early Stop Triggered at epoch 372.       *********
*********              Finished. Total Time =   0 h : 12 m : 16 s             *********
Final: Train ELBO = 3942.816,           Test ELBO = 3945.155


# Full VB

In [6]:
# Output folders specific to the Full VB model.
model_path = f'{model_path_base}/FullVB'
figure_path = f'{figure_path_base}/FullVB'

# Seed NumPy and PyTorch before the model is constructed.
np.random.seed(2022)
torch.manual_seed(2022)

full_vb = vv.VAEFullVB(adata, tmax=20, dim_z=5, device='cuda:0')

# Train without producing plots during training (plot=False).
full_vb.train(
    adata,
    figure_path=figure_path,
    gene_plot=gene_plot,
    plot=False,
)

# Save the model under `model_path` with the names 'encoder' / 'decoder',
# then save the AnnData under the key 'fullvb' into `data_path`.
# NOTE(review): the BrODE cell reads 'fullvb_time' and 'fullvb_z', which are
# presumably written to `adata` by this step; confirm.
full_vb.save_model(model_path, 'encoder', 'decoder')
full_vb.save_anndata(
    adata,
    'fullvb',
    data_path,
    file_name=f'{dataset}.h5ad',
)

Initialization using the steady-state and dynamical models.
Estimating ODE parameters...


  0%|          | 0/2000 [00:00<?, ?it/s]

Estimating the variance...


  0%|          | 0/2000 [00:00<?, ?it/s]

Reinitialize the regular ODE parameters based on estimated global latent time.


  0%|          | 0/2000 [00:00<?, ?it/s]

Gaussian Prior.
--------------------------- Train a VeloVAE ---------------------------
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
*********                      Stage  1                       *********
Total Number of Iterations Per Epoch: 112, test iteration: 222
*********       Stage 1: Early Stop Triggered at epoch 256.       *********
*********                      Stage  2                       *********
Cell-wise KNN Estimation.


  0%|          | 0/20400 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.270
Average Set Size: 609
Finished. Actual Time:   0 h :  0 m : 20 s
*********       Stage 2: Early Stop Triggered at epoch 369.       *********
*********              Finished. Total Time =   0 h : 13 m : 43 s             *********
Final: Train ELBO = 3917.174,           Test ELBO = 3918.884


# Train a Branching ODE

In [7]:
# Output folders specific to the branching ODE.
model_path = f'{model_path_base}/BrODE'
figure_path = f'{figure_path_base}/BrODE'

# Seed NumPy and PyTorch before the model is constructed.
np.random.seed(2022)
torch.manual_seed(2022)

# Build the branching ODE from the cluster labels plus the 'fullvb_time' and
# 'fullvb_z' entries of `adata`.
brode = vv.BrODE(
    adata,
    'clusters',
    'fullvb_time',
    'fullvb_z',
)
brode.print_weight()

# Note the positional order here ('fullvb_time' first, then 'clusters') is the
# reverse of the order used in the constructor call above.
brode.train(
    adata,
    'fullvb_time',
    'clusters',
    figure_path=figure_path,
    gene_plot=gene_plot,
    plot=False,
)

# Save the model under the name 'brode', then save the AnnData under the key
# 'brode' into `data_path`.
brode.save_model(model_path, 'brode')
brode.save_anndata(
    adata,
    'brode',
    data_path,
    file_name=f'{dataset}.h5ad',
)

# The transition graph figure goes to the dataset-level figure folder rather
# than the BrODE sub-folder.
vv.plot_transition_graph(
    adata,
    save=f'{figure_path_base}/transition.png',
)

Graph Partition
Number of partitions:  2
Computing type-to-type transition probability
Obtaining the MST in each partition
Initialization using type-specific dynamical model.
Estimating ODE parameters...


  0%|          | 0/2000 [00:00<?, ?it/s]

Estimating the variance...


  0%|          | 0/2000 [00:00<?, ?it/s]

                CD14 Mono  CD16 Mono  CD4 Memory  CD4 Naive  CD56 bright NK  CD8 Effector_1  CD8 Effector_2  CD8 Memory_1  CD8 Memory_2  CD8 Naive  GMP  HSC  LMPP  MAIT  Memory B   NK  Naive B  \
CD14 Mono             0.0        0.0         0.0        0.0             0.0             0.0             0.0           0.0           0.0        0.0  0.0  0.0   0.0   0.0       0.0  0.0      0.0   
CD16 Mono             1.0        0.0         0.0        0.0             0.0             0.0             0.0           0.0           0.0        0.0  0.0  0.0   0.0   0.0       0.0  0.0      0.0   
CD4 Memory            0.0        0.0         0.0        1.0             0.0             0.0             0.0           0.0           0.0        0.0  0.0  0.0   0.0   0.0       0.0  0.0      0.0   
CD4 Naive             0.0        0.0         0.0        1.0             0.0             0.0             0.0           0.0           0.0        0.0  0.0  0.0   0.0   0.0       0.0  0.0      0.0   
CD56 bright NK      

                CD14 Mono  CD16 Mono  CD4 Memory  CD4 Naive  CD56 bright NK  CD8 Effector_1  CD8 Effector_2  CD8 Memory_1  CD8 Memory_2  CD8 Naive  GMP  HSC  LMPP  MAIT  Memory B   NK  Naive B  \
CD14 Mono             0.0        0.0         0.0        0.0             0.0             0.0             0.0           0.0           0.0        0.0  0.0  0.0   0.0   0.0       0.0  0.0      0.0   
CD16 Mono             1.0        0.0         0.0        0.0             0.0             0.0             0.0           0.0           0.0        0.0  0.0  0.0   0.0   0.0       0.0  0.0      0.0   
CD4 Memory            0.0        0.0         0.0        1.0             0.0             0.0             0.0           0.0           0.0        0.0  0.0  0.0   0.0   0.0       0.0  0.0      0.0   
CD4 Naive             0.0        0.0         0.0        1.0             0.0             0.0             0.0           0.0           0.0        0.0  0.0  0.0   0.0   0.0       0.0  0.0      0.0   
CD56 bright NK      

*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
Total Number of Iterations Per Epoch: 112, test iteration: 222
*********           Early Stop Triggered at epoch 165.            *********
*********              Finished. Total Time =   0 h : 30 m : 56 s             *********


# Evaluation

In [None]:
# Pairs of cluster labels handed to post_analysis as `cluster_edges`.
# NOTE(review): 'cDc2' has unusual capitalisation, and the 'Prog ...' labels do
# not appear in the visible part of the printed transition table. Confirm every
# label matches adata.obs['clusters'] exactly.
cluster_edges = [
    ('HSC', 'LMPP'),
    ('LMPP', 'GMP'),
    ('GMP', 'CD14 Mono'),
    ('CD14 Mono', 'CD16 Mono'),
    ('Prog DC', 'cDc2'),
    ('Prog B 1', 'Prog B 2'),
    ('Prog MK', 'Prog RBC'),
]

# Evaluate the three VAE variants: the first list holds the display names, the
# second holds the keys that were passed to save_anndata in the training cells.
vv.post_analysis(
    adata,
    'eval',
    ['Vanilla VAE', 'VeloVAE', 'FullVB'],
    ['vanilla', 'velovae', 'fullvb'],
    cluster_edges=cluster_edges,
    compute_metrics=True,
    genes=gene_plot,
    grid_size=(2, 3),
    plot_type=['all'],
    save_path=data_path,
)