In [1]:
import pandas as pd
import networkx as nx
import pickle
import copy

from gensim.models import KeyedVectors

import torch
import torch.nn as nn
import torch.nn.functional as F

from deepsnap.graph import Graph

from torch_geometric.nn import SAGEConv
from torch_geometric.explain import Explainer, GNNExplainer, ModelConfig


In [2]:
dataset_nr = 2

# Load Necessary Data

For the node features, load the embeddings yielded from Edge2Vec.

In [60]:
loaded_node_embedding = KeyedVectors.load(f'output/w2v_{dataset_nr}.dvectors', mmap='r')

node_embeddings = pd.DataFrame(columns = ['Node', 'Embedding'])
for idx, key in enumerate(loaded_node_embedding.index_to_key):
    node_embeddings.loc[int(key)] = pd.Series({'Node':int(key), 'Embedding':list(loaded_node_embedding[key])})
    
node_features = node_embeddings.sort_values('Node')

node_features

Unnamed: 0,Node,Embedding
0,0,"[0.4429384, -0.64671946, -0.045503706, 0.04694..."
1,1,"[0.7495496, -0.5790302, 0.16192597, 0.03131027..."
2,2,"[-0.21050918, -0.28908953, 0.118277326, -0.162..."
3,3,"[0.2527442, -0.2967686, 0.29164207, 0.8072669,..."
4,4,"[0.41442072, 0.06997513, 0.4540072, -0.0370618..."
...,...,...
10270,10270,"[0.21862394, 0.11319974, 0.31551972, 0.1497048..."
10271,10271,"[0.8747792, -0.7461259, 0.22221777, -0.0425329..."
10272,10272,"[0.5863251, -0.8054981, 0.27565393, 0.516578, ..."
10273,10273,"[0.6474387, -0.38251266, 0.43746328, 0.3413055..."


For additional information of nodes, load csv file.

In [61]:
nodes = pd.read_csv(f'output/indexed_nodes_{dataset_nr}.csv')
nodes

Unnamed: 0,index_id,id,semantic,label,semantic_id
0,0,MP:0004187,phenotype,cardia bifida,9
1,1,ZP:0100138,phenotype,muscle tendon junction myotome increased amoun...,9
2,2,MGI:1346525,gene,Sgcd,5
3,3,OMIM:300377.0044,variant,"DMD, LYS770TER",11
4,4,ZP:0002210,phenotype,posterior lateral line neuromast primordium mi...,9
...,...,...,...,...,...
10270,10270,ZP:0014934,phenotype,atrioventricular valve development process qua...,9
10271,10271,ENSEMBL:ENSCAFG00000011207,gene,ENSEMBL:ENSCAFG00000011207,5
10272,10272,ENSEMBL:ENSXETG00000039922,gene,ENSEMBL:ENSXETG00000039922,5
10273,10273,ENSEMBL:ENSACAG00000010058,gene,ENSEMBL:ENSACAG00000010058,5


In [62]:
node_features = pd.merge(node_features, nodes,
                         left_on='Node', right_on='index_id',
                         how='inner')
node_features

Unnamed: 0,Node,Embedding,index_id,id,semantic,label,semantic_id
0,0,"[0.4429384, -0.64671946, -0.045503706, 0.04694...",0,MP:0004187,phenotype,cardia bifida,9
1,1,"[0.7495496, -0.5790302, 0.16192597, 0.03131027...",1,ZP:0100138,phenotype,muscle tendon junction myotome increased amoun...,9
2,2,"[-0.21050918, -0.28908953, 0.118277326, -0.162...",2,MGI:1346525,gene,Sgcd,5
3,3,"[0.2527442, -0.2967686, 0.29164207, 0.8072669,...",3,OMIM:300377.0044,variant,"DMD, LYS770TER",11
4,4,"[0.41442072, 0.06997513, 0.4540072, -0.0370618...",4,ZP:0002210,phenotype,posterior lateral line neuromast primordium mi...,9
...,...,...,...,...,...,...,...
10270,10270,"[0.21862394, 0.11319974, 0.31551972, 0.1497048...",10270,ZP:0014934,phenotype,atrioventricular valve development process qua...,9
10271,10271,"[0.8747792, -0.7461259, 0.22221777, -0.0425329...",10271,ENSEMBL:ENSCAFG00000011207,gene,ENSEMBL:ENSCAFG00000011207,5
10272,10272,"[0.5863251, -0.8054981, 0.27565393, 0.516578, ...",10272,ENSEMBL:ENSXETG00000039922,gene,ENSEMBL:ENSXETG00000039922,5
10273,10273,"[0.6474387, -0.38251266, 0.43746328, 0.3413055...",10273,ENSEMBL:ENSACAG00000010058,gene,ENSEMBL:ENSACAG00000010058,5


