This Jupyter notebook contains the GNNExplainer method.

In [22]:
import pandas as pd
import networkx as nx
import numpy as np
import torch

import pickle
import copy

import matplotlib.pyplot as plt

from gensim.models import KeyedVectors

In [23]:
dataset_nr = 1

# Load all data

Load the nodes

In [24]:
nodes = pd.read_csv(f'output/indexed_nodes_{dataset_nr}.csv')
nodes

Unnamed: 0,index_id,id,semantic,label,semantic_id
0,0,WormBase:WBGene00000389,ORTH,cdc-25.4,5
1,1,ZP:0018675,DISO,right side lateral plate mesoderm mislocalised...,1
2,2,ZFIN:ZDB-GENE-040426-1197,ORTH,tbc1d5,5
3,3,5,DRUG,(S)-nicardipine,2
4,4,RGD:3443,ORTH,Ptk2,5
...,...,...,...,...,...
10029,10029,MP:0009763,DISO,increased sensitivity to induced morbidity/mor...,1
10030,10030,MP:0011057,DISO,absent brain ependyma motile cilia,1
10031,10031,MP:0001412,DISO,excessive scratching,1
10032,10032,WBPhenotype:0004023,DISO,frequency of body bend variant,1


Load the Edge2Vec embedding

In [25]:
node_feat = KeyedVectors.load(f'output/w2v_{dataset_nr}.dvectors', mmap='r')

e2v_embedding = pd.DataFrame(columns = ['Node', 'Embedding'])
for idx, key in enumerate(node_feat.index_to_key):
    e2v_embedding.loc[int(key)] = pd.Series({'Node':int(key), 'Embedding':list(node_feat[key])})
    
e2v_embedding = e2v_embedding.sort_values('Node')
e2v_embedding

Unnamed: 0,Node,Embedding
0,0,"[0.80579096, 0.039609067, -0.17131972, 0.68869..."
1,1,"[0.13173065, -0.30342472, 0.44653553, 0.617166..."
2,2,"[0.62159157, -0.75673, 0.62674034, -0.35293385..."
3,3,"[0.44389758, -0.25133932, -0.03261666, -0.1510..."
4,4,"[0.44021487, -0.30697218, -0.24358109, 0.24869..."
...,...,...
10029,10029,"[0.35744143, -0.30719844, 0.43343344, -0.10684..."
10030,10030,"[0.5525602, -0.011062832, 0.14208537, 0.322823..."
10031,10031,"[0.61204404, -0.2899803, 0.26087046, -0.607966..."
10032,10032,"[0.29751068, -0.5034509, -0.01351818, -0.03332..."


Load the edges

In [26]:
edges = pd.read_csv(f'output/indexed_edges_{dataset_nr}.csv')
edges

Unnamed: 0,head,label_head,class_head,index_head,relation,tail,label_tail,class_tail,index_tail,type
0,ZFIN:ZDB-GENE-050626-112,myl4,5,5279,in orthology relationship with,FlyBase:FBgn0085464,CG34435,5,6825,0
1,ZFIN:ZDB-GENE-050626-112,myl4,5,5279,in orthology relationship with,HGNC:7585,MYL4,3,27,0
2,ZFIN:ZDB-GENE-050626-112,myl4,5,5279,in orthology relationship with,FlyBase:FBgn0002772,Mlc1,5,8901,0
3,ZFIN:ZDB-GENE-050626-112,myl4,5,5279,in orthology relationship with,NCBIGene:396472,MYL4,3,9508,0
4,ZFIN:ZDB-GENE-050626-112,myl4,5,5279,in 1 to 1 orthology relationship with,ENSEMBL:ENSECAG00000020967,ENSEMBL:ENSECAG00000020967,5,8807,1
...,...,...,...,...,...,...,...,...,...,...
82908,4810,ibrutinib,2,1618,targets,HGNC:11283,SRC,3,3279,14
82909,522,carvedilol,2,184,targets,HGNC:620,APP,3,547,14
82910,OMIM:300377.0013,"DMD, EX18DEL",1,2822,is allele of,HGNC:2928,DMD,3,6612,17
82911,Coriell:GM05113,NIGMS-GM05113,4,8105,has role in modeling,MONDO:0010679,Duchenne muscular dystrophy,1,6315,15


Initialize a graph with edges and nodes including the embedding features of each node.

In [27]:
G = nx.DiGraph() # TODO: changed from Graph
for ind, node in e2v_embedding.iterrows(): 
    G.add_node(int(node['Node']), node_feature=torch.Tensor(node['Embedding']))
for ind, edge in edges.iterrows(): 
    G.add_edge(int(edge['index_head']), int(edge['index_tail']), edge_label = edge['relation'])
  
