In [3]:
import anndata
import numpy as np
import sys
import torch
sys.path.append('../../../')
import velovae as vv

In [5]:
# Dataset identifier; it is interpolated into the input file name here and
# reused by later cells to build checkpoint, figure, and output paths.
dataset = 'BMMC'

# Read the AnnData object for this dataset from disk.
adata = anndata.read_h5ad(f'data/{dataset}_pp.h5ad')

# Expose the 'celltype.l2' annotation under the 'clusters' column as a NumPy array.
cell_labels = adata.obs['celltype.l2'].to_numpy()
adata.obs["clusters"] = cell_labels

In [None]:
# NOTE(review): this cell carries no execution count, and the file loaded
# earlier is named `*_pp.h5ad` — presumably the data is already preprocessed;
# confirm whether this step is still required on a fresh run.
vv.preprocess(
    adata,
    n_gene=2000,
    min_shared_counts=20,
    compute_umap=True,
)

In [6]:
# Genes handed to the training and evaluation calls through `gene_plot`.
gene_plot = ['SPINK2', 'AZU1', 'MPO', 'LYZ', 'CD74', 'HBB']

# Per-dataset locations. The model sections append a method-specific
# subfolder to the checkpoint and figure bases.
data_path = f'data/velovae/discrete/{dataset}'
figure_path_base = f'figures/{dataset}'
model_path_base = f'checkpoints/{dataset}'

# Discrete VAE

In [4]:
# Train the discrete VAE model, then save its weights and write the results
# stored on `adata` to disk.
figure_path = f'{figure_path_base}/DVAE'
model_path = f'{model_path_base}/DVAE'

# Seed the PyTorch and NumPy global RNGs before building the model.
torch.manual_seed(2022)
np.random.seed(2022)

# Prefer the first CUDA device, but fall back to the CPU when PyTorch cannot
# see a GPU, so the cell does not depend on a hard-coded 'cuda:0'.
# NOTE(review): assumes the model constructor accepts 'cpu' — confirm against
# the velovae API.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

dvae = vv.model.DVAE(adata,
                     tmax=20,
                     dim_z=5,
                     init_method='random',
                     init_ton_zero=True,
                     device=device)

dvae.train(adata, plot=False, gene_plot=gene_plot, figure_path=figure_path)

# Persist the trained networks under `model_path`, and save `adata` with the
# 'dvae' key to `data_path`.
dvae.save_model(model_path, 'encoder', 'decoder')
dvae.save_anndata(adata, 'dvae', data_path, f'{dataset}.h5ad')

Detecting zero scaling factors: 0, 0
Random Initialization.
Gaussian Prior.
Library scale (U): Max=17.65, Min=0.00, Mean=1.03
Library scale (S): Max=23.48, Min=0.22, Mean=1.13
--------------------------- Train a VeloVAE ---------------------------
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
*********                      Stage  1                       *********
Total Number of Iterations Per Epoch: 112, test iteration: 222
*********       Stage 1: Early Stop Triggered at epoch 197.       *********
*********                      Stage  2                       *********


  0%|          | 0/20400 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.004
Average Set Size: 643
*********       Stage 2: Early Stop Triggered at epoch 293.       *********
*********              Finished. Total Time =   0 h :  7 m : 59 s             *********
Final: Train ELBO = -1160.040,           Test ELBO = -1172.535
       Training MSE = 4.604, Test MSE = 5.425


# FullVB

In [11]:
# Train the DVAEFullVB model, then save its weights and write the results
# stored on `adata` to disk.
figure_path = f'{figure_path_base}/DFullVB'
model_path = f'{model_path_base}/DFullVB'

# Seed the PyTorch and NumPy global RNGs before building the model.
torch.manual_seed(2022)
np.random.seed(2022)

# Prefer the first CUDA device, but fall back to the CPU when PyTorch cannot
# see a GPU, so the cell does not depend on a hard-coded 'cuda:0'.
# NOTE(review): assumes the model constructor accepts 'cpu' — confirm against
# the velovae API.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# The name `dvae` is kept as in the original cell; it is rebound here to the
# DVAEFullVB instance.
dvae = vv.model.DVAEFullVB(adata,
                           tmax=20,
                           dim_z=5,
                           init_method='random',
                           init_ton_zero=True,
                           device=device)

dvae.train(adata, plot=False, gene_plot=gene_plot, figure_path=figure_path, embed="umap")

# Persist the trained networks under `model_path`, and save `adata` with the
# 'dfullvb' key to `data_path`.
dvae.save_model(model_path, 'encoder', 'decoder')
dvae.save_anndata(adata, 'dfullvb', data_path, f'{dataset}.h5ad')

Detecting zero scaling factors: 0, 0
Random Initialization.
Gaussian Prior.
Library scale (U): Max=17.65, Min=0.00, Mean=1.03
Library scale (S): Max=23.48, Min=0.22, Mean=1.13
--------------------------- Train a VeloVAE ---------------------------
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
*********                      Stage  1                       *********
Total Number of Iterations Per Epoch: 112, test iteration: 222
*********       Stage 1: Early Stop Triggered at epoch 347.       *********
*********                      Stage  2                       *********


  0%|          | 0/20400 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.016
Average Set Size: 632
*********       Stage 2: Early Stop Triggered at epoch 373.       *********
*********              Finished. Total Time =   0 h : 11 m :  5 s             *********
Final: Train ELBO = -1190.224,           Test ELBO = -1198.823
       Training MSE = 4.998, Test MSE = 5.271


# Evaluation

In [None]:
# Pairs of cluster labels passed to the evaluation as `cluster_edges`.
# NOTE(review): presumably each pair is (source, target) along a known
# differentiation path — confirm against the velovae documentation.
cluster_edges = [
    ('HSC', 'LMPP'),
    ('LMPP', 'GMP'),
    ('GMP', 'CD14 Mono'),
    ('CD14 Mono', 'CD16 Mono'),
    ('Prog B 1', 'Prog B 2'),
    ('Prog MK', 'Prog RBC'),
]

# Display names and the matching keys; the keys are the same strings given to
# `save_anndata` in the two training cells.
method_names = ['Discrete VeloVAE', 'Discrete FullVB']
method_keys = ['dvae', 'dfullvb']

# Metrics are requested; `plot_type` is an empty list.
vv.post_analysis(
    adata,
    'eval',
    method_names,
    method_keys,
    compute_metrics=True,
    genes=gene_plot,
    grid_size=(1, 4),
    plot_type=[],
    save_path=data_path,
    cluster_edges=cluster_edges,
)