In [1]:
import os
import random
import time
import warnings

# Installed before the third-party imports so their import-time warnings are silenced too.
warnings.filterwarnings("ignore")

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.utils import dense_to_sparse, add_self_loops, remove_self_loops, k_hop_subgraph
from tqdm.notebook import tqdm

from utils import *
from model import *
from explainer import *
from main_distiller import *

# Seed every RNG the notebook may touch (Python, numpy, torch) so runs are reproducible.
SEED = 123
random.seed(SEED)
np.random.seed(SEED)
# torch.manual_seed seeds the generators of all devices (CPU and CUDA);
# torch.random.manual_seed is the same function, so one call is enough.
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# DnX and FastDnX 


Example explanation:

Folder description:

  * trained_gcn: contains the models that will be explained
  * trained_distiller: contains the distilled models
  * explanation: contains the explanations generated by our models

Datasets: 

The short names (syn1, syn2, …, syn6) are used throughout the code to simplify running experiments:

  * BA-HouseShapes (syn1), 
  * BA-Community (syn2), 
  * BA-Grids (syn3), 
  * Tree-Cycles (syn4), 
  * Tree-Grids (syn5), 
  * BA-Bottle-Shaped (syn6) 

To run the DnX and FastDnX explainers on the synthetic datasets, run the notebook `main_syn.ipynb`:



In [2]:
# Synthetic benchmark to explain: one of "syn1" ... "syn6" (see the dataset list above).
dataset = "syn1"

In [3]:
# Directory prefixes, relative to the notebook's working directory.
# File names are appended directly to them, hence the trailing '/'.
dataset_path, out_path, distiller_path, model_path = (
    'datasets/',           # raw synthetic datasets
    'explanation/',        # generated explanations
    'trained_distiller/',  # distilled SGC surrogates
    'trained_gcn/',        # GCN models being explained
)

## Loading dataset

In [4]:
# Load the dataset and the checkpoint of the GCN to be explained.

A_np, X_np = load_XA(dataset, datadir=dataset_path)
num_nodes = X_np.shape[0]
labels = load_labels(dataset, datadir=dataset_path)
num_class = max(labels) + 1

ckpt = load_ckpt(dataset, datadir=model_path)
# Propagation depth; also the hop radius used by k_hop_subgraph in the explainer cells.
layer = 3

A = torch.tensor(A_np, dtype=torch.float32).to(device)
X = torch.tensor(X_np, dtype=torch.float32).to(device)

# One-hot encoding of each node's degree (row sums of A).
# Cast with `.long()` rather than `.type(torch.LongTensor)`: the latter always
# produces a CPU tensor, which breaks the syn2 concatenation on CUDA.
degree_one_hot = F.one_hot(torch.sum(A, 1).long()).type(torch.float32).to(device)

if dataset == 'syn2':
    # syn2 keeps its stored features, with the degree encoding prepended.
    X = torch.cat([degree_one_hot, X], 1)
else:
    # Every other dataset uses the degree encoding as its only node feature.
    X = degree_one_hot

input_dim = X.shape[1]

edge_index, _ = dense_to_sparse(A)
edge_index = edge_index.to(device)

# Class probabilities of the GCN being explained, stored in its checkpoint;
# passed to main() below (presumably as the distillation target).
pred = ckpt["save_data"]["pred"].squeeze(0)
L_model = torch.softmax(torch.tensor(pred), 1)

# Nodes to explain and `k`, the explanation size used by both explainers.
node_list, k = get_nodes_explained(dataset, A_np)

trained_gcn/syn1.pth.tar
=> loading checkpoint 'trained_gcn/syn1.pth.tar'


## Generating the distilled model

In [5]:
# Distil the GCN into an SGC surrogate fitted to its soft predictions (L_model).
# The next cell reloads the result from `distiller_path` as 'SGC_' + dataset.
# NOTE(review): 1000 and 1 are positional arguments of main(); the progress bar
# (0/999) suggests an epoch count and a number of runs — confirm in main_distiller.
distillation_inputs = (X, A, edge_index, L_model)
main(dataset, distillation_inputs, 1000, 1, layer, input_dim, num_class, distiller_path)

  0%|          | 0/999 [00:00<?, ?it/s]

Mean Accuracy:0.9267857142857143 |  Std Accuracy:0.0
trained_distiller/SGC_syn1.pth.tar


In [6]:
# Rebuild the distilled SGC surrogate from its checkpoint and cache its predictions
# on the full graph; both explainers below try to reproduce these predictions.
ckpt_distillation = load_ckpt('SGC_' + dataset, datadir=distiller_path)

# Use the same depth that was passed to main() (`layer`) instead of a hard-coded 3.
model = SGC(layer, input_dim, num_class).to(device)
model.load_state_dict(ckpt_distillation["model_state"])
model.eval()

# The predictions are fixed targets. Computing them without autograd avoids
# keeping the full-graph forward graph alive, and stops backward passes in the
# DnX loop from propagating through the target.
with torch.no_grad():
    pred_model = model(X, edge_index).to(device)

trained_distiller/SGC_syn1.pth.tar
=> loading checkpoint 'trained_distiller/SGC_syn1.pth.tar'


