This Jupyter notebook applies the GNNExplainer method to explain the predicted symptom–drug links of each trained run.

In [24]:
import os

import pandas as pd
import networkx as nx
import numpy as np
import torch

import pickle
import copy

import matplotlib.pyplot as plt

from gensim.models import KeyedVectors

from collections import Counter

# Set Parameters

In [36]:
dataset_nr = 1
# `x == 1 or 2` is always truthy; a membership test actually validates the value.
assert dataset_nr in (1, 2), f'dataset_nr must be 1 or 2, got {dataset_nr!r}'

embedding_method = 'e2v'
assert embedding_method in ('e2v', 'm2v'), f"embedding_method must be 'e2v' or 'm2v', got {embedding_method!r}"

# Selects the '_seeded' suffix used in the output folder and pickle file names.
seeded_emb = False
fixed_emb = '_seeded' if seeded_emb else ''

# Run-count label that picks which overlapping symptom-drug pair pickle is loaded
# (symptom_drug_pair_overlapping_{threshold_nr_runs}_runs_...).
threshold_nr_runs = '7' if dataset_nr == 1 and not seeded_emb else 'all'

# Set Folders

In [26]:
curr_working_dir = os.getcwd()
curr_output_dir = os.path.join(curr_working_dir, 'output')
dataset_output_dir = os.path.join(curr_output_dir, f'g{dataset_nr}_{embedding_method}{fixed_emb}')

# Fail fast: every later cell reads from this folder. Only printing here would just
# defer the failure to an opaque FileNotFoundError from os.listdir below.
if not os.path.exists(dataset_output_dir):
    raise FileNotFoundError(f'{dataset_output_dir} does not exist. First, run the edge2vec embedding and predictor script. Then, run this script.')
print(f'Output folder for dataset {dataset_nr} already exists and will be used: {dataset_output_dir}')

# Run sub-folders (any directory whose name contains 'run'), sorted so runs are
# processed in a deterministic order (os.listdir order is arbitrary).
run_folders_list = sorted(
    item for item in os.listdir(dataset_output_dir)
    if os.path.isdir(os.path.join(dataset_output_dir, item)) and 'run' in item
)

# Folder where explanation figures and graphs are written.
expl_output_dir = os.path.join(dataset_output_dir, 'expl')
if not os.path.exists(expl_output_dir):
    os.makedirs(expl_output_dir)
    print(f'Output folder for explanations from dataset {dataset_nr} using method {embedding_method} is created: {expl_output_dir}')
else:
    print(f'Output folder for explanations from dataset {dataset_nr} using method {embedding_method} already exists and will be used: {expl_output_dir}')

run_folders_paths = [os.path.join(dataset_output_dir, run_folder) for run_folder in run_folders_list]
if not run_folders_paths:
    print('First, run the edge2vec embedding and predictor script. Then, run this script.')

Output folder for dataset 1 already exists and will be used: c:\Users\rosa-\Google Drive\Msc_Bioinformatics\thesis\XAIFO-ThesisProject\output\g1_e2v
Output folder for predictions from dataset 1 using method e2v already exists and will be used: c:\Users\rosa-\Google Drive\Msc_Bioinformatics\thesis\XAIFO-ThesisProject\output\g1_e2v\expl


# Load all data

Load the nodes

In [27]:
# Read the indexed node table from the same output root configured in the folder cell,
# rather than a second, independently hard-coded relative path.
nodes = pd.read_csv(os.path.join(curr_output_dir, f'indexed_nodes_{dataset_nr}.csv'))
nodes

Unnamed: 0,index_id,id,semantic,label,semantic_id
0,0,WormBase:WBGene00000389,ORTH,cdc-25.4,5
1,1,ZP:0018675,DISO,right side lateral plate mesoderm mislocalised...,1
2,2,ZFIN:ZDB-GENE-040426-1197,ORTH,tbc1d5,5
3,3,5,DRUG,(S)-nicardipine,2
4,4,RGD:3443,ORTH,Ptk2,5
...,...,...,...,...,...
10029,10029,MP:0009763,DISO,increased sensitivity to induced morbidity/mor...,1
10030,10030,MP:0011057,DISO,absent brain ependyma motile cilia,1
10031,10031,MP:0001412,DISO,excessive scratching,1
10032,10032,WBPhenotype:0004023,DISO,frequency of body bend variant,1


