## import libraries

In [1]:
import torch
import time
from utils import load_dataset, load_model, plot_graph
from models import SGCNet, GCN
from explainer import FastDnX
import numpy as np
from torch_geometric.utils import k_hop_subgraph
from evaluation import get_ground_truth, node_to_edge, get_ground_truth_edge, evaluation_auc_node, fidelity_neg
from tqdm import tqdm

## load dataset and the distilled model

In [2]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset_name = 'syn1'

data = load_dataset(dataset_name)
input_dim = data.x.shape[1]
# Class count is the largest label value plus one (assumes 0-based integer labels — TODO confirm).
# Deliberately kept as a 0-d tensor: this cell and the fidelity cell both call `num_classes.item()`.
num_classes = data.y.max() + 1
# The trailing 3 is presumably the number of propagation hops; the same value is
# passed to FastDnX and k_hop_subgraph in later cells — verify against SGCNet.
model = SGCNet(input_dim, num_classes.item(), 3)
model = load_model(model, 'SGC', dataset_name, device)
model.eval()
# Inference only, so skip building the autograd graph.
# NOTE(review): `data` is never moved to `device`; this assumes load_model leaves
# the model on the same device as `data.x` — confirm before running on CUDA.
with torch.no_grad():
    labels = model(data.x, data.edge_index)

node_list = list(range(300,700)) # for syn1
motif_size = 5 # for syn1

## set the explanation model and explain the distilled model

In [3]:
# Build the node-level explainer around the loaded model and run its one-off preparation step.
explainer = FastDnX(model, data.x, 'node', 3, data.edge_index)
explainer.prepare()

# perf_counter() is a monotonic, high-resolution clock, which makes it the right
# tool for timing a sub-second loop (time.time() can jump with wall-clock adjustments).
start_time = time.perf_counter()
explanations = {}

# Iterating the NumPy array keeps the dict keys as NumPy integers, which is the
# same key type the evaluation cells use when they look the explanations up.
for target_node in np.array(node_list):
    nodes, values = explainer.explain(int(target_node))
    explanations[target_node] = [nodes, values]

elapsed = time.perf_counter() - start_time
print(f'Time: {elapsed}')

# The dict is wrapped in a one-element array so np.save can store it (as a pickled
# object); the reload cell reads it back with allow_pickle=True and takes element 0.
# NOTE(review): the './explanations/' directory must already exist, and the file name
# says 'gcn' although the explained model was loaded as 'SGC' — kept unchanged
# because the reload cell reads this exact path.
np.save('./explanations/fastdnx_'+dataset_name+'_gcn.npy', np.array([explanations]))

finding top nodes...
Time: 0.16661667823791504


In [4]:
# Read back the explanations written by the explanation cell.
# allow_pickle=True unpickles arbitrary Python objects, so only load files you trust.
explanations_path = './explanations/fastdnx_'+dataset_name+'_gcn.npy'
saved_explanations = np.load(explanations_path, allow_pickle=True)
# The file stores a one-element array; its single entry is the explanations dict.
explanations = saved_explanations[0]

import torch.nn as nn
# Softmax taken along dim 1 of its input.
m = nn.Softmax(dim=1)

## evaluate synthetic - node level

In [5]:
# Predicted class of every node, computed once: it does not depend on the loop
# variable, so there is no reason to redo the softmax/argmax on each iteration.
predicted_classes = m(labels).argmax(dim=1)

accs = []
for no_alvo in tqdm(np.array(node_list)):
    nodes, expl = explanations[no_alvo]
    if len(nodes) > motif_size:
        # Keep only the motif_size nodes with the highest explanation scores.
        top_idx = torch.topk(expl, dim=0, k=motif_size).indices
        node_expl = nodes[top_idx]
    else:
        # Fewer candidates than the motif size: use all of them.
        node_expl = nodes
    # Only nodes whose predicted class matches the label contribute to the score.
    if predicted_classes[no_alvo] == data.y[no_alvo]:
        real = np.array(get_ground_truth(no_alvo, dataset_name))
        # Fraction of the selected nodes that appear in the ground truth.
        hits = np.count_nonzero(np.isin(node_expl.numpy(), real))
        acc = hits / len(node_expl)
        accs.append(acc)

print(f'accuracy node level: {np.mean(accs)}')

100%|███████████████████████████████████████| 400/400 [00:00<00:00, 5538.48it/s]

accuracy node level: 0.9982954545454544





## evaluate synthetic - edge level

In [10]:
# Compute the edge-level ground truth here so this cell also runs on a fresh
# kernel; it previously only worked if the edge-level evaluation cell had
# already been executed (note the out-of-order execution count).
explanation_labels = get_ground_truth_edge(data.edge_index, data.y, dataset_name)
explanation_labels

(array([[  0,   0,   0, ..., 698, 699, 699],
        [  5,   6,   8, ..., 697, 695, 696]]),
 array([0, 0, 0, ..., 1, 1, 1]))

In [6]:
# Convert each target node's node-level explanation into an edge-level one.
all_expl_nodes = []
for target in tqdm(np.array(node_list)):
    expl_nodes, expl_scores = explanations[target]
    # Element 1 of k_hop_subgraph's result: the edges of the 3-hop neighbourhood of the target.
    neigh_edges = k_hop_subgraph(int(target), 3, data.edge_index)[1]
    edge_expl = node_to_edge(
        expl_scores.detach().numpy(),
        expl_nodes.numpy(),
        neigh_edges,
    )
    all_expl_nodes.append(edge_expl)

# Score the edge explanations against the dataset's edge-level ground truth.
explanation_labels = get_ground_truth_edge(data.edge_index, data.y, dataset_name)
auc_score = evaluation_auc_node(all_expl_nodes, explanation_labels)
print(f'auc edge level: {auc_score}')

100%|████████████████████████████████████████| 400/400 [00:00<00:00, 681.21it/s]


auc edge level: 0.9973833549713749


## fidelity-

<img src="./imgs/fidelity.png">  

where $\hat{y}$ is the original GNN prediction, $\hat{y}^{mi}$ is the GNN prediction obtained with only the explanations, and $y$ is the original label.

In [7]:
# Load the checkpoint of the original GCN (the model the explanations are checked against).
# torch.load unpickles the file, so only use checkpoints from a trusted source.
# `device` is already a torch.device, so it can be passed as map_location directly.
ckpt = torch.load('./checkpoints/GCN_'+dataset_name+'.pth.tar', map_location=device)

x = torch.ones((700, 10)) #syn1 original features
model_gcn = GCN(num_features=x.shape[1], num_classes=num_classes.item())
model_gcn.load_state_dict(ckpt["model_state"])
model_gcn = model_gcn.to(device)
# Switch to evaluation mode so layers such as dropout or batch norm (if the GCN
# has any) behave deterministically while fidelity is measured.
model_gcn.eval()
# NOTE(review): `x` and `data.edge_index` are not moved to `device`; this assumes
# fidelity_neg copes with that (or that everything is on the CPU) — confirm.
fidelity_neg(model_gcn, node_list, data.edge_index, motif_size, x, data.y, explanations)


100%|████████████████████████████████████████| 400/400 [00:02<00:00, 151.44it/s]


fidelidade-: 0.035



