In [22]:
import pandas as pd
import networkx as nx
import pickle

from gensim.models import KeyedVectors

import torch
import torch.nn as nn
import torch.nn.functional as F

from deepsnap.graph import Graph

from torch_geometric.nn import SAGEConv
from torch_geometric.explain import Explainer, GNNExplainer, ModelConfig


# Load Necessary Data

For the node features, load the embeddings yielded from Edge2Vec.

In [10]:
loaded_node_embedding = KeyedVectors.load('output/w2v.dvectors', mmap='r')

node_embeddings = pd.DataFrame(columns = ['Node', 'Embedding'])
for idx, key in enumerate(loaded_node_embedding.index_to_key):
    node_embeddings.loc[int(key)] = pd.Series({'Node':int(key), 'Embedding':list(loaded_node_embedding[key])})
    
node_features = node_embeddings.sort_values('Node')

node_features

Unnamed: 0,Node,Embedding
0,0,"[0.13517001, -0.25358334, 0.07978077, 0.045121..."
1,1,"[0.16805942, -0.8350107, 0.19610697, 0.1981789..."
2,2,"[0.2186089, 0.50546986, 0.6193271, 0.0123952, ..."
3,3,"[0.29987955, -0.30054563, 0.15258668, 0.367052..."
4,4,"[0.1618276, -0.36085665, 0.295463, 0.17283706,..."
...,...,...
10232,10232,"[0.3500288, -1.1846579, 0.15232153, 0.05306821..."
10233,10233,"[0.17927206, -0.17169905, -0.12125319, 0.18529..."
10234,10234,"[0.24352823, -0.24893956, 0.22431412, 0.153569..."
10235,10235,"[0.3218023, -0.7156479, 0.30151767, -0.3144425..."


For additional information of nodes, load csv file.

In [26]:
nodes = pd.read_csv('output/indexed_nodes.csv')
nodes

