In [None]:
import numpy as np
%load_ext autoreload
%autoreload 2
from load_data import load_data
import torch
from modules import GNN
from train_model import train_model
from subgraph_relevance import subgraph_original, subgraph_mp_transcription, subgraph_mp_forward_hook, get_H_transform
import time
from tqdm import tqdm
import matplotlib.pyplot as plt
import sys
import pandas as pd
from io import StringIO
import pickle as pkl
from top_walks import *
from utils import *
from IPython.display import SVG, display
from baseline_comp import *
import itertools

In [None]:
# Available experiments, one [dataset, model file] pair per entry:
#   [0] is the dataset name handed to load_data(),
#   [1] is the model file name that gets loaded from the models/ directory.
dataset_model_dirs = [
    ['BA-2motif', 'gin-3-ba2motif.torch'],
    ['BA-2motif', 'gin-5-ba2motif.torch'],
    ['BA-2motif', 'gin-7-ba2motif.torch'],
    ['MUTAG', 'gin-3-mutag.torch'],
    ['Mutagenicity', 'gin-3-mutagenicity.torch'],
    ['REDDIT-BINARY', 'gin-5-reddit.torch'],
    ['Graph-SST2', 'gcn-3-sst2graph.torch'],
]

In [None]:
# Pick the experiment to run: entry 6 of dataset_model_dirs, i.e. the
# 'Graph-SST2' dataset with the 'gcn-3-sst2graph.torch' model file.
dataset, model_dir = dataset_model_dirs[6]

# load_data returns the graphs plus two index collections (positive / negative,
# going by their names); only `graphs` is used by the experiment cell below.
graphs, pos_idx, neg_idx = load_data(dataset)

# The file holds a whole pickled model object (its .forward() is called directly
# later on), not just a state dict.
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which
# refuses such files; pass weights_only=False there (trusted checkpoints only).
nn = torch.load(f'models/{model_dir}')

In [None]:
# ---- Experiment settings ------------------------------------------------------
N_CANDIDATE_GRAPHS = 200   # number of random graph indices drawn
N_EVAL_GRAPHS = 50         # stop once this many graphs have been selected
WALK_LENGTH = 4            # nodes per enumerated walk
                           # NOTE(review): presumably tied to the depth of the selected
                           # model - confirm before switching to another checkpoint.
MIN_NUM_WALKS = 1000       # only keep graphs with at least this many possible walks
NUM_TOP_WALKS = 200        # size of the exact / approximate top-walk lists that are stored


def _lrp_settings(tag):
    """Return fresh `gammas` / `mode` keyword arguments for get_H_transform.

    A new dict (with new lists) is built on every call so that a callee mutating
    `gammas` cannot leak state between graphs.

    Raises:
        ValueError: if `tag` is not one of 'zero', 'gamma', '02', 'ab'.
    """
    settings = {
        'zero': {'gammas': [0] * 4, 'mode': 'none'},
        'gamma': {'gammas': None, 'mode': 'gamma'},
        '02': {'gammas': [0.2] * 4, 'mode': 'gamma'},
        'ab': {'gammas': [0.2] * 4, 'mode': 'clip-pos'},
    }
    if tag not in settings:
        raise ValueError(f"unknown LRP tag: {tag!r}")
    return settings[tag]


top_k_intersection_nb_dict = {}  # not used in this cell; kept in case other cells read it

# ---- 1) Select the graphs to evaluate -----------------------------------------
np.random.seed(0)
graph_tp_idxs = []
visited_idxs = set()
# np.random.choice draws WITH replacement, so an index can come up repeatedly.
# Repeats are skipped below (the seeded draw sequence itself is unchanged) so
# that no graph is evaluated - and counted in the averages - more than once.
graph_idxs = np.random.choice(len(graphs), N_CANDIDATE_GRAPHS)
for graph_idx in graph_idxs:
    if graph_idx in visited_idxs:
        continue
    visited_idxs.add(graph_idx)

    g = graphs[graph_idx]
    pred = nn.forward(g.get_adj(), H0=g.node_features).argmax()
    # Keep graphs labelled 0 that the model also predicts as 0, and that have at
    # least MIN_NUM_WALKS walks of WALK_LENGTH nodes (nbnodes ** WALK_LENGTH).
    if g.label == 0 and pred == 0 and g.nbnodes ** WALK_LENGTH >= MIN_NUM_WALKS:
        graph_tp_idxs.append(graph_idx)
    if len(graph_tp_idxs) == N_EVAL_GRAPHS:
        break

