In [1]:
import numpy as np
%load_ext autoreload
%autoreload 2
from load_data import load_data
import torch
from modules import GNN
from train_model import train_model
from subgraph_relevance import subgraph_original, subgraph_mp_transcription, subgraph_mp_forward_hook, get_H_transform
import time
from tqdm import tqdm
import matplotlib.pyplot as plt
import sys
import pandas as pd
from io import StringIO
import pickle as pkl
from top_walks import *
from utils import *
from IPython.display import SVG, display
from baseline_comp import *
import itertools

# BA-2motif

In [None]:
# (dataset name, checkpoint file under models/) for every trained model.
dataset_model_dirs = [
    ['BA-2motif', 'gin-3-ba2motif.torch'],
    ['BA-2motif', 'gin-5-ba2motif.torch'],
    ['BA-2motif', 'gin-7-ba2motif.torch'],
    ['MUTAG', 'gin-3-mutag.torch'],
    ['Mutagenicity', 'gin-3-mutagenicity.torch'],
    ['REDDIT-BINARY', 'gin-5-reddit.torch'],
    ['Graph-SST2', 'gcn-3-sst2graph.torch'],
]

# Entry 0: BA-2motif with the 3-layer GIN.
dataset, model_dir = dataset_model_dirs[0]
graphs, pos_idx, neg_idx = load_data(dataset)
# `nn` holds the loaded model; the name is kept because later cells use it.
nn = torch.load(f"models/{model_dir}")

## Runtime as a function of the number of layers $L$ (top-$K$ walks with $K=1$)

In [None]:
# Runtime vs. number of layers L on the first BA-2motif graph (one GIN checkpoint
# per L). "Naive": score the first N_SAMPLED_WALKS of the n^(L+1) walks,
# extrapolate linearly to all of them, then take the max. "Ours": topk_walks with K = 1.
N_SAMPLED_WALKS = 1000
iteration_num = 5

time_res = []
g = graphs[0]
for L in range(2, 8):
    # Keep each per-L model in its own variable so the notebook-wide `nn`
    # loaded above is not replaced by the last checkpoint of the sweep.
    model_L = torch.load(f"models/gin-{L}-ba2motif.torch")

    pred = model_L.forward(g.get_adj(), H0=g.node_features).argmax()
    H, transforms = get_H_transform(g.get_adj(), model_L, H0=g.node_features, gammas=None, mode='gamma')
    # Initial relevance: only the predicted-class column of H is kept.
    init_rel = np.zeros_like(H)
    init_rel[:, pred] = H[:, pred]
    n_walks_total = g.nbnodes ** (L + 1)

    time_accumulate_1 = 0
    time_accumulate_2 = 0
    for _ in tqdm(range(iteration_num)):
        time_a = time.time()
        walk_rels = {}
        for walk in itertools.islice(itertools.product(np.arange(g.nbnodes), repeat=L + 1), N_SAMPLED_WALKS):
            walk_rels[tuple(walk)] = walk_rel(transforms, init_rel, walk, mode="node")
        # Extrapolate from the walks actually scored (fewer than N_SAMPLED_WALKS
        # when the whole walk space is smaller).
        time_tmp = (time.time() - time_a) / max(len(walk_rels), 1) * n_walks_total

        time_a = time.time()
        top_walk = max(walk_rels.items(), key=lambda item: item[1])[0]
        time_accumulate_1 += time_tmp + time.time() - time_a

        time_a = time.time()
        top_max_walks, top_min_walks = topk_walks(g, model_L, num_walks=1, lrp_mode="gamma",
                negative_transition_strategy='none', mode="node", transforms=transforms, H=init_rel)
        time_accumulate_2 += time.time() - time_a

    time_res.append([L, time_accumulate_1 / iteration_num, time_accumulate_2 / iteration_num])

In [None]:
# Broken-axis figure of mean runtime vs. number of layers L. The top panel (ax1)
# shows the fast-growing naive enumeration; the bottom panel (ax2) zooms into
# [0, Y_BREAK] seconds, where the topk_walks curve lies.
model_times = np.array(time_res)  # rows: [L, naive time (s), topk_walks time (s)]
# Take the x-values from the data so ticks and labels always match the L-sweep.
num_layers = model_times[:, 0].astype(int)
Y_BREAK = 0.005  # seconds; boundary between the two panels
FONT_SIZE = 15

fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, gridspec_kw={'height_ratios': [3, 1]}, figsize=(3.5, 4))
fig.subplots_adjust(hspace=0.05)  # adjust space between axes
plt.rc('legend', fontsize=14.5)

ax1.set_ylabel("Time (s)", fontsize=FONT_SIZE)
ax2.set_xlabel(r'$L$', fontsize=FONT_SIZE)
ax2.set_xticks(num_layers)
ax2.set_xticklabels([str(n) if n % 2 == 1 else '' for n in num_layers])  # label odd L only

naive_line, = ax1.plot(num_layers, model_times[:, 1], 'b--')
ax2.plot(num_layers, model_times[:, 1], 'b--')
ours_line, = ax2.plot(num_layers, model_times[:, 2], 'r-')
# Legend handles may come from either panel; no invisible proxy line is needed.
ax1.legend([naive_line, ours_line], ['GNN-LRP naive', 'AMP-ave'])

# Hide the facing spines so the two panels read as one broken axis.
ax1.spines.bottom.set_visible(False)
ax2.spines.top.set_visible(False)
ax1.xaxis.tick_top()
ax1.tick_params(labeltop=False)  # don't put tick labels at the top

ax1.set_ylim(Y_BREAK)  # upper panel: only values above the break
ax2.set_ylim(0, Y_BREAK)
ax1.ticklabel_format(axis='y', style='sci', scilimits=(0, 0), useMathText=True)
for ax in (ax1, ax2):
    ax.tick_params(labelsize=FONT_SIZE)

# Slanted "break" markers at the left/right edges where the panels meet.
d = .5  # proportion of vertical to horizontal extent of the slanted line
kwargs = dict(marker=[(-1, -d), (1, d)], markersize=12,
              linestyle="none", color='k', mec='k', mew=1, clip_on=False)
ax1.plot([0, 1], [0, 0], transform=ax1.transAxes, **kwargs)
ax2.plot([0, 1], [1, 1], transform=ax2.transAxes, **kwargs)

fig.savefig('imgs/time_consumption_L.svg', dpi=600, format='svg', bbox_inches='tight')

## Comparison with baselines

In [None]:
# Exhaustive baseline on the first BA-2motif graph with the 3-layer GIN: score
# every walk of N_LAYERS + 1 nodes, then time (1) picking the top walk and
# (2) sorting all walks. Both timings include the enumeration.
# Reload the 3-layer model explicitly so this cell does not depend on which
# checkpoint earlier cells left in `nn`.
nn = torch.load('models/' + dataset_model_dirs[0][1])
N_LAYERS = 3  # gin-3 -> walks of N_LAYERS + 1 nodes

g = graphs[0]
pred = nn.forward(g.get_adj(), H0=g.node_features).argmax()
H, transforms = get_H_transform(g.get_adj(), nn, H0=g.node_features, gammas=None, mode='gamma')
init_rel = np.zeros_like(H)
init_rel[:, pred] = H[:, pred]

time_accumulate_1 = 0
time_accumulate_2 = 0
iteration_num = 5
for _ in tqdm(range(iteration_num)):
    time_a = time.time()

    walk_rels = {}
    for walk in itertools.product(np.arange(g.nbnodes), repeat=N_LAYERS + 1):
        walk_rels[tuple(walk)] = walk_rel(transforms, init_rel, walk, mode="node")

    time_tmp = time.time() - time_a

    time_a = time.time()
    top_walk = max(walk_rels.items(), key=lambda item: item[1])[0]
    time_accumulate_1 += time_tmp + time.time() - time_a

    time_a = time.time()
    sorted_walk_rels = sorted(walk_rels.items(), key=lambda item: item[1], reverse=True)
    time_accumulate_2 += time_tmp + time.time() - time_a

print(time_accumulate_1 / iteration_num, time_accumulate_2 / iteration_num)

In [None]:
# Mean runtime of topk_walks (node mode) for an increasing number K of extracted walks.
for num_walks in (1, 100, 1000):
    time_accumulate = 0
    for _ in range(iteration_num):
        time_a = time.time()
        top_max_walks, top_min_walks = topk_walks(g, nn, num_walks=num_walks, lrp_mode="gamma",
                negative_transition_strategy='none', mode="node", transforms=transforms, H=init_rel)
        time_accumulate += time.time() - time_a

    print(f"K={num_walks}: {time_accumulate / iteration_num}")

