In [1]:
import os
import sys

import anndata
import numpy as np
import torch
sys.path.append('../../../')
import velovae as vv
%load_ext autoreload
%autoreload 2

In [2]:
# Name of the dataset; it is reused below to build the checkpoint, figure and
# output paths.
dataset = 'Pancreas'

# Root directory that holds the data/, checkpoints/ and figures/ subfolders.
# Set the VELOVAE_ROOT environment variable to run the notebook on another
# machine; when it is unset (or empty) the original cluster location is used,
# so existing runs behave exactly as before.
root = os.environ.get('VELOVAE_ROOT') or "/scratch/blaauw_root/blaauw1/gyichen"

# Load the AnnData object for this dataset.
# NOTE(review): the `_pp` suffix presumably marks an already-preprocessed file - confirm.
adata = anndata.read_h5ad(f'{root}/data/{dataset}_pp.h5ad')

In [None]:
# Uncomment the next line if the data has not been preprocessed yet.
#vv.preprocess(adata, n_gene=2000, min_shared_counts=20)

In [3]:
# Genes handed to the training calls (`gene_plot=`) and to the evaluation
# call (`genes=`) further down.
gene_plot = [
    'Cpe',
    'Gng12',
    'Ppp3ca',
    'Smoc1',
]

# Per-dataset locations, all derived from `root` and `dataset`.
data_path = f'{root}/data/velovae/discrete/{dataset}'
figure_path_base = f'{root}/figures/{dataset}'
model_path_base = f'{root}/checkpoints/{dataset}'

# Discrete VAE

In [4]:
# Output locations for this run, nested under the per-dataset base folders.
model_path = f'{model_path_base}/Discrete_VeloVAE'
figure_path = f'{figure_path_base}/Discrete_VeloVAE'

# Seed NumPy and PyTorch with the same fixed value before building the model.
np.random.seed(2022)
torch.manual_seed(2022)

# Build the model with discrete=True (full_vb is left at its default here).
dvae = vv.VAE(
    adata,
    tmax=20,
    dim_z=5,
    init_method='steady',
    device='cuda:0',
    discrete=True,
)

# Train with plotting switched off.
dvae.train(
    adata,
    plot=False,
    gene_plot=gene_plot,
    figure_path=figure_path,
)

# Save the model, then save `adata` with the key 'dvae'; the same key is
# passed to the evaluation call at the end of the notebook.
dvae.save_model(model_path, 'encoder', 'decoder')
dvae.save_anndata(adata, 'dvae', data_path, f'{dataset}.h5ad')

Detecting zero scaling factors: 0, 0
Estimating ODE parameters...


  0%|          | 0/2000 [00:00<?, ?it/s]

Detected 901 velocity genes.
Estimating the variance...


  0%|          | 0/2000 [00:00<?, ?it/s]

Initialization using the steady-state and dynamical models.
Reinitialize the regular ODE parameters based on estimated global latent time.


  0%|          | 0/2000 [00:00<?, ?it/s]

3 clusters detected based on gene co-expression.
(0.51, 0.7331313053369634), (0.49, 0.34843595768045504)
(0.57, 0.812504473951968), (0.43, 0.44401169136605384)
(0.45, 0.7673639353636807), (0.55, 0.31774595674199757)
KS-test result: [0. 0. 0.]
Initial induction: 1164, repression: 836/2000
Learning Rate based on Data Sparsity: 0.0002
--------------------------- Train a VeloVAE ---------------------------
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
*********                      Stage  1                       *********
Total Number of Iterations Per Epoch: 21, test iteration: 40
*********                      Stage  2                       *********
*********             Velocit

  0%|          | 0/3696 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.029
Average Set Size: 80
*********     Round 1: Early Stop Triggered at epoch 1235.    *********
*********             Velocity Refinement Round 2             *********
*********     Round 2: Early Stop Triggered at epoch 1428.    *********
Change in x0: 0.2135
*********             Velocity Refinement Round 3             *********
*********     Round 3: Early Stop Triggered at epoch 1756.    *********
Change in x0: 0.1493
*********             Velocity Refinement Round 4             *********
*********     Round 4: Early Stop Triggered at epoch 2004.    *********
Change in x0: 0.1248
*********             Velocity Refinement Round 5             *********
*********     Round 5: Early Stop Triggered at epoch 2048.    *********
Change in x0: 0.0799
*********             Velocity Refinement Round 6             *********
*********     Round 6: Early Stop Triggered at epoch 2153.    *********
Change in x0: 0.0711
*********             Velocity Refinement Round 

# Discrete Full VB

