In [1]:
import anndata
import numpy as np
import sys
import torch
# The import search path is extended before velovae is imported.
# NOTE(review): presumably this makes a local velovae checkout four directory
# levels up importable — confirm the relative depth for this notebook's location.
sys.path.append('../../../../')
import velovae as vv

In [2]:
# NOTE(review): in the visible cells, scv is referenced only by the
# commented-out download line below.
import scvelo as scv
# Dataset name; interpolated into the data, checkpoint, and figure paths.
dataset = 'Bonemarrow'
# Disabled alternative: obtain the dataset through scvelo instead of the local file.
#adata = scv.datasets.bonemarrow(file_path=f'data/download/{dataset}.h5ad')
# Read the AnnData object from data/<dataset>_pp.h5ad.
adata = anndata.read_h5ad(f'data/{dataset}_pp.h5ad')

In [None]:
# Disabled preprocessing step. If re-enabled, it calls vv.preprocess on adata
# and writes the result to data/<dataset>_pp.h5ad, which is the same path the
# data-loading cell reads from.
#vv.preprocess(adata, n_gene=2000, min_shared_counts=20)
#adata.write_h5ad(f'data/{dataset}_pp.h5ad')

In [3]:
# Base output locations, each derived from the dataset name. The model cells
# append their own sub-directory to the checkpoint and figure bases.
data_path = f'data/velovae/continuous/{dataset}'
figure_path_base = f'figures/{dataset}'
model_path_base = f'checkpoints/{dataset}'

# Genes handed to every train() call and to the final evaluation as gene_plot/genes.
gene_plot = [
    'CD44',
    'CELF2',
    'TAOK3',
]

# Vanilla VAE

In [5]:
# Output directories for this model's checkpoints and figures.
model_path = f'{model_path_base}/Vanilla'
figure_path = f'{figure_path_base}/Vanilla'

# Seed NumPy and PyTorch before the model is constructed.
np.random.seed(2022)
torch.manual_seed(2022)

vanilla_vae = vv.VanillaVAE(
    adata,
    tmax=20,
    device='cuda:0',
    init_method='steady',
)

# Train with plotting disabled, then persist the networks and the AnnData.
vanilla_vae.train(
    adata,
    plot=False,
    gene_plot=gene_plot,
    figure_path=figure_path,
    embed='tsne',
)
vanilla_vae.save_model(model_path, 'encoder', 'decoder')
# 'vanilla' is the key later listed for this model in the evaluation cell.
vanilla_vae.save_anndata(
    adata,
    'vanilla',
    data_path,
    file_name=f'{dataset}.h5ad',
)

Initialization using the steady-state and dynamical models.
Estimating ODE parameters...


  0%|          | 0/2000 [00:00<?, ?it/s]

Estimating the variance...


  0%|          | 0/2000 [00:00<?, ?it/s]

Reinitialize the regular ODE parameters based on estimated global latent time.


  0%|          | 0/2000 [00:00<?, ?it/s]

Gaussian Prior.
------------------------- Train a Vanilla VAE -------------------------
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
Total Number of Iterations Per Epoch: 28, test iteration: 54
********* Early Stop Triggered at epoch 357. *********
*********              Finished. Total Time =   0 h :  2 m : 10 s             *********
Final: Train ELBO = 2612.867,           Test ELBO = 2564.188


# VeloVAE

In [6]:
# Output directories for this model's checkpoints and figures.
model_path = f'{model_path_base}/VeloVAE'
figure_path = f'{figure_path_base}/VeloVAE'

# Seed NumPy and PyTorch before the model is constructed.
np.random.seed(2022)
torch.manual_seed(2022)

vae = vv.VAE(
    adata,
    tmax=20,
    dim_z=5,
    device='cuda:0',
    init_method='steady',
)

# Train with plotting disabled, then persist the networks and the AnnData.
vae.train(
    adata,
    plot=False,
    gene_plot=gene_plot,
    figure_path=figure_path,
    embed='tsne',
)
vae.save_model(model_path, 'encoder', 'decoder')
# 'velovae' is the key later listed for this model in the evaluation cell.
vae.save_anndata(
    adata,
    'velovae',
    data_path,
    file_name=f'{dataset}.h5ad',
)