In [28]:
# Spot-check a single row of the node table (for dataset 1 this is HGNC:6717 / LTBP4).
sample_node_idx = 5487
nodes.iloc[sample_node_idx]

index_id            5487
id             HGNC:6717
semantic            GENE
label              LTBP4
semantic_id            3
Name: 5487, dtype: object

In [29]:
# Map each numeric semantic_id to its semantic type name (e.g. 3 -> 'GENE').
semantic_pairs = nodes[['semantic_id', 'semantic']].drop_duplicates()
# Guard the invariant: if one id had several names, the dict would silently keep only one.
assert semantic_pairs['semantic_id'].is_unique, 'semantic_id -> semantic mapping is not one-to-one'
node_labels_dict = semantic_pairs.set_index('semantic_id')['semantic'].to_dict()
node_labels_dict

{5: 'ORTH',
 1: 'DISO',
 2: 'DRUG',
 4: 'GENO',
 7: 'VARI',
 3: 'GENE',
 0: 'ANAT',
 6: 'PHYS'}

Load the edges

In [30]:
# Read the indexed edge table from the same output root configured in the folder cell.
edges = pd.read_csv(os.path.join(curr_output_dir, f'indexed_edges_{dataset_nr}.csv'))
edges

Unnamed: 0,head,label_head,class_head,index_head,relation,tail,label_tail,class_tail,index_tail,type
0,ZFIN:ZDB-GENE-050626-112,myl4,5,5279,in orthology relationship with,FlyBase:FBgn0085464,CG34435,5,6825,0
1,ZFIN:ZDB-GENE-050626-112,myl4,5,5279,in orthology relationship with,HGNC:7585,MYL4,3,27,0
2,ZFIN:ZDB-GENE-050626-112,myl4,5,5279,in orthology relationship with,FlyBase:FBgn0002772,Mlc1,5,8901,0
3,ZFIN:ZDB-GENE-050626-112,myl4,5,5279,in orthology relationship with,NCBIGene:396472,MYL4,3,9508,0
4,ZFIN:ZDB-GENE-050626-112,myl4,5,5279,in 1 to 1 orthology relationship with,ENSEMBL:ENSECAG00000020967,ENSEMBL:ENSECAG00000020967,5,8807,1
...,...,...,...,...,...,...,...,...,...,...
82908,4810,ibrutinib,2,1618,targets,HGNC:11283,SRC,3,3279,14
82909,522,carvedilol,2,184,targets,HGNC:620,APP,3,547,14
82910,OMIM:300377.0013,"DMD, EX18DEL",1,2822,is allele of,HGNC:2928,DMD,3,6612,16
82911,Coriell:GM05113,NIGMS-GM05113,4,8105,has role in modeling,MONDO:0010679,Duchenne muscular dystrophy,1,6315,15


In [31]:
def loadEdge2VecEmbedding(current_run_dir):
    """Load the edge2vec node vectors of one run into a DataFrame.

    Reads ``{current_run_dir}/w2v_{dataset_nr}.dvectors`` (gensim KeyedVectors,
    memory-mapped read-only). Keys are converted with ``int(key)``.

    Returns
    -------
    pd.DataFrame
        Indexed by node id, with columns 'Node' (int node id) and 'Embedding'
        (list of vector components), sorted by 'Node'.
    """
    node_feat = KeyedVectors.load(f'{current_run_dir}/w2v_{dataset_nr}.dvectors', mmap='r')
    keys = node_feat.index_to_key
    node_ids = [int(key) for key in keys]
    # Build the frame in one go; growing it row by row with .loc enlargement is quadratic.
    e2v_embedding = pd.DataFrame(
        {'Node': node_ids, 'Embedding': [list(node_feat[key]) for key in keys]},
        index=node_ids,
    )
    return e2v_embedding.sort_values('Node')

def loadMetapath2VecEmbedding(current_run_dir):
    """Load the metapath2vec node vectors of one run into a DataFrame.

    Reads ``{current_run_dir}/metapath2vec_embedding_{dataset_nr}.csv``; each CSV row
    becomes one 'Embedding' list and 'Node' is the row's index label.
    NOTE(review): assumes every CSV column is an embedding dimension and row i is
    node i -- an index column written by ``to_csv`` would leak into the vectors; confirm.

    Returns
    -------
    pd.DataFrame
        Columns 'Node' and 'Embedding', with the CSV's row index.
    """
    raw_embedding = pd.read_csv(f'{current_run_dir}/metapath2vec_embedding_{dataset_nr}.csv')
    # Build a new frame instead of assigning into a column slice, which raised
    # SettingWithCopyWarning and risked writing into a temporary copy.
    return pd.DataFrame(
        {'Node': raw_embedding.index, 'Embedding': raw_embedding.values.tolist()},
        index=raw_embedding.index,
    )

