This Jupyter notebook applies the GNNExplainer method to explain predictions on the graph built from the dataset selected below.

In [6]:
import pandas as pd
import networkx as nx

import torch

from gensim.models import KeyedVectors

In [1]:
# Dataset selector: picks which files are read from output/
# (w2v_<nr>.dvectors for embeddings, indexed_edges_<nr>.csv for edges).
dataset_nr: int = 2

# Load all data

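Load the Edge2Vec node embeddings (saved as gensim `KeyedVectors`) into a DataFrame with one row per node, indexed and sorted by node id.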

In [4]:
# Load the Edge2Vec vectors for the selected dataset; mmap='r' opens the stored
# arrays read-only via memory mapping.
node_feat = KeyedVectors.load(f'output/w2v_{dataset_nr}.dvectors', mmap='r')

# Build the table in a single constructor call. Growing a DataFrame row by row
# with `.loc[new_label] = ...` re-allocates the frame on every insert, which is
# quadratic in the number of nodes.
# Vocabulary keys are converted with int() to obtain the numeric node ids.
node_ids = [int(key) for key in node_feat.index_to_key]
e2v_embedding = pd.DataFrame(
    {
        'Node': node_ids,
        # One plain Python list of vector components per node.
        'Embedding': [list(node_feat[key]) for key in node_feat.index_to_key],
    },
    index=node_ids,  # index mirrors 'Node', as in the row-by-row version
).sort_values('Node')
e2v_embedding

Unnamed: 0,Node,Embedding
0,0,"[0.4429384, -0.64671946, -0.045503706, 0.04694..."
1,1,"[0.7495496, -0.5790302, 0.16192597, 0.03131027..."
2,2,"[-0.21050918, -0.28908953, 0.118277326, -0.162..."
3,3,"[0.2527442, -0.2967686, 0.29164207, 0.8072669,..."
4,4,"[0.41442072, 0.06997513, 0.4540072, -0.0370618..."
...,...,...
10270,10270,"[0.21862394, 0.11319974, 0.31551972, 0.1497048..."
10271,10271,"[0.8747792, -0.7461259, 0.22221777, -0.0425329..."
10272,10272,"[0.5863251, -0.8054981, 0.27565393, 0.516578, ..."
10273,10273,"[0.6474387, -0.38251266, 0.43746328, 0.3413055..."


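Load the indexed edge list; the `index_head` and `index_tail` columns hold the integer node ids that match the `Node` column of the embedding table.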

In [5]:
# Edge list for the selected dataset: one row per (head, relation, tail) triple,
# with integer node ids in index_head / index_tail.
edges_csv = f'output/indexed_edges_{dataset_nr}.csv'
edges = pd.read_csv(edges_csv)
edges

Unnamed: 0,head,label_head,class_head,index_head,relation,tail,label_tail,class_tail,index_tail,type
0,WormBase:WBGene00006787,unc-52,5,304,interacts with,WormBase:WBGene00003929,pat-2,5,1542,0
1,WormBase:WBGene00006787,unc-52,5,304,interacts with,WormBase:WBGene00006789,unc-54,5,6544,0
2,WormBase:WBGene00006787,unc-52,5,304,in orthology relationship with,ENSEMBL:ENSSSCG00000015555,LAMC1,5,9268,1
3,WormBase:WBGene00006787,unc-52,5,304,in orthology relationship with,ZFIN:ZDB-GENE-021226-3,lamc1,5,5387,1
4,WormBase:WBGene00006787,unc-52,5,304,in orthology relationship with,ENSEMBL:ENSOANG00000001050,ENSEMBL:ENSOANG00000001050,5,2204,1
...,...,...,...,...,...,...,...,...,...,...
85987,458,scopolamine butylbromide,4,5945,targets,P11229,Muscarinic acetylcholine receptor M1,6,5919,17
85988,OMIM:300377.0080,"DMD, IVS62, A-G, -285",11,1578,is allele of,HGNC:2928,DMD,5,3310,15
85989,5297,dacomitinib,4,8798,targets,P12931,Proto-oncogene tyrosine-protein kinase Src,6,2379,17
85990,ClinVarVariant:981988,NC_000023.11:g.(31875374_31929595)_(31968515_3...,11,8189,has affected feature,HGNC:2928,DMD,5,3310,11


Build a directed graph: each node stores its Edge2Vec embedding as a tensor under the `node_feature` attribute, and each edge points from `index_head` to `index_tail`.

In [7]:
# Directed graph: each edge points from `index_head` to `index_tail`.
# NOTE(review): a DiGraph keeps at most one edge per ordered (head, tail) pair,
# so repeated pairs in `edges` collapse into a single edge (the edge count printed
# below is smaller than the number of edge rows), and the relation `type` column
# is not stored on the edge — TODO confirm this is intended.
G = nx.DiGraph()

# Attach each node's embedding as a float tensor under the 'node_feature' key.
G.add_nodes_from(
    (int(node_id), {'node_feature': torch.Tensor(embedding)})
    for node_id, embedding in zip(e2v_embedding['Node'], e2v_embedding['Embedding'])
)

# Bulk-insert the edges instead of looping with iterrows(), which builds a
# pandas Series for every row. .tolist() yields plain Python ints, matching the
# int() node ids used above.
G.add_edges_from(zip(edges['index_head'].astype(int).tolist(),
                     edges['index_tail'].astype(int).tolist()))

# Print one sample node and one sample edge (if any) as a sanity check.
for node in G.nodes(data=True):
    print(node)
    break
for edge in G.edges(data=True):
    print(edge)
    break

print("Number of edges is {}".format(G.number_of_edges()))
print("Number of nodes is {}".format(G.number_of_nodes()))

(0, {'node_feature': tensor([ 0.4429, -0.6467, -0.0455,  0.0469,  0.3007, -0.4822,  0.2423,  0.0707,
        -0.5349,  0.4175,  0.7552,  0.0628, -0.3813, -0.3389,  0.5671,  0.4919,
         0.0854,  0.0053, -0.2314,  0.8349,  0.5443,  0.7359, -0.1399, -0.2049,
        -0.1793,  0.0314, -0.7152,  0.1764,  0.3586, -0.4770,  0.0838, -0.3013,
         0.1751, -0.7785, -0.1280,  0.4082,  0.5959,  0.4058,  0.5182,  0.5561,
         0.3995, -0.1704,  0.3317, -0.1183, -0.1047,  0.2789,  0.3293, -0.1368,
         0.1259,  0.1704,  0.7724,  0.4460,  0.6070, -0.3315,  0.5776,  0.4367,
         0.2875, -0.7309,  0.2535,  0.5972,  0.4410, -0.4248, -0.3355,  0.3258])})
(2, 7693, {})
Number of edges is 85878
Number of nodes is 10275


# Explain predictions

In [None]:
def explain_edge(node_idx1, node_idx2):
    explainer = 