In [1]:
import pandas as pd
import networkx as nx
import pickle
import copy

from gensim.models import KeyedVectors

import torch
import torch.nn as nn
import torch.nn.functional as F

from deepsnap.graph import Graph

from torch_geometric.nn import SAGEConv
from torch_geometric.explain import Explainer, GNNExplainer, ModelConfig


In [2]:
# Number of the dataset to analyse; it is interpolated into every file name
# read from / written next to the 'output/' directory in the cells below.
dataset_nr = 2

# Load Necessary Data

For the node features, load the embeddings produced by Edge2Vec.

In [3]:
# Memory-mapped, read-only view of the stored keyed vectors for this dataset.
loaded_node_embedding = KeyedVectors.load(f'output/w2v_{dataset_nr}.dvectors', mmap='r')

# Collect every row first and build the frame in one go: enlarging a DataFrame
# with `.loc[...] = ...` inside a loop re-allocates on every insertion.
# Keys are stored as strings, so they are cast to int to serve as node indices.
embedding_records = [
    {'Node': int(key), 'Embedding': list(loaded_node_embedding[key])}
    for key in loaded_node_embedding.index_to_key
]
node_embeddings = pd.DataFrame(
    embedding_records,
    columns=['Node', 'Embedding'],
    # Keep the node number as the (unnamed) row label, as the row-wise version did.
    index=[record['Node'] for record in embedding_records],
)

# Order the rows by node index so that row position follows the node number.
node_features = node_embeddings.sort_values('Node')

node_features

Unnamed: 0,Node,Embedding
0,0,"[0.3984332, -0.31286958, 0.5400291, 0.50470465..."
1,1,"[0.44448715, -0.21083395, 0.77053374, 0.059516..."
2,2,"[0.06366989, 0.3194347, 1.0812027, 0.20341216,..."
3,3,"[0.1271398, 0.03148787, 0.27353236, 0.39875212..."
4,4,"[0.5111069, -0.14326268, 0.4403008, 0.511082, ..."
...,...,...
10270,10270,"[0.14994565, 0.14421983, 0.35291353, 0.1048778..."
10271,10271,"[0.57958156, 0.07746426, 0.004016864, -0.28448..."
10272,10272,"[0.74247813, -0.6906312, 0.66188675, 0.8234299..."
10273,10273,"[-0.22202384, -0.9451154, 0.3080682, 0.2242091..."


For additional information about the nodes, load the CSV file.

In [4]:
# Node metadata table; the displayed columns are index_id, id, semantic, label
# and semantic_id. Later cells look nodes up in this frame by their 'id' value.
nodes = pd.read_csv(f'output/indexed_nodes_{dataset_nr}.csv')
nodes

Unnamed: 0,index_id,id,semantic,label,semantic_id
0,0,MP:0004187,phenotype,cardia bifida,9
1,1,ZP:0100138,phenotype,muscle tendon junction myotome increased amoun...,9
2,2,MGI:1346525,gene,Sgcd,5
3,3,OMIM:300377.0044,variant,"DMD, LYS770TER",11
4,4,ZP:0002210,phenotype,posterior lateral line neuromast primordium mi...,9
...,...,...,...,...,...
10270,10270,ZP:0014934,phenotype,atrioventricular valve development process qua...,9
10271,10271,ENSEMBL:ENSCAFG00000011207,gene,ENSEMBL:ENSCAFG00000011207,5
10272,10272,ENSEMBL:ENSXETG00000039922,gene,ENSEMBL:ENSXETG00000039922,5
10273,10273,ENSEMBL:ENSACAG00000010058,gene,ENSEMBL:ENSACAG00000010058,5


In [5]:
# Attach the node metadata to each embedding row: 'Node' (embedding side) is
# matched against 'index_id' (metadata side); rows without a partner are dropped.
# NOTE: this rebinds `node_features`, so re-running this cell without first
# re-running the embedding cell merges an already-merged frame.
node_features = node_features.merge(
    nodes,
    how='inner',
    left_on='Node',
    right_on='index_id',
)
node_features

