In [1]:
import pandas as pd
import networkx as nx
import pickle
import copy

from gensim.models import KeyedVectors

import torch
import torch.nn as nn
import torch.nn.functional as F

from deepsnap.graph import Graph

from torch_geometric.nn import SAGEConv
from torch_geometric.explain import Explainer, GNNExplainer, ModelConfig


# Load Necessary Data

For the node features, load the embeddings yielded from Edge2Vec.

In [2]:
loaded_node_embedding = KeyedVectors.load('output/w2v.dvectors', mmap='r')

node_embeddings = pd.DataFrame(columns = ['Node', 'Embedding'])
for idx, key in enumerate(loaded_node_embedding.index_to_key):
    node_embeddings.loc[int(key)] = pd.Series({'Node':int(key), 'Embedding':list(loaded_node_embedding[key])})
    
node_features = node_embeddings.sort_values('Node')

node_features

Unnamed: 0,Node,Embedding
0,0,"[0.025715737, -0.21057369, 0.03792795, 0.25484..."
1,1,"[0.18812153, -0.1237438, -0.007631695, 0.13901..."
2,2,"[0.56500244, 0.06161292, 0.7965068, 0.28490803..."
3,3,"[0.40217963, -0.4195659, -0.0022494257, 0.4107..."
4,4,"[0.18838538, -0.45605302, 0.12164134, 0.125727..."
...,...,...
10232,10232,"[0.52268314, -0.26845437, -0.17985809, 0.13895..."
10233,10233,"[0.44407392, -0.21608466, 0.30382442, 0.208148..."
10234,10234,"[0.33063313, -0.20261967, 0.14750157, 0.319578..."
10235,10235,"[0.3033455, -0.19828539, -0.02041113, 0.050759..."


For additional information of nodes, load csv file.

In [3]:
nodes = pd.read_csv('output/indexed_nodes.csv')
nodes

