In [1]:
# Dependencies for this notebook.
import anndata
import numpy as np
import sys
import torch
# Extend the module search path with the directory three levels up *before*
# importing velovae. The relative entry is resolved against the kernel's
# working directory.
# NOTE(review): presumably this makes a local velovae checkout importable
# instead of an installed package — confirm.
sys.path.append('../../../')
import velovae as vv

In [2]:
# Dataset label; other cells interpolate it into checkpoint/figure/data paths.
dataset = 'BMMC'
# NOTE(review): absolute, user-specific scratch directory — adjust per machine.
root = '/scratch/blaauw_root/blaauw1/gyichen'
adata = anndata.read_h5ad(f'{root}/data/BMMC_pp.h5ad')
# Copy the 'celltype.l2' annotation into obs["clusters"] as plain strings.
# A fixed-width 'U15' dtype silently cuts every label longer than 15
# characters (and can make distinct labels collide), so let NumPy pick a
# width that fits the longest label instead.
adata.obs["clusters"] = adata.obs['celltype.l2'].to_numpy().astype(str)

In [None]:
# Preprocess `adata` in place: n_gene=2000, min_shared_counts=20, and
# compute_umap=True.
# NOTE(review): this cell carries no execution count in the saved notebook, so
# the recorded outputs of the other cells were presumably produced without
# running it — confirm whether the loaded file is already preprocessed.
vv.preprocess(
    adata,
    n_gene=2000,
    min_shared_counts=20,
    compute_umap=True,
)

In [3]:
# Output locations: one sub-directory per dataset underneath `root`.
model_path_base = '{root}/checkpoints/{dataset}'.format(root=root, dataset=dataset)
figure_path_base = '{root}/figures/{dataset}'.format(root=root, dataset=dataset)
data_path = '{root}/data/velovae/discrete/{dataset}'.format(root=root, dataset=dataset)

# Genes handed to the training and evaluation calls for plotting.
gene_plot = [
    'SPINK2',
    'AZU1',
    'MPO',
    'LYZ',
    'CD74',
    'HBB',
]

# Discrete VAE

In [4]:
# Per-model output directories for the discrete VAE run.
model_path = f'{model_path_base}/DVAE'
figure_path = f'{figure_path_base}/DVAE'

# Seed the NumPy and torch RNGs with the same fixed value before the model
# is constructed.
np.random.seed(2022)
torch.manual_seed(2022)

# VAE on the first CUDA device with discrete=True.
dvae = vv.model.VAE(adata, tmax=20, dim_z=30, device='cuda:0', discrete=True)

# plot=False: no figures are requested during training; gene_plot and
# figure_path are still forwarded to train().
dvae.train(
    adata,
    plot=False,
    gene_plot=gene_plot,
    figure_path=figure_path,
)

# Write the trained model and the annotated AnnData to disk. 'dvae' is the
# same key the evaluation cell passes to vv.post_analysis.
# NOTE(review): 'encoder' / 'decoder' look like checkpoint file names — confirm.
dvae.save_model(model_path, 'encoder', 'decoder')
dvae.save_anndata(adata, 'dvae', data_path, f'{dataset}.h5ad')

Detecting zero scaling factors: 0, 0
Estimating ODE parameters...


  0%|          | 0/2000 [00:00<?, ?it/s]

Detected 717 velocity genes.
Estimating the variance...


  0%|          | 0/2000 [00:00<?, ?it/s]

Initialization using the steady-state and dynamical models.
Reinitialize the regular ODE parameters based on estimated global latent time.


  0%|          | 0/2000 [00:00<?, ?it/s]

3 clusters detected based on gene co-expression.
(0.59, 0.732014521107534), (0.41, 0.3886589322687619)
(0.68, 0.7303502936602487), (0.32, 0.2898126955613042)
KS-test result: [0. 0. 1.]
Initial induction: 1377, repression: 623/2000
Learning Rate based on Data Sparsity: 0.0012
--------------------------- Train a VeloVAE ---------------------------
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
*********                      Stage  1                       *********
Total Number of Iterations Per Epoch: 121, test iteration: 240
*********                      Stage  2                       *********
*********             Velocity Refinement Round 1             *********


  0%|          | 0/22122 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.028
Average Set Size: 447
*********     Round 1: Early Stop Triggered at epoch 1270.    *********
*********             Velocity Refinement Round 2             *********
*********     Round 2: Early Stop Triggered at epoch 1296.    *********
Change in x0: 0.1568
*********             Velocity Refinement Round 3             *********
*********     Round 3: Early Stop Triggered at epoch 1384.    *********
Change in x0: 0.0847
*********             Velocity Refinement Round 4             *********
*********     Round 4: Early Stop Triggered at epoch 1424.    *********
Change in x0: 0.1459
*********             Velocity Refinement Round 5             *********
Stage 2: Early Stop Triggered at round 4.
*********              Finished. Total Time =   0 h : 45 m : 44 s             *********
Final: Train ELBO = -1066.183,	Test ELBO = -1079.890


# Discrete FullVB

In [5]:
# Per-model output directories for the discrete FullVB run.
model_path = f'{model_path_base}/DFullVB'
figure_path = f'{figure_path_base}/DFullVB'