Unnamed: 0,Node,Embedding,index_id,id,semantic,label,semantic_id
0,0,"[0.3984332, -0.31286958, 0.5400291, 0.50470465...",0,MP:0004187,phenotype,cardia bifida,9
1,1,"[0.44448715, -0.21083395, 0.77053374, 0.059516...",1,ZP:0100138,phenotype,muscle tendon junction myotome increased amoun...,9
2,2,"[0.06366989, 0.3194347, 1.0812027, 0.20341216,...",2,MGI:1346525,gene,Sgcd,5
3,3,"[0.1271398, 0.03148787, 0.27353236, 0.39875212...",3,OMIM:300377.0044,variant,"DMD, LYS770TER",11
4,4,"[0.5111069, -0.14326268, 0.4403008, 0.511082, ...",4,ZP:0002210,phenotype,posterior lateral line neuromast primordium mi...,9
...,...,...,...,...,...,...,...
10270,10270,"[0.14994565, 0.14421983, 0.35291353, 0.1048778...",10270,ZP:0014934,phenotype,atrioventricular valve development process qua...,9
10271,10271,"[0.57958156, 0.07746426, 0.004016864, -0.28448...",10271,ENSEMBL:ENSCAFG00000011207,gene,ENSEMBL:ENSCAFG00000011207,5
10272,10272,"[0.74247813, -0.6906312, 0.66188675, 0.8234299...",10272,ENSEMBL:ENSXETG00000039922,gene,ENSEMBL:ENSXETG00000039922,5
10273,10273,"[-0.22202384, -0.9451154, 0.3080682, 0.2242091...",10273,ENSEMBL:ENSACAG00000010058,gene,ENSEMBL:ENSACAG00000010058,5


For the edges, load the CSV file.

In [6]:
# Edge table; each row describes one relation between a head and a tail node.
# The graph-building cells use 'index_head', 'index_tail' and 'relation'.
edges = pd.read_csv(f'output/indexed_edges_{dataset_nr}.csv')
edges