Unnamed: 0,id,semantic,label,semantic_id
0,MP:0020358,phenotype,abnormal inhibitory synapse morphology,8
1,HP:0000670,phenotype,Carious teeth,8
2,ENSEMBL:ENSCAFG00000017522,gene,ENSEMBL:ENSCAFG00000017522,5
3,ZFIN:ZDB-GENE-070705-188,gene,chrm1b,5
4,HP:0000544,phenotype,External ophthalmoplegia,8
...,...,...,...,...
10232,MGI:2176879,genotype,Dmd<mdx>/Dmd<mdx>; Utrn<tm1Ked>/Utrn<tm1Ked> [...,6
10233,ZP:0008070,phenotype,"skeletal muscle cell dystrophic, abnormal",8
10234,ClinVarVariant:577742,variant,NM_004006.2(DMD):c.8655C>A (p.Tyr2885Ter),10
10235,MP:0004087,phenotype,abnormal muscle fiber morphology,8


For the edges, load csv file.

In [11]:
edges = pd.read_csv('output/indexed_edges.csv')
edges

Unnamed: 0,head,label_head,class_head,index_head,relation,tail,label_tail,class_tail,index_tail,type
0,FlyBase:FBgn0011676,Nos,5,9594,in orthology relationship with,ENSEMBL:ENSCAFG00000009820,ENSEMBL:ENSCAFG00000009820,5,7362,0
1,HGNC:12012,TPM3,5,1813,interacts with,HGNC:11946,TNNI2,5,7360,1
2,HGNC:15516,XYLT1,5,3228,causes condition,HP:0002650,Scoliosis,8,2102,2
3,WormBase:WBGene00000065,act-3,5,3272,in orthology relationship with,ENSEMBL:ENSOANG00000007850,ENSEMBL:ENSOANG00000007850,5,2284,0
4,HGNC:6485,LAMA5,5,4641,interacts with,HGNC:329,AGRN,5,2205,1
...,...,...,...,...,...,...,...,...,...,...
85949,dictyBase:DDB_G0276459,pakB,5,3746,in orthology relationship with,WormBase:WBGene00003911,pak-1,5,6906,0
85950,ZFIN:ZDB-GENE-030113-2,ttn.2,5,6355,causes condition,ZP:0106692,"Z disc myocardium disorganized, abnormal",8,6931,2
85951,ENSEMBL:ENSGALG00000015211,DTNA,5,828,in orthology relationship with,ENSEMBL:ENSECAG00000020912,ENSEMBL:ENSECAG00000020912,5,6757,0
85952,ENSEMBL:ENSECAG00000000207,ENSEMBL:ENSECAG00000000207,5,8915,in orthology relationship with,WormBase:WBGene00000064,act-2,5,5490,0


Load nodes with their features and edges into graph.

In [12]:
G = nx.Graph()
for ind, node in node_features.iterrows(): 
    G.add_node(int(node['Node']), node_feature=torch.Tensor(node['Embedding']))
for ind, edge in edges.iterrows(): 
    G.add_edge(int(edge['index_head']), int(edge['index_tail']))
  
for node in G.nodes(data=True):
    print(node)
    break
for edge in G.edges(data=True):
    print(edge)
    break

print("Number of edges is {}".format(G.number_of_edges()))
print("Number of nodes is {}".format(G.number_of_nodes()))

(0, {'node_feature': tensor([ 0.1352, -0.2536,  0.0798,  0.0451,  0.0747, -0.2546, -0.0705, -0.0646,
        -0.2445, -0.0202,  0.0493, -0.1636, -0.0623, -0.0508,  0.0288,  0.2002,
        -0.1173, -0.0943, -0.0744,  0.1843,  0.2936,  0.1871,  0.0121, -0.1656,
        -0.0181,  0.1075, -0.2460, -0.0466,  0.0871,  0.0957,  0.0087,  0.0946,
         0.0594, -0.0770, -0.0258,  0.0603, -0.0336, -0.0642,  0.0645,  0.0790,
        -0.0484, -0.0365, -0.0421,  0.0775,  0.1667, -0.0309,  0.0096, -0.1305,
        -0.0259,  0.1215,  0.0949,  0.0348,  0.1015,  0.1412,  0.0737,  0.0777,
         0.0288, -0.0531, -0.1087,  0.0693, -0.1133, -0.0443, -0.0622, -0.0670])})
(0, 5535, {})
Number of edges is 54994
Number of nodes is 10237


In [30]:
DeepG = Graph(G)
# is this needed?

print(DeepG.node_feature.shape)
print(DeepG.edge_index.shape)

torch.Size([10237, 64])
torch.Size([2, 109988])


Load arguments and parameters of trained GNN model.

In [32]:
with open('output/best_model_args.pkl', 'rb') as f:
    loaded_args = pickle.load(f)
    
class LinkPredModel(torch.nn.Module):
    """
        Architecture contains Batch Normalization layers (https://towardsdatascience.com/batch-norm-explained-visually-how-it-works-and-why-neural-networks-need-it-b18919692739) 
        between the SAGEConvolutional layers.
    """
    def __init__(self, input_size, hidden_size, out_size, num_layers, aggr, dropout, device):
        super(LinkPredModel, self).__init__()

        self.convs = torch.nn.ModuleList()
        self.bns = torch.nn.ModuleList()
        
        self.convs.append(SAGEConv(input_size, hidden_size, normalize=True, aggr=aggr)) # input node embedding features, output is size of hidden layers
        self.bns.append(nn.BatchNorm1d(hidden_size))
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(hidden_size, hidden_size, normalize=True, aggr=aggr))
            self.bns.append(nn.BatchNorm1d(hidden_size))
        self.convs.append(SAGEConv(hidden_size, out_size, normalize=True, aggr=aggr))
        self.bns.append(nn.BatchNorm1d(out_size))

        self.dropout = dropout
        self.loss_fn = torch.nn.BCEWithLogitsLoss()
        
        self.device = device


    def forward(self, x, edge_index, edge_label_index, training = True):
        x = x.to(self.device)
        edge_index = edge_index.to(self.device)
        edge_label_index = edge_label_index.to(self.device)

        for i in range(len(self.convs) - 1):
          x = self.convs[i](x, edge_index)
          x = self.bns[i](x)
          x = F.leaky_relu(x)
          x = F.dropout(x, p=self.dropout, training=training)
        
        x = self.convs[-1](x, edge_index)  
        x = self.bns[-1](x)
        

        nodes_first = torch.index_select(x, 0, edge_label_index[0,:].long())
        nodes_second = torch.index_select(x, 0, edge_label_index[1,:].long())
        pred = torch.sum(nodes_first * nodes_second, dim=-1)
        
        return pred, x
    
    def loss(self, pred, label):
        return self.loss_fn(pred, label)

num_node_features = len(node_features['Embedding'][0])

best_model = LinkPredModel(num_node_features, loaded_args['hidden_dim'], loaded_args['output_dim'], 
                           loaded_args['layers'], loaded_args['aggr'], loaded_args['dropout'], 
                           loaded_args['device']).to(loaded_args['device'])

best_model.load_state_dict(torch.load('output/best_model.pth'))

<All keys matched successfully>

In [19]:
_, best_x = best_model(DeepG.node_feature, DeepG.edge_index, DeepG.edge_label_index, training = False)
print(best_x.shape)

torch.Size([10237, 64])


# GNNExplainer

In [33]:
model_config = ModelConfig(
    mode='binary_classification',
    task_level='edge',
    return_type='raw',
)

In [34]:
explainer = Explainer(
    model=best_model,
    explanation_type='model',
    algorithm=GNNExplainer(epochs=200),
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=model_config,
)

In [35]:
def get_node_idx(id, nodes): 
    return nodes[nodes['id'] == id].index.values.astype(int)[0]

node_idx1 = get_node_idx('HP:0003560', nodes)
print('node1:')
print(nodes.loc[[node_idx1]])

node_idx2 = get_node_idx('522', nodes)
print('node2:')
print(nodes.loc[[node_idx2]])

node1:
             id   semantic               label  semantic_id
184  HP:0003560  phenotype  Muscular dystrophy            8
node2:
        id semantic       label  semantic_id
10137  522     drug  carvedilol            4


In [36]:
edge_label_index = torch.tensor([node_idx1, node_idx2])
edge_label_index

tensor([  184, 10137], dtype=torch.int32)

In [37]:
explanation = explainer(
    x=DeepG.node_feature,
    edge_index=DeepG.edge_index,
    edge_label_index=edge_label_index,
)

TypeError: '>' not supported between instances of 'tuple' and 'int'