For the edges, load csv file.

In [63]:
edges = pd.read_csv(f'output/indexed_edges_{dataset_nr}.csv')
edges

Unnamed: 0,head,label_head,class_head,index_head,relation,tail,label_tail,class_tail,index_tail,type
0,WormBase:WBGene00006787,unc-52,5,304,interacts with,WormBase:WBGene00003929,pat-2,5,1542,0
1,WormBase:WBGene00006787,unc-52,5,304,interacts with,WormBase:WBGene00006789,unc-54,5,6544,0
2,WormBase:WBGene00006787,unc-52,5,304,in orthology relationship with,ENSEMBL:ENSSSCG00000015555,LAMC1,5,9268,1
3,WormBase:WBGene00006787,unc-52,5,304,in orthology relationship with,ZFIN:ZDB-GENE-021226-3,lamc1,5,5387,1
4,WormBase:WBGene00006787,unc-52,5,304,in orthology relationship with,ENSEMBL:ENSOANG00000001050,ENSEMBL:ENSOANG00000001050,5,2204,1
...,...,...,...,...,...,...,...,...,...,...
85987,458,scopolamine butylbromide,4,5945,targets,P11229,Muscarinic acetylcholine receptor M1,6,5919,17
85988,OMIM:300377.0080,"DMD, IVS62, A-G, -285",11,1578,is allele of,HGNC:2928,DMD,5,3310,15
85989,5297,dacomitinib,4,8798,targets,P12931,Proto-oncogene tyrosine-protein kinase Src,6,2379,17
85990,ClinVarVariant:981988,NC_000023.11:g.(31875374_31929595)_(31968515_3...,11,8189,has affected feature,HGNC:2928,DMD,5,3310,11


Load nodes with their features and edges into graph.

In [64]:
G = nx.DiGraph()
for ind, node in node_features.iterrows(): 
    G.add_node(int(node['Node']), node_feature = torch.Tensor(node['Embedding']))
for ind, edge in edges.iterrows(): 
    G.add_edge(int(edge['index_head']), int(edge['index_tail']))
  
for node in G.nodes(data=True):
    print(node)
    break
for edge in G.edges(data=True):
    print(edge)
    break

print("Number of edges is {}".format(G.number_of_edges()))
print("Number of nodes is {}".format(G.number_of_nodes()))

(0, {'node_feature': tensor([ 0.4429, -0.6467, -0.0455,  0.0469,  0.3007, -0.4822,  0.2423,  0.0707,
        -0.5349,  0.4175,  0.7552,  0.0628, -0.3813, -0.3389,  0.5671,  0.4919,
         0.0854,  0.0053, -0.2314,  0.8349,  0.5443,  0.7359, -0.1399, -0.2049,
        -0.1793,  0.0314, -0.7152,  0.1764,  0.3586, -0.4770,  0.0838, -0.3013,
         0.1751, -0.7785, -0.1280,  0.4082,  0.5959,  0.4058,  0.5182,  0.5561,
         0.3995, -0.1704,  0.3317, -0.1183, -0.1047,  0.2789,  0.3293, -0.1368,
         0.1259,  0.1704,  0.7724,  0.4460,  0.6070, -0.3315,  0.5776,  0.4367,
         0.2875, -0.7309,  0.2535,  0.5972,  0.4410, -0.4248, -0.3355,  0.3258])})
(2, 7693, {})
Number of edges is 85878
Number of nodes is 10275


In [65]:
G_attributes = nx.DiGraph()
for ind, node in node_features.iterrows(): 
    G_attributes.add_node(int(node['Node']), node_label = node['label'], node_semantic = node['semantic'])
for ind, edge in edges.iterrows(): 
    G_attributes.add_edge(int(edge['index_head']), int(edge['index_tail']), edge_label = edge['relation'])
  
for node in G_attributes.nodes(data=True):
    print(node)
    break
for edge in G_attributes.edges(data=True):
    print(edge)
    break

print("Number of edges is {}".format(G_attributes.number_of_edges()))
print("Number of nodes is {}".format(G_attributes.number_of_nodes()))

(0, {'node_label': 'cardia bifida', 'node_semantic': 'phenotype'})
(2, 7693, {'edge_label': 'found in'})
Number of edges is 85878
Number of nodes is 10275


In [66]:
DeepG = Graph(G)

print(DeepG.node_feature.shape)
print(DeepG.edge_index.shape)

torch.Size([10275, 64])
torch.Size([2, 85878])


In [67]:
x = DeepG.node_feature
x

tensor([[ 0.4429, -0.6467, -0.0455,  ..., -0.4248, -0.3355,  0.3258],
        [ 0.7495, -0.5790,  0.1619,  ..., -0.4995,  0.4305,  0.0857],
        [-0.2105, -0.2891,  0.1183,  ...,  0.3423, -0.0471,  0.2484],
        ...,
        [ 0.5863, -0.8055,  0.2757,  ..., -0.5787, -0.5317, -0.0063],
        [ 0.6474, -0.3825,  0.4375,  ...,  0.3179, -0.5550, -0.1071],
        [ 0.5623, -0.5136, -0.0366,  ..., -0.7215, -0.4545, -0.0579]])

Load arguments and parameters of trained GNN model.

In [68]:
with open(f'output/best_model_{dataset_nr}_args.pkl', 'rb') as f:
    loaded_args = pickle.load(f)

# Important! Changed the forward function such that it allows input of less edges that are predicted
# and it does not return newly calculated x

class LinkPredModel(torch.nn.Module):
    """
        Architecture contains Batch Normalization layers (https://towardsdatascience.com/batch-norm-explained-visually-how-it-works-and-why-neural-networks-need-it-b18919692739) 
        between the SAGEConvolutional layers.
    """
    def __init__(self, input_size, hidden_size, out_size, num_layers, aggr, dropout, device):
        super(LinkPredModel, self).__init__()

        self.convs = torch.nn.ModuleList()
        self.bns = torch.nn.ModuleList()
        
        self.convs.append(SAGEConv(input_size, hidden_size, normalize=True, aggr=aggr)) # input node embedding features, output is size of hidden layers
        self.bns.append(nn.BatchNorm1d(hidden_size))
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(hidden_size, hidden_size, normalize=True, aggr=aggr))
            self.bns.append(nn.BatchNorm1d(hidden_size))
        self.convs.append(SAGEConv(hidden_size, out_size, normalize=True, aggr=aggr))
        self.bns.append(nn.BatchNorm1d(out_size))

        self.dropout = dropout
        self.loss_fn = torch.nn.BCEWithLogitsLoss()
        
        self.device = device


    def forward(self, x, edge_index, edge_label_index, training = False):
        x = x.to(self.device)
        edge_index = edge_index.to(self.device)
        edge_label_index = edge_label_index.to(self.device)

        for i in range(len(self.convs) - 1):
          x = self.convs[i](x, edge_index)
          x = self.bns[i](x)
          x = F.leaky_relu(x)
          x = F.dropout(x, p=self.dropout, training=training)
        
        x = self.convs[-1](x, edge_index)  
        x = self.bns[-1](x)
        

        nodes_first, nodes_second = edge_label_index
        nodes_first_features = x[nodes_first.long()]
        nodes_second_features = x[nodes_second.long()]
        
        pred = torch.sum(nodes_first_features * nodes_second_features, dim=-1)
        
        return pred
    
    def loss(self, pred, label):
        return self.loss_fn(pred, label)