Unnamed: 0,head,label_head,class_head,index_head,relation,tail,label_tail,class_tail,index_tail,type
0,WormBase:WBGene00006787,unc-52,5,304,interacts with,WormBase:WBGene00003929,pat-2,5,1542,0
1,WormBase:WBGene00006787,unc-52,5,304,interacts with,WormBase:WBGene00006789,unc-54,5,6544,0
2,WormBase:WBGene00006787,unc-52,5,304,in orthology relationship with,ENSEMBL:ENSSSCG00000015555,LAMC1,5,9268,1
3,WormBase:WBGene00006787,unc-52,5,304,in orthology relationship with,ZFIN:ZDB-GENE-021226-3,lamc1,5,5387,1
4,WormBase:WBGene00006787,unc-52,5,304,in orthology relationship with,ENSEMBL:ENSOANG00000001050,ENSEMBL:ENSOANG00000001050,5,2204,1
...,...,...,...,...,...,...,...,...,...,...
85987,458,scopolamine butylbromide,4,5945,targets,P11229,Muscarinic acetylcholine receptor M1,6,5919,17
85988,OMIM:300377.0080,"DMD, IVS62, A-G, -285",11,1578,is allele of,HGNC:2928,DMD,5,3310,15
85989,5297,dacomitinib,4,8798,targets,P12931,Proto-oncogene tyrosine-protein kinase Src,6,2379,17
85990,ClinVarVariant:981988,NC_000023.11:g.(31875374_31929595)_(31968515_3...,11,8189,has affected feature,HGNC:2928,DMD,5,3310,11


Load the nodes with their features and the edges into a graph.

In [7]:
# Undirected graph used as model input: every node carries its embedding as a
# 'node_feature' tensor, edges carry no attributes.
G = nx.Graph()
for node_id, embedding in zip(node_features['Node'], node_features['Embedding']):
    G.add_node(int(node_id), node_feature=torch.Tensor(embedding))
G.add_edges_from(
    (int(head_idx), int(tail_idx))
    for head_idx, tail_idx in zip(edges['index_head'], edges['index_tail'])
)

# Sanity check: show the first node and the first edge only.
for first_node in G.nodes(data=True):
    print(first_node)
    break
for first_edge in G.edges(data=True):
    print(first_edge)
    break

print("Number of edges is {}".format(G.number_of_edges()))
print("Number of nodes is {}".format(G.number_of_nodes()))

(0, {'node_feature': tensor([ 0.3984, -0.3129,  0.5400,  0.5047,  0.3606, -0.1538, -0.4962, -0.2255,
        -0.3281, -0.0869,  0.2944, -0.4044,  0.2928, -0.4601,  0.1927, -0.1504,
        -0.4191, -0.1028, -0.3842,  0.1555,  0.4455,  0.3406,  0.0176, -0.4781,
        -0.3071,  0.6394, -0.4143, -0.0489,  0.0523, -0.2817, -0.2545, -0.5112,
        -0.1196, -0.2351,  0.3489,  0.1787,  0.1554,  0.1126,  0.6581,  0.2639,
        -0.0216, -0.2815, -0.3214, -0.1436,  0.1828, -0.0387,  0.1260,  0.1041,
         0.0978,  0.0397,  0.1876,  0.1959,  0.2125,  0.5915,  0.2164, -0.2079,
        -0.3812, -0.3433, -0.0945,  0.2793, -0.4201, -0.2743, -0.4724, -0.3125])})
(0, 8472, {})
Number of edges is 55032
Number of nodes is 10275


In [8]:
# Human-readable twin of G: same nodes and edges, but annotated with labels
# ('node_label', 'node_semantic', 'edge_label') instead of feature tensors.
G_attributes = nx.Graph()
node_columns = zip(node_features['Node'], node_features['label'], node_features['semantic'])
for node_id, label, semantic in node_columns:
    G_attributes.add_node(int(node_id), node_label=label, node_semantic=semantic)
G_attributes.add_edges_from(
    (int(head_idx), int(tail_idx), {'edge_label': relation})
    for head_idx, tail_idx, relation in zip(
        edges['index_head'], edges['index_tail'], edges['relation']
    )
)

# Sanity check: show the first node and the first edge only.
for first_node in G_attributes.nodes(data=True):
    print(first_node)
    break
for first_edge in G_attributes.edges(data=True):
    print(first_edge)
    break

print("Number of edges is {}".format(G_attributes.number_of_edges()))
print("Number of nodes is {}".format(G_attributes.number_of_nodes()))

(0, {'node_label': 'cardia bifida', 'node_semantic': 'phenotype'})
(0, 8472, {'edge_label': 'causes condition'})
Number of edges is 55032
Number of nodes is 10275


In [9]:
# Wrap the NetworkX graph in a DeepSNAP Graph to obtain tensor views of it
# (node_feature and edge_index are the two attributes used further down).
# The printed edge_index has twice as many columns as G has edges —
# presumably each undirected edge is stored once per direction; TODO confirm.
DeepG = Graph(G)

print(DeepG.node_feature.shape)
print(DeepG.edge_index.shape)

torch.Size([10275, 64])
torch.Size([2, 110064])


In [10]:
# Node feature matrix of the DeepSNAP graph, displayed for inspection.
x = DeepG.node_feature
x

tensor([[ 0.3984, -0.3129,  0.5400,  ..., -0.2743, -0.4724, -0.3125],
        [ 0.4445, -0.2108,  0.7705,  ..., -0.3755,  0.3339,  0.6621],
        [ 0.0637,  0.3194,  1.0812,  ...,  0.1435,  0.2400,  0.2337],
        ...,
        [ 0.7425, -0.6906,  0.6619,  ..., -0.2799,  0.4221,  0.3836],
        [-0.2220, -0.9451,  0.3081,  ..., -0.4940,  0.8063, -0.3299],
        [ 0.2752, -0.0174, -0.1116,  ..., -0.2055, -0.1233,  0.1177]])

Load the arguments and parameters of the trained GNN model.

In [11]:
# Hyper-parameters the model was trained with. The model construction below
# reads the keys 'hidden_dim', 'output_dim', 'layers', 'aggr', 'dropout' and 'device'.
# WARNING: pickle.load can execute arbitrary code; only open files you produced yourself.
with open(f'output/best_model_{dataset_nr}_args.pkl', 'rb') as f:
    loaded_args = pickle.load(f)

# Important! The forward function was changed so that it accepts fewer edges to
# predict (an explicit edge_label_index) and no longer returns the newly calculated x.

class LinkPredModel(torch.nn.Module):
    """
        Link predictor built from SAGEConv layers; an edge is scored by the dot
        product of the final embeddings of its two end nodes.

        Architecture contains Batch Normalization layers (https://towardsdatascience.com/batch-norm-explained-visually-how-it-works-and-why-neural-networks-need-it-b18919692739) 
        between the SAGEConvolutional layers.
    """
    def __init__(self, input_size, hidden_size, out_size, num_layers, aggr, dropout, device):
        """
            input_size:  width of the input node features
            hidden_size: width of every intermediate layer
            out_size:    width of the final node embedding
            num_layers:  number of SAGEConv layers (at least two are always created)
            aggr:        aggregation scheme handed to every SAGEConv
            dropout:     dropout probability applied after each non-final layer
            device:      device the inputs are moved to in forward()
        """
        super().__init__()

        # Layer widths: input -> hidden -> (num_layers - 2 more hidden) -> output.
        widths = [input_size, hidden_size] + [hidden_size] * (num_layers - 2) + [out_size]

        self.convs = torch.nn.ModuleList()
        self.bns = torch.nn.ModuleList()
        # One convolution followed by one batch norm per consecutive pair of widths.
        for width_in, width_out in zip(widths[:-1], widths[1:]):
            self.convs.append(SAGEConv(width_in, width_out, normalize=True, aggr=aggr))
            self.bns.append(nn.BatchNorm1d(width_out))

        self.dropout = dropout
        self.loss_fn = torch.nn.BCEWithLogitsLoss()

        self.device = device


    def forward(self, x, edge_index, edge_label_index, training = False):
        """
            Return one raw score per edge in edge_label_index.

            x:                node feature matrix
            edge_index:       edges used for message passing
            edge_label_index: pair (first nodes, second nodes) of the edges to score
            training:         enables dropout when True (batch norm follows the module's own mode)
        """
        x, edge_index, edge_label_index = (
            tensor.to(self.device) for tensor in (x, edge_index, edge_label_index)
        )

        # Every layer but the last: conv -> batch norm -> leaky ReLU -> dropout.
        hidden_layers = list(zip(self.convs, self.bns))[:-1]
        for conv, bn in hidden_layers:
            x = F.dropout(F.leaky_relu(bn(conv(x, edge_index))), p=self.dropout, training=training)

        # Last layer: no activation and no dropout.
        x = self.bns[-1](self.convs[-1](x, edge_index))

        # Dot product of the embeddings of both end nodes of each requested edge.
        nodes_first, nodes_second = edge_label_index
        return (x[nodes_first.long()] * x[nodes_second.long()]).sum(dim=-1)

    def loss(self, pred, label):
        """Binary cross entropy between the raw scores `pred` and the targets `label`."""
        return self.loss_fn(pred, label)

# Width of one embedding vector; .iloc[0] takes the first row by position, so
# this does not depend on the frame having a row labelled 0.
num_node_features = len(node_features['Embedding'].iloc[0])

# Rebuild the architecture with the hyper-parameters the model was trained with.
best_model = LinkPredModel(num_node_features, loaded_args['hidden_dim'], loaded_args['output_dim'], 
                           loaded_args['layers'], loaded_args['aggr'], loaded_args['dropout'], 
                           loaded_args['device']).to(loaded_args['device'])

# The model is only used for inference from here on: switch to evaluation mode
# so the BatchNorm layers use the running statistics restored from the
# checkpoint instead of statistics of the current input.
best_model.eval()

# Kept as the last expression so the key-matching report is displayed.
best_model.load_state_dict(torch.load(f'output/best_model_{dataset_nr}.pth'))

<All keys matched successfully>

# GNNExplainer

In [12]:
# Describes the model to the explainer: a binary classifier on edges whose
# output is a raw score (LinkPredModel.forward returns the un-squashed dot
# product and is trained with BCEWithLogitsLoss).
model_config = ModelConfig(
    mode='binary_classification',
    task_level='edge',
    return_type='raw',
)

In [13]:
# Explainer around the trained link predictor, using GNNExplainer optimised for
# 200 epochs per call. It is configured to produce a mask over node attributes
# and a mask over edges; explain_edge below only reads the edge mask.
explainer = Explainer(
    model=best_model,
    explanation_type='model',
    algorithm=GNNExplainer(epochs=200),
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=model_config,
)

In [14]:
import ast

# Each row pairs a symptom id with a list of candidate drug ids. The list is
# stored as text in the CSV, so it is parsed back into a Python list on load.
# The converter is keyed by column name rather than by position so that it keeps
# targeting 'Candidates' (the column iterated over later) if the column order changes.
symptoms_drugs = pd.read_csv(
    f'output/symptom_drugs_{dataset_nr}.csv',
    converters={'Candidates': ast.literal_eval},
)
symptoms_drugs

Unnamed: 0,Symptom,Candidates
0,HP:0003707,"[269, 1576, 926]"
1,HP:0003236,"[1576, 269, 1795]"
2,HP:0001256,"[269, 1576, 5359]"
3,HP:0001265,"[269, 1576, 1795]"
4,HP:0001290,"[1795, 269, 1576]"
5,HP:0003307,"[1576, 1795, 522]"
6,HP:0003202,"[1795, 1576, 5330]"
7,HP:0001263,"[269, 5359, 1795]"
8,HP:0002791,"[1576, 269, 5359]"
9,HP:0001371,"[1576, 1795, 522]"


In [21]:
# Running number of the next explanation file to write.
graph_nr = 0

def explain_edge(node_idx1, node_idx2):
    """
    Search for an explanation subgraph for the predicted edge (node_idx1, node_idx2).

    The explainer is run repeatedly (its result differs between runs, as the
    logged maxima show). After each run only the highest-scoring edges are kept;
    the run counts as successful when the kept edges still connect the two
    nodes. The successful subgraph is written to 'output/explanation<nr>.gexf'.

    Parameters
    ----------
    node_idx1, node_idx2 : int
        Indices of the two end nodes of the edge to explain.

    Returns
    -------
    int or None
        Number used in the name of the written file, or None when no connected
        explanation was found within the allowed number of attempts.
    """
    # `graph_nr` is rebound below; without this declaration the name would be
    # local to the function and reading it would raise UnboundLocalError.
    global graph_nr

    # Create a tensor as input of the model. Only predict the label of a single
    # edge instead of all edges in the current dataset.
    edge_label_index = torch.tensor([node_idx1, node_idx2])

    total_iterations = 30   # maximum number of explainer runs
    graph_size = 15         # rank of the edge score that is used as cut-off

    for current_iteration in range(total_iterations):
        print(f'--- Iteration {current_iteration} ---')

        explanation = explainer(
            x=DeepG.node_feature,
            edge_index=DeepG.edge_index,
            edge_label_index=edge_label_index,
        )

        edge_mask = explanation.edge_mask

        print(f'A total of {torch.count_nonzero(edge_mask)} edges contributing to prediction.')

        sorted_contributions = edge_mask.sort(descending=True)[0]
        print(f'Maximum contribution value is {sorted_contributions[0]}')
        # Clamp the rank so masks with fewer than graph_size + 1 entries do not
        # raise an IndexError.
        limit = sorted_contributions[min(graph_size, len(sorted_contributions) - 1)]
        print(f'Based on desired graph size, contribution limit is set to {limit}')

        # Compare once on the tensor instead of once per edge in the Python loop.
        below_limit = (edge_mask < limit).tolist()

        # NOTE(review): only the first G.number_of_edges() mask entries are read,
        # while DeepG.edge_index has twice that many columns. This assumes the
        # leading entries follow the order of G.edges — confirm against DeepSNAP.
        pruned_G = copy.deepcopy(G_attributes)
        for indx, edge in enumerate(G.edges):
            if below_limit[indx]:
                pruned_G.remove_edge(edge[0], edge[1])

        # Must be checked before isolated nodes are dropped: has_path needs both
        # end nodes to still be present in the graph.
        found = nx.has_path(pruned_G, node_idx1, node_idx2)

        pruned_G.remove_nodes_from(list(nx.isolates(pruned_G)))
        print('Pruned graph has edge total of:', pruned_G.number_of_edges())
        print('Pruned graph has node total of:', pruned_G.number_of_nodes())

        if found:
            print('A good explanation has been found!')
            written_nr = graph_nr
            nx.write_gexf(pruned_G, f"output/explanation{written_nr}.gexf")
            graph_nr += 1
            # Return the number that is part of the file name just written.
            return written_nr

    print(f'No good explanations found after {total_iterations} iterations...')
    return None

In [22]:
def get_node_idx(id, nodes): 
    """
    Return the row label of the first row of `nodes` whose 'id' column equals `id`.

    Parameters
    ----------
    id : value compared against the 'id' column (the parameter name shadows the
        builtin `id`; it is kept for compatibility with existing callers).
    nodes : pandas.DataFrame with an 'id' column and integer-convertible row labels.

    Returns
    -------
    Integer row label of the first matching row.

    Raises
    ------
    IndexError
        If no row of `nodes` has the requested id.
    """
    matching_labels = nodes[nodes['id'] == id].index.values
    if len(matching_labels) == 0:
        # Name the missing id instead of surfacing "index 0 is out of bounds".
        raise IndexError(f'No node with id {id!r} found in nodes')
    return matching_labels.astype(int)[0]

# One entry per (symptom, drug) pair for which explain_edge returned a graph number.
found_explanations = []

for _, candidate_row in symptoms_drugs.iterrows():
    symptom_id = candidate_row['Symptom']
    candidate_drugs = candidate_row['Candidates']

    for drug_id in candidate_drugs:
        # Look both end nodes up in the node table and show their metadata.
        symptom_idx = get_node_idx(symptom_id, nodes)
        print('node1:')
        print(nodes.loc[[symptom_idx]])

        # Drug ids are parsed as integers, whereas the 'id' column holds strings.
        drug_idx = get_node_idx(str(drug_id), nodes)
        print('node2:')
        print(nodes.loc[[drug_idx]])

        explanation_nr = explain_edge(symptom_idx, drug_idx)
        if explanation_nr is None:
            continue

        found_explanations.append(
            {'symptom': symptom_id, 'drug': drug_id, 'graphnr': explanation_nr}
        )

node1:
      index_id          id   semantic                          label   
5727      5727  HP:0003707  phenotype  Calf muscle pseudohypertrophy  \

      semantic_id  
5727            9  
node2:
      index_id    id semantic         label  semantic_id
5262      5262  1576     drug  levosimendan            4
--- Iteration 0 ---
A total of 3542 edges contributing to prediction.
Maximum contribution value is 0.310735285282135
Based on desired graph size, contribution limit is set to 0.14891646802425385
Pruned graph has edge total of: 9
Pruned graph has node total of: 12
--- Iteration 1 ---
A total of 3542 edges contributing to prediction.
Maximum contribution value is 0.29607337713241577
Based on desired graph size, contribution limit is set to 0.15019941329956055
Pruned graph has edge total of: 8
Pruned graph has node total of: 11
--- Iteration 2 ---
A total of 3542 edges contributing to prediction.
Maximum contribution value is 0.32425495982170105
Based on desired graph size, contri

KeyboardInterrupt: 

In [None]:
# Show the collected explanations: one dict with 'symptom', 'drug' and 'graphnr'
# for every pair for which explain_edge returned a value other than None.
found_explanations