In [None]:
# Time the walk-based node explanation: top-N_TOP_WALKS walks -> summed walk
# relevance per node -> the SUBGRAPH_SIZE most relevant nodes, in both modes.
N_TOP_WALKS = 25
SUBGRAPH_SIZE = 5
for walk_mode in ("node", "neuron"):
    time_accumulate = 0
    for _ in range(iteration_num):
        time_a = time.time()

        top_max_walks, top_min_walks = topk_walks(g, nn, num_walks=N_TOP_WALKS, lrp_mode="", mode=walk_mode, transforms=transforms, H=init_rel)
        # Add each walk's relevance (walk[1]) to every node it visits (walk[0]),
        # once per occurrence. Build new values instead of `+=` so the walk's
        # own relevance object is never mutated in place.
        node_rel = {}
        for walk in top_max_walks:
            for node in walk[0]:
                node_rel[node] = node_rel.get(node, 0) + walk[1]

        # Select once, after aggregation (previously re-sorted for every walk,
        # which inflated the measured time). Keeps all nodes if fewer exist.
        ranked = sorted(node_rel.items(), key=lambda kv: kv[1], reverse=True)
        S = [node for node, score in ranked[:SUBGRAPH_SIZE]]

        time_accumulate += time.time() - time_a

    print(f"mode={walk_mode}: {time_accumulate / iteration_num}")


In [None]:
# Brute-force baseline: evaluate subgraph_mp_transcription on every
# SUBGRAPH_SIZE-node subset of the graph and keep the most relevant one.
SUBGRAPH_SIZE = 5
time_accumulate = 0
for _ in tqdm(range(iteration_num)):
    time_a = time.time()

    max_S = None
    max_rel = -float('inf')
    # Enumerate subsets of the actual graph instead of a hard-coded 25 nodes.
    for S in itertools.combinations(range(g.nbnodes), SUBGRAPH_SIZE):
        rel = subgraph_mp_transcription(nn, g, S, 0., H=H, transforms=transforms)
        if rel > max_rel:
            max_S, max_rel = S, rel

    time_accumulate += time.time() - time_a

print(time_accumulate / iteration_num)

In [None]:
subgraph_size = 5


def _time_explainer(explain, n_iter):
    """Call ``explain()`` ``n_iter`` times.

    Returns:
        (mean wall-clock seconds per call, return value of the last call).
    """
    total = 0.0
    result = None
    for _ in tqdm(range(n_iter)):
        start = time.time()
        result = explain()
        total += time.time() - start
    return total / n_iter, result


# Each baseline keeps the first `subgraph_size` entries of its get_fo_* output.
gb_time, subgraph_gb = _time_explainer(
    lambda: get_fo_gb(nn, g).tolist()[:subgraph_size], iteration_num)
print(f"fo_gb: {gb_time}")

cam_time, subgraph_cam = _time_explainer(
    lambda: get_fo_cam(nn, g).tolist()[:subgraph_size], iteration_num)
print(f"fo_cam: {cam_time}")

gnnexpl_time, subgraph_gnnexpl = _time_explainer(
    lambda: get_fo_gnnexpl(gnnexplainer(g, nn, g.node_features, verbose=False))[:subgraph_size],
    iteration_num)
print(f"gnnexplainer: {gnnexpl_time}")

# Mutagenicity

In [None]:
# Switch to Mutagenicity and its 3-layer GIN (entry 4 of dataset_model_dirs).
mutagenicity_entry = dataset_model_dirs[4]
dataset, model_dir = mutagenicity_entry
graphs, pos_idx, neg_idx = load_data(dataset)
nn = torch.load(f"models/{model_dir}")

In [None]:
# Runtime vs. graph size M on Mutagenicity (3-layer GIN, K = 1 top walk).
# The naive cost is extrapolated from scoring the first N_SAMPLED_WALKS of the
# M^(L+1) walks of each graph.
time_res = []
L = 3
K = 1
N_SAMPLED_WALKS = 100
n_repeats = 1  # own name: leaves the notebook-wide iteration_num untouched