# ---- 2) Exact vs. approximate top walks, one result file per LRP setting -------
for tag in ['ab', 'zero', 'gamma', '02']:
    graphs_res = []
    for graph_idx in graph_tp_idxs:
        g = graphs[graph_idx]
        pred = nn.forward(g.get_adj(), H0=g.node_features).argmax()

        H, transforms = get_H_transform(g.get_adj(), nn, H0=g.node_features, **_lrp_settings(tag))
        # Start the relevance from the predicted class column only.
        init_rel = np.zeros_like(H)
        init_rel[:, pred] = H[:, pred]

        # Brute force: relevance of every node tuple of length WALK_LENGTH.
        node_ids = np.arange(g.nbnodes)
        walk_rels = {}
        for walk in tqdm(itertools.product(node_ids, repeat=WALK_LENGTH),
                         total=int(g.nbnodes) ** WALK_LENGTH):
            walk_rels[tuple(walk)] = walk_rel(transforms, init_rel, walk, mode="node")
        # (walk, relevance) pairs, most relevant first.
        sorted_walk_rels = sorted(walk_rels.items(), key=lambda item: item[1], reverse=True)

        # Approximate search for the most / least relevant walks.
        top_max_walks, top_min_walks = topk_walks(
            g, nn, num_walks=NUM_TOP_WALKS, lrp_mode="gamma",
            negative_transition_strategy='none', mode="node",
            transforms=transforms, H=init_rel, how_max='none')

        # Stored layout per graph: [[approximate_walks, exact_top_walks]]
        # (process_data reads it as gres[0][0] and gres[0][1]).
        graphs_res.append([[top_max_walks + top_min_walks, sorted_walk_rels[:NUM_TOP_WALKS]]])

    # `with` guarantees the file is flushed and closed even if pickling fails.
    with open(f"results/{dataset}_top_k_walk_NODE_raw_{tag}.pkl", "wb") as result_file:
        pkl.dump(graphs_res, result_file)


# Plot: precision-recall of the approximate top-K walks vs. the brute-force ranking

In [None]:
def process_data(graphs_res, K=200):
    """Compute precision and recall curves of approximate top-K walks per graph.

    Args:
        graphs_res: one entry per graph, each laid out as
            ``[[approx_walks, exact_walks]]``. Both inner lists hold items whose
            element ``[0]`` is the walk (hashable) and ``[1]`` its relevance.
            ``exact_walks`` must be sorted by decreasing relevance.
        K: number of approximate walks considered; only the first ``K`` entries
            of ``approx_walks`` are used and the curves have ``K`` points.

    Returns:
        ``(pres_l, recall_l)``: two lists with one length-``K`` list per graph.
        Point ``k`` (1-based) is the number of distinct walks among the ``k``
        best approximate walks that belong to the reference set, divided by
        ``min(k, n)`` for precision and by ``n`` for recall, where ``n`` is the
        number of approximate walks used. A graph without approximate walks
        gets all-zero curves.

    The reference set is the first ``n`` exact walks plus every following exact
    walk whose relevance equals that of the ``n``-th one, so walks tied at the
    cut-off are all treated as correct.
    """
    pres_l = []
    recall_l = []

    for gres in tqdm(graphs_res):
        approx_top_k = sorted(gres[0][0][:K], key=lambda x: x[1], reverse=True)
        real_top_k = gres[0][1]
        n_approx = len(approx_top_k)

        if n_approx == 0:
            # Nothing was retrieved: avoid dividing by zero and keep the curve
            # length consistent so results can still be stacked and averaged.
            pres_l.append([0.0] * K)
            recall_l.append([0.0] * K)
            continue

        # Reference set: as many exact walks as approximate ones (clamped in
        # case fewer exact walks were stored) ...
        n_ref = min(n_approx, len(real_top_k))
        set_real_top_k = {item[0] for item in real_top_k[:n_ref]}
        # ... extended by all walks tied with the last one inside the cut-off.
        # The loop stops at the first differing relevance, so no extra check
        # against the previous rank is needed.
        if n_ref:
            boundary_rel = real_top_k[n_ref - 1][1]
            for item in real_top_k[n_ref:]:
                if item[1] != boundary_rel:
                    break
                set_real_top_k.add(item[0])

        pres = []
        recall = []
        counted = set()  # approximate walks already seen (duplicates count once)
        hits = 0         # distinct approximate walks so far that are in the reference set
        for k in range(1, K + 1):
            if k <= n_approx:
                walk = approx_top_k[k - 1][0]
                if walk not in counted:
                    counted.add(walk)
                    if walk in set_real_top_k:
                        hits += 1
            pres.append(hits / min(k, n_approx))
            recall.append(hits / min(K, n_approx))
        pres_l.append(pres)
        recall_l.append(recall)
    return pres_l, recall_l