In [32]:
def create_graph(embedding, edges_df=None):
    """Build an undirected NetworkX graph from node embeddings and an edge table.

    Parameters
    ----------
    embedding : pd.DataFrame
        Columns 'Node' (node index) and 'Embedding' (vector); each row becomes a
        node whose ``node_feature`` attribute is ``torch.Tensor(Embedding)``.
    edges_df : pd.DataFrame, optional
        Edge table with 'index_head', 'index_tail' and 'relation' columns; each row
        becomes an edge with ``edge_label=relation``. Defaults to the notebook-level
        ``edges`` frame.

    Returns
    -------
    nx.Graph
        Note that nx.Graph is undirected and simple: several relations between the
        same node pair collapse into one edge, keeping the last 'relation' seen.

    Prints one sample node, one sample edge, and the edge/node counts.
    """
    if edges_df is None:
        edges_df = edges

    G = nx.Graph()  # TODO: DiGraph?
    # Column-wise zip is much faster than DataFrame.iterrows and keeps row order, so
    # node/edge insertion order (and therefore G.edges order) is unchanged.
    for node_id, node_emb in zip(embedding['Node'], embedding['Embedding']):
        G.add_node(int(node_id), node_feature=torch.Tensor(node_emb))
    for head, tail, relation in zip(edges_df['index_head'], edges_df['index_tail'], edges_df['relation']):
        G.add_edge(int(head), int(tail), edge_label=relation)

    # Show one sample node and edge as a quick sanity check (skipped for an empty graph).
    first_node = next(iter(G.nodes(data=True)), None)
    if first_node is not None:
        print(first_node)
    first_edge = next(iter(G.edges(data=True)), None)
    if first_edge is not None:
        print(first_edge)

    print("Number of edges is {}".format(G.number_of_edges()))
    print("Number of nodes is {}".format(G.number_of_nodes()))

    return G

In [33]:
from gnn.linkpred_model import LinkPredModel

def load_trained_model(current_run_dir, embedding):
    """Rebuild a run's best LinkPredModel and load its trained weights.

    Hyperparameters come from ``best_model_{dataset_nr}_{embedding_method}_args.pkl``
    and weights from ``best_model_{dataset_nr}_{embedding_method}.pth``, both inside
    ``current_run_dir``. The input dimension is the length of the first embedding.

    Returns
    -------
    LinkPredModel
        Model moved to ``loaded_args['device']`` with the saved state loaded.
    """
    # NOTE: pickle can execute code; these files are assumed to be our own
    # predictor-script outputs -- never point this at untrusted files.
    with open(f'{current_run_dir}/best_model_{dataset_nr}_{embedding_method}_args.pkl', 'rb') as f:
        loaded_args = pickle.load(f)

    device = loaded_args['device']
    input_dim = len(embedding['Embedding'].iloc[0])  # positional: independent of index labels

    best_model = LinkPredModel(input_dim,
                               loaded_args['hidden_dim'], loaded_args['output_dim'],
                               loaded_args['layers'], loaded_args['aggr'],
                               loaded_args['dropout'], device).to(device)
    # Load from this function's own run directory (not a global loop variable), and map
    # tensors onto the target device so GPU-saved weights also load on CPU-only machines.
    state_dict = torch.load(f'{current_run_dir}/best_model_{dataset_nr}_{embedding_method}.pth',
                            map_location=device)
    best_model.load_state_dict(state_dict)

    return best_model

# Explain predictions

In [34]:
from gnn.gnnexplainer import GNNExplainer, visualize_subgraph

num_hops = 1