for g in tqdm(graphs):
    pred = nn.forward(g.get_adj(), H0=g.node_features).argmax()
    H, transforms = get_H_transform(g.get_adj(), nn, H0=g.node_features, gammas=None, mode='gamma')
    init_rel = np.zeros_like(H)
    init_rel[:, pred] = H[:, pred]
    n_walks_total = g.nbnodes ** (L + 1)

    time_accumulate_1 = 0
    time_accumulate_2 = 0
    for _ in range(n_repeats):
        time_a = time.time()
        walk_rels = {}
        for walk in itertools.islice(itertools.product(np.arange(g.nbnodes), repeat=L + 1), N_SAMPLED_WALKS):
            walk_rels[tuple(walk)] = walk_rel(transforms, init_rel, walk, mode="node")
        # Small graphs have fewer than N_SAMPLED_WALKS walks in total: extrapolate
        # from the number actually scored, not from the constant.
        time_tmp = (time.time() - time_a) / max(len(walk_rels), 1) * n_walks_total

        time_a = time.time()
        sorted_walk_rels = sorted(walk_rels.items(), key=lambda item: item[1], reverse=True)[:K]
        time_accumulate_1 += time_tmp + time.time() - time_a

        time_a = time.time()
        top_max_walks, top_min_walks = topk_walks(g, nn, num_walks=K, lrp_mode="gamma",
                negative_transition_strategy='none', mode="node", transforms=transforms, H=init_rel)
        time_accumulate_2 += time.time() - time_a

    time_res.append([g.nbnodes, time_accumulate_1 / n_repeats, time_accumulate_2 / n_repeats])

In [None]:
# Mean runtime per graph size M (number of nodes), averaged over all graphs of
# that size. Rows of time_res are [M, exhaustive time, topk_walks time].
# NOTE(review): this table was referenced but never built in the notebook;
# reconstructed from time_res — confirm it matches the published figure.
time_res_M_mean = (
    pd.DataFrame(time_res, columns=['M', 'exhaustive', 'ours'])
    .groupby('M')
    .mean()
)
M_values = time_res_M_mean.index.tolist()
FONT_SIZE = 15

fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, gridspec_kw={'height_ratios': [3, 1]}, figsize=(3.5, 4))
fig.subplots_adjust(hspace=0.05)  # adjust space between axes

ax1.set_ylabel("Time (s)", fontsize=FONT_SIZE)
ax2.set_xlabel(r'$M$', fontsize=FONT_SIZE)

ax1.plot(M_values, time_res_M_mean['exhaustive'], 'b--')
ax2.plot(M_values, time_res_M_mean['exhaustive'], 'b--')
ax2.plot(M_values, time_res_M_mean['ours'], 'r-')

# Hide the facing spines so the two panels read as one broken axis.
ax1.spines.bottom.set_visible(False)
ax2.spines.top.set_visible(False)
ax1.xaxis.tick_top()
ax1.tick_params(labeltop=False)  # don't put tick labels at the top

ax1.set_ylim(0.51)  # upper panel: only the slow exhaustive values
ax2.set_ylim(0, 0.5)
ax2.set_xlim(1, time_res_M_mean.index.max())

for ax in (ax1, ax2):
    ax.grid()
    ax.tick_params(labelsize=FONT_SIZE)

# Slanted "break" markers at the left/right edges where the panels meet.
d = .5  # proportion of vertical to horizontal extent of the slanted line
kwargs = dict(marker=[(-1, -d), (1, d)], markersize=12,
              linestyle="none", color='k', mec='k', mew=1, clip_on=False)
ax1.plot([0, 1], [0, 0], transform=ax1.transAxes, **kwargs)
ax2.plot([0, 1], [1, 1], transform=ax2.transAxes, **kwargs)

fig.savefig('imgs/time_consumption_M.svg', dpi=600, format='svg', bbox_inches='tight')

# Infection

In [None]:
from torch_geometric.utils import to_dense_adj
from infection_utils import get_rel_components, approx_most_rel_walk, plot_n_hop_infect_graph, walk_to_walk_edges, walk_to_walk_edges_set, approx_top_k_walk
import torch.nn.functional as F

In [None]:
# NOTE: pickle.load can execute arbitrary code — only load these files from a
# trusted source.
MODEL_PATH = "models/gcn-4-infection-sir_100_many_init_patient.torch"
DATASET_PATH = "datasets/infection-sir/dataset_100.pt"

with open(MODEL_PATH, 'rb') as model_file:
    model = pkl.load(model_file)
with open(DATASET_PATH, 'rb') as dataset_file:
    dataset = pkl.load(dataset_file)

In [None]:
# Split the loaded dataset: the first 80 graphs go to train_set, the rest to test_set.
# Both lists reference the same Data objects as `dataset`, so the in-place
# modifications below are also visible through `dataset`.
train_set = []
for d in dataset[:80]:
    train_set.append(d)