# Generating explanations for the model

## DnX

In [8]:
# DnX: for each node, optimise an explainer over the node's computation subgraph
# so that the surrogate's prediction on it matches pred_model, and keep the
# explanation with the best per-node accuracy.
nodes_explanations = {}

# NOTE(review): `t` is forwarded to the DnX forward pass and to
# generating_explanations; its meaning is defined in explainer.py — confirm.
t = 5
num_epochs = 100  # optimisation budget per node

print("Explaining {} dataset ".format(dataset))

for node in tqdm(node_list):
    nodes_explanations_aux = {}
    # Start below any attainable accuracy so the first epoch's explanation is
    # always recorded. With 0, a node whose accuracy never exceeded 0 would
    # silently reuse the previous node's explanation (or raise NameError if it
    # was the first node).
    acc_top = float('-inf')
    exp_top = None

    # `layer`-hop computation subgraph of `node`, with nodes relabelled from 0.
    neighbors, sub_edge_index, node_idx_new, _ = k_hop_subgraph(int(node), layer, edge_index, relabel_nodes=True)
    sub_X = X[neighbors].to(device)
    node_idx_new = node_idx_new.to(device)
    sub_edge_index = sub_edge_index.to(device)
    neighbors = neighbors.to(device)

    explainer = DnX(len(neighbors), node_idx_new).to(device)
    opt = torch.optim.Adam(explainer.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min',
                                                           factor=0.5, min_lr=1e-5,
                                                           patience=20,
                                                           verbose=True)
    loss = nn.MSELoss()

    for _ in range(num_epochs):
        explainer.train()
        opt.zero_grad()

        expl, pred_ex = explainer(sub_X, sub_edge_index, model, t)

        # Match the explainer's prediction for the target node to the surrogate's.
        l = loss(pred_ex[node_idx_new], pred_model[node])
        # retain_graph stays safe if pred_model was computed with autograd
        # enabled, in which case its graph is shared across epochs.
        l.backward(retain_graph=True)

        opt.step()
        scheduler.step(l)

        # Score the current explanation for this node alone; keep the best so far.
        nodes_explanations_aux[node] = generating_explanations(node, model, explainer, edge_index, X, t, k)[node]
        acc, prec = evaluate_syn_explanation(nodes_explanations_aux, dataset)
        if acc > acc_top:
            acc_top = acc
            exp_top = nodes_explanations_aux[node]

        if acc == 1.0:
            break

    nodes_explanations[node] = exp_top

# Evaluate all explanations once, after every node has been processed.
acc, prec = evaluate_syn_explanation(nodes_explanations, dataset)

os.makedirs(out_path, exist_ok=True)
with open(out_path + 'nodes_explanations_DnX_{}.txt'.format(dataset), 'w') as f:
    f.write("%s\n" % nodes_explanations)

print("Accuracy: ", acc)
print("Precision: ", prec)

Explaining syn1 dataset 


  0%|          | 0/400 [00:00<?, ?it/s]

Accuracy:  0.975
Precision:  0.975


## FastDnX

In [9]:
# FastDnX: score each neighbour of a target node in closed form, using the
# distilled SGC's linear layer, and keep the top-k neighbours as the explanation.
print("Explaining {} dataset ".format(dataset))

explanations = {}

W = ckpt_distillation['model_state']['conv.lin.weight'].to(device)
bias = ckpt_distillation['model_state']['conv.lin.bias'].to(device)
# NOTE(review): presumably the `layer`-hop propagation matrix (A^layer) — see A_k_hop.
A_pot = A_k_hop(A, layer).to(device)

# Pure inference: no autograd graph is needed for scoring.
with torch.no_grad():
    for target_node in tqdm(node_list):
        nodes_neigh, _, _, _ = k_hop_subgraph(int(target_node), layer, edge_index)

        # Each neighbour's features, weighted by its propagation coefficient to the target.
        S = X[nodes_neigh] * A_pot[target_node, nodes_neigh].unsqueeze(1)
        # Each neighbour's bias-free contribution to the linear layer's output.
        contrib = torch.matmul(S, W.T)
        # Score = dot product of each contribution with the target's prediction
        # minus the bias. This equals diag(L @ contrib.T) when every row of L
        # is that vector, computed as a matrix-vector product instead of an n x n one.
        expl = torch.matmul(contrib, pred_model[target_node] - bias)

        # Top-k neighbours, or all of them when the neighbourhood is smaller than k.
        k_nodes = min(k, len(nodes_neigh))
        _, top_idx = torch.topk(expl, dim=0, k=k_nodes)

        explanations[target_node] = nodes_neigh[top_idx].tolist()

acc, prec = evaluate_syn_explanation(explanations, dataset)

print("Accuracy: ", acc)
print("Precision: ", prec)

os.makedirs(out_path, exist_ok=True)
with open(out_path + 'nodes_explanations_fastDnX_{}.txt'.format(dataset), 'w') as f:
    f.write("%s\n" % explanations)


Explaining syn1 dataset 


  0%|          | 0/400 [00:00<?, ?it/s]

Accuracy:  0.995
Precision:  0.995
