In [None]:
import numpy as np
%load_ext autoreload
%autoreload 2
from load_data import load_data
import torch
from modules import GNN
from train_model import train_model
from subgraph_relevance import subgraph_original, subgraph_mp_transcription, subgraph_mp_forward_hook, get_H_transform
import time
from tqdm import tqdm
import matplotlib.pyplot as plt
import sys
import pandas as pd
from io import StringIO
import pickle as pkl
from top_walks import *
from utils import *
from IPython.display import SVG, display
from baseline_comp import *
import itertools

In [None]:
# [dataset name, checkpoint file under models/] pairs available to this notebook.
# Checkpoint names follow <architecture>-<depth>-<dataset>.torch; the depth
# field is read back later with model_dir.split('-')[1].
dataset_model_dirs = [
    ['BA-2motif', 'gin-3-ba2motif.torch'],
    ['BA-2motif', 'gin-5-ba2motif.torch'],
    ['BA-2motif', 'gin-7-ba2motif.torch'],
    ['MUTAG', 'gin-3-mutag.torch'],
    ['Mutagenicity', 'gin-3-mutagenicity.torch'],
    ['REDDIT-BINARY', 'gin-5-reddit.torch'],
    ['Graph-SST2', 'gcn-3-sst2graph.torch'],
]

In [None]:
# Select the experiment configuration and load its graphs and trained model.
dataset, model_dir = dataset_model_dirs[0]
graphs, pos_idx, neg_idx = load_data(dataset)

# The checkpoint stores a whole pickled model object (it is called as
# nn.forward(...) later), not a state_dict. PyTorch >= 2.6 defaults
# torch.load to weights_only=True, which rejects such files, so opt out
# explicitly. SECURITY: unpickling can execute arbitrary code - only load
# checkpoints you trust.
model_path = 'models/' + model_dir
try:
    nn = torch.load(model_path, weights_only=False)
except TypeError:
    # Older PyTorch releases (< 1.13) do not accept the `weights_only` keyword.
    nn = torch.load(model_path)

In [None]:
# Edge-level ground-truth benchmark on BA-2motif.
#
# For each sampled graph, four explainers produce a ranked list of directed edges:
#   * AMP-ave      - edges in the order they first appear in the walks returned
#                    by approx_top_k_walk
#   * GNNExplainer - edges sorted by the score matrix returned by gnnexplainer
#   * Edge-IG      - get_top_edges_edge_ig
#   * GNN-LRP edge - gnnlrp_edge_rel
# Every baseline is asked for the same budget: the number of distinct edges the
# AMP-ave walks cover. For each prefix length k we record precision and recall
# against the ground-truth motif edges; rows are written to CSV for plotting.

# The ground-truth node range and the output path below are BA-2motif specific.
assert dataset == 'BA-2motif', f"ground truth is only defined for BA-2motif, got {dataset!r}"

drop_selfloop = True

np.random.seed(1)
graph_idxs = np.random.choice(len(graphs), min(10000, len(graphs)), replace=False)
use_softmax = True
# NOTE(review): model_depth, log_probability and use_max are not read in this cell.
model_depth = int(model_dir.split('-')[1])
log_probability = False
use_max = False
K = 1000  # first (k) argument passed to approx_top_k_walk

lrp_rule = 'gamma'
# NOTE(review): assumes every BA-2motif graph carries its motif on nodes 20..24
# (the original hard-coded np.arange(20, 25)); confirm against load_data.
gt_nodes = set(range(20, 25))


def _walk_edges(walks, drop_selfloop):
    """Distinct directed edges (u, v) traversed by `walks`, in first-seen order.

    Each item of `walks` is indexed as item[0] to obtain its node sequence.
    Self-loops are skipped when `drop_selfloop` is true.
    """
    seen = set()  # O(1) membership instead of scanning the output list
    ordered = []
    for item in walks:
        walk = item[0]
        for i in range(len(walk) - 1):
            edge = (walk[i], walk[i + 1])
            if drop_selfloop and edge[0] == edge[1]:
                continue
            if edge not in seen:
                seen.add(edge)
                ordered.append(edge)
    return ordered


def _top_ranked_edges(scores, budget, drop_selfloop):
    """Up to `budget` (row, col) index pairs of `scores`, highest score first."""
    flat_order = np.unravel_index(np.argsort(scores, axis=None), scores.shape)
    ranked = np.array(flat_order).T[::-1]
    picked = []
    for edge in ranked:
        if len(picked) >= budget:
            break
        if drop_selfloop and edge[0] == edge[1]:
            continue
        picked.append(edge)
    return picked


def _ground_truth_edges(g, gt_nodes):
    """Directed, non-self-loop edges of `g` whose endpoints both lie in `gt_nodes`."""
    return {
        tuple(edge)
        for edge in g.get_adj().nonzero().tolist()
        if edge[0] != edge[1] and edge[0] in gt_nodes and edge[1] in gt_nodes
    }