def explain_edge(node_idx1, node_idx2, model, x, edge_index, G,
                 size=20, iterations=150, epochs=700, lr=0.01):
    """Run GNNExplainer on the link (node_idx1, node_idx2) until it gives a connected explanation.

    GNNExplainer is re-run up to ``iterations`` times. An attempt is accepted when,
    after dropping every edge whose mask value is below the contribution threshold,
    the two endpoints are still connected and the threshold is non-zero.

    Parameters
    ----------
    node_idx1, node_idx2 : int
        Endpoints of the link to explain.
    model : torch.nn.Module
        Trained link-prediction model handed to GNNExplainer.
    x : torch.Tensor
        Node feature matrix.
    edge_index : torch.Tensor
        Edge index passed to the explainer. Assumes ``edge_mask[i]`` scores the
        i-th edge of ``G.edges`` (edge_index built from G.edges in that order) -- TODO confirm.
    G : nx.Graph
        Graph used for the connectivity check.
    size : int or None
        Position in the descending-sorted mask used as threshold (see note below);
        None uses a fixed threshold of 0.5.
    iterations : int
        Maximum number of GNNExplainer attempts.
    epochs, lr :
        GNNExplainer training settings.

    Returns
    -------
    (edge_mask, limit) of the first accepted attempt, or (None, None) if none was accepted.
    """
    explainer = GNNExplainer(model, epochs=epochs, num_hops=num_hops, lr=lr)
    graph_edges = list(G.edges)

    for _ in range(iterations):
        _, edge_mask = explainer.explain_link(node_idx1=node_idx1, node_idx2=node_idx2,
                                              x=x, edge_index=edge_index,
                                              G=G)

        if edge_mask.numel() < len(graph_edges):
            raise ValueError(f'edge_mask has {edge_mask.numel()} entries but G has {len(graph_edges)} edges')

        if size is not None:
            sorted_mask = edge_mask.sort(descending=True)[0]
            # NOTE(review): indexing at `size` makes the threshold the (size+1)-th largest
            # score, so at least size+1 edges survive; confirm whether top-`size` was meant.
            # Clamp so graphs with <= size edges do not raise IndexError.
            limit = sorted_mask[min(size, sorted_mask.numel() - 1)]
            print('Contribution threshold is', limit)
        else:
            limit = 0.5

        # Only endpoint connectivity matters, so build a light graph of the kept edges
        # instead of deep-copying G (with all node features) on every attempt.
        dropped = (edge_mask < limit).tolist()
        kept_graph = nx.Graph()
        kept_graph.add_nodes_from((node_idx1, node_idx2))
        kept_graph.add_edges_from(edge for edge, drop in zip(graph_edges, dropped) if not drop)

        # float() handles both the tensor threshold and the plain 0.5 fallback;
        # torch.is_nonzero only accepts tensors and raised TypeError for size=None.
        if float(limit) != 0 and nx.has_path(kept_graph, node_idx1, node_idx2):
            print('Explanation found!')
            return edge_mask, limit

    print('No good explanation found after {} iterations'.format(iterations))
    return None, None
    

In [39]:
# Load the symptom-drug pairs that overlapped across runs (pickle is assumed to be our own output).
overlap_pickle_path = f'{dataset_output_dir}/symptom_drug_pair_overlapping_{threshold_nr_runs}_runs_{dataset_nr}_{embedding_method}{fixed_emb}.pkl'
with open(overlap_pickle_path, 'rb') as f:
    overlapping_symptom_drug_pairs = pickle.load(f)

# Group into {symptom_id: [drug_id, ...]}, keeping first-seen order of symptoms and drugs.
overlapping_symptoms_drugs = {}
for pair in overlapping_symptom_drug_pairs:
    overlapping_symptoms_drugs.setdefault(pair[0], []).append(pair[1])

overlapping_symptoms_drugs

{'HP:0011675': ['231', '1576'],
 'HP:0001638': ['1576'],
 'HP:0001635': ['231', '1576'],
 'HP:0002791': ['5345'],
 'HP:0003236': ['1576'],
 'HP:0003323': ['1576'],
 'HP:0003307': ['1576'],
 'HP:0001644': ['1576'],
 'HP:0002650': ['5252']}

In [40]:
def get_occurrence_df(c, label_name):
    """Turn a Counter into a frequency table sorted by count (descending).

    Parameters
    ----------
    c : collections.Counter
        Item -> count.
    label_name : str
        Column name used for the counted items.

    Returns
    -------
    pd.DataFrame
        Columns [label_name, 'Percentage', 'Appearances']; 'Percentage' is a string
        such as '75.00%'. An empty Counter yields an empty frame with these columns.
    """
    columns = [label_name, 'Percentage', 'Appearances']
    # Computed once (and works before Python 3.10, unlike Counter.total()).
    total = sum(c.values())
    records = [{label_name: item, 'Percentage': count / total * 100.0, 'Appearances': count}
               for item, count in c.items()]
    # Explicit columns keep an empty Counter from producing a column-less frame (KeyError below).
    c_df = pd.DataFrame(records, columns=columns)
    c_df['Percentage'] = c_df['Percentage'].map('{:,.2f}%'.format)
    return c_df.sort_values(by=['Appearances'], ascending=False)

