In [1]:
# Libraries used by the cells below: AnnData I/O, NumPy/PyTorch (seeding), velovae (models).
import anndata
import numpy as np
import sys
import torch
# Extend the module search path four directories up *before* importing velovae;
# presumably that directory is the repository root holding the local velovae
# package — TODO confirm (an installed velovae would make this line unnecessary).
sys.path.append('../../../../')
import velovae as vv

In [4]:
# NOTE(review): within the visible cells, scvelo is referenced only by the
# commented-out download line below.
import scvelo as scv
# Dataset tag; it is interpolated into every data, checkpoint and figure path
# built in the following cells.
dataset = 'Bonemarrow'
# Disabled alternative: fetch the dataset through scvelo into data/download/.
#adata = scv.datasets.bonemarrow(file_path=f'data/download/{dataset}.h5ad')
# Read the AnnData object from a local .h5ad file. The `_pp` suffix presumably
# marks an already-preprocessed copy — TODO confirm.
adata = anndata.read_h5ad(f'data/{dataset}_pp.h5ad')

In [None]:
# Preprocessing call, intentionally left disabled in this notebook.
# NOTE(review): presumably the `*_pp.h5ad` file read earlier already contains the
# result of this step — confirm, and re-enable the line when starting from raw data.
#vv.preprocess(adata, n_gene=2000, min_shared_counts=20, keep_raw=True)

In [6]:
# Genes forwarded to the training and evaluation calls for plotting.
gene_plot = ['CD44', 'CELF2', 'TAOK3']

# Dataset-specific roots; each model cell appends its own sub-folder to the
# figure and checkpoint roots, while data_path is shared by all models.
data_path = f'data/velovae/discrete/{dataset}'
figure_path_base = f'figures/{dataset}'
model_path_base = f'checkpoints/{dataset}'

# Discrete VeloVAE

In [7]:
# Output folders for this model (checkpoints and figures).
model_path = f'{model_path_base}/DVAE'
figure_path = f'{figure_path_base}/DVAE'

# Seed NumPy and PyTorch before the model is constructed and trained.
np.random.seed(2022)
torch.manual_seed(2022)

# NOTE(review): the device is hard-coded to the first CUDA GPU.
dvae = vv.DVAE(
    adata,
    tmax=20,
    dim_z=5,
    device='cuda:0',
    init_method='steady',
)

# Train with plotting switched off; the gene list and figure folder are still
# forwarded to train().
dvae.train(
    adata,
    plot=False,
    gene_plot=gene_plot,
    figure_path=figure_path,
    embed='tsne',
)

# Save the model under model_path, then write adata to data_path. 'dvae' is the
# same string the Evaluation cell passes for this model — presumably the key
# under which the results are stored in adata; confirm against the velovae API.
dvae.save_model(model_path, 'encoder', 'decoder')
dvae.save_anndata(adata, 'dvae', data_path, file_name=f'{dataset}.h5ad')

Detecting zero scaling factors: 0, 0
Initialization using the steady-state and dynamical models.


  0%|          | 0/2000 [00:00<?, ?it/s]

Reinitialize the regular ODE parameters based on estimated global latent time.


  0%|          | 0/2000 [00:00<?, ?it/s]

Gaussian Prior.
Library scale (U): Max=6.09, Min=0.03, Mean=0.98
Library scale (S): Max=8.76, Min=0.35, Mean=1.03
--------------------------- Train a VeloVAE ---------------------------
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
*********                      Stage  1                       *********
Total Number of Iterations Per Epoch: 28, test iteration: 54
*********       Stage 1: Early Stop Triggered at epoch 404.       *********
*********                      Stage  2                       *********


  0%|          | 0/5028 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.011
Average Set Size: 229
*********       Stage 2: Early Stop Triggered at epoch 576.       *********
*********              Finished. Total Time =   0 h :  4 m : 17 s             *********
Final: Train ELBO = -2033.331,           Test ELBO = -2103.886
       Training MSE = 1.864, Test MSE = 2.644


# Discrete Full VB (DVAEFullVB)

In [8]:
# Output folders for this model (checkpoints and figures).
model_path = f'{model_path_base}/DFullVB'
figure_path = f'{figure_path_base}/DFullVB'

# Seed NumPy and PyTorch before the model is constructed and trained.
np.random.seed(2022)
torch.manual_seed(2022)

# NOTE(review): the device is hard-coded to the first CUDA GPU.
full_vb = vv.DVAEFullVB(
    adata,
    tmax=20,
    dim_z=5,
    device='cuda:0',
    init_method='steady',
)

# Train with plotting switched off; the gene list and figure folder are still
# forwarded to train().
full_vb.train(
    adata,
    plot=False,
    gene_plot=gene_plot,
    figure_path=figure_path,
    embed='tsne',
)

# Save the model under model_path, then write adata to data_path. 'dfullvb' is
# the same string the Evaluation cell passes for this model — presumably the key
# under which the results are stored in adata; confirm against the velovae API.
full_vb.save_model(model_path, 'encoder', 'decoder')
full_vb.save_anndata(adata, 'dfullvb', data_path, file_name=f'{dataset}.h5ad')

Detecting zero scaling factors: 0, 0
Initialization using the steady-state and dynamical models.


  0%|          | 0/2000 [00:00<?, ?it/s]

Reinitialize the regular ODE parameters based on estimated global latent time.


  0%|          | 0/2000 [00:00<?, ?it/s]

Gaussian Prior.
Library scale (U): Max=6.09, Min=0.03, Mean=0.98
Library scale (S): Max=8.76, Min=0.35, Mean=1.03
--------------------------- Train a VeloVAE ---------------------------
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
*********                      Stage  1                       *********
Total Number of Iterations Per Epoch: 28, test iteration: 54
*********       Stage 1: Early Stop Triggered at epoch 463.       *********
*********                      Stage  2                       *********


  0%|          | 0/5028 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.003
Average Set Size: 238
*********       Stage 2: Early Stop Triggered at epoch 492.       *********
*********              Finished. Total Time =   0 h :  4 m :  4 s             *********
Final: Train ELBO = -2054.394,           Test ELBO = -2092.880
       Training MSE = 2.143, Test MSE = 2.956


# Evaluation

In [None]:
# Pairs of cluster labels handed to post_analysis as `cluster_edges`;
# presumably (progenitor, descendant) transitions — TODO confirm.
cluster_edges = [
    ('HSC_1', 'Ery_1'),
    ('HSC_1', 'HSC_2'),
    ('Ery_1', 'Ery_2'),
]

# The two list arguments are parallel: display names first, then the strings
# that were passed to save_anndata() by the two training cells.
# NOTE(review): save_path is the data directory (data_path), not
# figure_path_base — confirm this is the intended output location.
vv.post_analysis(
    adata,
    'eval',
    ['Discrete VeloVAE', 'Discrete FullVB'],
    ['dvae', 'dfullvb'],
    compute_metrics=True,
    genes=gene_plot,
    grid_size=(1, 4),
    plot_type=['all'],
    embed='tsne',
    save_path=data_path,
    cluster_edges=cluster_edges,
)