# OntoVAE workflow

In [2]:
# import modules
import os
from onto_vae.ontobj import *
from onto_vae.vae_model import *


## 1. Creation of Ontology object

First, we need to initialize an `Ontobj`. OntoVAE requires it to set up the latent space and decoder and to train the model. The `Ontobj` stores information about:
  - the ontology used
  - datasets that were matched to this ontology so that they can be used for training or passed through the VAE model.

### Process ontology data

- processing the `.obo` file:
    - filter terms if needed
    - create the term annotation dataframe
    - create a `graph_base`: a dictionary with ontology relationships (children -> parents).
- processing the `gene_term_mapping.txt` file:
    - create the gene base
    - update the term annotation dataframe by adding columns:
        - `descendants`: number of descendant terms
        - `desc_genes`: number of genes annotated to the term and all its descendants
        - `genes`: number of genes directly annotated to the term
    - update the `graph_base`: add gene keys to the dict, so the graph includes both genes and terms.

The processed results will be stored in the following slots in the initialized `Ontobj` object.

```python
# fill the basic slots
self.annot_base = annot_updated
self.genes_base = sorted(list(set(gene_annot.Gene.tolist())))
term_dict.update(gene_term_dict)
self.graph_base = term_dict

```
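
To make these slots more concrete, here is a minimal toy sketch of the bookkeeping described above. It is not OntoVAE's implementation: the term IDs, gene names, and the helpers `descendants` and `summarize` are hypothetical and only illustrate how a children -> parents dictionary that contains both terms and genes can yield the `descendants`, `desc_genes` and `genes` counts.

```python
# Toy illustration only: the IDs and gene names below are made up.
from collections import defaultdict

# child term -> list of parent terms (same orientation as graph_base)
term_dict = {
    "T:2": ["T:1"],
    "T:3": ["T:1"],
    "T:4": ["T:2", "T:3"],
}
# gene -> list of terms it is directly annotated to
gene_term_dict = {
    "GENE_A": ["T:4"],
    "GENE_B": ["T:2"],
    "GENE_C": ["T:3", "T:4"],
}

# merged graph holding both terms and genes as keys
graph_base = {**term_dict, **gene_term_dict}
genes_base = sorted(gene_term_dict)

# invert to parent -> children so we can walk downwards
children = defaultdict(set)
for child, parents in graph_base.items():
    for parent in parents:
        children[parent].add(child)


def descendants(node):
    """Return every node reachable below `node` (terms and genes)."""
    found = set()
    stack = [node]
    while stack:
        for child in children[stack.pop()]:
            if child not in found:
                found.add(child)
                stack.append(child)
    return found


def summarize(term):
    """Compute the three count columns for one term."""
    below = descendants(term)
    return {
        "descendants": sum(1 for n in below if n in term_dict),
        "desc_genes": sum(1 for n in below if n in gene_term_dict),
        "genes": sum(1 for n in children[term] if n in gene_term_dict),
    }
```

Calling `summarize("T:2")` on this toy graph counts one descendant term (`T:4`), three genes at or below the term, and one directly annotated gene.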

In [3]:
# initialize the Ontobj
# the description should be an identifier, e.g. the ontology used, here: PWO (Pathway Ontology)
pwo = Ontobj(description='PWO')

In [4]:
# initialize our ontology
# obo: path to an obo file
# gene_annot: path to a tab separated file with two columns: Genes and Ontology IDs
pwo.initialize_dag(obo=data_path() + 'pw.obo',
                   gene_annot=data_path() + 'gene_term_mapping.txt')

In [5]:
pwo.annot_base

| | ID | Name | depth | children | parents | descendants | desc_genes | genes |
|---|---|---|---|---|---|---|---|---|
| 0 | PW:0000001 | pathway | 0 | 5.0 | 0 | 1593 | 5720 | 0.0 |
| 1 | PW:0000024 | inflammatory response pathway | 0 | 3.0 | 0 | 9 | 251 | 5.0 |
| 2 | PW:0000034 | electron transport chain pathway | 0 | 0.0 | 0 | 0 | 22 | 22.0 |
| 3 | PW:0000060 | long term potentiation | 0 | 0.0 | 0 | 0 | 77 | 77.0 |
| 4 | PW:0000061 | long term depression | 0 | 0.0 | 0 | 0 | 71 | 71.0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 1647 | PW:0002343 | methylenetetrahydrofolate reductase deficiency... | 8 | 0.0 | 2 | 0 | 19 | 19.0 |
| 1648 | PW:0002431 | altered retromer-mediated pathway | 8 | 0.0 | 2 | 0 | 2 | 2.0 |
| 1649 | PW:0002532 | metachromatic leukodystrophy pathway | 8 | 0.0 | 2 | 0 | 23 | 23.0 |
| 1650 | PW:0000773 | aldosterone biosynthetic pathway | 9 | 0.0 | 1 | 0 | 8 | 8.0 |


In [6]:
genes_base = pwo.genes_base

In [7]:
graph_base = pwo.graph_base

### Trim the ontology

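The DAG is trimmed based on the user-defined thresholds `top_thresh` and `bottom_thresh`. The trimmed version is saved in the `annot`, `graph`, `genes` and `desc_genes` slots under a key built from both thresholds (e.g. `'1000_30'`).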
```python
# save trimming results in respective slots
self.annot[str(top_thresh) + '_' + str(bottom_thresh)] = new_annot
self.graph[str(top_thresh) + '_' + str(bottom_thresh)] = term_dict_trim
self.genes[str(top_thresh) + '_' + str(bottom_thresh)] = sorted(list(gene_trim.keys()))
self.desc_genes[str(top_thresh) + '_' + str(bottom_thresh)] = desc_genes

```
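
The keying scheme can be sketched as follows. The helper names `trim_key` and `toy_trim` are hypothetical, the counts are made up, and the filtering rule (keep terms whose `desc_genes` count lies between the two thresholds) is an assumption used purely for illustration. OntoVAE's actual trimming logic may differ, for example in how genes of removed terms are handled.

```python
# Toy sketch; the rule below (keep terms whose desc_genes lies within the
# two thresholds) is an assumption, not OntoVAE's actual trimming code.
def trim_key(top_thresh, bottom_thresh):
    """Build the slot key used for a trimmed version, e.g. '1000_30'."""
    return str(top_thresh) + '_' + str(bottom_thresh)


def toy_trim(desc_genes, top_thresh, bottom_thresh):
    """Keep terms with bottom_thresh <= desc_genes <= top_thresh."""
    return {term: n for term, n in desc_genes.items()
            if bottom_thresh <= n <= top_thresh}


# made-up counts of genes annotated to each term or its descendants
desc_genes = {"T:1": 5720, "T:2": 251, "T:3": 22, "T:4": 77}

# several trimmed versions can live side by side under different keys
trimmed = {}
key = trim_key(1000, 30)
trimmed[key] = toy_trim(desc_genes, 1000, 30)
```

Because each trimmed version is stored under its own key, the same `Ontobj` can hold several trimmings, and later calls select one by passing the same `top_thresh` and `bottom_thresh`.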

In [8]:
# trim the ontology
pwo.trim_dag(top_thresh=1000, 
             bottom_thresh=30)

In [9]:
pwo.graph

{'1000_30': {'PW:0000095': ['PW:0000094'],
  'PW:0000556': ['PW:0000555'],
  'PW:0000926': ['PW:0000024'],
  'PW:0000925': ['PW:0000024', 'PW:0000819'],
  'PW:0000303': ['PW:0000095'],
  'PW:0000043': ['PW:0000556', 'PW:0000005'],
  'PW:0000640': ['PW:0000556', 'PW:0000025'],
  'PW:0000025': ['PW:0000005'],
  'PW:0000041': ['PW:0000005'],
  'PW:0000055': ['PW:0000005'],
  'PW:0000064': ['PW:0000005'],
  'PW:0000065': ['PW:0000005'],
  'PW:0000146': ['PW:0000005'],
  'PW:0000151': ['PW:0000005'],
  'PW:0000152': ['PW:0000005'],
  'PW:0000154': ['PW:0000005'],
  'PW:0000188': ['PW:0000005'],
  'PW:0000533': ['PW:0000005'],
  'PW:0000641': ['PW:0000025', 'PW:0000685'],
  'PW:0000919': ['PW:0000754'],
  'PW:0000920': ['PW:0000754'],
  'PW:0000921': ['PW:0000754'],
  'PW:0000922': ['PW:0000754'],
  'PW:0000923': ['PW:0000754'],
  'PW:0000924': ['PW:0000754'],
  'PW:0001170': ['PW:0000754'],
  'PW:0001202': ['PW:0000754'],
  'PW:0002159': ['PW:0000754'],
  'PW:0002281': ['PW:0000754'],
  'PW

In [10]:
pwo.annot

{'1000_30':              ID                                               Name  depth  \
 0    PW:0000005                     carbohydrate metabolic pathway      0   
 1    PW:0000010                            lipid metabolic pathway      0   
 2    PW:0000011                       amino acid metabolic pathway      0   
 3    PW:0000012                       nucleotide metabolic pathway      0   
 4    PW:0000020              cardiovascular system disease pathway      0   
 ..          ...                                                ...    ...   
 652  PW:0000568                      aldosterone signaling pathway      7   
 653  PW:0001591                        xanthinuria  type I pathway      7   
 654  PW:0001592                        xanthinuria type II pathway      7   
 655  PW:0001752          pyruvate decarboxylase deficiency pathway      7   
 656  PW:0002576  3-methylcrotonyl CoA carboxylase 1 deficiency ...      7   
 
      children  parents  descendants  desc_genes  g

In [11]:
pwo.genes

{'1000_30': ['A2M',
  'A4GALT',
  'AACS',
  'AADAT',
  'AANAT',
  'AAR2',
  'AARS1',
  'AARS2',
  'AASDHPPT',
  'AASS',
  'ABAT',
  'ABCA1',
  'ABCA3',
  'ABCA4',
  'ABCA5',
  'ABCB1',
  'ABCB11',
  'ABCB4',
  'ABCC1',
  'ABCC10',
  'ABCC2',
  'ABCC3',
  'ABCC4',
  'ABCC5',
  'ABCC6',
  'ABCC8',
  'ABCC9',
  'ABCD1',
  'ABCD2',
  'ABCG2',
  'ABCG5',
  'ABCG8',
  'ABI1',
  'ABL1',
  'ABL2',
  'ABO',
  'ACAA1',
  'ACAA2',
  'ACACA',
  'ACACB',
  'ACAD8',
  'ACAD9',
  'ACADL',
  'ACADM',
  'ACADS',
  'ACADSB',
  'ACADVL',
  'ACAP1',
  'ACAP2',
  'ACAP3',
  'ACAT1',
  'ACAT2',
  'ACE',
  'ACE2',
  'ACER1',
  'ACER2',
  'ACER3',
  'ACHE',
  'ACIN1',
  'ACKR1',
  'ACLY',
  'ACMSD',
  'ACO1',
  'ACO2',
  'ACOT12',
  'ACOX2',
  'ACP1',
  'ACP2',
  'ACP3',
  'ACP4',
  'ACP5',
  'ACP6',
  'ACSL1',
  'ACSL3',
  'ACSL4',
  'ACSL5',
  'ACSL6',
  'ACSM1',
  'ACSM2A',
  'ACSM2B',
  'ACSM3',
  'ACSM4',
  'ACSM5',
  'ACSS1',
  'ACSS2',
  'ACSS3',
  'ACTA2',
  'ACTB',
  'ACTC1',
  'ACTG1',
  'ACTL6A',
 

### Visualize the ontology

In [12]:
ontology_dict = pwo.graph['1000_30']
visualize_ontology(ontology_dict,max_depth=5,sample_size=20)

ontology_hierarchy.html


### Create masks for decoder 

We create a list of binary masks, one for each possible depth combination. Each mask is a binary matrix covering all elements within that depth combination.
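
As a rough illustration of what one such mask encodes, the sketch below builds a binary matrix from a toy children -> parents graph. The helper `binary_mask`, the node names, and the orientation (rows are children, columns are parents) are assumptions made for this example and are not taken from OntoVAE's code.

```python
import numpy as np

# Toy child -> parents graph; IDs are made up.
graph = {
    "T:2": ["T:1"],
    "T:3": ["T:1"],
    "G_A": ["T:2"],
    "G_B": ["T:2", "T:3"],
}


def binary_mask(rows, cols, graph):
    """mask[i, j] = 1 if cols[j] is listed as a parent of rows[i]."""
    mask = np.zeros((len(rows), len(cols)), dtype=np.int64)
    for i, child in enumerate(rows):
        for j, parent in enumerate(cols):
            if parent in graph.get(child, []):
                mask[i, j] = 1
    return mask


# connections between the gene level and the deepest term level
mask = binary_mask(["G_A", "G_B"], ["T:2", "T:3"], graph)
```

Multiplying a weight matrix elementwise by such a mask zeroes every connection that has no corresponding edge in the ontology, which is the general idea behind masking a decoder.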

In [13]:
# create masks for decoder initialization
pwo.create_masks(top_thresh=1000,
                 bottom_thresh=30)

In [14]:
pwo.masks['1000_30']['decoder']

[array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=int64),
 array([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

### Match dataset

Match the ontology genes with the genes in the expression data.
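
Conceptually, matching means aligning the columns of the expression table to the ontology's gene list. The pandas sketch below shows one way this could look. It assumes samples in rows and genes in columns, and it fills genes missing from the data with zeros. Both choices, as well as all names used, are assumptions for illustration and not a description of what `match_dataset` does internally.

```python
import pandas as pd

# Made-up expression table: rows are samples, columns are genes.
expr = pd.DataFrame(
    {"GENE_B": [1.0, 2.0], "GENE_X": [5.0, 6.0], "GENE_A": [0.5, 0.0]},
    index=["cell_1", "cell_2"],
)
ontology_genes = ["GENE_A", "GENE_B", "GENE_C"]  # sorted gene list

# Reorder columns to the ontology gene order; genes absent from the data
# are filled with 0, genes absent from the ontology are dropped.
matched = expr.reindex(columns=ontology_genes, fill_value=0.0)
```

After this step every sample has exactly one value per ontology gene, in a fixed order, which is what a model with a fixed input layer needs.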

In [15]:
# match a dataset to the ontology
# expr_data: path to the expression dataset (here a CSV file)
pwo.match_dataset(expr_data = data_path() + 'pbmc_sample_expr.csv',
                  name='PBMC_CD4T')

## 2. OntoVAE model training

In [None]:
# initialize OntoVAE 
pwo_model = OntoVAE(ontobj=pwo,              # the Ontobj we will use
                    dataset='PBMC_CD4T',     # which dataset from the Ontobj to use for model training
                    top_thresh=1000,         # which trimmed version to use
                    bottom_thresh=30)        # which trimmed version to use     
pwo_model.to(pwo_model.device)         

In [None]:
# create a directory in which to store the best model
os.makedirs(os.path.join(os.getcwd(), 'models'), exist_ok=True)

In [None]:
# train the model
pwo_model.train_model(os.getcwd() + '/models/best_model.pt',   # where to store the best model
                     lr=1e-4,                                 # the learning rate
                     kl_coeff=1e-4,                           # the weighting coefficient for the Kullback Leibler loss
                     batch_size=128,                          # the size of the minibatches
                     epochs=5)                                # over how many epochs to train
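
To illustrate the role of `kl_coeff`, here is a small NumPy sketch of a weighted VAE objective. It assumes a mean-squared-error reconstruction term and a diagonal Gaussian posterior compared against a standard normal prior. The function `toy_vae_loss` is hypothetical and the exact loss used by OntoVAE may differ.

```python
import numpy as np


def toy_vae_loss(x, x_hat, mu, logvar, kl_coeff=1e-4):
    """Reconstruction error plus kl_coeff times the KL divergence.

    Assumes a squared-error reconstruction term and a diagonal Gaussian
    posterior N(mu, exp(logvar)) compared against a standard normal prior.
    """
    # squared error summed over genes, averaged over samples
    recon = np.mean(np.sum((x - x_hat) ** 2, axis=1))
    # closed-form KL divergence, summed over latent dims, averaged over samples
    kl = np.mean(-0.5 * np.sum(1 + logvar - mu ** 2 - np.exp(logvar), axis=1))
    return recon + kl_coeff * kl


x = np.array([[1.0, 2.0], [0.0, 1.0]])       # two samples, two genes
x_hat = np.array([[1.0, 1.0], [0.0, 1.0]])   # toy reconstruction
mu = np.zeros((2, 3))                        # three latent dimensions
logvar = np.zeros((2, 3))
loss = toy_vae_loss(x, x_hat, mu, logvar)
```

A small coefficient such as `1e-4` keeps the KL term from dominating the reconstruction error, so the latent variables stay informative about the input.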

## 3. Analysis with OntoVAE model (pathway activities + perturbations)

We can use a trained OntoVAE model to retrieve pathway activities from latent space and decoder, but also to perform in silico perturbations prior to retrieving the pathway activities.

In [None]:
# load the best model
checkpoint = torch.load(os.getcwd() + '/models/best_model.pt',
                        map_location = torch.device(pwo_model.device), 
                        weights_only=True)
pwo_model.load_state_dict(checkpoint['model_state_dict'])

In [None]:
# retrieve pathway activities
pwo_act = pwo_model.get_pathway_activities(ontobj=pwo,
                                           dataset='PBMC_CD4T')

We can use a function of the Ontobj to make an example scatterplot for two pathways of our choice.

In [None]:
# make scatterplot for two pathway activities
pwo.plot_scatter(sample_annot = data_path() + 'pbmc_sample_annot.csv',   # pandas Dataframe or path to annotation file
                 color_by = 'condition',                                 # variable to use for coloring
                 act = pwo_act,                                          # pathway activities computed from OntoVAE model
                 term1 = 'interferon mediated signaling pathway',        # term on x-axis of scatter plot
                 term2 = 'T cell receptor signaling pathway',            # term on y-axis of scatter plot
                 top_thresh = 1000,                                      # which trimmed version to use
                 bottom_thresh = 30)                                     # which trimmed version to use

# Note: the scatterplot displayed in the vignette was produced with a model trained for 200 epochs!

We can now perform in silico perturbations and Wilcoxon tests to see which terms are influenced most.

In [None]:
# get pathway activities where ISG15 was perturbed
pwo_ko_act = pwo_model.perturbation(ontobj=pwo,
                                    dataset='PBMC_CD4T',
                                    genes=['ISG15'],        # list of genes to be perturbed
                                    values=[0])             # list of new values for the genes
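
The idea behind an in silico perturbation can be sketched on a plain expression table: the listed genes are overwritten with the new values before the data are passed through the model. The helper `toy_perturb` and the toy values are hypothetical, and this sketch covers only the input modification, not the forward pass that `perturbation` performs.

```python
import pandas as pd


def toy_perturb(expr, genes, values):
    """Return a copy of `expr` with each listed gene set to its new value."""
    perturbed = expr.copy()
    for gene, value in zip(genes, values):
        perturbed[gene] = value
    return perturbed


# Made-up expression table: rows are samples, columns are genes.
expr = pd.DataFrame(
    {"ISG15": [3.0, 1.5], "ACTB": [9.0, 8.0]},
    index=["cell_1", "cell_2"],
)
knocked_out = toy_perturb(expr, genes=["ISG15"], values=[0])
```

Setting the value to `0` mimics a knockout. Working on a copy leaves the original table untouched, so control and perturbed activities can be compared sample by sample.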

We can use a function of the Ontobj to perform a paired Wilcoxon test between perturbed and non-perturbed activities for all terms of the ontology and get a ranked dataframe.

In [None]:
# perform paired Wilcoxon test
results = pwo.wilcox_test(control = pwo_act,
                          perturbed = pwo_ko_act,
                          direction = 'down',
                          top_thresh=1000,
                          bottom_thresh=30)
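
For a single term, a paired one-sided test of this kind can be sketched with `scipy.stats.wilcoxon`. The simulated activities are made up, and mapping `direction = 'down'` to `alternative="less"` is an assumption about what the option means, not a statement about how `wilcox_test` is implemented.

```python
import numpy as np
from scipy.stats import wilcoxon

# Simulated activities of one term across 50 samples (made-up numbers).
rng = np.random.default_rng(0)
control = rng.normal(loc=1.0, scale=0.1, size=50)
perturbed = control - rng.uniform(0.05, 0.15, size=50)  # consistently lower

# Paired test: alternative="less" asks whether the paired differences
# (perturbed - control) are shifted below zero.
stat, pval = wilcoxon(perturbed, control, alternative="less")
```

Repeating such a test for every term and sorting by the resulting statistic or p-value gives a ranking of the terms most affected by the perturbation.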

In [None]:
# display the top results
results.head(10)

We see that when we modulate ISG15 expression, our top hits are all pathways related to the immune response!