test_set = []
for d in dataset[80:]:
    test_set.append(d)

# --- Infection-simulation parameters ---
INIT_INFECT_RATE = 0.02  # upper bound: each graph draws its initial infection probability from U(0, INIT_INFECT_RATE)
INFECTION_RATE = 0.6     # per-step probability that an infected node infects a given neighbour
CURE_RATE = 0.           # per-step probability that an infected node becomes IMMUNED (0. => never)
IMMUNE_RATE = 1          # probability that an IMMUNED node resists infection (1 => never reinfected)

# Node states stored in x_state. SUSPECTFUL presumably means "susceptible";
# the identifier is kept as-is.
SUSPECTFUL = 0
INFECTED = 1
IMMUNED = 2

zero_patient = False  # True: only node 0 starts infected; False: random initial infections
np.random.seed(0)     # all randomness below comes from np.random, so labels are reproducible

steps = 4  # number of spreading rounds simulated per graph
for data in train_set + test_set:
    data.num_classes = 2
    # Drop the stored explanation attributes (presumably because they would not
    # match the re-simulated labels). NOTE(review): may fail if this cell is
    # re-run on graphs that were already processed — confirm.
    del data.unique_solution_nodes
    del data.unique_solution_explanations

    A = to_dense_adj(data.edge_index)[0]  # dense adjacency of this graph

    # Node features: column 0 = initially infected, column 1 = not initially infected.
    if zero_patient:
        data.x = torch.zeros_like(data.x)
        data.x[:,1] = 1
        data.x[0,0] = 1
        data.x[0,1] = 0
    else:
        init_infect_rate_ = np.random.rand() * INIT_INFECT_RATE
        x = torch.tensor([int(np.random.rand() < init_infect_rate_) for _ in range(data.num_nodes)])
        data.x = torch.column_stack([x, 1*(x != 1)]).float()

    # x_state: per-node state; all SUSPECTFUL (0) except the initial infections.
    x_state = torch.zeros(data.num_nodes)
    x_state[torch.where(data.x[:,0]==1)] = INFECTED

    # infection_chains[v]: nodes along which the infection reached v, starting
    # at its initially infected node.
    infection_chains = {}
    for node in range(data.num_nodes):
        if x_state[node] == INFECTED:
            infection_chains[node] = [node]

    for step in range(steps):
        # Snapshot of infected nodes at the start of the step: nodes infected
        # during this step only start spreading in the next one.
        # NOTE(review): a node infected earlier in this step is not in the
        # snapshot, so a later infector can "infect" it again and overwrite its
        # chain — confirm this is intended.
        I_nodes = torch.where(x_state==INFECTED)[0].tolist()
        for I_node in I_nodes:
            # neighbours: columns j with A[I_node, j] != 0
            for node in A[I_node].nonzero().flatten().tolist():
                if node in I_nodes:
                    continue

                if np.random.rand() < INFECTION_RATE:
                    if x_state[node] == IMMUNED:
                        # immune nodes are reinfected with probability 1 - IMMUNE_RATE
                        if np.random.rand() < 1 - IMMUNE_RATE:
                            x_state[node] = INFECTED
                            infection_chains[node] = infection_chains[I_node].copy() + [node]
                    else:
                        x_state[node] = INFECTED
                        infection_chains[node] = infection_chains[I_node].copy() + [node]

            # Recovery draw. With CURE_RATE = 0. it never triggers, but a random
            # number is still consumed, which affects all later draws.
            if np.random.rand() < CURE_RATE:
                x_state[I_node] = IMMUNED
                del infection_chains[I_node]

    data.infection_chains = infection_chains.copy()
    data.y = (x_state==INFECTED) * 1  # label: 1 if infected after `steps` rounds, else 0

In [None]:
# The first graph of the infection dataset is the one used by the timing cells that follow.
data = dataset[0]

In [None]:
# Mean runtime of approx_top_k_walk on the infection model for k = 1, 100 and
# 1000 walks explaining node 7, including the get_rel_components call.
iteration_num = 5  # set explicitly so this cell does not depend on earlier cells
node_idx_to_explain = 7
for k in (1, 100, 1000):
    time_accumulate = 0
    for _ in tqdm(range(iteration_num)):
        time_a = time.time()

        rel_components, X_fc, A = get_rel_components(data=data, model=model, lrp_rule='gamma', gammas=torch.linspace(4, 1, 4), normalize=False)
        real_top_k_max_walks_rels, _ = approx_top_k_walk(k, node_idx_to_explain, data, model, rel_components, X_fc, A, verbose=False)

        time_accumulate += time.time() - time_a

    print(f"k={k}: {time_accumulate / iteration_num}")