for node in G.nodes(data=True):
    print(node)
    break
for edge in G.edges(data=True):
    print(edge)
    break

print("Number of edges is {}".format(G.number_of_edges()))
print("Number of nodes is {}".format(G.number_of_nodes()))

(0, {'node_feature': tensor([ 8.0579e-01,  3.9609e-02, -1.7132e-01,  6.8870e-01,  3.8687e-01,
        -1.1756e+00,  4.3465e-01, -5.1411e-01, -8.6061e-01, -1.5018e-01,
        -5.6375e-01,  5.2018e-02,  4.2212e-01, -1.4776e-01, -1.6446e-01,
         9.3733e-01, -6.3037e-01, -2.5282e-01,  1.1521e-01,  5.3040e-01,
         7.8000e-01,  3.6201e-01, -2.6033e-01, -1.1161e+00,  6.4173e-01,
         2.9151e-01, -1.0554e+00, -8.0453e-01, -6.6271e-01,  2.2767e-02,
         1.0419e-01,  1.0334e-01,  5.3798e-01, -1.7435e-01,  4.3009e-01,
         1.6690e-01, -4.3212e-01,  1.6400e-01, -2.9133e-01,  1.0707e+00,
        -3.5062e-02,  1.2267e-01,  3.4321e-01, -8.6810e-02,  1.0513e+00,
        -5.3982e-01, -2.1837e-01, -1.5077e-01,  4.7775e-04, -1.2586e-02,
        -7.0311e-01,  1.1769e-01,  2.7447e-01, -2.5308e-02,  6.0167e-01,
         1.3919e-01,  3.6101e-02, -5.8889e-03, -4.1151e-01,  2.1858e-01,
        -7.8684e-01, -1.7285e-01,  6.9146e-01, -7.3086e-02])})
