In [1]:
import os
import pickle as pkl

import numpy as np
import torch

%load_ext autoreload
%autoreload 2
from load_data import load_data
from subgraph_relevance import subgraph_original, subgraph_mp_transcription, subgraph_mp_forward_hook, get_H_transform
from top_walks import *


In [None]:
# [dataset name, checkpoint file name] pairs. The checkpoint is loaded from the
# `models/` directory; the loading cell below picks entry [0].
dataset_model_dirs = [
    ['BA-2motif', 'gin-3-ba2motif.torch'],
    ['BA-2motif', 'gin-5-ba2motif.torch'],
    ['BA-2motif', 'gin-7-ba2motif.torch'],
    ['MUTAG', 'gin-3-mutag.torch'],
    ['Mutagenicity', 'gin-3-mutagenicity.torch'],
    ['REDDIT-BINARY', 'gin-5-reddit.torch'],
    ['Graph-SST2', 'gcn-3-sst2graph.torch'],
]

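# Empirical comparison between $\widetilde K$ and $K$

Here $\widetilde K$ is the number of top walks taken after ranking by absolute relevance, and $K$ is how many of those walks have positive relevance.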

In [None]:
# Pick the first (dataset, checkpoint) configuration and load both.
dataset, model_dir = dataset_model_dirs[0]
graphs, pos_idx, neg_idx = load_data(dataset)
# NOTE(review): torch.load unpickles the checkpoint, which appears to hold a whole
# model object (it is used via `nn.forward` below). Only load trusted files. On
# recent PyTorch releases, loading a full pickled module may require
# `weights_only=False` — confirm against the installed version.
nn = torch.load(f'models/{model_dir}')

In [None]:
# Sanity check on one graph: print the predicted class next to the true label.
g = graphs[0]
pred = nn.forward(g.get_adj(), H0=g.node_features).argmax()
print(pred, g.label)
# Forward pass returning activations and (presumably) per-layer transforms for
# relevance propagation. gammas=None here; the next cell recomputes H and
# transforms with explicit gamma values.
H, transforms = get_H_transform(
    g.get_adj(),
    nn,
    H0=g.node_features,
    gammas=None,
    mode='gamma',
)

In [None]:
# LRP-gamma with gamma=3 (one value per entry, 4 entries) on the graph `g`
# selected in the previous cell.
lrp_rule = 'gamma'
H, transforms = get_H_transform(
    g.get_adj(),
    nn,
    H0=g.node_features,
    gammas=[3] * 4,
    mode=lrp_rule,
)

# Relevance is seeded only on the predicted-class column; all other columns are zero.
init_rel = np.zeros_like(H)
init_rel[:, pred] = H[:, pred]

# Extract the highest- and lowest-relevance walks through the graph.
top_max_walks, top_min_walks = topk_walks(
    g,
    nn,
    num_walks=200,
    lrp_mode="gamma",
    negative_transition_strategy='none',
    mode="node",
    transforms=transforms,
    H=init_rel,
)


In [None]:
# Experiment: pick up to NUM_GRAPHS randomly sampled graphs with label 0 that the
# model also predicts as class 0. For each one, rank the returned top walks by
# |relevance| and record, for every prefix length \tilde K = 1, 2, ..., the
# fraction of walks whose relevance is negative. The cell's output is that curve
# averaged over the selected graphs. The hard-coded gamma*_res arrays further
# down appear to be this output for different GAMMA values.
GAMMA = 0.2           # LRP-gamma value, repeated for each of the 4 `gammas` entries
NUM_WALKS = 200       # num_walks requested from topk_walks
NUM_CANDIDATES = 200  # number of random draws used to search for qualifying graphs
NUM_GRAPHS = 10       # stop once this many qualifying graphs are found

top_k_intersection_nb_dict = {}
np.random.seed(0)
# np.random.choice draws WITH replacement by default, so the same graph could be
# selected (and averaged) twice. Drop repeated indices but keep the draw order,
# so for this seed the selection is otherwise unchanged.
graph_idxs = list(dict.fromkeys(np.random.choice(len(graphs), NUM_CANDIDATES)))
graph_tp_idxs = []
for graph_idx in graph_idxs:
    g = graphs[graph_idx]
    pred = nn.forward(g.get_adj(), H0=g.node_features).argmax()
    # NOTE(review): nbnodes**4 >= 1000 presumably ensures enough candidate walks
    # exist for the requested top walks — confirm the intended walk length.
    if g.label == 0 and pred == 0 and g.nbnodes ** 4 >= 1000:
        graph_tp_idxs.append(graph_idx)
    if len(graph_tp_idxs) == NUM_GRAPHS:
        break

graphs_res = []
for graph_idx in tqdm(graph_tp_idxs):
    g = graphs[graph_idx]
    pred = nn.forward(g.get_adj(), H0=g.node_features).argmax()

    lrp_rule = 'gamma'
    H, transforms = get_H_transform(g.get_adj(), nn, H0=g.node_features, gammas=[GAMMA] * 4, mode=lrp_rule)
    # Seed relevance only on the predicted-class column.
    init_rel = np.zeros_like(H)
    init_rel[:, pred] = H[:, pred]

    top_max_walks, top_min_walks = topk_walks(g, nn, num_walks=NUM_WALKS, lrp_mode="gamma", negative_transition_strategy='none', mode="node", transforms=transforms, H=init_rel)
    # Merge both lists and rank by absolute relevance, largest first.
    top_tilde_K_order = sorted([(item[0], item[1]) for item in top_max_walks + top_min_walks], key=lambda x: abs(x[1]), reverse=True)

    # res[k-1] = (# negative-relevance walks among the top k) / k
    neg_cnt = 0
    res = []
    for k, (_, score) in enumerate(top_tilde_K_order, start=1):
        if score < 0:
            neg_cnt += 1
        res.append(neg_cnt / k)

    graphs_res.append(res)

if not graphs_res:
    raise ValueError('no graph satisfied the selection criteria; nothing to average')
# Graphs can yield different numbers of walks, and np.array on ragged lists fails.
# Average only over the prefix length that every graph provides.
common_len = min(len(r) for r in graphs_res)
np.array([r[:common_len] for r in graphs_res]).mean(axis=0)

In [None]:
# Hard-coded averaged negative-walk-fraction curve for LRP gamma = 0 (plotted
# below with the legend "gamma=0"). It appears to be pasted output of the
# experiment cell above. Entry i corresponds to \tilde K = i + 1.
gamma0_res = [1.        , 1.        , 0.83333333, 0.825     , 0.84      ,
       0.78333333, 0.78571429, 0.8       , 0.78888889, 0.77      ,
       0.75454545, 0.74166667, 0.73076923, 0.72857143, 0.73333333,
       0.7375    , 0.71764706, 0.71111111, 0.72105263, 0.72      ,
       0.71904762, 0.71363636, 0.72173913, 0.71666667, 0.704     ,
       0.68461538, 0.67777778, 0.66428571, 0.65517241, 0.65666667,
       0.65483871, 0.65      , 0.64545455, 0.64117647, 0.63428571,
       0.62777778, 0.62702703, 0.62105263, 0.62051282, 0.62      ,
       0.61463415, 0.61428571, 0.61162791, 0.61136364, 0.61555556,
       0.6173913 , 0.6212766 , 0.625     , 0.62040816, 0.614     ,
       0.61176471, 0.60384615, 0.59622642, 0.59444444, 0.58909091,
       0.58035714, 0.57894737, 0.57586207, 0.57118644, 0.57      ,
       0.56557377, 0.56612903, 0.55873016, 0.553125  , 0.55384615,
       0.5530303 , 0.55074627, 0.55      , 0.55072464, 0.55      ,
       0.55211268, 0.55138889, 0.54931507, 0.54864865, 0.54533333,
       0.54473684, 0.54415584, 0.54358974, 0.54177215, 0.54      ,
       0.5382716 , 0.53414634, 0.53253012, 0.5297619 , 0.52941176,
       0.52906977, 0.53103448, 0.53181818, 0.53146067, 0.53      ,
       0.52967033, 0.53152174, 0.53225806, 0.53191489, 0.53263158,
       0.53229167, 0.53195876, 0.53265306, 0.53434343, 0.536     ,
       0.53861386, 0.54019608, 0.53980583, 0.5375    , 0.53904762,
       0.53962264, 0.54205607, 0.53981481, 0.53761468, 0.53545455,
       0.53603604, 0.53571429, 0.53539823, 0.53508772, 0.53565217,
       0.53448276, 0.53589744, 0.53813559, 0.53781513, 0.53916667,
       0.53884298, 0.53934426, 0.54065041, 0.54112903, 0.5432    ,
       0.54603175, 0.5480315 , 0.54921875, 0.54806202, 0.54461538,
       0.54198473, 0.54090909, 0.53984962, 0.5380597 , 0.53703704,
       0.5375    , 0.5379562 , 0.5384058 , 0.53884892, 0.53785714,
       0.53687943, 0.53732394, 0.53636364, 0.53611111, 0.53517241,
       0.53424658, 0.53129252, 0.53040541, 0.52885906, 0.52666667,
       0.52649007, 0.52763158, 0.52810458, 0.52727273, 0.52580645,
       0.525     , 0.52484076, 0.52531646, 0.5245283 , 0.52375   ,
       0.52173913, 0.52222222, 0.52269939, 0.52317073, 0.5230303 ,
       0.52409639, 0.52275449, 0.52261905, 0.52130178, 0.52058824,
       0.51988304, 0.51918605, 0.51791908, 0.51666667, 0.51428571,
       0.51306818, 0.51129944, 0.50955056, 0.50837989, 0.50777778,
       0.50718232, 0.50659341, 0.50601093, 0.5048913 , 0.50432432,
       0.50376344, 0.50374332, 0.5037234 , 0.5026455 , 0.50263158,
       0.50209424, 0.50104167, 0.49948187, 0.49742268, 0.49589744,
       0.49540816, 0.49593909, 0.4959596 , 0.49648241, 0.498     ]

In [None]:
# Hard-coded averaged negative-walk-fraction curve, presumably for LRP gamma = 0.1
# (inferred from the variable name — confirm). It appears to be pasted output of
# the experiment cell above. Entry i corresponds to \tilde K = i + 1.
# NOTE(review): this curve is not included in the figure below.
gamma01_res = [0.2       , 0.2       , 0.33333333, 0.375     , 0.4       ,
       0.41666667, 0.38571429, 0.3625    , 0.35555556, 0.36      ,
       0.36363636, 0.375     , 0.38461538, 0.4       , 0.42      ,
       0.425     , 0.43529412, 0.45555556, 0.46315789, 0.47      ,
       0.48095238, 0.49545455, 0.50434783, 0.51666667, 0.52      ,
       0.52307692, 0.51851852, 0.51785714, 0.52413793, 0.54      ,
       0.5483871 , 0.55      , 0.54848485, 0.55      , 0.55428571,
       0.56388889, 0.56486486, 0.55789474, 0.55384615, 0.56      ,
       0.56341463, 0.56428571, 0.5627907 , 0.56136364, 0.56      ,
       0.55869565, 0.56170213, 0.55833333, 0.55714286, 0.552     ,
       0.55098039, 0.54807692, 0.54339623, 0.54444444, 0.54363636,
       0.53928571, 0.53684211, 0.53448276, 0.53389831, 0.53833333,
       0.53934426, 0.54193548, 0.54444444, 0.546875  , 0.55076923,
       0.5530303 , 0.55373134, 0.55588235, 0.55652174, 0.55571429,
       0.55492958, 0.55277778, 0.54931507, 0.5472973 , 0.54533333,
       0.54473684, 0.54285714, 0.53846154, 0.53544304, 0.53375   ,
       0.53209877, 0.53292683, 0.53373494, 0.53333333, 0.53529412,
       0.53604651, 0.53563218, 0.53409091, 0.53595506, 0.53888889,
       0.54175824, 0.54347826, 0.54516129, 0.54680851, 0.54947368,
       0.55208333, 0.55360825, 0.55510204, 0.55555556, 0.556     ,
       0.55544554, 0.55686275, 0.55728155, 0.55673077, 0.55809524,
       0.55660377, 0.55607477, 0.55555556, 0.55412844, 0.55181818,
       0.54954955, 0.54642857, 0.5460177 , 0.54385965, 0.5426087 ,
       0.54137931, 0.53846154, 0.53559322, 0.53361345, 0.53166667,
       0.52975207, 0.53032787, 0.53089431, 0.52983871, 0.5288    ,
       0.52698413, 0.52519685, 0.52578125, 0.5248062 , 0.52615385,
       0.5259542 , 0.525     , 0.52406015, 0.5238806 , 0.52296296,
       0.52279412, 0.52335766, 0.52246377, 0.52302158, 0.52357143,
       0.52198582, 0.52042254, 0.52027972, 0.51944444, 0.51931034,
       0.51917808, 0.51836735, 0.51689189, 0.5147651 , 0.514     ,
       0.51324503, 0.5125    , 0.51176471, 0.51038961, 0.5116129 ,
       0.51217949, 0.51210191, 0.51202532, 0.51194969, 0.5125    ,
       0.51180124, 0.5117284 , 0.51104294, 0.51097561, 0.51151515,
       0.51204819, 0.51077844, 0.50952381, 0.50887574, 0.50823529,
       0.50818713, 0.50872093, 0.50809249, 0.50862069, 0.50914286,
       0.50909091, 0.5079096 , 0.50617978, 0.50502793, 0.50277778,
       0.50110497, 0.49945055, 0.49726776, 0.49565217, 0.49405405,
       0.49247312, 0.49251337, 0.49202128, 0.49206349, 0.49210526,
       0.49162304, 0.49270833, 0.49274611, 0.49329897, 0.49282051,
       0.49183673, 0.49137056, 0.48989899, 0.48844221, 0.487     ]

In [None]:
# Hard-coded averaged negative-walk-fraction curve for LRP gamma = 0.2 (plotted
# below with the legend "gamma=0.2"). It appears to be pasted output of the
# experiment cell above. Entry i corresponds to \tilde K = i + 1.
gamma02_res = [0.4       , 0.3       , 0.3       , 0.3       , 0.32      ,
       0.31666667, 0.31428571, 0.3375    , 0.34444444, 0.34      ,
       0.33636364, 0.34166667, 0.33846154, 0.32857143, 0.33333333,
       0.35      , 0.35294118, 0.37222222, 0.37368421, 0.365     ,
       0.36666667, 0.37272727, 0.3826087 , 0.38333333, 0.392     ,
       0.4       , 0.41481481, 0.41428571, 0.42068966, 0.43      ,
       0.43225806, 0.434375  , 0.43636364, 0.44117647, 0.44      ,
       0.44722222, 0.44324324, 0.43421053, 0.42820513, 0.4225    ,
       0.41463415, 0.40952381, 0.40697674, 0.40681818, 0.40666667,
       0.40434783, 0.40212766, 0.39375   , 0.38979592, 0.394     ,
       0.4       , 0.40192308, 0.40188679, 0.40185185, 0.40545455,
       0.4125    , 0.41578947, 0.42068966, 0.4220339 , 0.42333333,
       0.42459016, 0.42419355, 0.42063492, 0.415625  , 0.41692308,
       0.41818182, 0.41940299, 0.42058824, 0.42028986, 0.42142857,
       0.42535211, 0.42638889, 0.4260274 , 0.42567568, 0.42666667,
       0.42763158, 0.42987013, 0.43076923, 0.43037975, 0.43375   ,
       0.43580247, 0.43658537, 0.4373494 , 0.43928571, 0.44117647,
       0.44186047, 0.44367816, 0.44204545, 0.44044944, 0.44      ,
       0.43846154, 0.43695652, 0.43333333, 0.43191489, 0.43157895,
       0.43020833, 0.42989691, 0.42755102, 0.42727273, 0.426     ,
       0.42673267, 0.42745098, 0.42718447, 0.42692308, 0.42857143,
       0.42830189, 0.42616822, 0.425     , 0.42385321, 0.42272727,
       0.42072072, 0.41964286, 0.41858407, 0.41929825, 0.42086957,
       0.42241379, 0.42222222, 0.4220339 , 0.42268908, 0.42416667,
       0.4231405 , 0.42377049, 0.42520325, 0.42580645, 0.4264    ,
       0.42777778, 0.42913386, 0.43046875, 0.42945736, 0.43      ,
       0.42900763, 0.4280303 , 0.42781955, 0.42686567, 0.42518519,
       0.42573529, 0.42408759, 0.42173913, 0.42086331, 0.41928571,
       0.41843972, 0.41760563, 0.41748252, 0.41666667, 0.41586207,
       0.41506849, 0.41428571, 0.41486486, 0.41610738, 0.41666667,
       0.41788079, 0.41776316, 0.41830065, 0.41818182, 0.41806452,
       0.41858974, 0.41910828, 0.41898734, 0.41698113, 0.415     ,
       0.41428571, 0.41358025, 0.41472393, 0.41463415, 0.41454545,
       0.41385542, 0.41257485, 0.4125    , 0.41242604, 0.41294118,
       0.4128655 , 0.4127907 , 0.41213873, 0.41206897, 0.41257143,
       0.41136364, 0.41016949, 0.40955056, 0.40837989, 0.40777778,
       0.40828729, 0.40824176, 0.40819672, 0.40869565, 0.40864865,
       0.40860215, 0.40909091, 0.4106383 , 0.41216931, 0.41315789,
       0.41308901, 0.41302083, 0.41295337, 0.41340206, 0.41538462,
       0.41683673, 0.4177665 , 0.41919192, 0.42060302, 0.4215    ]

In [None]:
# Hard-coded averaged negative-walk-fraction curve for LRP gamma = 3 (plotted
# below with the legend "gamma=3"). It appears to be pasted output of the
# experiment cell above. Entry i corresponds to \tilde K = i + 1.
gamma3_res = [0.3       , 0.3       , 0.33333333, 0.35      , 0.36      ,
       0.36666667, 0.37142857, 0.3625    , 0.36666667, 0.37      ,
       0.37272727, 0.36666667, 0.36923077, 0.37857143, 0.38      ,
       0.375     , 0.37647059, 0.38888889, 0.4       , 0.42      ,
       0.43809524, 0.45      , 0.46521739, 0.47083333, 0.468     ,
       0.46153846, 0.45555556, 0.46071429, 0.47241379, 0.48      ,
       0.48387097, 0.4875    , 0.49090909, 0.49705882, 0.50285714,
       0.50277778, 0.4972973 , 0.49210526, 0.49230769, 0.4875    ,
       0.48292683, 0.47380952, 0.46744186, 0.46136364, 0.45333333,
       0.44565217, 0.43617021, 0.43541667, 0.43265306, 0.428     ,
       0.42352941, 0.42115385, 0.41886792, 0.41666667, 0.41636364,
       0.41607143, 0.41578947, 0.4137931 , 0.41355932, 0.41333333,
       0.41311475, 0.41290323, 0.41269841, 0.4109375 , 0.41076923,
       0.41060606, 0.41044776, 0.40588235, 0.40434783, 0.40571429,
       0.4056338 , 0.40555556, 0.40547945, 0.40675676, 0.40666667,
       0.40526316, 0.40519481, 0.40769231, 0.41139241, 0.415     ,
       0.41851852, 0.42195122, 0.42650602, 0.4297619 , 0.43294118,
       0.43488372, 0.43678161, 0.43977273, 0.44044944, 0.44      ,
       0.44065934, 0.44130435, 0.44193548, 0.44468085, 0.44631579,
       0.44583333, 0.4443299 , 0.44285714, 0.44242424, 0.44      ,
       0.43663366, 0.43431373, 0.43203883, 0.42980769, 0.43047619,
       0.43113208, 0.4317757 , 0.43148148, 0.43119266, 0.43181818,
       0.43423423, 0.43571429, 0.43716814, 0.4377193 , 0.43826087,
       0.43793103, 0.43675214, 0.43644068, 0.43781513, 0.43833333,
       0.43966942, 0.44262295, 0.44471545, 0.44677419, 0.448     ,
       0.4484127 , 0.4488189 , 0.44921875, 0.4496124 , 0.45      ,
       0.45038168, 0.45      , 0.45037594, 0.45      , 0.45111111,
       0.45294118, 0.45474453, 0.45724638, 0.45755396, 0.45785714,
       0.45602837, 0.45492958, 0.45244755, 0.45069444, 0.44965517,
       0.45      , 0.44965986, 0.44932432, 0.44832215, 0.44666667,
       0.44503311, 0.44342105, 0.44248366, 0.44155844, 0.44193548,
       0.44230769, 0.44267516, 0.44113924, 0.43899371, 0.4375    ,
       0.43602484, 0.4345679 , 0.43251534, 0.42987805, 0.42787879,
       0.42590361, 0.4245509 , 0.42380952, 0.42307692, 0.42235294,
       0.42163743, 0.42034884, 0.41907514, 0.4183908 , 0.41771429,
       0.41761364, 0.41751412, 0.41685393, 0.41675978, 0.41666667,
       0.41657459, 0.41648352, 0.41639344, 0.41684783, 0.41783784,
       0.41774194, 0.41764706, 0.41861702, 0.41957672, 0.42052632,
       0.42094241, 0.42135417, 0.42124352, 0.42164948, 0.4225641 ,
       0.42295918, 0.42385787, 0.42424242, 0.42462312, 0.4255    ]

In [None]:
# Hard-coded averaged negative-walk-fraction curve, presumably for LRP gamma = 10
# (inferred from the variable name — confirm). It appears to be pasted output of
# the experiment cell above. Entry i corresponds to \tilde K = i + 1.
# NOTE(review): this curve is not included in the figure below.
gamma10_res = [0.2       , 0.2       , 0.2       , 0.2       , 0.18      ,
       0.16666667, 0.15714286, 0.15      , 0.14444444, 0.14      ,
       0.13636364, 0.13333333, 0.13076923, 0.12857143, 0.12      ,
       0.11875   , 0.11764706, 0.11666667, 0.11578947, 0.11      ,
       0.10952381, 0.10454545, 0.10434783, 0.1       , 0.1       ,
       0.1       , 0.1       , 0.1       , 0.1       , 0.1       ,
       0.1       , 0.1       , 0.1       , 0.10294118, 0.10571429,
       0.10277778, 0.10810811, 0.10789474, 0.10769231, 0.1075    ,
       0.10731707, 0.10714286, 0.10697674, 0.10681818, 0.10888889,
       0.10869565, 0.1106383 , 0.1125    , 0.11632653, 0.118     ,
       0.11960784, 0.11923077, 0.12075472, 0.12037037, 0.12181818,
       0.12321429, 0.1245614 , 0.12758621, 0.12881356, 0.13166667,
       0.13278689, 0.13548387, 0.13650794, 0.1375    , 0.14      ,
       0.14242424, 0.14477612, 0.14558824, 0.14927536, 0.15      ,
       0.15070423, 0.15138889, 0.15068493, 0.15      , 0.14933333,
       0.14736842, 0.14675325, 0.14487179, 0.14303797, 0.1425    ,
       0.14074074, 0.1402439 , 0.13975904, 0.13809524, 0.13647059,
       0.13488372, 0.13448276, 0.13409091, 0.13370787, 0.13222222,
       0.13186813, 0.13043478, 0.12903226, 0.12765957, 0.12631579,
       0.12604167, 0.1257732 , 0.12755102, 0.12929293, 0.13      ,
       0.13069307, 0.13235294, 0.13398058, 0.13557692, 0.13714286,
       0.14056604, 0.14392523, 0.1462963 , 0.14862385, 0.15090909,
       0.15135135, 0.15267857, 0.15486726, 0.15614035, 0.15826087,
       0.15948276, 0.15982906, 0.16271186, 0.16554622, 0.16916667,
       0.17190083, 0.17295082, 0.17479675, 0.17580645, 0.1768    ,
       0.17777778, 0.17952756, 0.18203125, 0.18449612, 0.18615385,
       0.1870229 , 0.18863636, 0.1887218 , 0.18955224, 0.19037037,
       0.19117647, 0.1919708 , 0.19202899, 0.19280576, 0.19357143,
       0.19503546, 0.19577465, 0.1972028 , 0.19722222, 0.19724138,
       0.19794521, 0.19931973, 0.19932432, 0.20067114, 0.20133333,
       0.20331126, 0.20526316, 0.20653595, 0.20714286, 0.20774194,
       0.20897436, 0.2089172 , 0.20886076, 0.20880503, 0.208125  ,
       0.20745342, 0.20679012, 0.20674847, 0.20670732, 0.20727273,
       0.20783133, 0.20778443, 0.20714286, 0.20710059, 0.20647059,
       0.20584795, 0.20465116, 0.20404624, 0.20287356, 0.20228571,
       0.20170455, 0.20112994, 0.2005618 , 0.2       , 0.2       ,
       0.19889503, 0.19835165, 0.19726776, 0.19619565, 0.19513514,
       0.19462366, 0.19411765, 0.19361702, 0.19312169, 0.19210526,
       0.19109948, 0.19010417, 0.19015544, 0.18969072, 0.18923077,
       0.18928571, 0.18883249, 0.18838384, 0.1879397 , 0.188     ]

In [None]:
# Hard-coded averaged negative-walk-fraction curve for the gamma -> +infinity
# limit (plotted below with that legend). Every entry is 0, i.e. no
# negative-relevance walk appears among the top \tilde K. It appears to be pasted
# output of the experiment cell above. Entry i corresponds to \tilde K = i + 1.
gammainf_res = [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]

In [None]:
# Plot K / \tilde K, the fraction of positive-relevance walks among the top
# \tilde K walks (1 - negative fraction), for several LRP-gamma values.
# gamma01_res and gamma10_res are defined above but not drawn in this figure.
curves = [
    (gamma0_res, 'r-', r'$\gamma=0$'),
    (gamma02_res, 'g--', r'$\gamma=0.2$'),
    (gamma3_res, 'b-.', r'$\gamma=3$'),
    (gammainf_res, 'y-.', r'$\gamma\rightarrow +\infty$'),
]

fig, ax = plt.subplots(figsize=(5, 3))
max_len = 0
for res, fmt, label in curves:
    pos_frac = 1 - np.asarray(res)
    # x = \tilde K = 1..len(res), derived from the data instead of a hard-coded 200.
    ax.plot(np.arange(1, len(pos_frac) + 1), pos_frac, fmt, label=label)
    max_len = max(max_len, len(pos_frac))

ax.legend()
ax.set_xlim(1, max_len + 1)
ax.set_ylim(0, 1.01)
ax.set_ylabel(r"${K}/{\widetilde K}$", rotation=90)
ax.set_xlabel(r"$\widetilde K$")

# savefig raises if the output directory does not exist.
os.makedirs('imgs', exist_ok=True)
fig.savefig('imgs/tilde_K_K.svg', dpi=600, format='svg', bbox_inches='tight')