In [None]:
# Initial relevance for explaining one node of the infection model: the final
# linear layer's contribution X_fc * W_fc.T[:, target], kept only on the row of
# `node_idx_to_explain` and normalised to sum to 1.
node_idx_to_explain = 7
target = data.y[node_idx_to_explain]  # NOTE(review): ground-truth label, not the model prediction — confirm intended
fc_weight_T = model.fc.state_dict()['weight'].T
init_rel = X_fc * fc_weight_T[:, target]
# Zero every row except the explained node (a slice past the end is simply empty).
init_rel[:node_idx_to_explain] = 0
init_rel[node_idx_to_explain + 1:] = 0
init_rel = init_rel / init_rel.sum()  # NOTE(review): divides by zero if that row sums to 0

In [None]:
def walk_rel_infection(walk, init_rel, rel_components, A):
    """Relevance of a single walk, propagated backwards through the layers.

    Relevance starts at the last node of ``walk`` (``init_rel[walk[-1]]``). For
    ``i = 1 .. len(rel_components)``, layer ``rel_components[-i]`` moves it from
    ``J = walk[-i]`` to ``I = walk[-1 - i]``. The message weight ``W`` contributes
    when ``A[I, J] != 0``, and the root/self weight ``U`` contributes when ``I == J``.

    Args:
        walk: sequence of ``len(rel_components) + 1`` node indices; ``walk[-1]``
            is the node whose output is explained.
        init_rel: initial relevance, indexed by node.
        rel_components: per-layer ``(W, U, X, denom)`` tuples (presumably from
            ``get_rel_components`` — confirm).
        A: dense adjacency matrix indexed ``A[I, J]``.

    Returns:
        Scalar tensor: the summed relevance that reaches ``walk[0]``.
    """
    R_Ii = init_rel[walk[-1]]
    for i in range(1, 1 + len(rel_components)):
        W, U, X, denom = rel_components[-i]
        J = walk[-i]
        I = walk[-1 - i]
        R_denom_J = R_Ii / denom[J]  # shape [j]
        # Identity-matrix entry eye[I, J], computed directly. Allocating an N x N
        # torch.eye on every call dominated the exhaustive baseline's timing.
        self_loop = 1.0 if I == J else 0.0
        R_Ii = X[I] * (A[I, J] * R_denom_J @ W.T + self_loop * R_denom_J @ U.T)  # shape [i]
    return R_Ii.sum()

In [None]:
# Exhaustive baseline for the infection model: enumerate every candidate walk
# ending at `node_idx_to_explain`, score it with walk_rel_infection, then time
# (1) picking the single best walk and (2) sorting all walks.
time_accumulate_1 = 0
time_accumulate_2 = 0
iteration_num = 5
n_layers = len(rel_components)  # walk_rel_infection consumes walks of n_layers + 1 nodes

# NOTE(review): this raw adjacency replaces the `A` returned by
# get_rel_components above; confirm both are the same matrix.
A = to_dense_adj(data.edge_index)[0]

# levels[k]: nodes I with A[I, J] != 0 for some J in the previous level
# (starting from the explained node), plus the explained node itself.
levels = []
frontier = {node_idx_to_explain}
for _ in range(n_layers):
    frontier = set(A[:, list(frontier)].sum(axis=1).nonzero().flatten().tolist())
    frontier.add(node_idx_to_explain)
    levels.append(frontier)

print(int(np.prod([len(level) for level in levels])), 'possible walks')

for _ in tqdm(range(iteration_num)):
    time_a = time.time()

    # Walks run from the earliest layer to the explained node, hence the reversed
    # levels. No per-walk progress bar here: its overhead would be charged to
    # the baseline's measured time.
    walk_rels = {}
    for walk in itertools.product(*reversed(levels)):
        walk = list(walk) + [node_idx_to_explain]
        walk_rels[tuple(walk)] = walk_rel_infection(walk, init_rel, rel_components, A)

    time_tmp = time.time() - time_a

    time_a = time.time()
    top_walk = max(walk_rels.items(), key=lambda item: item[1])[0]
    time_accumulate_1 += time_tmp + time.time() - time_a

    time_a = time.time()
    sorted_walk_rels = sorted(walk_rels.items(), key=lambda item: item[1], reverse=True)
    time_accumulate_2 += time_tmp + time.time() - time_a

