In [1]:
import anndata
import numpy as np
import scanpy as sc
import sys
import torch
# Extend the module search path four directories up before importing velovae.
# NOTE(review): presumably this points at a local velovae checkout rather than
# an installed package; the path is relative to the kernel's working directory,
# so confirm the notebook is launched from its own folder.
sys.path.append('../../../../')
import velovae as vv

In [2]:
# Dataset identifier and the root directory under which its files live.
dataset = 'Erythroid_Human'
root = '/scratch/blaauw_root/blaauw1/gyichen'

# Read the `<dataset>_pp.h5ad` file stored in `<root>/data`.
pp_file = f'{root}/data/{dataset}_pp.h5ad'
adata = anndata.read_h5ad(pp_file)

# Copy the `type2` annotation into the `clusters` column as a plain array.
adata.obs['clusters'] = adata.obs['type2'].to_numpy()

In [None]:
# Disabled preprocessing step, kept for reference only (this cell runs nothing).
# NOTE(review): the write below targets the relative path `data/...`, while the
# loading cell reads `{root}/data/{dataset}_pp.h5ad` -- confirm both refer to
# the same file before re-enabling.
#vv.preprocess(adata, n_gene=2000, min_shared_counts=20)
#adata.write_h5ad(f'data/{dataset}_pp.h5ad')

In [4]:
# Genes passed as `gene_plot` / `genes` to the training and evaluation calls.
gene_plot = ['CNN3', 'CYR61', 'ABCG2', 'HBA2']

# Per-dataset output locations, all rooted at `root`.
figure_path_base = f'{root}/figures/{dataset}'
model_path_base = f'{root}/checkpoints/{dataset}'
data_path = f'{root}/data/velovae/discrete/{dataset}'

# Discrete VeloVAE

In [6]:
# Train the discrete VeloVAE model, then save its networks and write the
# results back to disk under the 'dvae' key.
figure_path = f'{figure_path_base}/VeloVAE'
model_path = f'{model_path_base}/VeloVAE'

# Seed both PyTorch and NumPy so repeated runs start from the same RNG state.
torch.manual_seed(2022)
np.random.seed(2022)

# Use the first GPU when CUDA is available; otherwise fall back to the CPU so
# the cell does not fail outright on a machine without a GPU.
# NOTE: the recorded run took about 2 hours on a GPU; expect CPU training to be
# considerably slower.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

dvae = vv.VAE(adata,
              tmax=20,
              dim_z=5,
              device=device,
              discrete=True)

dvae.train(adata, gene_plot=gene_plot, plot=False, figure_path=figure_path)

# 'encoder' / 'decoder' are the names given to the two saved model parts.
dvae.save_model(model_path, 'encoder', 'decoder')
# 'dvae' is the same key that the evaluation cell passes to post_analysis.
dvae.save_anndata(adata, 'dvae', data_path, file_name=f'{dataset}.h5ad')

Detecting zero scaling factors: 0, 0
Estimating ODE parameters...


  0%|          | 0/2000 [00:00<?, ?it/s]

Detected 992 velocity genes.
Estimating the variance...


  0%|          | 0/2000 [00:00<?, ?it/s]

Initialization using the steady-state and dynamical models.
Reinitialize the regular ODE parameters based on estimated global latent time.


  0%|          | 0/2000 [00:00<?, ?it/s]

3 clusters detected based on gene co-expression.
(0.66, 0.2137221315308991), (0.34, 0.6793733088533553)
(0.55, 0.7527810779600155), (0.45, 0.3355563953927338)
(0.58, 0.27398700344399946), (0.42, 0.7237131612612386)
KS-test result: [0. 0. 0.]
Initial induction: 892, repression: 1108/2000
Learning Rate based on Data Sparsity: 0.0001
--------------------------- Train a VeloVAE ---------------------------
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
*********                      Stage  1                       *********
Total Number of Iterations Per Epoch: 197, test iteration: 392
*********       Stage 1: Early Stop Triggered at epoch 599.       *********
*********               

  0%|          | 0/35877 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.026
Average Set Size: 731
*********     Round 1: Early Stop Triggered at epoch 972.    *********
*********             Velocity Refinement Round 2             *********
*********     Round 2: Early Stop Triggered at epoch 1319.    *********
Change in x0: 0.0694
*********             Velocity Refinement Round 3             *********
Change in x0: 0.0381
*********             Velocity Refinement Round 4             *********
Change in x0: 0.0316
*********             Velocity Refinement Round 5             *********
Stage 2: Early Stop Triggered at round 4.
*********              Finished. Total Time =   2 h :  6 m : 19 s             *********
Final: Train ELBO = -2290.309,	Test ELBO = -2325.054


# Discrete Full VB

