In [1]:
import os
import sys

import anndata
import numpy as np
import torch

sys.path.append('../../../')
import velovae as vv

In [2]:
# Dataset identifier; also used to build every input/output path in this notebook.
dataset = 'Braindev'
# Project root on the cluster scratch space. Set the VELOVAE_ROOT environment
# variable to run from a different location; the default matches the original run.
root = os.environ.get('VELOVAE_ROOT', "/scratch/blaauw_root/blaauw1/gyichen")
adata = anndata.read_h5ad(f'{root}/data/{dataset}_pp.h5ad')
# Expose the 'Class' annotation under the 'clusters' key (presumably the column
# velovae reads cell-type labels from — confirm against the library).
adata.obs['clusters'] = adata.obs['Class'].to_numpy()

In [None]:
# This cell was not executed when the recorded results below were produced
# (execution count "None"), and the training logs report 2000 genes rather than
# the n_gene=1000 requested here. The *_pp.h5ad input therefore appears to be
# already preprocessed. Keep preprocessing opt-in so that Restart & Run All
# reproduces the saved outputs. Set RUN_PREPROCESS = True only when starting
# from an unprocessed AnnData file.
RUN_PREPROCESS = False
if RUN_PREPROCESS:
    vv.preprocess(adata, n_gene=1000, min_shared_counts=20, compute_umap=True, keep_raw=True)

In [3]:
# Output locations: model checkpoints and figures go to per-dataset folders,
# and annotated AnnData results go to the discrete-model data directory.
model_path_base = '/'.join([root, 'checkpoints', dataset])
figure_path_base = '/'.join([root, 'figures', dataset])
data_path = '/'.join([root, 'data', 'velovae', 'discrete', dataset])
# Genes passed to the training and evaluation plotting routines.
gene_plot = ['Mapt', 'Tmsb10', 'Fabp7', 'Npm1']

# Informative Time Prior

In [4]:
# Build a numeric time prior from the 'Age' labels. The first character of each
# label is dropped and the remainder is parsed as a float (e.g. a label such as
# 'e10.5' would map to 10.5 — label format assumed, confirm against the data).
capture_time = adata.obs['Age'].to_numpy()
tprior = np.fromiter((float(age[1:]) for age in capture_time),
                     dtype=float,
                     count=len(capture_time))
adata.obs['tprior'] = tprior

## Discrete VAE

In [6]:
# Train the discrete VeloVAE (discrete=True) initialized from the informative
# time prior, then save the encoder/decoder weights and the annotated AnnData.
figure_path = f'{figure_path_base}/DVAE'
model_path = f'{model_path_base}/DVAE'

# Seed torch and numpy so initialization is repeatable
# (GPU kernels may still introduce nondeterminism).
torch.manual_seed(2022)
np.random.seed(2022)

# Use the first GPU when available; otherwise fall back to CPU instead of failing.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

dvae = vv.VAE(adata,
              tmax=20,
              dim_z=5,
              device=device,
              init_method='tprior',  # initialize latent time from the prior
              init_key='tprior',
              tprior='tprior',       # obs column holding the informative time prior
              discrete=True,
              init_ton_zero=False)

dvae.train(adata, plot=False, gene_plot=gene_plot, figure_path=figure_path)

dvae.save_model(model_path, 'encoder', 'decoder')
# Results are stored in adata under the 'dvae' key, which the evaluation cell
# refers to, and written to {data_path}/{dataset}.h5ad.
dvae.save_anndata(adata, 'dvae', data_path, f'{dataset}.h5ad')

Detecting zero scaling factors: 5, 0
Estimating ODE parameters...


  0%|          | 0/2000 [00:00<?, ?it/s]

Detected 1218 velocity genes.
Estimating the variance...


  0%|          | 0/2000 [00:00<?, ?it/s]

Initialization using prior time.
Reinitialize the regular ODE parameters based on estimated global latent time.


  0%|          | 0/2000 [00:00<?, ?it/s]

Initial induction: 1733, repression: 267/2000
Using informative time prior.
Learning Rate based on Data Sparsity: 0.0009
--------------------------- Train a VeloVAE ---------------------------
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
*********                      Stage  1                       *********
Total Number of Iterations Per Epoch: 164, test iteration: 326
*********                      Stage  2                       *********
*********             Velocity Refinement Round 1             *********


  0%|          | 0/29948 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.028
Average Set Size: 603
*********     Round 1: Early Stop Triggered at epoch 1279.    *********
*********             Velocity Refinement Round 2             *********
*********     Round 2: Early Stop Triggered at epoch 1731.    *********
Change in x0: 0.1918
*********             Velocity Refinement Round 3             *********
*********     Round 3: Early Stop Triggered at epoch 1811.    *********
Change in x0: 0.1339
*********             Velocity Refinement Round 4             *********
Change in x0: 0.1235
*********             Velocity Refinement Round 5             *********
*********     Round 5: Early Stop Triggered at epoch 2348.    *********
Change in x0: 0.1262
*********             Velocity Refinement Round 6             *********
Stage 2: Early Stop Triggered at round 5.
*********              Finished. Total Time =   1 h : 36 m : 26 s             *********
Final: Train ELBO = -1256.971,	Test ELBO = -1272.793


## Discrete FullVB