def count_occurrences(G):
    """Summarise how often each node type, edge type and metapath occurs in G.

    Node types are each node's 'y2' attribute mapped through node_labels_dict; edge
    types are the edges' 'label' attribute; a metapath is the triple
    (head node type, edge label, tail node type). Each summary is printed, and the
    frames are returned as (node_types_df, edge_types_df, metapaths_df).
    """
    node_type_counts = Counter(node_labels_dict[attrs['y2']] for _, attrs in G.nodes(data=True))

    edge_type_counts = Counter()
    metapath_counts = Counter()
    for head, tail, attrs in G.edges(data=True):
        relation = attrs['label']
        edge_type_counts[relation] += 1
        head_class, tail_class = G.nodes[head]['y2'], G.nodes[tail]['y2']
        metapath_counts[(node_labels_dict[head_class], relation, node_labels_dict[tail_class])] += 1

    summaries = []
    for counts, label_name in ((node_type_counts, 'Node Type'),
                               (edge_type_counts, 'Edge Type'),
                               (metapath_counts, 'Metapath')):
        summary_df = get_occurrence_df(counts, label_name)
        print(summary_df)
        summaries.append(summary_df)

    node_types_df, edge_types_df, metapaths_df = summaries
    return node_types_df, edge_types_df, metapaths_df

In [41]:
def get_graph_to_save(G):
    """Return a deep copy of G reduced to the attributes worth persisting.

    Every node gains a 'type' attribute (node_labels_dict[y2]) and loses 'y' and 'y2';
    every edge loses 'att' and 'edge_color'. The input graph is left untouched.
    All nodes and edges of the result are printed for inspection.
    """
    cleaned = copy.deepcopy(G)

    # Attach the readable type first -- it needs 'y2', which is removed next.
    for _, attrs in cleaned.nodes(data=True):
        attrs['type'] = node_labels_dict[attrs['y2']]

    for _, attrs in cleaned.nodes(data=True):
        for key in ('y', 'y2'):
            attrs.pop(key)

    for _, _, attrs in cleaned.edges(data=True):
        for key in ('att', 'edge_color'):
            attrs.pop(key)

    for entry in cleaned.nodes(data=True):
        print(entry)
    for entry in cleaned.edges(data=True):
        print(entry)

    return cleaned

def visualize_explanation(nr, explanation, edge_index, edge_labels_dict):
    """Plot one explanation subgraph, save it as PNG and pickle the cleaned subgraph.

    Parameters
    ----------
    nr : int
        Explanation number, used in the output file names.
    explanation : dict
        Keys 'node_idx1', 'node_idx2', 'drug', 'symptom', 'run', 'found_edge_mask'
        and 'found_limit', as built by the main explanation loop.
    edge_index : torch.Tensor
        Edge index passed to visualize_subgraph.
    edge_labels_dict : dict
        (u, v) -> relation label, used for edge annotations.

    Writes ``{expl_output_dir}/explanation_{nr}_{run}.png`` and
    ``{expl_output_dir}/explanation_{nr}_{run}_graph.gpickle``.
    """
    # Build int64 indices directly rather than via a float32 torch.Tensor.
    nodes_idxs = torch.tensor([explanation['node_idx1'], explanation['node_idx2']], dtype=torch.long)

    drug_id = explanation['drug']
    drug_name = nodes.loc[nodes['id'] == drug_id, 'label'].iloc[0]

    symptom_id = explanation['symptom']
    symptom_name = nodes.loc[nodes['id'] == symptom_id, 'label'].iloc[0]

    run_name = explanation['run']
    base_path = f'{expl_output_dir}/explanation_{nr}_{run_name}'

    explanation_title = f"Explanation for link between {drug_name} ({drug_id}) and {symptom_name} ({symptom_id}) from run {run_name}"

    fig = plt.figure(figsize=(10, 10))
    plt.title(explanation_title)
    _, G_sub = visualize_subgraph(nodes_idxs, edge_index, explanation['found_edge_mask'],
                                  nodes=nodes, node_labels_dict=node_labels_dict, y=torch.Tensor(nodes.semantic_id),
                                  seed=667, num_hops=num_hops,
                                  threshold=explanation['found_limit'],
                                  node_label='label', edge_labels=edge_labels_dict,
                                  show_inactive=False, remove_unconnected=True)

    plt.savefig(f'{base_path}.png', bbox_inches='tight')
    plt.show()
    # This is called once per explanation in a loop; release the figure explicitly.
    plt.close(fig)

    # nx.write_gpickle was removed in NetworkX 3.0, so pickle directly. See
    # https://networkx.org/documentation/stable/release/migration_guide_from_2.x_to_3.0.html#deprecated-code
    graph_path = f'{base_path}_graph.gpickle'
    with open(graph_path, 'wb') as f:
        pickle.dump(get_graph_to_save(G_sub), f, pickle.HIGHEST_PROTOCOL)

    print(f'Explanation graph saved at {graph_path}')