In [None]:
def _load_walk_results(dataset_name, tag):
    """Load the pickled per-graph walk results of one dataset / LRP tag.

    NOTE: unpickling can execute arbitrary code - only load result files that
    were written by this notebook.
    """
    with open(f"results/{dataset_name}_top_k_walk_NODE_raw_{tag}.pkl", "rb") as result_file:
        return pkl.load(result_file)


styles = {'ab': 'r-', 'gamma': 'b-.', 'zero': 'g--', '02': 'y-'}
tags = ['zero', '02', 'gamma', 'ab']
# Legend entries, in the same order as `tags`.
legend_labels = [r'$\gamma=0$', r'$\gamma = 0.2$', r'$\gamma=[3,\cdots,0]$', r'$\gamma \rightarrow \infty$']

# One figure row per dataset, one column per K*. To plot other results, use the
# exact dataset name the result files were saved under (e.g. 'MUTAG', 'Graph-SST2').
# Local names are used on purpose: the notebook-level `dataset` and `graphs_res`
# belong to the experiment cell and must not be overwritten by plotting.
row_datasets = ['BA-2motif', 'Mutagenicity']
top_k_values = [5, 25, 100, 200]

# Read each result file once instead of once per K*.
row_results = [
    {tag: _load_walk_results(dataset_name, tag) for tag in tags}
    for dataset_name in row_datasets
]

fig, axes = plt.subplots(len(row_datasets), len(top_k_values), figsize=(8, 4), squeeze=False)
last_row = len(row_datasets) - 1
for row, (dataset_name, results_by_tag) in enumerate(zip(row_datasets, row_results)):
    for col, top_k in enumerate(top_k_values):
        ax = axes[row, col]

        # Mean precision over mean recall across graphs, one curve per LRP setting.
        for tag in tags:
            pres_l, recall_l = process_data(results_by_tag[tag], K=top_k)
            ax.plot(np.array(recall_l).mean(axis=0), np.array(pres_l).mean(axis=0), styles[tag])

        # Only the left-most column carries the y axis annotation.
        if col == 0:
            ax.set_ylabel(f"Precision\n{dataset_name}")
        else:
            ax.set_yticks([])

        ax.set_xlim(0, 1)
        ax.set_ylim(-0.01, 1.01)

        if row == 0:
            # Top row: column title, no x ticks.
            ax.set_xticks([])
            ax.set_title(r"$K^*$ = " + str(top_k))
        if row == last_row:
            ax.set_xlabel('Recall')

# Shared legend, anchored to the top-right panel and pushed below the figure.
axes[0, -1].legend(legend_labels, loc='lower center', bbox_to_anchor=(-1.5, -1.85), ncol=4)

plt.savefig('imgs/precision_recall_topk_walks_AMP_2x4.svg', dpi=600, format='svg', bbox_inches='tight')
