In [2]:
import anndata
import numpy as np
import scvelo as scv
import sys
import torch

sys.path.append('../')
import velovae as vv

# 1. Train a Vanilla VAE
## Load the input AnnData Object

In [4]:
# Location of the input .h5ad file, relative to the notebook's working directory.
filename = '../data/Pancreas/pancreas.h5ad'

# Read the file into an in-memory AnnData object used by every later cell.
adata = anndata.read_h5ad(filename=filename)

## Preparation Work

In [5]:
# Gene-count setting handed to vv.preprocess (the captured log reports
# "2000 highly variable genes"). Change it here only; the call below reads it.
Ngene = 2000
vv.preprocess(adata, Ngene)

# Output locations used by the training, saving and post-analysis cells.
figure_path = '../figures/Pancreas/Default'      # passed to model.train(figure_path=...)
model_path = '../checkpoints/Pancreas/Default'   # passed to model.saveModel(...)
data_path = '../data/Pancreas/Default'           # passed to model.saveAnnData(...)

Filtered out 22645 genes that are detected 50 counts (shared).
Normalized count data: X, spliced, unspliced.
Exctracted 2000 highly variable genes.
Logarithmized X.
computing neighbors
    finished (0:00:03) --> added 
    'distances' and 'connectivities', weighted adjacency matrices (adata.obsp)
computing moments based on connectivities
    finished (0:00:00) --> added 
    'Ms' and 'Mu', moments of un/spliced abundances (adata.layers)


## Train the model

In [None]:
gene_plot = ['Pcsk2','Dcdc2a','Gng12','Cpe','Ppp3ca'] #some marker genes

# Hyper-parameters forwarded to model.train() through its `config` argument.
config_vae = {
    'num_epochs':800, 
    'test_epoch':50, 
    'save_epoch':50, 
    'learning_rate':2e-4, 
    'learning_rate_ode':2e-4, 
    'lambda':1e-3, 
    'reg_t':2.0, 
    'batch_size':128
}

# Seed NumPy and PyTorch before the model is constructed so that repeated
# "Restart & Run All" executions start from the same random state.
# GPU kernels may still introduce small run-to-run differences.
SEED = 2022
np.random.seed(SEED)
torch.manual_seed(SEED)

# NOTE(review): the positional 20 is presumably the time-range argument (Tmax)
# -- confirm against the VanillaVAE signature of the installed velovae.
# NOTE(review): confirm that 'gpu' is the device string this velovae version
# expects (some versions may expect 'cuda:0' instead).
model = vv.VanillaVAE(adata, 20, hidden_size=(500, 250), tprior=None, device='gpu')
model.train(adata, config=config_vae, gene_plot=gene_plot, figure_path=figure_path)

## Save the model and AnnData with learned ODE parameters

In [None]:
# Name of the AnnData file written into data_path.
output_name = 'output.h5ad'

# Store the trained model under model_path first, then write the AnnData.
# 'vanilla' is presumably the key under which the learned parameters are
# stored in adata -- verify against saveAnnData.
model.saveModel(model_path)
model.saveAnnData(adata, 'vanilla', data_path, file_name=output_name)

## Post-Analysis

In [None]:
# Reload the AnnData produced by the save step. That step passes `data_path`
# as the target directory and 'output.h5ad' as the file name, so read from the
# same location instead of the parent '../data/Pancreas' folder.
# NOTE(review): assumes saveAnnData writes to <data_path>/<file_name> -- confirm.
adata = anndata.read_h5ad(f'{data_path}/output.h5ad')
print(adata.var.keys())

# Compare the vanilla VAE against scVelo for the marker gene 'Cpe'.
# NOTE(review): the 'scvelo'/'fit' pair presumably needs scVelo fit results
# to already be stored in adata; no cell visible here computes them -- verify.
vv.postAnalysis(adata, methods=['vanilla','scvelo'], keys=['vanilla','fit'], genes=['Cpe'], plot_type=['time','signal'])