# Seed the NumPy and torch RNGs with the same fixed value before the model
# is constructed.
np.random.seed(2022)
torch.manual_seed(2022)

# Same constructor arguments as the discrete VAE cell, plus full_vb=True.
# NOTE: the name `dvae` is rebound here, replacing any model object that an
# earlier cell stored under it.
dvae = vv.model.VAE(
    adata,
    tmax=20,
    dim_z=30,
    device='cuda:0',
    discrete=True,
    full_vb=True,
)

# plot=False: no figures are requested during training; embed="umap" is
# passed explicitly in this cell.
dvae.train(
    adata,
    plot=False,
    gene_plot=gene_plot,
    figure_path=figure_path,
    embed="umap",
)

# Write the trained model and the annotated AnnData to disk. 'dfullvb' is the
# same key the evaluation cell passes to vv.post_analysis.
# NOTE(review): 'encoder' / 'decoder' look like checkpoint file names — confirm.
dvae.save_model(model_path, 'encoder', 'decoder')
dvae.save_anndata(adata, 'dfullvb', data_path, f'{dataset}.h5ad')

Detecting zero scaling factors: 0, 0
Estimating ODE parameters...


  0%|          | 0/2000 [00:00<?, ?it/s]

Detected 717 velocity genes.
Estimating the variance...


  0%|          | 0/2000 [00:00<?, ?it/s]

Initialization using the steady-state and dynamical models.
Reinitialize the regular ODE parameters based on estimated global latent time.


  0%|          | 0/2000 [00:00<?, ?it/s]

3 clusters detected based on gene co-expression.
(0.59, 0.732014521107534), (0.41, 0.3886589322687619)
(0.68, 0.7303502936602487), (0.32, 0.2898126955613042)
KS-test result: [0. 0. 1.]
Initial induction: 1377, repression: 623/2000
Learning Rate based on Data Sparsity: 0.0012
--------------------------- Train a VeloVAE ---------------------------
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
*********                      Stage  1                       *********
Total Number of Iterations Per Epoch: 121, test iteration: 240
*********       Stage 1: Early Stop Triggered at epoch 308.       *********
*********                      Stage  2                       *********
*********

  0%|          | 0/22122 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.025
Average Set Size: 447
*********     Round 1: Early Stop Triggered at epoch 388.    *********
*********             Velocity Refinement Round 2             *********
*********     Round 2: Early Stop Triggered at epoch 543.    *********
Change in x0: 0.2262
*********             Velocity Refinement Round 3             *********
*********     Round 3: Early Stop Triggered at epoch 561.    *********
Change in x0: 0.1410
*********             Velocity Refinement Round 4             *********
*********     Round 4: Early Stop Triggered at epoch 597.    *********
Change in x0: 0.1241
*********             Velocity Refinement Round 5             *********
*********     Round 5: Early Stop Triggered at epoch 655.    *********
Change in x0: 0.1169
*********             Velocity Refinement Round 6             *********
Stage 2: Early Stop Triggered at round 5.
*********              Finished. Total Time =   0 h : 25 m : 13 s             *********
Final: Train EL

# Evaluation

In [11]:
# Evaluate both trained models on `adata`. The display names in the first list
# pair positionally with the keys ('dvae', 'dfullvb') that the training cells
# used in save_anndata. Metric computation is turned off
# (compute_metrics=False); per the recorded output, stream plots are saved
# under figure_path_base with file names starting with 'eval_'.
# NOTE(review): gene_plot is supplied under both `genes` and `gene_plot` —
# confirm whether both keywords are actually consumed.
# The trailing semicolon keeps the notebook from echoing the call's
# uninformative return value as the cell output.
vv.post_analysis(adata,
                 'eval',
                 ['Discrete VeloVAE', 'Discrete FullVB'],
                 ['dvae', 'dfullvb'],
                 compute_metrics=False,
                 raw_count=True,
                 genes=gene_plot,
                 grid_size=(1,4),
                 gene_plot=gene_plot,
                 figure_path=figure_path_base);

---   Plotting  Results   ---
computing velocity graph (using 11/32 cores)


  0%|          | 0/22122 [00:00<?, ?cells/s]

    finished (0:01:35) --> added 
    'dvae_velocity_graph', sparse matrix with cosine correlations (adata.uns)
computing velocity embedding
    finished (0:00:06) --> added
    'dvae_velocity_umap', embedded velocity vectors (adata.obsm)
saving figure to file /scratch/blaauw_root/blaauw1/gyichen/figures/BMMC/eval_dvae_stream.png
computing velocity graph (using 11/32 cores)


  0%|          | 0/22122 [00:00<?, ?cells/s]

    finished (0:01:26) --> added 
    'dfullvb_velocity_graph', sparse matrix with cosine correlations (adata.uns)
computing velocity embedding
    finished (0:00:06) --> added
    'dfullvb_velocity_umap', embedded velocity vectors (adata.obsm)
saving figure to file /scratch/blaauw_root/blaauw1/gyichen/figures/BMMC/eval_dfullvb_stream.png


(None, None)