In [7]:
# Train the discrete VeloVAE with full_vb=True (presumably full variational
# Bayes over the model parameters — confirm in velovae), initialized from the
# informative time prior, then save the weights and the annotated AnnData.
figure_path = f'{figure_path_base}/DFullVB'
model_path = f'{model_path_base}/DFullVB'

# Seed torch and numpy so initialization is repeatable
# (GPU kernels may still introduce nondeterminism).
torch.manual_seed(2022)
np.random.seed(2022)

# Use the first GPU when available; otherwise fall back to CPU instead of failing.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

dfullvb = vv.VAE(adata,
                 tmax=20,
                 dim_z=5,
                 device=device,
                 init_method='tprior',  # initialize latent time from the prior
                 init_key='tprior',
                 tprior='tprior',       # obs column holding the informative time prior
                 discrete=True,
                 full_vb=True,
                 init_ton_zero=False)

dfullvb.train(adata, plot=False, gene_plot=gene_plot, figure_path=figure_path)

dfullvb.save_model(model_path, 'encoder', 'decoder')
# Results are stored in adata under the 'dfullvb' key, which the evaluation cell
# refers to. This writes to the same {data_path}/{dataset}.h5ad as the DVAE cell,
# so the saved file contains both models' outputs.
dfullvb.save_anndata(adata, 'dfullvb', data_path, f'{dataset}.h5ad')

Detecting zero scaling factors: 5, 0
Estimating ODE parameters...


  0%|          | 0/2000 [00:00<?, ?it/s]

Detected 1218 velocity genes.
Estimating the variance...


  0%|          | 0/2000 [00:00<?, ?it/s]

Initialization using prior time.
Reinitialize the regular ODE parameters based on estimated global latent time.


  0%|          | 0/2000 [00:00<?, ?it/s]

Initial induction: 1733, repression: 267/2000
Using informative time prior.
Learning Rate based on Data Sparsity: 0.0009
--------------------------- Train a VeloVAE ---------------------------
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
*********                      Stage  1                       *********
Total Number of Iterations Per Epoch: 164, test iteration: 326
*********       Stage 1: Early Stop Triggered at epoch 750.       *********
*********                      Stage  2                       *********
*********             Velocity Refinement Round 1             *********


  0%|          | 0/29948 [00:00<?, ?it/s]

Percentage of Invalid Sets: 0.028
Average Set Size: 602
*********     Round 1: Early Stop Triggered at epoch 850.    *********
*********             Velocity Refinement Round 2             *********
*********     Round 2: Early Stop Triggered at epoch 946.    *********
Change in x0: 0.1948
*********             Velocity Refinement Round 3             *********
*********     Round 3: Early Stop Triggered at epoch 1006.    *********
Change in x0: 0.1368
*********             Velocity Refinement Round 4             *********
*********     Round 4: Early Stop Triggered at epoch 1142.    *********
Change in x0: 0.1442
*********             Velocity Refinement Round 5             *********
Stage 2: Early Stop Triggered at round 4.
*********              Finished. Total Time =   0 h : 53 m : 43 s             *********
Final: Train ELBO = -1338.947,	Test ELBO = -1346.175


# Evaluation

In [8]:
# Expected parent -> child cell-type transitions, passed to post_analysis as
# cluster_edges (presumably used to score/visualize velocity direction between
# clusters — confirm in velovae).
cluster_edges = [('Neural tube', 'Radial glia'),
                 ('Radial glia', 'Neuroblast'),
                 ('Radial glia', 'Glioblast'),
                 ('Radial glia', 'Oligodendrocyte'),
                 ('Radial glia', 'Ependymal'),
                 ('Neural crest', 'Mesenchyme'),
                 ('Mesenchyme', 'Fibroblast')]

# Compare both discrete models. Display labels are paired positionally with the
# adata keys written by the training cells. Figures are saved under
# figure_path_base with the 'eval' prefix. The trailing ';' keeps the notebook
# from echoing the function's return value.
vv.post_analysis(adata,
                 'eval',
                 ['Discrete VeloVAE', 'Discrete FullVB'],
                 ['dvae', 'dfullvb'],
                 compute_metrics=False,
                 genes=gene_plot,
                 grid_size=(1, 4),
                 figure_path=figure_path_base,
                 cluster_edges=cluster_edges);

---   Plotting  Results   ---
computing velocity graph (using 14/32 cores)


  0%|          | 0/29948 [00:00<?, ?cells/s]

    finished (0:01:20) --> added 
    'dvae_velocity_graph', sparse matrix with cosine correlations (adata.uns)
computing velocity embedding
    finished (0:00:08) --> added
    'dvae_velocity_umap', embedded velocity vectors (adata.obsm)
saving figure to file /scratch/blaauw_root/blaauw1/gyichen/figures/Braindev/eval_dvae_stream.png
computing velocity graph (using 14/32 cores)


  0%|          | 0/29948 [00:00<?, ?cells/s]

    finished (0:01:06) --> added 
    'dfullvb_velocity_graph', sparse matrix with cosine correlations (adata.uns)
computing velocity embedding
    finished (0:00:08) --> added
    'dfullvb_velocity_umap', embedded velocity vectors (adata.obsm)
saving figure to file /scratch/blaauw_root/blaauw1/gyichen/figures/Braindev/eval_dfullvb_stream.png


(None, None)