Unnamed: 0,index_id,id,semantic,label,semantic_id
0,0,MP:0020358,phenotype,abnormal inhibitory synapse morphology,8
1,1,HP:0000670,phenotype,Carious teeth,8
2,2,ENSEMBL:ENSCAFG00000017522,gene,ENSEMBL:ENSCAFG00000017522,5
3,3,ZFIN:ZDB-GENE-070705-188,gene,chrm1b,5
4,4,HP:0000544,phenotype,External ophthalmoplegia,8
...,...,...,...,...,...
10232,10232,MGI:2176879,genotype,Dmd<mdx>/Dmd<mdx>; Utrn<tm1Ked>/Utrn<tm1Ked> [...,6
10233,10233,ZP:0008070,phenotype,"skeletal muscle cell dystrophic, abnormal",8
10234,10234,ClinVarVariant:577742,variant,NM_004006.2(DMD):c.8655C>A (p.Tyr2885Ter),10
10235,10235,MP:0004087,phenotype,abnormal muscle fiber morphology,8


For the edges, load csv file.

In [4]:
edges = pd.read_csv('output/indexed_edges.csv')
edges

Unnamed: 0,head,label_head,class_head,index_head,relation,tail,label_tail,class_tail,index_tail,type
0,FlyBase:FBgn0011676,Nos,5,9594,in orthology relationship with,ENSEMBL:ENSCAFG00000009820,ENSEMBL:ENSCAFG00000009820,5,7362,0
1,HGNC:12012,TPM3,5,1813,interacts with,HGNC:11946,TNNI2,5,7360,1
2,HGNC:15516,XYLT1,5,3228,causes condition,HP:0002650,Scoliosis,8,2102,2
3,WormBase:WBGene00000065,act-3,5,3272,in orthology relationship with,ENSEMBL:ENSOANG00000007850,ENSEMBL:ENSOANG00000007850,5,2284,0
4,HGNC:6485,LAMA5,5,4641,interacts with,HGNC:329,AGRN,5,2205,1
...,...,...,...,...,...,...,...,...,...,...
85949,dictyBase:DDB_G0276459,pakB,5,3746,in orthology relationship with,WormBase:WBGene00003911,pak-1,5,6906,0
85950,ZFIN:ZDB-GENE-030113-2,ttn.2,5,6355,causes condition,ZP:0106692,"Z disc myocardium disorganized, abnormal",8,6931,2
85951,ENSEMBL:ENSGALG00000015211,DTNA,5,828,in orthology relationship with,ENSEMBL:ENSECAG00000020912,ENSEMBL:ENSECAG00000020912,5,6757,0
85952,ENSEMBL:ENSECAG00000000207,ENSEMBL:ENSECAG00000000207,5,8915,in orthology relationship with,WormBase:WBGene00000064,act-2,5,5490,0


Load nodes with their features and edges into graph.

In [5]:
G = nx.DiGraph()
for ind, node in node_features.iterrows(): 
    G.add_node(int(node['Node']), node_feature=torch.Tensor(node['Embedding']))
for ind, edge in edges.iterrows(): 
    G.add_edge(int(edge['index_head']), int(edge['index_tail']))
  
for node in G.nodes(data=True):
    print(node)
    break
for edge in G.edges(data=True):
    print(edge)
    break

print("Number of edges is {}".format(G.number_of_edges()))
print("Number of nodes is {}".format(G.number_of_nodes()))

(0, {'node_feature': tensor([ 0.0257, -0.2106,  0.0379,  0.2548, -0.0710, -0.1161,  0.0487, -0.0889,
        -0.2617,  0.0058,  0.1430, -0.2625, -0.0045,  0.0088, -0.0159,  0.2041,
        -0.0750, -0.0746, -0.1348,  0.5121,  0.3134,  0.3092, -0.0215, -0.1648,
        -0.1561,  0.3388, -0.3543, -0.0308, -0.0959, -0.1503, -0.0089,  0.0727,
        -0.2139, -0.1520, -0.0370,  0.2890, -0.0060,  0.0132,  0.1780,  0.0230,
         0.0030,  0.1089, -0.0066,  0.0677,  0.0310,  0.0451,  0.0139,  0.0874,
        -0.0229,  0.2928,  0.1286,  0.0255,  0.1482, -0.0255,  0.1047,  0.1834,
         0.0289, -0.0020, -0.1158,  0.1742, -0.0059, -0.2629,  0.0610,  0.0691])})
(2, 835, {})
Number of edges is 85840
Number of nodes is 10237


In [6]:
DeepG = Graph(G)

print(DeepG.node_feature.shape)
print(DeepG.edge_index.shape)

torch.Size([10237, 64])
torch.Size([2, 85840])


In [7]:
x = DeepG.node_feature
x

tensor([[ 0.0257, -0.2106,  0.0379,  ..., -0.2629,  0.0610,  0.0691],
        [ 0.1881, -0.1237, -0.0076,  ..., -0.0790, -0.0277, -0.0832],
        [ 0.5650,  0.0616,  0.7965,  ..., -0.4596, -0.2780,  0.3672],
        ...,
        [ 0.3306, -0.2026,  0.1475,  ..., -0.0057,  0.0220, -0.0708],
        [ 0.3033, -0.1983, -0.0204,  ...,  0.0915,  0.1807,  0.2646],
        [ 0.7833, -0.0757,  0.5275,  ..., -0.4399,  0.0885, -0.2627]])

Load arguments and parameters of trained GNN model.

In [8]:
with open('output/best_model_args.pkl', 'rb') as f:
    loaded_args = pickle.load(f)

# Important! Changed the forward function such that it allows input of less edges that are predicted
# and it does not return newly calculated x

class LinkPredModel(torch.nn.Module):
    """
        Architecture contains Batch Normalization layers (https://towardsdatascience.com/batch-norm-explained-visually-how-it-works-and-why-neural-networks-need-it-b18919692739) 
        between the SAGEConvolutional layers.
    """
    def __init__(self, input_size, hidden_size, out_size, num_layers, aggr, dropout, device):
        super(LinkPredModel, self).__init__()

        self.convs = torch.nn.ModuleList()
        self.bns = torch.nn.ModuleList()
        
        self.convs.append(SAGEConv(input_size, hidden_size, normalize=True, aggr=aggr)) # input node embedding features, output is size of hidden layers
        self.bns.append(nn.BatchNorm1d(hidden_size))
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(hidden_size, hidden_size, normalize=True, aggr=aggr))
            self.bns.append(nn.BatchNorm1d(hidden_size))
        self.convs.append(SAGEConv(hidden_size, out_size, normalize=True, aggr=aggr))
        self.bns.append(nn.BatchNorm1d(out_size))

        self.dropout = dropout
        self.loss_fn = torch.nn.BCEWithLogitsLoss()
        
        self.device = device


    def forward(self, x, edge_index, edge_label_index, training = False):
        x = x.to(self.device)
        edge_index = edge_index.to(self.device)
        edge_label_index = edge_label_index.to(self.device)

        for i in range(len(self.convs) - 1):
          x = self.convs[i](x, edge_index)
          x = self.bns[i](x)
          x = F.leaky_relu(x)
          x = F.dropout(x, p=self.dropout, training=training)
        
        x = self.convs[-1](x, edge_index)  
        x = self.bns[-1](x)
        

        nodes_first, nodes_second = edge_label_index
        nodes_first_features = x[nodes_first.long()]
        nodes_second_features = x[nodes_second.long()]
        
        pred = torch.sum(nodes_first_features * nodes_second_features, dim=-1)
        
        return pred
    
    def loss(self, pred, label):
        return self.loss_fn(pred, label)