res = []
for graph_idx in tqdm(graph_idxs):
    g = graphs[graph_idx]
    if drop_selfloop and g.nbnodes == 1:
        continue

    gt = _ground_truth_edges(g, gt_nodes)
    if not gt:
        # Recall is undefined without ground-truth edges; skip rather than divide by zero.
        continue

    # Predicted class; no gradients are needed for this forward pass.
    with torch.no_grad():
        out = nn.forward(g.get_adj(), g.node_features)
    prob = torch.nn.functional.softmax(out, dim=0) if use_softmax else out
    pred = int(prob.argmax())
    p1 = float(prob[pred])  # score of the predicted class (not used below)

    # AMP-ave: edges covered by the most relevant walks define the budget.
    rel_components, X_fc, A = get_rel_components(g, nn, lrp_rule=lrp_rule)
    top_walks, extra_walks = approx_top_k_walk(K, g, nn, rel_components, X_fc, A, verbose=False)
    top_k_max_walks_rels = top_walks + extra_walks
    edges_list = _walk_edges(top_k_max_walks_rels, drop_selfloop)
    edge_num = len(edges_list)

    # Baselines, each asked for the same edge budget.
    R = gnnexplainer(g, nn, H0=g.node_features, steps=500, lr=0.5, lambd=0.01, verbose=False)
    edges_gnnexpl = _top_ranked_edges(R, edge_num, drop_selfloop)
    edges_edge_ig = get_top_edges_edge_ig(nn, g, pred, edge_num, drop_selfloop)
    edges_gnnlrp_edge = gnnlrp_edge_rel(nn, g, edge_num, drop_selfloop)

    # List order fixes the column order of the CSV written below.
    rankings = [edges_list, edges_gnnexpl, edges_edge_ig, edges_gnnlrp_edge]
    selected = [set() for ranking in rankings]
    for i in range(min(len(ranking) for ranking in rankings)):
        k = i + 1
        precisions, recalls = [], []
        # Each method's recall uses its own selected edge set.
        for ranking, chosen in zip(rankings, selected):
            chosen.add(tuple(ranking[i]))
            hits = len(gt & chosen)
            precisions.append(hits / k)
            recalls.append(hits / len(gt))
        res.append(precisions + recalls + [k, pred, int(g.label)])

result = pd.DataFrame(res, columns=['AMP-ave_pres', 'gnnexpl_pres', 'edgeig_pres', 'gnnlrp-edge_pres',
                                    'AMP-ave_recall', 'gnnexpl_recall', 'edgeig_recall', 'gnnlrp-edge_recall',
                                    'edge_num', 'pred', 'label'])
result.to_csv("results/EDGE_GROUNDTRUTH_BA-2motif.csv", index=False)

In [None]:
# Edge-recall curves for the BA-2motif ground-truth benchmark.
# Top panel: graphs with pred == label == 0; bottom panel: pred == label == 1.
# NOTE(review): class 0 is labelled "positive" and class 1 "negative" below -
# confirm this matches the dataset's class convention.
result = pd.read_csv("results/EDGE_GROUNDTRUTH_BA-2motif.csv")

MAX_EDGES = 50
# (recall column written by the benchmark cell, legend label, format string, extra plot kwargs)
RECALL_CURVES = [
    ('AMP-ave_recall', 'AMP-ave (ours)', 'b-.', {}),
    ('gnnexpl_recall', 'GNNExplainer', '--', {}),
    ('edgeig_recall', 'Edge-IG', '-+g', {}),
    ('gnnlrp-edge_recall', 'GNN-LRP edge-lvl', '-', {'color': 'orange'}),
]


def _mean_by_edge_budget(result, cls, max_edges):
    """Column means per edge budget over graphs with pred == label == `cls`, budgets <= max_edges."""
    correct = result[(result['pred'] == cls) & (result['label'] == cls)]
    ave = correct.groupby('edge_num').mean().drop(['pred', 'label'], axis=1)
    return ave.loc[ave.index <= max_edges]


fig, axes = plt.subplots(2, 1, figsize=(5, 3))
panels = [(0, 'Edge Recall \n positive'), (1, 'Edge Recall \n negative')]
for ax, (cls, ylabel) in zip(axes, panels):
    ave = _mean_by_edge_budget(result, cls, MAX_EDGES)
    # Plot against the actual edge budgets so a panel with fewer than
    # MAX_EDGES budgets does not cause an x/y length mismatch.
    for column, _label, fmt, kwargs in RECALL_CURVES:
        ax.plot(ave.index, ave[column], fmt, **kwargs)
    ax.grid()
    ax.set_xlim(1, MAX_EDGES)
    ax.set_ylim(0, 1.01)
    ax.set_xticks([1, 10, 20, 30, 40, 50])
    ax.set_ylabel(ylabel)

# Top panel keeps tick positions but hides labels; the bottom panel carries them.
axes[0].set_xticklabels([''] * 6)
axes[0].legend([label for _column, label, _fmt, _kwargs in RECALL_CURVES],
               loc='lower center', bbox_to_anchor=(0.75, -0.25), ncol=1)
axes[1].set_xlabel('# of edges')

fig.savefig("imgs/ba2motif_edge_detection.svg", dpi=600, format='svg', bbox_inches='tight')