In [5]:
# Output locations for this run, nested under the per-dataset base folders.
model_path = f'{model_path_base}/Discrete_FullVB'
figure_path = f'{figure_path_base}/Discrete_FullVB'

# Seed NumPy and PyTorch with the same fixed value before building the model.
np.random.seed(2022)
torch.manual_seed(2022)

# Same constructor arguments as the Discrete VeloVAE run, plus full_vb=True.
dvae = vv.VAE(
    adata,
    tmax=20,
    dim_z=5,
    init_method='steady',
    device='cuda:0',
    discrete=True,
    full_vb=True,
)

# Train with plotting switched off.
dvae.train(
    adata,
    plot=False,
    gene_plot=gene_plot,
    figure_path=figure_path,
    embed='umap',
)

# Save the model, then save `adata` with the key 'dfullvb'. The file name is
# the same `<dataset>.h5ad` used by the Discrete VeloVAE run.
dvae.save_model(model_path, 'encoder', 'decoder')
dvae.save_anndata(adata, 'dfullvb', data_path, f'{dataset}.h5ad')

Detecting zero scaling factors: 0, 0
Estimating ODE parameters...


  0%|          | 0/2000 [00:00<?, ?it/s]

Detected 901 velocity genes.
Estimating the variance...


  0%|          | 0/2000 [00:00<?, ?it/s]

Initialization using the steady-state and dynamical models.
Reinitialize the regular ODE parameters based on estimated global latent time.


  0%|          | 0/2000 [00:00<?, ?it/s]

3 clusters detected based on gene co-expression.
(0.51, 0.7331313053369634), (0.49, 0.34843595768045504)
(0.57, 0.812504473951968), (0.43, 0.44401169136605384)
(0.45, 0.7673639353636807), (0.55, 0.31774595674199757)
KS-test result: [0. 0. 0.]
Initial induction: 1164, repression: 836/2000
Learning Rate based on Data Sparsity: 0.0002
--------------------------- Train a VeloVAE ---------------------------
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
*********                      Stage  1                       *********
Total Number of Iterations Per Epoch: 21, test iteration: 40
*********                      Stage  2                       *********
*********             Velocit

  0%|          | 0/3696 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.027
Average Set Size: 81
*********     Round 1: Early Stop Triggered at epoch 1197.    *********
*********             Velocity Refinement Round 2             *********
*********     Round 2: Early Stop Triggered at epoch 1407.    *********
Change in x0: 0.2289
*********             Velocity Refinement Round 3             *********
*********     Round 3: Early Stop Triggered at epoch 1596.    *********
Change in x0: 0.1404
*********             Velocity Refinement Round 4             *********
*********     Round 4: Early Stop Triggered at epoch 1798.    *********
Change in x0: 0.0922
*********             Velocity Refinement Round 5             *********
*********     Round 5: Early Stop Triggered at epoch 1833.    *********
Change in x0: 0.0767
*********             Velocity Refinement Round 6             *********
*********     Round 6: Early Stop Triggered at epoch 2062.    *********
Change in x0: 0.0707
*********             Velocity Refinement Round 

# Evaluation

In [7]:
# Run the post-analysis for both trained models. Metric computation is turned
# off (compute_metrics=False). The second list holds the keys that were given
# to `save_anndata` in the two training cells; the first list is presumably
# the matching display names - confirm against the velovae documentation.
res, res_type = vv.post_analysis(
    adata,
    dataset,
    ['Discrete VeloVAE', 'Discrete FullVB'],
    ['dvae', 'dfullvb'],
    compute_metrics=False,
    raw_count=True,
    genes=gene_plot,
    grid_size=(1, 4),
    figure_path=figure_path_base,
)

---   Plotting  Results   ---
computing velocity graph (using 1/32 cores)


  0%|          | 0/3696 [00:00<?, ?cells/s]

    finished (0:00:30) --> added 
    'dvae_velocity_graph', sparse matrix with cosine correlations (adata.uns)
computing velocity embedding
    finished (0:00:00) --> added
    'dvae_velocity_umap', embedded velocity vectors (adata.obsm)
saving figure to file /scratch/blaauw_root/blaauw1/gyichen/figures/Pancreas/Pancreas_dvae_stream.png
computing velocity graph (using 1/32 cores)


  0%|          | 0/3696 [00:00<?, ?cells/s]

    finished (0:00:30) --> added 
    'dfullvb_velocity_graph', sparse matrix with cosine correlations (adata.uns)
computing velocity embedding
    finished (0:00:00) --> added
    'dfullvb_velocity_umap', embedded velocity vectors (adata.obsm)
saving figure to file /scratch/blaauw_root/blaauw1/gyichen/figures/Pancreas/Pancreas_dfullvb_stream.png