In [None]:
def get_node_idx(id, nodes):
    """Return the index label of the first row of ``nodes`` whose 'id' column equals ``id``.

    Raises
    ------
    KeyError
        If no row has that id (previously an opaque IndexError).
    """
    matches = nodes.index[nodes['id'] == id]
    if len(matches) == 0:
        raise KeyError(f'Node id {id!r} not found in nodes table')
    return int(matches[0])

def is_overlapping_pair(symptom_id, drug_id):
    """Return True if drug_id is listed for symptom_id in overlapping_symptoms_drugs, else False."""
    return drug_id in overlapping_symptoms_drugs.get(symptom_id, ())

found_explations_list = []  # every accepted explanation across all runs (name kept for later cells)

for run_dir in run_folders_paths:
    print(run_dir)
    # Portable run-folder name; splitting on '\\' only worked with Windows paths.
    run_name = os.path.basename(os.path.normpath(run_dir))

    symptoms_drugs = pd.read_pickle(f'{run_dir}/pred/candidates_per_symptom_{dataset_nr}_{embedding_method}.pkl')

    if embedding_method == 'e2v':
        embedding = loadEdge2VecEmbedding(current_run_dir=run_dir)
    else:
        embedding = loadMetapath2VecEmbedding(current_run_dir=run_dir)

    G = create_graph(embedding=embedding)
    edge_labels_dict = {(n1, n2): label for n1, n2, label in G.edges(data='edge_label')}

    best_model = load_trained_model(current_run_dir=run_dir, embedding=embedding)

    # Node features in the row order of `embedding`; .tolist() is positional, so this
    # does not depend on the frame's index labels.
    x = torch.Tensor(embedding['Embedding'].tolist())
    # (2, num_edges) int64 indices in G.edges order (explain_edge pairs edge_mask with G.edges),
    # built directly as int64 instead of round-tripping through a float32 torch.Tensor.
    edge_index = torch.tensor(np.array(list(G.edges), dtype=np.int64).reshape(-1, 2).T, dtype=torch.long)

    explanation_nr = 0

    for _, row in symptoms_drugs.iterrows():
        symptom_id = row['Symptom']
        for drug_id in row['Candidates']:
            # The overlap dict stores drug ids as strings (e.g. '231') and the node lookup
            # already used str(drug_id); normalise once so the overlap check, the lookup
            # and the saved explanation all use the same key.
            drug_id = str(drug_id)
            if not is_overlapping_pair(symptom_id=symptom_id, drug_id=drug_id):
                continue

            node_idx1 = get_node_idx(symptom_id, nodes)
            print('node1:')
            print(nodes.loc[[node_idx1]])

            node_idx2 = get_node_idx(drug_id, nodes)
            print('node2:')
            print(nodes.loc[[node_idx2]])

            found_edge_mask, found_limit = explain_edge(node_idx1=node_idx1, node_idx2=node_idx2,
                                                        model=best_model, x=x, edge_index=edge_index, G=G)
            if found_edge_mask is None or found_limit is None:
                continue

            explanation_nr += 1
            found_explanation = {'symptom': symptom_id, 'drug': drug_id,
                                 'node_idx1': node_idx1, 'node_idx2': node_idx2,
                                 'found_edge_mask': found_edge_mask, 'found_limit': found_limit,
                                 'run': run_name}
            found_explations_list.append(found_explanation)

            visualize_explanation(explanation_nr, found_explanation, edge_index, edge_labels_dict)