Initialization using the steady-state and dynamical models.
Estimating ODE parameters...


  0%|          | 0/2000 [00:00<?, ?it/s]

Estimating the variance...


  0%|          | 0/2000 [00:00<?, ?it/s]

Reinitialize the regular ODE parameters based on estimated global latent time.


  0%|          | 0/2000 [00:00<?, ?it/s]

Gaussian Prior.
--------------------------- Train a VeloVAE ---------------------------
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
*********                      Stage  1                       *********
Total Number of Iterations Per Epoch: 28, test iteration: 54
Epoch 1: Train ELBO = -8018.343, Test ELBO = -63392.723, 	 Total Time =   0 h :  0 m :  2 s
Epoch 100: Train ELBO = 2378.334, Test ELBO = 2292.820, 	 Total Time =   0 h :  0 m : 36 s
Epoch 200: Train ELBO = 2940.919, Test ELBO = 2904.138, 	 Total Time =   0 h :  1 m : 10 s
Epoch 300: Train ELBO = 3159.082, Test ELBO = 3138.598, 	 Total Time =   0 h :  1 m : 44 s
Epoch 400: Train ELBO = 3240.431, Test ELBO = 3235.591

  0%|          | 0/5028 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.035
Average Set Size: 156
Finished. Actual Time:   0 h :  0 m :  4 s
Epoch 542: Train ELBO = 2859.782, Test ELBO = 2716.002, 	 Total Time =   0 h :  3 m : 15 s
Epoch 600: Train ELBO = 3135.189, Test ELBO = 3117.954, 	 Total Time =   0 h :  3 m : 41 s
Epoch 700: Train ELBO = 3178.974, Test ELBO = 3158.185, 	 Total Time =   0 h :  4 m : 30 s
*********       Stage 2: Early Stop Triggered at epoch 783.       *********
*********              Finished. Total Time =   0 h :  5 m : 39 s             *********
Final: Train ELBO = 3174.538,           Test ELBO = 3172.169


# Full VB

In [7]:
# Output directories for this model's checkpoints and figures.
model_path = f'{model_path_base}/FullVB'
figure_path = f'{figure_path_base}/FullVB'

# Seed NumPy and PyTorch before the model is constructed.
np.random.seed(2022)
torch.manual_seed(2022)

full_vb = vv.VAEFullVB(
    adata,
    tmax=20,
    dim_z=5,
    device='cuda:0',
    init_method='steady',
)

# Train with plotting disabled, then persist the networks and the AnnData.
full_vb.train(
    adata,
    plot=False,
    gene_plot=gene_plot,
    figure_path=figure_path,
    embed='tsne',
)
full_vb.save_model(model_path, 'encoder', 'decoder')
# The Branching ODE cell reads 'fullvb_time' and 'fullvb_z' from adata;
# presumably they are written here under the 'fullvb' key — confirm.
full_vb.save_anndata(
    adata,
    'fullvb',
    data_path,
    file_name=f'{dataset}.h5ad',
)

Initialization using the steady-state and dynamical models.
Estimating ODE parameters...


  0%|          | 0/2000 [00:00<?, ?it/s]

Estimating the variance...


  0%|          | 0/2000 [00:00<?, ?it/s]

Reinitialize the regular ODE parameters based on estimated global latent time.


  0%|          | 0/2000 [00:00<?, ?it/s]

Gaussian Prior.
--------------------------- Train a VeloVAE ---------------------------
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
*********                      Stage  1                       *********
Total Number of Iterations Per Epoch: 28, test iteration: 54
Epoch 1: Train ELBO = -8005.302, Test ELBO = -63441.996, 	 Total Time =   0 h :  0 m :  2 s
Epoch 100: Train ELBO = 2387.361, Test ELBO = 2281.423, 	 Total Time =   0 h :  0 m : 41 s
Epoch 200: Train ELBO = 2937.164, Test ELBO = 2865.263, 	 Total Time =   0 h :  1 m : 20 s
Epoch 300: Train ELBO = 3136.198, Test ELBO = 3098.264, 	 Total Time =   0 h :  1 m : 59 s
Epoch 400: Train ELBO = 3198.658, Test ELBO = 3164.523

  0%|          | 0/5028 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.031