(0, 6308, {'edge_label': 'in orthology r

In [28]:
edge_labels_dict = dict([((n1, n2), G.edges[(n1,n2)]['edge_label']) for n1, n2 in G.edges])
edge_labels_dict

{(0, 6308): 'in orthology relationship with',
 (0, 9835): 'in 1 to 1 orthology relationship with',
 (0, 1681): 'in orthology relationship with',
 (0, 3670): 'in 1 to 1 orthology relationship with',
 (0, 5156): 'in orthology relationship with',
 (0, 2619): 'in orthology relationship with',
 (0, 2226): 'in orthology relationship with',
 (0, 10028): 'in 1 to 1 orthology relationship with',
 (0, 363): 'in orthology relationship with',
 (0, 8710): 'in orthology relationship with',
 (0, 8473): 'in orthology relationship with',
 (0, 8615): 'in orthology relationship with',
 (0, 5765): 'interacts with',
 (0, 5949): 'in orthology relationship with',
 (0, 3961): 'in orthology relationship with',
 (0, 5248): 'in 1 to 1 orthology relationship with',
 (0, 8756): 'in orthology relationship with',
 (0, 2139): 'in orthology relationship with',
 (0, 9308): 'has phenotype',
 (0, 71): 'has phenotype',
 (0, 1856): 'in orthology relationship with',
 (0, 6666): 'in orthology relationship with',
 (2, 1154): 

Load trained model

In [29]:
from gnn.linkpred_model import LinkPredModel

with open(f'output/best_model_{dataset_nr}_args.pkl', 'rb') as f:
    loaded_args = pickle.load(f)
    
best_model = LinkPredModel(len(e2v_embedding['Embedding'][0]), 
                           loaded_args['hidden_dim'], loaded_args['output_dim'], 
                           loaded_args['layers'], loaded_args['aggr'], 
                           loaded_args['dropout'], loaded_args['device']).to(loaded_args['device'])
best_model.load_state_dict(torch.load(f'output/best_model_{dataset_nr}.pth'))

<All keys matched successfully>

# Explain predictions

In [30]:
x = torch.Tensor(e2v_embedding['Embedding'])
print(x.shape)

torch.Size([10034, 64])


In [31]:
edge_index = torch.Tensor(np.array(G.edges).transpose()).type(torch.int64).long()
print(edge_index.shape)

torch.Size([2, 82899])


In [32]:
from gnn.gnnexplainer import GNNExplainer, visualize_subgraph

def explain_edge(node_idx1, node_idx2):
    explainer = GNNExplainer(best_model,
                             epochs=700, num_hops=1, lr=0.01)
    
    trigger = False
    early_stop = 0
    size = 15   # change size of the explanation graph
    iterations = 50 # number of times GNNExplainer is executed
    
    while not trigger:
        _, edge_mask = explainer.explain_link(node_idx1=node_idx1, node_idx2=node_idx2,
                                              x=x, edge_index=edge_index,
                                              G=G)
        
        G2 = copy.deepcopy(G)
        
        if size is not None:
            limit = edge_mask.sort(descending = True)[0][size]
            print('Contribution threshold is', limit)
        else:
            limit = 0.5
        
        for indx, edge in enumerate(G.edges): 
            if edge_mask[indx] < limit:
                G2.remove_edge(edge[0], edge[1])
        
        trigger = nx.has_path(G2, node_idx1, node_idx2)
        early_stop += 1
        
        if early_stop == iterations and not trigger:
            print('No good explanation found after {} iterations'.format(early_stop))
            trigger = True
        elif trigger: 
            print('Explanation found!')
            
            plt.figure(figsize=(10, 10))
            ax, G_sub = visualize_subgraph([node_idx1, node_idx2], 
                                           edge_index, edge_mask, 
                                           nodes=nodes, y=torch.Tensor(nodes.semantic_id), 
                                           seed=667, num_hops=2, threshold=limit, 
                                           node_label='label', edge_labels=edge_labels_dict, 
                                           show_inactive=False, remove_unconnected=True)
            plt.show()
            
            return True
        
    return False
    

In [33]:
import ast

symptoms_drugs = pd.read_csv(f'output/symptom_drugs_{dataset_nr}.csv', converters={1:ast.literal_eval})
symptoms_drugs

Unnamed: 0,Symptom,Candidates
0,HP:0011675,"[2974, 4285, 6]"
1,HP:0002515,"[5302, 4187, 4285]"
2,HP:0003236,"[2974, 4285, 5345]"
3,HP:0002093,"[4285, 5302, 5345]"
4,HP:0003707,"[1529, 5345, 4187]"
5,HP:0001256,"[4285, 5345, 5302]"
6,HP:0003701,"[4285, 5302, 5345]"
7,HP:0003202,"[4285, 2835, 624]"
8,HP:0003560,"[2612, 624, 4285]"
9,HP:0003391,"[5302, 4285, 4187]"


In [35]:
def get_node_idx(id, nodes): 
    return nodes[nodes['id'] == id].index.values.astype(int)[0]

found_explanation = []

for ind, edge in symptoms_drugs.iterrows(): 
    symptom_id = edge['Symptom']
    
    for drug_id in edge['Candidates']:
        node_idx1 = get_node_idx(symptom_id, nodes)
        print('node1:')
        print(nodes.loc[[node_idx1]])

        node_idx2 = get_node_idx(str(drug_id), nodes)
        print('node2:')
        print(nodes.loc[[node_idx2]])
        
        explained = explain_edge(node_idx1, node_idx2)
        
        if explained:
            found_explanation.append({'symptom': symptom_id, 'drug': drug_id})

node1:
      index_id          id semantic       label  semantic_id
9512      9512  HP:0011675     DISO  Arrhythmia            1
node2:
      index_id    id semantic       label  semantic_id
1019      1019  2974     DRUG  doxacurium            2
Explain for edge between nodes 9512 and 1019


Explain edge between nodes 9512 and 1019: 100%|██████████| 700/700 [00:10<00:00, 65.01it/s]


Contribution threshold is tensor(0.0086)


Explain edge between nodes 9512 and 1019: 100%|██████████| 700/700 [00:10<00:00, 67.15it/s]


Contribution threshold is tensor(0.0086)


Explain edge between nodes 9512 and 1019: 100%|██████████| 700/700 [00:09<00:00, 71.55it/s]


Contribution threshold is tensor(0.0086)


Explain edge between nodes 9512 and 1019: 100%|██████████| 700/700 [00:10<00:00, 64.99it/s]


Contribution threshold is tensor(0.0086)


Explain edge between nodes 9512 and 1019: 100%|██████████| 700/700 [00:11<00:00, 60.91it/s]


Contribution threshold is tensor(0.0086)


Explain edge between nodes 9512 and 1019: 100%|██████████| 700/700 [00:10<00:00, 68.76it/s]


Contribution threshold is tensor(0.0086)


Explain edge between nodes 9512 and 1019: 100%|██████████| 700/700 [00:09<00:00, 75.91it/s]


Contribution threshold is tensor(0.0086)


Explain edge between nodes 9512 and 1019: 100%|██████████| 700/700 [00:09<00:00, 77.45it/s]


Contribution threshold is tensor(0.0086)


Explain edge between nodes 9512 and 1019: 100%|██████████| 700/700 [00:09<00:00, 72.82it/s]


Contribution threshold is tensor(0.0086)


KeyboardInterrupt: 

In [None]:
found_explanation