In [1]:
import anndata
import numpy as np
import sys
import torch
sys.path.append('../../../../')
import velovae as vv

In [2]:
dataset = 'IPSC'
# Input AnnData for this dataset, presumably already preprocessed (the `_pp` suffix).
# Later cells read `adata` and add fields to it (e.g. `obs["tprior"]`).
h5ad_file = f'data/{dataset}_pp.h5ad'
adata = anndata.read_h5ad(h5ad_file)

In [None]:
# NOTE(review): this cell was not executed (In [None]) in the recorded run whose
# outputs appear below, and the input file `{dataset}_pp.h5ad` appears (from its
# suffix) to be preprocessed already. Gate it so Restart & Run All reproduces the
# recorded results; set RUN_PREPROCESS = True to preprocess the loaded data again.
RUN_PREPROCESS = False
if RUN_PREPROCESS:
    vv.preprocess(adata, n_gene=2000, min_shared_counts=20, compute_umap=True)

In [3]:
# Output locations for this experiment; every model writes into a subfolder of these.
run_name = f'{dataset}_discrete_notime'
model_path_base = f'checkpoints/{run_name}'
figure_path_base = f'figures/{run_name}'
data_path = f'data/velovae/discrete/{dataset}_notime'

# Genes passed to training and evaluation for per-gene plots.
gene_plot = ['Vim', 'Nr2f1', 'Krt7', 'H19']

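# Informative Time Prior

Use each cell's sampling day (`adata.obs["day"]`) as an informative prior on latent time. The values are shifted so the earliest day is 0 and stored in `adata.obs["tprior"]`; both models below initialize from this column.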

In [4]:
# Parse each cell's sampling day as a float, then shift so the earliest day is 0.
# The shifted values go to adata.obs["tprior"], which the models below use as `init_key`.
tprior = np.array(list(map(float, adata.obs["day"].to_numpy())))
adata.obs["tprior"] = tprior - np.min(tprior)

# Discrete VeloVAE

In [5]:
figure_path = f'{figure_path_base}/DVAE'
model_path = f'{model_path_base}/DVAE'

# Seed both torch and numpy so training is reproducible.
torch.manual_seed(2022)
np.random.seed(2022)

# Use the first GPU when present; otherwise fall back to CPU instead of crashing on 'cuda:0'.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

dvae = vv.model.DVAE(adata,
                     tmax=20,
                     dim_z=5,
                     device=device,
                     init_method='tprior',   # initialize from the prior time ...
                     init_key='tprior',      # ... stored in adata.obs["tprior"] (Informative Time Prior cell)
                     tprior=None,            # NOTE(review): no time prior passed to the model itself — presumably the "notime" in the paths
                     init_ton_zero=False)

dvae.train(adata, plot=False, gene_plot=gene_plot, figure_path=figure_path, embed='embed')

# Save weights, then write results into adata under the 'dvae' key read by the Evaluation cell.
dvae.save_model(model_path, 'encoder', 'decoder')
dvae.save_anndata(adata, 'dvae', data_path, f'{dataset}.h5ad')

Detecting zero scaling factors: 3, 0
Initialization using prior time.


  0%|          | 0/2000 [00:00<?, ?it/s]

Reinitialize the regular ODE parameters based on estimated global latent time.


  0%|          | 0/2000 [00:00<?, ?it/s]

Gaussian Prior.
Library scale (U): Max=10.35, Min=0.00, Mean=0.89
Library scale (S): Max=20.01, Min=0.13, Mean=1.06
--------------------------- Train a VeloVAE ---------------------------
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
*********                      Stage  1                       *********
Total Number of Iterations Per Epoch: 146, test iteration: 290
*********       Stage 1: Early Stop Triggered at epoch 199.       *********
*********                      Stage  2                       *********


  0%|          | 0/26682 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.000
Average Set Size: 741
*********       Stage 2: Early Stop Triggered at epoch 440.       *********
*********              Finished. Total Time =   0 h : 22 m : 32 s             *********
Final: Train ELBO = -1945.400,           Test ELBO = -1988.029
       Training MSE = 31.827, Test MSE = 34.611


# Discrete Full VB

In [8]:
figure_path = f'{figure_path_base}/DFullVB'
model_path = f'{model_path_base}/DFullVB'

# Re-seed with the same value as the Discrete VeloVAE run so the two models are comparable.
torch.manual_seed(2022)
np.random.seed(2022)

# Use the first GPU when present; otherwise fall back to CPU instead of crashing on 'cuda:0'.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Stored as `dfullvb` (not `dvae`) so it does not silently shadow the Discrete VeloVAE model.
dfullvb = vv.model.DVAEFullVB(adata,
                              tmax=20,
                              dim_z=5,
                              device=device,
                              init_method='tprior',   # initialize from the prior time ...
                              init_key='tprior',      # ... stored in adata.obs["tprior"]
                              tprior=None,            # NOTE(review): same setting as the Discrete VeloVAE run
                              init_ton_zero=False)

dfullvb.train(adata, plot=False, gene_plot=gene_plot, figure_path=figure_path, embed='embed')

# Save weights, then write results into adata under the 'dfullvb' key read by the Evaluation cell.
dfullvb.save_model(model_path, 'encoder', 'decoder')
dfullvb.save_anndata(adata, 'dfullvb', data_path, f'{dataset}.h5ad')

Detecting zero scaling factors: 3, 0
Initialization using prior time.


  0%|          | 0/2000 [00:00<?, ?it/s]

Reinitialize the regular ODE parameters based on estimated global latent time.


  0%|          | 0/2000 [00:00<?, ?it/s]

Gaussian Prior.
Library scale (U): Max=10.35, Min=0.00, Mean=0.89
Library scale (S): Max=20.01, Min=0.13, Mean=1.06
--------------------------- Train a VeloVAE ---------------------------
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
*********                      Stage  1                       *********
Total Number of Iterations Per Epoch: 146, test iteration: 290
*********       Stage 1: Early Stop Triggered at epoch 318.       *********
*********                      Stage  2                       *********


  0%|          | 0/26682 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.000
Average Set Size: 759
*********       Stage 2: Early Stop Triggered at epoch 501.       *********
*********              Finished. Total Time =   0 h : 20 m :  2 s             *********
Final: Train ELBO = -1961.476,           Test ELBO = -1998.037
       Training MSE = 28.089, Test MSE = 31.716


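# Evaluation

Compare the two discrete models, stored in `adata` under the `dvae` and `dfullvb` keys, using the metrics and plots produced by `vv.post_analysis`.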

In [None]:
# (source, target) cluster transitions passed to post_analysis: MET -> Epithelial,
# then Epithelial -> each of the four downstream cluster labels.
cluster_edges = [('MET', 'Epithelial')] + [
    ('Epithelial', target) for target in ('IPS', 'Neural', 'Trophoblast', 'Stromal')
]

# Display names paired with the adata keys written by the two save_anndata calls.
method_names = ['Discrete VeloVAE', 'Discrete FullVB']
method_keys = ['dvae', 'dfullvb']

vv.post_analysis(
    adata,
    'eval',
    method_names,
    method_keys,
    compute_metrics=True,
    genes=gene_plot,
    grid_size=(1, 4),
    plot_type=['all'],
    save_path=data_path,
    cluster_edges=cluster_edges,
)