num_node_features = len(node_features['Embedding'][0])

best_model = LinkPredModel(num_node_features, loaded_args['hidden_dim'], loaded_args['output_dim'], 
                           loaded_args['layers'], loaded_args['aggr'], loaded_args['dropout'], 
                           loaded_args['device']).to(loaded_args['device'])

best_model.load_state_dict(torch.load('output/best_model.pth'))

<All keys matched successfully>

# GNNExplainer

In [9]:
model_config = ModelConfig(
    mode='binary_classification',
    task_level='edge',
    return_type='raw',
)

In [10]:
explainer = Explainer(
    model=best_model,
    explanation_type='model',
    algorithm=GNNExplainer(epochs=700),
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=model_config,
)

Get the node indices that are relevant to the edge of interest.

In [11]:
def get_node_idx(id, nodes): 
    return nodes[nodes['id'] == id].index.values.astype(int)[0]

node_idx1 = get_node_idx('HP:0003560', nodes)
print('node1:')
print(nodes.loc[[node_idx1]])

node_idx2 = get_node_idx('522', nodes)
print('node2:')
print(nodes.loc[[node_idx2]])

node1:
     index_id          id   semantic               label  semantic_id
184       184  HP:0003560  phenotype  Muscular dystrophy            8
node2:
       index_id   id semantic       label  semantic_id
10137     10137  522     drug  carvedilol            4


Create a tensor as input of the model. Only predict the label of a single edge instead of all edges in current dataset.

In [12]:
edge_label_index = torch.tensor([node_idx1, node_idx2])
edge_label_index

tensor([  184, 10137], dtype=torch.int32)

In [13]:
done = False
current_iterations = 0
total_iterations = 10
graph_size = 15 

while not done:
    print(f'--- Iteration {current_iterations} ---')

    explanation = explainer(
        x=DeepG.node_feature,
        edge_index=DeepG.edge_index,
        edge_label_index=edge_label_index,
    )
    
    edge_mask = explanation.edge_mask
    
    print(f'A total of {torch.count_nonzero(edge_mask)} edges contributing to prediction.')
    
    sorted_contributions = edge_mask.sort(descending=True)[0]
    print(f'Maximum contribution value is {sorted_contributions[0]}')
    limit = sorted_contributions[graph_size]
    print(f'Based on desired graph size, contribution limit is set to {limit}')
    
    pruned_G = copy.deepcopy(G)
    for indx, edge in enumerate(G.edges):
        if edge_mask[indx] < limit:
            pruned_G.remove_edge(edge[0], edge[1])
    
    done = nx.has_path(pruned_G, node_idx1, node_idx2)
    
    current_iterations += 1
    
    if done:
        print(f'A good explanation has been found!')
    elif current_iterations == total_iterations and not done:
        done = True
        print(f'No good explanations found after {current_iterations} iterations...')

    

--- Iteration 0 ---
A total of 3687 edges contributing to prediction.
Maximum contribution value is 0.030153825879096985
Based on desired graph size, contribution limit is set to 0.02365393005311489
--- Iteration 1 ---
A total of 3687 edges contributing to prediction.
Maximum contribution value is 0.029727665707468987
Based on desired graph size, contribution limit is set to 0.023980459198355675
--- Iteration 2 ---
A total of 3687 edges contributing to prediction.
Maximum contribution value is 0.029797092080116272
Based on desired graph size, contribution limit is set to 0.02377462200820446
--- Iteration 3 ---
A total of 3687 edges contributing to prediction.
Maximum contribution value is 0.029377831146121025
Based on desired graph size, contribution limit is set to 0.023571254685521126
--- Iteration 4 ---
A total of 3687 edges contributing to prediction.
Maximum contribution value is 0.030073348432779312
Based on desired graph size, contribution limit is set to 0.0237234178930521
--- 