# OntoVAE workflow

In [None]:
# import modules
import os

# torch is called directly further down (torch.load / torch.device); import it
# explicitly instead of relying on it leaking through the wildcard imports
import torch

# the wildcard imports supply the remaining names used in this notebook
# (Ontobj, OntoVAE, data_path, visualize_ontology)
from onto_vae.ontobj import *
from onto_vae.vae_model import *


## 1. Creation of Ontology object

First, we need to initialize an Ontobj. This is required by OntoVAE to initialize latent space and decoder and train the model. The Ontobj stores information about the used ontology, and also datasets that were matched to this ontology so that they can be used for training or passed through the VAE model.

### process ontology data

In [None]:
# Create the Ontobj that the following cells fill with ontology data and datasets.
# `description` should be an identifier, e.g. the ontology used;
# here 'PWO' stands for the Pathway Ontology.
pwo = Ontobj(description='PWO')

In [None]:
# Initialize the ontology from two input files, both addressed relative to data_path():
#   obo        - path to an obo file
#   gene_annot - path to a tab-separated file with two columns: Genes and Ontology IDs
pwo.initialize_dag(
    obo=data_path() + 'pw.obo',
    gene_annot=data_path() + 'gene_term_mapping.txt',
)

```
# fill the basic slots
self.annot_base = annot_updated
self.genes_base = sorted(list(set(gene_annot.Gene.tolist())))
term_dict.update(gene_term_dict)
self.graph_base = term_dict

```

In [None]:
# show the base annotation (`annot_base`, one of the basic slots listed in the excerpt above)
pwo.annot_base

In [None]:
# show the base gene list (`genes_base`: sorted, de-duplicated genes per the excerpt above)
pwo.genes_base

In [None]:
# show the untrimmed ontology graph: the basic slots carry the `_base` suffix,
# whereas `pwo.graph` is keyed by trimmed version and is only filled by trim_dag
pwo.graph_base

### Trim the ontology

In [None]:
# Trim the ontology DAG using the two user-defined thresholds.
# The result is stored under the key '<top_thresh>_<bottom_thresh>', i.e. '1000_30' here.
pwo.trim_dag(top_thresh=1000, bottom_thresh=30)

```
DAG is trimmed based on user-defined thresholds.
Trimmed version is saved in the graph, annot and genes slots.
# save trimming results in respective slots
self.annot[str(top_thresh) + '_' + str(bottom_thresh)] = new_annot
self.graph[str(top_thresh) + '_' + str(bottom_thresh)] = term_dict_trim
self.genes[str(top_thresh) + '_' + str(bottom_thresh)] = sorted(list(gene_trim.keys()))
self.desc_genes[str(top_thresh) + '_' + str(bottom_thresh)] = desc_genes

```

In [None]:
# list the available trimmed versions; keys have the form '<top_thresh>_<bottom_thresh>'
pwo.graph.keys()

In [None]:
# show the trimmed graph stored for top_thresh=1000 / bottom_thresh=30
pwo.graph['1000_30']

### Visualize the ontology

In [None]:
# Draw the trimmed (1000/30) ontology graph.
# NOTE(review): max_depth and sample_size presumably bound how much of the graph
# is drawn - confirm against visualize_ontology.
ontology_dict = pwo.graph['1000_30']
visualize_ontology(
    ontology_dict,
    max_depth=5,
    sample_size=20,
)

### Create binary masks

We create a list of binary masks, one for each possible combination of depth levels. Each mask is a binary matrix covering all elements within that depth combination: `0` for a child-parent relationship, `1` for no child-parent relationship.

In [None]:
# Build the masks used for decoder initialization, for the 1000/30 trimmed ontology.
pwo.create_masks(top_thresh=1000, bottom_thresh=30)

In [None]:
# keep a handle on the decoder masks stored for the 1000/30 trimmed ontology
decoder_mask = pwo.masks['1000_30']['decoder']


### Match datasets

In [None]:
# Match an expression dataset to the ontology.
#   expr_data - path to the dataset (a CSV file here)
#               NOTE(review): other accepted inputs are not visible in this notebook -
#               check match_dataset
#   name      - identifier of the matched dataset; later cells refer to it via `dataset=`
pwo.match_dataset(
    expr_data=data_path() + 'pbmc_sample_expr.csv',
    name='PBMC_CD4T',
)