In [7]:
# Train the discrete Full VB variant (same settings as the discrete VeloVAE
# cell plus `full_vb=True`), then save its networks and write the results back
# to disk under the 'dfullvb' key.
figure_path = f'{figure_path_base}/FullVB'
model_path = f'{model_path_base}/FullVB'

# Re-seed both RNGs so this model starts from the same state as the other run.
torch.manual_seed(2022)
np.random.seed(2022)

# Use the first GPU when CUDA is available; otherwise fall back to the CPU so
# the cell does not fail outright on a machine without a GPU.
# NOTE: the recorded run took over 2 hours on a GPU; expect CPU training to be
# considerably slower.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

full_vb = vv.VAE(adata,
                 tmax=20,
                 dim_z=5,
                 device=device,
                 discrete=True,
                 full_vb=True)

full_vb.train(adata, gene_plot=gene_plot, plot=False, figure_path=figure_path)

# 'encoder' / 'decoder' are the names given to the two saved model parts.
full_vb.save_model(model_path, 'encoder', 'decoder')
# 'dfullvb' is the same key that the evaluation cell passes to post_analysis.
full_vb.save_anndata(adata, 'dfullvb', data_path, file_name=f'{dataset}.h5ad')

Detecting zero scaling factors: 0, 0
Estimating ODE parameters...


  0%|          | 0/2000 [00:00<?, ?it/s]

Detected 992 velocity genes.
Estimating the variance...


  0%|          | 0/2000 [00:00<?, ?it/s]

Initialization using the steady-state and dynamical models.
Reinitialize the regular ODE parameters based on estimated global latent time.


  0%|          | 0/2000 [00:00<?, ?it/s]

3 clusters detected based on gene co-expression.
(0.66, 0.2137221315308991), (0.34, 0.6793733088533553)
(0.55, 0.7527810779600155), (0.45, 0.3355563953927338)
(0.58, 0.27398700344399946), (0.42, 0.7237131612612386)
KS-test result: [0. 0. 0.]
Initial induction: 892, repression: 1108/2000
Learning Rate based on Data Sparsity: 0.0001
--------------------------- Train a VeloVAE ---------------------------
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
*********                      Stage  1                       *********
Total Number of Iterations Per Epoch: 197, test iteration: 392
*********                      Stage  2                       *********
*********             Veloci

  0%|          | 0/35877 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.027
Average Set Size: 725
*********     Round 1: Early Stop Triggered at epoch 1320.    *********
*********             Velocity Refinement Round 2             *********
*********     Round 2: Early Stop Triggered at epoch 1496.    *********
Change in x0: 0.0940
*********             Velocity Refinement Round 3             *********
*********     Round 3: Early Stop Triggered at epoch 1678.    *********
Change in x0: 0.0517
*********             Velocity Refinement Round 4             *********
*********     Round 4: Early Stop Triggered at epoch 1959.    *********
Change in x0: 0.0398
*********             Velocity Refinement Round 5             *********
*********     Round 5: Early Stop Triggered at epoch 2270.    *********
Change in x0: 0.0364
*********             Velocity Refinement Round 6             *********
Stage 2: Early Stop Triggered at round 5.
*********              Finished. Total Time =   2 h : 16 m : 43 s             *********
Final: Tra

# Evaluation

In [8]:
# Plot the results of both trained models (keys 'dvae' and 'dfullvb') for the
# genes in `gene_plot`, saving figures under `figure_path_base`.
# With compute_metrics=False the call returned (None, None) in the recorded
# run; the trailing semicolon keeps that tuple from being echoed as the cell's
# output.
vv.post_analysis(adata,
                 'eval',
                 ['Discrete VeloVAE','Discrete FullVB'],
                 ['dvae','dfullvb'],
                 compute_metrics=False,
                 raw_count=True,
                 genes=gene_plot,
                 grid_size=(1,4),
                 figure_path=figure_path_base);

---   Plotting  Results   ---
computing velocity graph (using 17/32 cores)


  0%|          | 0/35877 [00:00<?, ?cells/s]

    finished (0:03:50) --> added 
    'dvae_velocity_graph', sparse matrix with cosine correlations (adata.uns)
computing velocity embedding
    finished (0:00:10) --> added
    'dvae_velocity_umap', embedded velocity vectors (adata.obsm)
saving figure to file /scratch/blaauw_root/blaauw1/gyichen/figures/Erythroid_Human/eval_dvae_stream.png
computing velocity graph (using 17/32 cores)


  0%|          | 0/35877 [00:00<?, ?cells/s]

    finished (0:03:26) --> added 
    'dfullvb_velocity_graph', sparse matrix with cosine correlations (adata.uns)
computing velocity embedding
    finished (0:00:10) --> added
    'dfullvb_velocity_umap', embedded velocity vectors (adata.obsm)
saving figure to file /scratch/blaauw_root/blaauw1/gyichen/figures/Erythroid_Human/eval_dfullvb_stream.png


(None, None)