Average Set Size: 179
Finished. Actual Time:   0 h :  0 m :  4 s
Epoch 493: Train ELBO = 2858.462, Test ELBO = 2781.656, 	 Total Time =   0 h :  3 m : 24 s
Epoch 500: Train ELBO = 3042.335, Test ELBO = 3007.220, 	 Total Time =   0 h :  3 m : 31 s
*********       Stage 2: Early Stop Triggered at epoch 579.       *********
*********              Finished. Total Time =   0 h :  4 m : 41 s             *********
Final: Train ELBO = 3104.790,           Test ELBO = 3071.202


# Branching ODE

In [8]:
# Output directories for this model's checkpoints and figures.
model_path = f'{model_path_base}/BrODE'
figure_path = f'{figure_path_base}/BrODE'

# Seed NumPy and PyTorch before the model is constructed.
np.random.seed(2022)
torch.manual_seed(2022)

# Built from the FullVB results ('fullvb_time', 'fullvb_z', param_key='fullvb'),
# so the Full VB cell has to run before this one.
# NOTE(review): no device is passed here, unlike the three VAE cells —
# confirm whether that is intentional.
brode = vv.BrODE(
    adata,
    'clusters',
    'fullvb_time',
    'fullvb_z',
    param_key='fullvb',
)

brode.print_weight()

brode.train(
    adata,
    'fullvb_time',
    'clusters',
    plot=False,
    gene_plot=gene_plot,
    figure_path=figure_path,
)

# Persist the trained model and the AnnData under the 'brode' key.
brode.save_model(model_path, 'brode')
brode.save_anndata(
    adata,
    'brode',
    data_path,
    file_name=f'{dataset}.h5ad',
)

# The transition graph goes to the shared figure directory, not the BrODE one.
vv.plot_transition_graph(
    adata,
    save=f'{figure_path_base}/transition.png',
)

Graph Partition
Number of partitions:  1
Computing type-to-type transition probability
Obtaining the MST in each partition
Initialization using type-specific dynamical model.
            CLP  DCs  Ery_1  Ery_2  HSC_1  HSC_2  Mega  Mono_1  Mono_2  Precursors
CLP         0.0  0.0    0.0    0.0    1.0    0.0   0.0     0.0     0.0         0.0
DCs         0.0  0.0    0.0    0.0    0.0    0.0   0.0     0.0     1.0         0.0
Ery_1       0.0  0.0    0.0    1.0    0.0    0.0   0.0     0.0     0.0         0.0
Ery_2       0.0  0.0    0.0    1.0    0.0    0.0   0.0     0.0     0.0         0.0
HSC_1       0.0  0.0    0.0    0.0    0.0    1.0   0.0     0.0     0.0         0.0
HSC_2       0.0  0.0    0.0    0.0    0.0    0.0   0.0     0.0     0.0         1.0
Mega        0.0  0.0    0.0    0.0    1.0    0.0   0.0     0.0     0.0         0.0
Mono_1      0.0  0.0    1.0    0.0    0.0    0.0   0.0     0.0     0.0         0.0
Mono_2      0.0  0.0    0.0    0.0    0.0    0.0   0.0     1.0     0.0        

# Evaluation

In [None]:
# Pairs of cluster labels passed to post_analysis as cluster_edges.
cluster_edges = [
    ('HSC_1', 'Ery_1'),
    ('HSC_1', 'HSC_2'),
    ('Ery_1', 'Ery_2'),
]

# Compare all four models. The display names and result keys are given in
# matching order; the keys are the ones used by the save_anndata calls above.
# NOTE(review): save_path receives data_path rather than a figures directory —
# confirm this is the intended output location.
vv.post_analysis(
    adata,
    'eval',
    ['Vanilla VAE', 'VeloVAE', 'FullVB', 'BrODE'],
    ['vanilla', 'velovae', 'fullvb', 'brode'],
    compute_metrics=True,
    genes=gene_plot,
    grid_size=(1, 3),
    plot_type=['all'],
    embed='tsne',
    save_path=data_path,
    cluster_edges=cluster_edges,
)