## 2. OntoVAE model training

In [None]:
# Set up the OntoVAE model on top of the Ontobj.
#   ontobj                     - the Ontobj we will use
#   dataset                    - which dataset from the Ontobj to use for model training
#   top_thresh / bottom_thresh - which trimmed version to use
pwo_model = OntoVAE(
    ontobj=pwo,
    dataset='PBMC_CD4T',
    top_thresh=1000,
    bottom_thresh=30,
)
# move the model onto the device it reports for itself
pwo_model.to(pwo_model.device)

In [None]:
# Make sure the directory for storing the best model exists.
# exist_ok=True keeps the cell safe to re-run and avoids the check-then-create
# race of a separate isdir() test followed by mkdir().
os.makedirs(os.getcwd() + '/models', exist_ok=True)

In [None]:
# Train the model; the positional argument is where the best model is stored.
#   lr         - the learning rate
#   kl_coeff   - the weighting coefficient for the Kullback Leibler loss
#   batch_size - the size of the minibatches
#   epochs     - over how many epochs to train
pwo_model.train_model(
    os.getcwd() + '/models/best_model.pt',
    lr=1e-4,
    kl_coeff=1e-4,
    batch_size=128,
    epochs=5,
)

## 3. Analysis with OntoVAE model (pathway activities + perturbations)

We can use a trained OntoVAE model to retrieve pathway activities from latent space and decoder, but also to perform in silico perturbations prior to retrieving the pathway activities.

In [None]:
# Load the best model checkpoint onto the model's device.
# weights_only=True restricts unpickling to tensors and other plain data types.
checkpoint = torch.load(
    os.getcwd() + '/models/best_model.pt',
    map_location=torch.device(pwo_model.device),
    weights_only=True,
)
# restore the weights stored under the 'model_state_dict' key
pwo_model.load_state_dict(checkpoint['model_state_dict'])

In [None]:
# Retrieve the pathway activities of the matched PBMC dataset from the trained model.
pwo_act = pwo_model.get_pathway_activities(ontobj=pwo, dataset='PBMC_CD4T')

We can use a function of the Ontobj to make an example scatterplot for two pathways of our choice.

In [None]:
# Scatterplot of two pathway activities.
# Note: the model behind the scatterplot displayed in the vignette was trained over 200 epochs!
#   sample_annot               - pandas Dataframe or path to annotation file
#   color_by                   - variable to use for coloring
#   act                        - pathway activities computed from OntoVAE model
#   term1 / term2              - terms on the x-axis / y-axis of the scatter plot
#   top_thresh / bottom_thresh - which trimmed version to use
pwo.plot_scatter(
    sample_annot=data_path() + 'pbmc_sample_annot.csv',
    color_by='condition',
    act=pwo_act,
    term1='interferon mediated signaling pathway',
    term2='T cell receptor signaling pathway',
    top_thresh=1000,
    bottom_thresh=30,
)

We can now perform in silico perturbations and Wilcoxon tests to see which terms are influenced most.

In [None]:
# Pathway activities after an in silico perturbation of ISG15.
#   genes  - list of genes to be perturbed
#   values - list of new values for the genes
pwo_ko_act = pwo_model.perturbation(
    ontobj=pwo,
    dataset='PBMC_CD4T',
    genes=['ISG15'],
    values=[0],
)

We can use a function of the Ontobj to perform a paired Wilcoxon test between perturbed and non-perturbed activities for all terms of the ontology and get a ranked dataframe.

In [None]:
# Paired Wilcoxon test between unperturbed (control) and perturbed pathway activities.
# NOTE(review): direction='down' presumably tests for lower activity after the
# perturbation - confirm against wilcox_test.
results = pwo.wilcox_test(
    control=pwo_act,
    perturbed=pwo_ko_act,
    direction='down',
    top_thresh=1000,
    bottom_thresh=30,
)

In [None]:
# display the first 10 rows of the ranked results
results.head(10)

We see that when we modulate ISG15 expression, our top hits are all pathways related to the immune response!