num_node_features = len(node_features['Embedding'][0])

best_model = LinkPredModel(num_node_features, loaded_args['hidden_dim'], loaded_args['output_dim'], 
                           loaded_args['layers'], loaded_args['aggr'], loaded_args['dropout'], 
                           loaded_args['device']).to(loaded_args['device'])

best_model.load_state_dict(torch.load(f'output/best_model_{dataset_nr}.pth'))

<All keys matched successfully>

# GNNExplainer

In [69]:
model_config = ModelConfig(
    mode='binary_classification',
    task_level='edge',
    return_type='raw',
)

In [70]:
explainer = Explainer(
    model=best_model,
    explanation_type='model',
    algorithm=GNNExplainer(epochs=200),
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=model_config,
)

In [75]:
import ast

symptoms_drugs = pd.read_csv(f'output/symptom_drugs_{dataset_nr}.csv', converters={1:ast.literal_eval})
symptoms_drugs

Unnamed: 0,Symptom,Candidates
0,HP:0003707,"[2774, 176, 1043]"
1,HP:0003236,"[1043, 176, 1971]"
2,HP:0001256,"[65, 1273, 1971]"
3,HP:0001265,"[1273, 1971, 176]"
4,HP:0001290,"[1971, 1043, 2630]"
5,HP:0003307,"[176, 2774, 1043]"
6,HP:0003202,"[1273, 1043, 2774]"
7,HP:0001263,"[1273, 1150, 832]"
8,HP:0002791,"[1273, 1646, 271]"
9,HP:0001371,"[1043, 522, 1273]"


