In [8]:
import os
import time

import numpy as np
import torch
from torch_geometric.nn import GNNExplainer
from torch_geometric.utils import k_hop_subgraph
from tqdm import tqdm

from evaluation import get_ground_truth, edge_to_node, get_ground_truth_edge, evaluation_auc_node, fidelity_neg
from explainer import FastDnX
from models import SGCNet, GCN
from utils import load_dataset, load_model, plot_graph

In [2]:
# Configuration: compute device and the dataset under evaluation.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset_name = 'syn1'

data = load_dataset(dataset_name)
input_dim = data.x.shape[1]
# Largest label + 1; kept as a 0-d tensor, hence the `.item()` below.
num_classes = data.y.max() + 1

# NOTE(review): torch.load unpickles the checkpoint -- only load trusted files.
ckpt = torch.load('./checkpoints/GCN_'+dataset_name+'.pth.tar', map_location=device)

# syn1 original features: an all-ones 10-dim vector per node. The node count is
# taken from the loaded data instead of a hard-coded 700 (assumes data.x has one
# row per node). Created on `device` so it matches the model below.
x = torch.ones((data.x.shape[0], 10), device=device)
model_gcn = GCN(num_features=x.shape[1], num_classes=num_classes.item())
model_gcn.load_state_dict(ckpt["model_state"])
model_gcn = model_gcn.to(device)
# Inference only: switch any train-time layers (e.g. dropout, if GCN has them) off.
model_gcn.eval()

# Full-graph forward pass; no gradients are needed for this.
with torch.no_grad():
    labels = model_gcn(x, data.edge_index.to(device))

In [3]:
# Per-dataset evaluation setup: which nodes to explain and how many top-ranked
# nodes make up one explanation (the ground-truth motif size).
if dataset_name == 'syn1':
    # NOTE(review): presumably nodes 300..699 are the motif nodes attached to the
    # base graph -- confirm against get_ground_truth.
    node_list = list(range(300, 700))
    motif_size = 5
else:
    # Fail fast here instead of a NameError on node_list/motif_size in later cells.
    raise ValueError(f'No node_list/motif_size configured for dataset {dataset_name!r}')

## GNNExplainer

In [4]:
# Explain every target node with GNNExplainer and cache the results to disk as a
# dict: node id -> (subgraph edge_index, explainer score for each of those edges).
hops = 3
explai = GNNExplainer(model_gcn, num_hops=hops)
# Explainer inputs on the model's device (`.to` is a no-op if already there).
x_dev = x.to(device)
edge_index_dev = data.edge_index.to(device)

start = time.perf_counter()
explanations = {}
for idx in tqdm(node_list):
    # Compute the k-hop subgraph once: entry [1] is the subgraph's edge_index,
    # entry [-1] is the boolean mask selecting those edges from the full graph.
    subgraph = k_hop_subgraph(int(idx), hops, edge_index_dev)
    neigh_edge_index, edge_idx = subgraph[1], subgraph[-1]
    _, edge_mask = explai.explain_node(int(idx), x_dev, edge_index_dev)
    # edge_mask has one score per edge of the full graph; keep the subgraph's.
    explanations[idx] = (
        neigh_edge_index.cpu().numpy(),
        edge_mask[edge_idx].detach().cpu().numpy(),
    )

print(time.perf_counter() - start)

os.makedirs('./explanations', exist_ok=True)
# Wrapped in a 1-element object array so np.save pickles the dict.
np.save('./explanations/gnnexplainer_'+dataset_name+'_gcn.npy', np.array([explanations]))

100%|█████████████████████████████████████████| 400/400 [04:27<00:00,  1.49it/s]

267.8792634010315





In [5]:
# Reload the cached GNNExplainer output. It was saved as a dict wrapped in a
# 1-element object array, so pickling must be allowed and element 0 unwrapped.
explanations_file = './explanations/gnnexplainer_'+dataset_name+'_gcn.npy'
explanations = np.load(explanations_file, allow_pickle=True)[0]

import torch.nn as nn

# Row-wise softmax helper (not referenced by the other cells shown here).
m = nn.Softmax(dim=1)

## Edge-level AUC

In [6]:
# Edge-level AUC: pair each node's subgraph edges with their explainer scores
# and compare against the ground-truth edge labels.
all_expl_nodes = [
    (explanations[target][0], explanations[target][1])
    for target in tqdm(np.array(node_list))
]

explanation_labels = get_ground_truth_edge(data.edge_index, data.y, dataset_name)
auc_score = evaluation_auc_node(all_expl_nodes, explanation_labels)
print(f'auc edge level: {auc_score}')

100%|█████████████████████████████████████| 400/400 [00:00<00:00, 423346.35it/s]


auc edge level: 0.6692554604838357


## Node-level accuracy

In [10]:
# Node-level accuracy: for each explained node, the fraction of its top
# `motif_size` explanation nodes that are in its ground-truth motif, averaged.
accs = []
for no_alvo in tqdm(node_list):
    alledges = explanations[no_alvo][0]
    allexpls = explanations[no_alvo][1]

    # Nodes whose explanation subgraph has no edges are skipped, not scored as 0.
    if alledges.shape[1] < 1:
        continue

    # Aggregate edge scores into per-node scores.
    nodes, expls = edge_to_node(allexpls, alledges)
    if len(nodes) == 0:
        continue  # nothing to score; avoids dividing by zero below

    if len(nodes) > motif_size:
        # Keep only the motif_size highest-scoring nodes.
        _, idx_expl = torch.topk(torch.tensor(expls), dim=0, k=motif_size)
        node_expl = nodes[idx_expl.numpy()]
    else:
        node_expl = nodes

    real = np.array(get_ground_truth(no_alvo, dataset_name))
    # Share of selected nodes that belong to the ground-truth motif.
    accs.append(np.isin(node_expl, real).mean())

print(f'accuracy node level: {np.mean(accs)}')

100%|████████████████████████████████████████| 400/400 [00:00<00:00, 894.17it/s]

accuracy node level: 0.701





## Fidelity- (fidelity minus)

In [12]:
# Fidelity-: convert each node's edge-level explanation into node-level scores,
# then hand them (with motif_size) to fidelity_neg for scoring.
all_expl_nodes = {}
for target in tqdm(node_list):
    target_edges, target_scores = explanations[target]
    # Only nodes whose explanation subgraph has at least one edge are included.
    if target_edges.shape[1] >= 1:
        node_ids, node_scores = edge_to_node(target_scores, target_edges)
        all_expl_nodes[target] = [node_ids, torch.tensor(node_scores)]

fidelity_neg(model_gcn, node_list, data.edge_index, motif_size, x, data.y, all_expl_nodes)

100%|████████████████████████████████████████| 400/400 [00:00<00:00, 953.82it/s]
100%|████████████████████████████████████████| 400/400 [00:01<00:00, 203.13it/s]


fidelidade-: 0.035