print(time_accumulate_1 / iteration_num, time_accumulate_2 / iteration_num)

In [None]:
# Time the walk-based node explanation for the infection model: top-N_TOP_WALKS
# walks (approx_top_k_walk) -> summed walk relevance per node -> the
# SUBGRAPH_SIZE most relevant nodes.
N_TOP_WALKS = 25
SUBGRAPH_SIZE = 5
time_accumulate = 0
for _ in range(iteration_num):
    time_a = time.time()

    rel_components, X_fc, A = get_rel_components(data=data, model=model, lrp_rule='gamma', gammas=torch.linspace(4, 1, 4), normalize=False)
    real_top_k_max_walks_rels, _ = approx_top_k_walk(N_TOP_WALKS, node_idx_to_explain, data, model, rel_components, X_fc, A, verbose=False)

    # Add each walk's relevance (walk[1]) to every node it visits (walk[0]),
    # once per occurrence, without mutating the walk's relevance in place.
    node_rel = {}
    for walk in real_top_k_max_walks_rels:
        for node in walk[0]:
            node_rel[node] = node_rel.get(node, 0) + walk[1]

    # Select once, after aggregation (previously re-sorted for every walk,
    # which inflated the measured time). Keeps all nodes if fewer exist.
    ranked = sorted(node_rel.items(), key=lambda kv: kv[1], reverse=True)
    S = [node for node, score in ranked[:SUBGRAPH_SIZE]]

    time_accumulate += time.time() - time_a

print(time_accumulate / iteration_num)

In [None]:
# GNNExplainer
def sigm(z):
    """Logistic sigmoid written via tanh: sigmoid(z) = 0.5 + 0.5 * tanh(z / 2)."""
    half_tanh = torch.tanh(z * 0.5)
    return 0.5 + half_tanh * 0.5
# GNNExplainer-style edge-mask optimisation for `node_idx_to_explain`: learn
# logits z so that the masked adjacency sigm(z) * A maximises the predicted-class
# score, with an L2 penalty on z; then keep the top nodes from the mask.
x_init, edge_index = data.x, data.edge_index
pred = model.forward(edge_index=edge_index, x=data.x)[node_idx_to_explain].argmax()
A = to_dense_adj(edge_index)[0]

steps = 500
lr = 0.5
lambd = 0.01
verbose = False
# Steps at which progress is printed when verbose (perfect cubes), computed once.
log_steps = {j ** 3 for j in range(100)}

time_accumulate = 0
for _ in range(iteration_num):
    time_a = time.time()

    # Layer weights do not change during the optimisation: read them once per run.
    conv_params = []
    for conv in model.convs:
        state_dict = conv.state_dict()
        conv_params.append((state_dict['lin_rel.weight'].T,
                            state_dict['lin_root.weight'].T,
                            state_dict['lin_rel.bias']))
    fc_state = model.fc.state_dict()
    fc_weight_T, fc_bias = fc_state['weight'].T, fc_state['bias']

    # Edge-mask logits: 2 on every existing edge, 0 elsewhere.
    z = torch.ones(A.shape) * A * 2
    bar = tqdm(range(steps)) if verbose else range(steps)
    for step in bar:
        z.requires_grad_(True)
        x = x_init

        # Separate loop variable: the old inner `i` shadowed the step counter,
        # so the verbose log printed the layer index instead of the step.
        for W, U, rel_bias in conv_params:
            out = (sigm(z) * A).T @ x
            out = out @ W + rel_bias + x @ U
            x = F.relu(out)

        x = x @ fc_weight_T + fc_bias
        score = x[node_idx_to_explain][pred]

        emp = -score
        reg = lambd * (z ** 2).sum()

        if verbose and step in log_steps:
            print('%5d %8.3f %8.3f' % (step, emp.item(), reg.item()))

        (emp + reg).backward()

        # Plain gradient step; the result is a fresh tensor without a grad.
        with torch.no_grad():
            z = z - lr * z.grad

    S = get_fo_gnnexpl(z.data, topk=25)[:5]

    time_accumulate += time.time() - time_a

print(time_accumulate / iteration_num)