In [82]:
def explain_edge(node_idx1, node_idx2):
    edge_label_index = torch.tensor([node_idx1, node_idx2]) # Create a tensor as input of the model. Only predict the label of a single edge instead of all edges in current dataset.
    
    done = False
    current_iterations = 0
    total_iterations = 30
    graph_size = 30

    while not done:
        print(f'--- Iteration {current_iterations} ---')

        explanation = explainer(
            x=DeepG.node_feature,
            edge_index=DeepG.edge_index,
            edge_label_index=edge_label_index,
        )
        
        edge_mask = explanation.edge_mask
        
        print(f'A total of {torch.count_nonzero(edge_mask)} edges contributing to prediction.')
        
        sorted_contributions = edge_mask.sort(descending=True)[0]
        print(f'Maximum contribution value is {sorted_contributions[0]}')
        limit = sorted_contributions[graph_size]
        print(f'Based on desired graph size, contribution limit is set to {limit}')
        
        pruned_G = copy.deepcopy(G_attributes)
        for indx, edge in enumerate(G.edges):
            if edge_mask[indx] < limit:
                pruned_G.remove_edge(edge[0], edge[1])
        
        done = nx.has_path(pruned_G, node_idx1, node_idx2)
        
        pruned_G.remove_nodes_from(list(nx.isolates(pruned_G)))
        print('Pruned graph has edge total of:', pruned_G.number_of_edges())
        print('Pruned graph has node total of:', pruned_G.number_of_nodes())
        
        current_iterations += 1
        
        if done:
            print(f'A good explanation has been found!')
            nx.write_gexf(pruned_G, "output/explanation.gexf")
            return True
            
        elif current_iterations == total_iterations and not done:
            done = True
            print(f'No good explanations found after {current_iterations} iterations...')
    
    return False

In [83]:
def get_node_idx(id, nodes): 
    return nodes[nodes['id'] == id].index.values.astype(int)[0]

found_explanation = []

for ind, edge in symptoms_drugs.iterrows(): 
    symptom_id = edge['Symptom']
    
    for drug_id in edge['Candidates']:
        
        print('Explain for edge between nodes', node_idx1, 'and', node_idx2)

        node_idx1 = get_node_idx(symptom_id, nodes)
        print('node1:')
        print(nodes.loc[[node_idx1]])

        node_idx2 = get_node_idx(str(drug_id), nodes)
        print('node2:')
        print(nodes.loc[[node_idx2]])
        
        explained = explain_edge(node_idx1, node_idx2)
        if explained:
            found_explanation.append({'symptom': symptom_id, 'drug': drug_id})

Explain for edge between nodes 1351 and 7315
node1:
      index_id          id   semantic                          label   
5727      5727  HP:0003707  phenotype  Calf muscle pseudohypertrophy  \

      semantic_id  
5727            9  
node2:
      index_id    id semantic        label  semantic_id
9663      9663  2774     drug  tropicamide            4
--- Iteration 0 ---
A total of 4260 edges contributing to prediction.
Maximum contribution value is 0.17691539227962494
Based on desired graph size, contribution limit is set to 0.14854896068572998
Pruned graph has edge total of: 31
Pruned graph has node total of: 39
--- Iteration 1 ---
A total of 2958 edges contributing to prediction.
Maximum contribution value is 0.17125378549098969
Based on desired graph size, contribution limit is set to 0.14849451184272766
Pruned graph has edge total of: 31
Pruned graph has node total of: 35
--- Iteration 2 ---
A total of 2958 edges contributing to prediction.
Maximum contribution value is 0.165783

KeyboardInterrupt: 

In [None]:
found_explanation