# Mutual Interactors 

*Mutual Interactors* is a machine learning algorithm for node set expansion in large networks. The algorithm is motivated by the structure of disease-associated proteins, drug targets, and protein functions in molecular networks, and can be used to predict molecular phenotypes *in silico*. For a detailed description of the algorithm, please see our [paper](TODO).

In this notebook, we walk through training a *Mutual Interactors* model to predict novel disease-protein associations. We train the model on a protein-protein interaction (PPI) network and a large set of disease-protein associations.

Although this notebook uses a PPI network and disease-protein associations, it can easily be adapted to work with any network and any type of node set.

In [1]:
%load_ext autoreload
%autoreload 2

import os

import networkx as nx

from milieu.data.network import Network
from milieu.data.associations import load_diseases
from milieu.util.util import load_mapping
from milieu.milieu import MilieuDataset, Milieu
from milieu.paper.figures.network_vis import show_network

# Change directory to the root directory of the package
root_dir = os.path.dirname(os.getcwd())
os.chdir(root_dir)

## Load a Network
To use *Mutual Interactors* we need a network! 

We'll use the human protein-protein interaction network compiled by Menche *et al.* [1]. The network consists of 342,353 interactions between 21,557 proteins.
In `data/networks`, you can find this network as `bio-pathways-network.txt`. See methods for a more detailed description of the network.

We use the class `milieu.data.network.Network` to load and represent networks. The constructor accepts a path to an edge list.

In [4]:
network = Network("data/networks/species_9606/bio-pathways/network.txt")
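To make "edge list" concrete, here is a minimal sketch of how such a file can be read into an adjacency map. It assumes one whitespace-separated pair of node IDs per line, which may not match the real file exactly; `read_edge_list` and the toy IDs are hypothetical and are not part of the `milieu` package.

```python
from collections import defaultdict


def read_edge_list(lines):
    """Build an undirected adjacency map from whitespace-separated edge lines."""
    adjacency = defaultdict(set)
    for line in lines:
        line = line.strip()
        if not line or line.startswith("#"):
            continue  # skip blank lines and comments
        parts = line.split()
        u, v = parts[0], parts[1]
        adjacency[u].add(v)
        adjacency[v].add(u)
    return dict(adjacency)


# Hypothetical toy edge list; the IDs are illustrative, not taken from the real file.
toy_edges = ["1 2", "2 3", "# comment", "", "3 1", "3 4"]
toy_adjacency = read_edge_list(toy_edges)

num_nodes = len(toy_adjacency)
# Each undirected edge is stored twice, once per endpoint.
num_edges = sum(len(neighbors) for neighbors in toy_adjacency.values()) // 2
print(num_nodes, num_edges)
```

Counting nodes and edges this way is a quick sanity check that a file was parsed as expected before handing it to `Network`.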

## Build the *Milieu* Model

The *Mutual Interactors* model is parameterized by a few important hyperparameters.

We find that the learning rate (i.e. `optim_args/lr` in the nested dictionary below) can have a significant impact on performance. The optimal value varies substantially between networks and applications, so we recommend tuning it.

If you have a GPU available, setting `cuda` to `True` and specifying an available `device` should speed up training considerably. That said, training *Mutual Interactors* is usually tractable on CPU for networks with $n \leq 30\text{k}$ nodes.

In [4]:
params = {
    "cuda": False,
    "device": 2,
    
    "batch_size": 200,
    "num_workers": 4,
    "num_epochs": 10,
    
    "optim_class": "Adam",
    "optim_args": {
        "lr": 0.01,
        "weight_decay": 0.0
    },
    
    "metric_configs": [
        {
            "name": "recall_at_25",
            "fn": "batch_recall_at", 
            "args": {"k":25}
        }
    ]
}
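Since we recommend tuning the learning rate, one simple approach is to build one parameter dictionary per candidate value and train a model with each. The sketch below only builds the dictionaries; `with_learning_rate` and the candidate grid are hypothetical choices, and `base_params` is a trimmed stand-in for the full `params` above.

```python
import copy

# Trimmed stand-in for the params dictionary defined above.
base_params = {
    "num_epochs": 10,
    "optim_class": "Adam",
    "optim_args": {"lr": 0.01, "weight_decay": 0.0},
}


def with_learning_rate(params, lr):
    """Return a deep copy of params with optim_args/lr replaced."""
    new_params = copy.deepcopy(params)  # deep copy so the nested dict is not shared
    new_params["optim_args"]["lr"] = lr
    return new_params


# A log-spaced grid is a common starting point; these values are an assumption.
candidate_lrs = [1e-4, 1e-3, 1e-2, 1e-1]
sweep = [with_learning_rate(base_params, lr) for lr in candidate_lrs]

for config in sweep:
    # In the notebook, each config would be passed to Milieu(network, config)
    # and the runs compared on validation recall; that step is omitted here.
    print(config["optim_args"])
```

The deep copy matters because `optim_args` is nested: a shallow copy would share the inner dictionary, and every configuration would end up with the last learning rate.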

We've implemented the *Mutual Interactors* model in a self-contained class, `milieu.milieu.Milieu`. This class contains methods for training the model (`Milieu.train_model`), evaluating the model on a test set (`Milieu.score`), and predicting node set expansions (`Milieu.expand`).

The constructor accepts the network and the dictionary of params we defined above. 


In [5]:
milieu = Milieu(network, params)

Milieu
Setting parameters...
Building model...
Building optimizer...
Done.


## Train the Model
*Mutual Interactors* is trained on a dataset of groups of nodes known to be associated with one another in some way.  
In this example, we use sets of proteins associated with the same disease. Our disease-protein associations come from DisGeNET and are found at `data/disease_associations/disgenet-associations.csv`.

We load the disease-protein associations with `milieu.data.associations.load_diseases`, which returns a collection of `milieu.data.associations.NodeSet` objects (the cell below calls `.values()` on the result and wraps it in a list). Each `NodeSet` represents the set of proteins associated with one disease.

To evaluate the model as we train it, we'll split the set of diseases into a train set and a validation set. Next, we'll create a `milieu.milieu.MilieuDataset` for each. A `MilieuDataset` is simply a PyTorch dataset that creates training examples for the *Mutual Interactors* model.

In [5]:
node_sets = list(load_diseases("data/associations/disgenet/associations.csv", exclude_splits=["none"]).values())
train_node_sets = node_sets[:int(len(node_sets)* 0.9)]
valid_node_sets = node_sets[int(len(node_sets)* 0.9):]
train_dataset = MilieuDataset(network, node_sets=train_node_sets)
valid_dataset = MilieuDataset(network, node_sets=valid_node_sets)
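The split above takes the first 90% of the node sets in the order they were loaded. If that order is not random (we have not checked how the file is sorted), the validation set may not be representative. A minimal sketch of a seeded, shuffled split follows; `shuffled_split` and the toy node set names are hypothetical.

```python
import random


def shuffled_split(items, train_fraction=0.9, seed=0):
    """Shuffle a copy of items with a fixed seed, then split it in two."""
    shuffled = list(items)  # copy so the caller's list is left untouched
    random.Random(seed).shuffle(shuffled)
    cutoff = int(len(shuffled) * train_fraction)
    return shuffled[:cutoff], shuffled[cutoff:]


# Toy stand-ins for NodeSet objects.
toy_node_sets = [f"disease_{i}" for i in range(20)]
train_part, valid_part = shuffled_split(toy_node_sets)
print(len(train_part), len(valid_part))
```

Using a private `random.Random(seed)` instance keeps the split reproducible without changing the global random state.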

In [11]:
milieu.train_model(train_dataset, valid_dataset)

Starting training for 10 epoch(s)
Epoch 1 of 10
Training
100%|██████████| 9/9 [00:26<00:00,  2.51s/it, loss=1.584]
Validation
100%|██████████| 1/1 [00:02<00:00,  2.73s/it]

Epoch 2 of 10
Training
100%|██████████| 9/9 [00:28<00:00,  2.55s/it, loss=1.503]
Validation
100%|██████████| 1/1 [00:02<00:00,  2.41s/it]

Epoch 3 of 10
Training
100%|██████████| 9/9 [00:25<00:00,  2.36s/it, loss=1.441]
Validation
100%|██████████| 1/1 [00:02<00:00,  2.34s/it]

Epoch 4 of 10
Training
100%|██████████| 9/9 [00:25<00:00,  2.38s/it, loss=1.389]
Validation
100%|██████████| 1/1 [00:02<00:00,  2.38s/it]

Epoch 5 of 10
Training
100%|██████████| 9/9 [00:26<00:00,  2.46s/it, loss=1.350]
Validation
100%|██████████| 1/1 [00:02<00:00,  2.39s/it]

Epoch 6 of 10
Training
100%|██████████| 9/9 [00:26<00:00,  2.43s/it, loss=1.325]
Validation
100%|██████████| 1/1 [00:02<00:00,  2.36s/it]

Epoch 7 of 10
Training
100%|██████████| 9/9 [00:26<00:00,  2.48s/it, loss=1.303]
Validation
100%|██████████| 1/1 [00:02<00:00,  2.39s/it]

Epoch 8 of 10
Training
100%|██████████| 9/9 [00:27<00:00,  2.55s/it, loss=1.287]
Validation
100%|██████████| 1/1 [00:02<00:00,  2.45s/it]

Epoch 9 of 10
Training
100%|██████████| 9/9 [00:28<00:00,  2.77s/it, loss=1.277]
Validation
100%|██████████| 1/1 [00:03<00:00,  3.18s/it]

Epoch 10 of 10
Training
100%|██████████| 9/9 [00:32<00:00,  2.91s/it, loss=1.264]
Validation
100%|██████████| 1/1 [00:02<00:00,  2.71s/it]


([{'recall_at_25': 0.05622317725766016},
  {'recall_at_25': 0.057041638494999596},
  {'recall_at_25': 0.06314191199974062},
  {'recall_at_25': 0.06293379194140619},
  {'recall_at_25': 0.07390836548325869},
  {'recall_at_25': 0.06672080551383561},
  {'recall_at_25': 0.06160882956753036},
  {'recall_at_25': 0.06786463317964866},
  {'recall_at_25': 0.06509961636312872},
  {'recall_at_25': 0.0671994546838794}],
 [defaultdict(list, {'recall_at_25': [0.04652185421416191]}),
  defaultdict(list, {'recall_at_25': [0.04717101681387396]}),
  defaultdict(list, {'recall_at_25': [0.05337757480614624]}),
  defaultdict(list, {'recall_at_25': [0.05485644103342195]}),
  defaultdict(list, {'recall_at_25': [0.0822745341542334]}),
  defaultdict(list, {'recall_at_25': [0.061185832889129585]}),
  defaultdict(list, {'recall_at_25': [0.05309115139942208]}),
  defaultdict(list, {'recall_at_25': [0.08251139885755271]}),
  defaultdict(list, {'recall_at_25': [0.0589163447246154]}),
  defaultdict(list, {'recall_at_
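
The `recall_at_25` values reported above measure how many held-out nodes appear among the top-ranked predictions. The implementation of `batch_recall_at` is not shown in this notebook, so the sketch below uses the common definition of recall@k and may differ from it in details; `recall_at_k` and the toy node names are hypothetical.

```python
def recall_at_k(ranked_nodes, held_out_nodes, k):
    """Fraction of held-out nodes that appear among the top-k ranked nodes."""
    held_out = set(held_out_nodes)
    if not held_out:
        return 0.0  # no held-out nodes to recover
    # Assumes ranked_nodes contains no duplicates.
    hits = sum(1 for node in ranked_nodes[:k] if node in held_out)
    return hits / len(held_out)


ranking = ["p7", "p2", "p9", "p4", "p1"]  # best prediction first
held_out = ["p2", "p4", "p8", "p3"]

recall_top3 = recall_at_k(ranking, held_out, k=3)
recall_top5 = recall_at_k(ranking, held_out, k=5)
print(recall_top3, recall_top5)
```

In the toy example only `p2` is in the top 3, so one of the four held-out nodes is recovered; widening to the top 5 also recovers `p4`.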

## Predict Novel Associations
Now that we've got a trained *Mutual Interactors* model, we can use it to expand some node sets!

In particular, we are going to use it to predict which proteins are associated with Tracheomalacia, a condition characterized by flaccidity of the supporting tracheal cartilage.

To do so, we specify the set of proteins associated with Tracheomalacia using GenBank IDs. 

In [309]:
# Specify a set of proteins by their GenBank IDs
# For example, we use the proteins associated with Tracheomalacia
# Swap out these GenBank IDs for another set of proteins! 
tracheomalacia_proteins = ['COL2A1', 'HRAS', 'DCHS1', 'SNRPB', 'ORC4', 'LTBP4', 
                           'FLNB', 'PRRX1', 'RAB3GAP2', 'FGFR2','TRIM2']

In [310]:
# Convert GenBank IDs to Entrez IDs, since our network uses Entrez IDs
genbank_to_entrez = load_mapping("data/protein_attrs/genbank_to_entrez.txt",
                                 b_transform=int, delimiter='\t')
tracheomalacia_entrez = [genbank_to_entrez[protein] for protein in tracheomalacia_proteins]
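
The list comprehension above raises a `KeyError` if any ID is missing from the mapping. When swapping in your own protein set, it can help to separate the IDs that map from those that do not. The following is a minimal sketch; `map_ids` and the toy mapping are hypothetical, and the integers are placeholders rather than real Entrez IDs.

```python
def map_ids(ids, mapping):
    """Split ids into (mapped values, ids missing from the mapping)."""
    mapped, missing = [], []
    for identifier in ids:
        if identifier in mapping:
            mapped.append(mapping[identifier])
        else:
            missing.append(identifier)
    return mapped, missing


# Hypothetical mapping with placeholder integers, not real Entrez IDs.
toy_mapping = {"GENE_A": 101, "GENE_B": 102}
mapped, missing = map_ids(["GENE_A", "GENE_X", "GENE_B"], toy_mapping)
print(mapped, missing)
```

Reporting the missing IDs, rather than silently dropping them, makes it clear when a seed set is smaller than intended.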

In [300]:
# Expand the set of proteins using our trained model! 
# Change the number of predicted proteins using the top_k parameter
predicted_entrez = milieu.expand(node_names=tracheomalacia_entrez, top_k=5)
predicted_entrez = list(zip(*predicted_entrez))[0]
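
The second line above uses `zip(*...)` to "unzip" a list of pairs and keep only the first element of each. The sketch below shows the idiom on toy data, assuming `expand` returns `(node name, score)` pairs; that shape is inferred from the unzip and is not confirmed here, and the values are made up.

```python
# Toy stand-in for the output of expand: (node name, score) pairs.
toy_predictions = [(111, 0.91), (222, 0.85), (333, 0.40)]

# zip(*pairs) transposes the list of pairs into one tuple per position.
names, scores = zip(*toy_predictions)
print(names)
print(scores)
```

Note that unpacking into two names fails on an empty list, so guard against an empty result if `top_k` could be zero.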

Using the function `milieu.paper.figures.network_vis.show_network` we can generate a Cytoscape visualization of the predictions!

In [302]:
# Generate a network visualization with cytoscape 
# Note: it is recommended to limit the size of the visualization to ~250 nodes  
cy_vis = show_network(network, tracheomalacia_entrez, predicted_entrez, id_format="entrez",
                      model=milieu,
                      show_seed_mi=True, excluded_interactions=[("mutual_interactor", "mutual_interactor")],
                      size_limit=250)

In [303]:
# Show the visualization!
# Red nodes are the seed nodes fed to the model.
# Orange nodes are predicted nodes. Blue nodes are the interactors between them. 
cy_vis

Cytoscape(data={'elements': {'nodes': [{'data': {'role': 'seed', 'id': '925', 'entrez': '1280', 'genbank': 'CO…
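
The visualization highlights the interactors that sit between seed nodes and predicted nodes. As rough intuition only, one can count, for each candidate node, how many of its neighbors also neighbor a seed. This unweighted count is our simplification for illustration and is not the trained model, which is learned from data; `mutual_interactor_counts` and the toy graph are hypothetical.

```python
def mutual_interactor_counts(adjacency, seeds):
    """For each non-seed node, count its neighbors that also neighbor a seed."""
    seed_set = set(seeds)
    seed_neighbors = set()
    for seed in seed_set:
        seed_neighbors |= adjacency.get(seed, set())
    counts = {}
    for node, neighbors in adjacency.items():
        if node in seed_set:
            continue  # seeds are not candidates
        counts[node] = len(neighbors & seed_neighbors)
    return counts


# Toy undirected graph: s1 and s2 are seeds, a and b are interactors.
toy_graph = {
    "s1": {"a", "b"},
    "s2": {"b"},
    "a": {"s1", "c1"},
    "b": {"s1", "s2", "c1", "c2"},
    "c1": {"a", "b"},
    "c2": {"b"},
    "c3": set(),
}
counts = mutual_interactor_counts(toy_graph, ["s1", "s2"])
# Rank by descending count, breaking ties alphabetically.
ranked = sorted(counts, key=lambda node: (-counts[node], node))
print(ranked[:2], counts["c1"], counts["c2"])
```

Here `c1` reaches the seeds through both `a` and `b`, so it ranks above `c2`, which reaches them only through `b`.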

1. Menche, J. et al. Uncovering disease-disease relationships through the incomplete interactome. Science 347, 1257601 (2015).