In [1]:
import numpy as np
from scipy.special import softmax
import torch
from torch.utils.data import DataLoader
import networkx as nx
from torchvision import datasets as ds
from torchvision import transforms
import argparse

from nets.superpixels_graph_classification.load_net import gnn_model # import all GNNS
from data.data import LoadData # import dataset
from utils import GCN_params

In [2]:
# Load the superpixel-MNIST graph dataset, build the GCN and restore its trained weights.
# NOTE(review): MNIST_test_dataset is not used by any visible cell, and 'PATH'
# looks like a placeholder root directory — TODO confirm or remove.
MNIST_test_dataset = ds.MNIST(root='PATH', train=False, download=True, transform=transforms.ToTensor())
MODEL_NAME = 'GCN'
DATASET_NAME = 'MNIST'
CHECKPOINT_PATH = "data/superpixels/epoch_188.pkl"

dataset = LoadData(DATASET_NAME)
trainset, valset, testset = dataset.train, dataset.val, dataset.test

net_params = GCN_params.net_params()
model = gnn_model(MODEL_NAME, net_params)
# load_state_dict copies the values into the model's existing parameters, so
# mapping the checkpoint to CPU first is safe and lets a checkpoint saved from a
# GPU run load on a machine without CUDA.
# torch.load unpickles the file: only load checkpoints from a trusted source.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location='cpu'))
model.eval()

# batch_size=1 and no shuffling, so loader position i is test graph i.
test_loader = DataLoader(testset, batch_size=1, shuffle=False, drop_last=False, collate_fn=dataset.collate)

[I] Loading dataset MNIST...
train, test, val sizes : 55000 10000 5000
[I] Finished loading.
[I] Data load time: 48.0152s


In [3]:
from gnnExplainer import explain

In [4]:
# Fetch the test graph(s) listed in index_to_explain together with the model's
# output on them. Later cells reuse g, l, pred, adj, feat, norm_n and norm_e;
# if several indices are listed, these names end up holding the last one.
index_to_explain = [0]
Explanations = []
last_index = max(index_to_explain, default=-1)
for idx, (graph, label, snorm_n, snorm_e) in enumerate(test_loader):
    if idx > last_index:
        # test_loader is built with shuffle=False, so no later position can match;
        # stop instead of walking the remaining test set.
        break
    if idx in index_to_explain:
        g = graph
        l = label
        pred = model.forward(g,
                             g.ndata['feat'],
                             g.edata['feat'],
                             snorm_n,
                             snorm_e)
        adj = g.adjacency_matrix().to_dense()
        feat = g.ndata['feat']
        norm_n = snorm_n
        norm_e = snorm_e
        
            
        



In [5]:
# Prepend a batch axis of length 1; both results are NumPy arrays.
adj_ = np.asarray(adj)[np.newaxis, ...]
feat_ = np.asarray(feat)[np.newaxis, ...]

In [6]:
# Back to float tensors. x_torch is a fresh leaf tensor that records gradients.
adj_torch = torch.tensor(adj_).float()
x_torch = torch.tensor(feat_, dtype=torch.float).requires_grad_()
l_torch = l.detach().clone()
# Index of the highest class score of the first (only) graph in the batch.
pred_label = pred[0].detach().numpy().argmax(axis=0)

In [86]:
# ExplainModule is defined in cell In[85] further down — run that cell first.
explainer = ExplainModule(
    graph=g,
    model=model,
    label=l_torch,
    nfeat=x_torch,
    efeat=g.edata['feat'],
    snorm_n=norm_n,
    snorm_e=norm_e,
)

In [87]:
# One forward pass with the default flags; displays the class probabilities.
explainer.forward(unconstrained=False, mask_features=True, marginalize=False)

tensor([6.0067e-12, 4.6651e-08, 9.9997e-01, 1.2170e-06, 1.7609e-10, 1.3807e-08,
        3.0896e-05, 3.8291e-10, 3.7713e-13, 2.7728e-14],
       grad_fn=<SoftmaxBackward>)

In [12]:
import importlib

# Import before reloading: train_utils is otherwise only bound by a later cell,
# so a bare reload raises NameError on a fresh kernel.
from gnnExplainer import train_utils

# Pick up edits made to gnnExplainer/train_utils.py without restarting the kernel.
importlib.reload(train_utils)

<module 'gnnExplainer.train_utils' from '/home/minhvu/GCN/NIPS2020/PGM_Graph/gnnExplainer/train_utils.py'>

In [85]:
import torch.nn as nn
import math

from gnnExplainer import train_utils

class ExplainModule(nn.Module):
    """GNNExplainer-style module holding a learnable edge mask and node-feature mask.

    Args:
        graph: graph object exposing ``number_of_nodes()`` and
            ``adjacency_matrix(transpose=...).to_dense()`` (the DGL API used in
            this notebook).
        model: trained GNN, called as
            ``model.forward(graph, nfeat, efeat, snorm_n, snorm_e)``; ``forward``
            applies a softmax over ``pred[0]`` of its output.
        label: label tensor of the explained graph (stored; not used by ``forward``).
        nfeat: node features with a leading batch axis; only ``nfeat[0]`` is used.
        efeat: edge features, passed unchanged to ``model``.
        snorm_n, snorm_e: normalisation tensors, passed unchanged to ``model``.
        use_sigmoid: apply a sigmoid to ``feat_mask`` (and to the edge mask in
            unconstrained mode).
        mask_act: activation applied to the edge mask by ``_masked_adj``:
            ``"sigmoid"`` (default, the previously hard-coded value), ``"ReLU"``,
            or any other value to use the raw mask.
    """

    def __init__(
        self,
        graph,
        model,
        label,
        nfeat,
        efeat,
        snorm_n,
        snorm_e,
        use_sigmoid=True,
        mask_act="sigmoid",
    ):
        super(ExplainModule, self).__init__()
        self.graph = graph
        self.model = model
        self.label = label
        self.nfeat = nfeat
        self.efeat = efeat
        self.snorm_n = snorm_n
        self.snorm_e = snorm_e
        self.mask_act = mask_act
        self.use_sigmoid = use_sigmoid

        num_nodes = graph.number_of_nodes()

        self.mask = self.construct_edge_mask(num_nodes, init_strategy="const")
        self.feat_mask = self.construct_feat_mask(nfeat.size(-1), init_strategy="constant")
        params = [self.mask, self.feat_mask]

        # All-ones except the diagonal: removes self-loops from the masked adjacency.
        self.diag_mask = torch.ones(num_nodes, num_nodes) - torch.eye(num_nodes)
        # Only the two masks are handed to the optimizer; the model's weights are not.
        self.scheduler, self.optimizer = train_utils.build_optimizer_(params)

        # Loss weights; not referenced by the methods of this class.
        self.coeffs = {
            "size": 0.005,
            "feat_size": 1.0,
            "ent": 1.0,
            "feat_ent": 0.1,
            "grad": 0,
            "lap": 1.0,
        }

    def construct_feat_mask(self, feat_dim, init_strategy="normal", const_val=0.0):
        """Create a learnable feature mask of length ``feat_dim``.

        Args:
            feat_dim: number of node-feature channels.
            init_strategy: ``"normal"`` (N(1, 0.1)) or ``"constant"``/``"const"``.
            const_val: fill value for the constant strategy (default 0.0, i.e.
                sigmoid(mask) starts at 0.5).

        Raises:
            ValueError: for an unknown ``init_strategy`` (previously the
                parameter was left uninitialised).
        """
        mask = nn.Parameter(torch.FloatTensor(feat_dim))
        with torch.no_grad():
            if init_strategy == "normal":
                mask.normal_(1.0, 0.1)
            elif init_strategy in ("constant", "const"):
                nn.init.constant_(mask, const_val)
            else:
                raise ValueError(f"unknown init_strategy {init_strategy!r}")
        return mask

    def construct_edge_mask(self, num_nodes, init_strategy="normal", const_val=1.0):
        """Create a learnable ``(num_nodes, num_nodes)`` edge mask.

        Args:
            num_nodes: number of nodes in the explained graph.
            init_strategy: ``"normal"`` (mean 1, ReLU-gain scaled std) or
                ``"const"``/``"constant"``.
            const_val: fill value for the constant strategy.

        Raises:
            ValueError: for an unknown ``init_strategy``.
        """
        mask = nn.Parameter(torch.FloatTensor(num_nodes, num_nodes))
        with torch.no_grad():
            if init_strategy == "normal":
                std = nn.init.calculate_gain("relu") * math.sqrt(
                    2.0 / (num_nodes + num_nodes)
                )
                mask.normal_(1.0, std)
            elif init_strategy in ("const", "constant"):
                nn.init.constant_(mask, const_val)
            else:
                raise ValueError(f"unknown init_strategy {init_strategy!r}")
        return mask

    def _dense_adj(self):
        """Dense adjacency matrix of ``self.graph``."""
        return self.graph.adjacency_matrix(transpose=False).to_dense()

    def _masked_adj(self):
        """Adjacency multiplied by the symmetrised, activated edge mask, without self-loops."""
        if self.mask_act == "sigmoid":
            sym_mask = torch.sigmoid(self.mask)
        elif self.mask_act == "ReLU":
            sym_mask = nn.ReLU()(self.mask)
        else:
            sym_mask = self.mask
        # Symmetrise so edge (i, j) and (j, i) share one weight.
        sym_mask = (sym_mask + sym_mask.t()) / 2
        return self._dense_adj() * sym_mask * self.diag_mask

    def mask_density(self):
        """Sum of the masked adjacency divided by the sum of the raw adjacency (CPU tensor)."""
        mask_sum = torch.sum(self._masked_adj()).cpu()
        adj_sum = torch.sum(self._dense_adj()).cpu()
        return mask_sum / adj_sum

    def forward(self, unconstrained=False, mask_features=True, marginalize=False):
        """Run ``model`` on the (optionally) feature-masked input and return class probabilities.

        Args:
            unconstrained: build ``self.masked_adj`` from the mask alone (no
                adjacency, no feature masking).
            mask_features: multiply node features by the (sigmoid of the) feature mask.
            marginalize: instead of scaling features, add noise drawn from
                N(-x, 0.5) weighted by ``1 - feat_mask``.

        Returns:
            Softmax over ``pred[0]`` where ``pred = model.forward(...)``.
        """
        x = self.nfeat[0]

        # NOTE(review): self.masked_adj is computed and stored, but it is not
        # passed to self.model below, so only the feature mask affects the
        # prediction — TODO confirm this is intended.
        if unconstrained:
            sym_mask = torch.sigmoid(self.mask) if self.use_sigmoid else self.mask
            self.masked_adj = (
                torch.unsqueeze((sym_mask + sym_mask.t()) / 2, 0) * self.diag_mask
            )
        else:
            self.masked_adj = self._masked_adj()
            if mask_features:
                feat_mask = (
                    torch.sigmoid(self.feat_mask)
                    if self.use_sigmoid
                    else self.feat_mask
                )
                if marginalize:
                    std_tensor = torch.ones_like(x, dtype=torch.float) / 2
                    mean_tensor = torch.zeros_like(x, dtype=torch.float) - x
                    z = torch.normal(mean=mean_tensor, std=std_tensor)
                    x = x + z * (1 - feat_mask)
                else:
                    x = x * feat_mask

        # Use the graph this explainer was built for, not a notebook-global `g`.
        pred = self.model.forward(self.graph,
                                  x,
                                  self.efeat,
                                  self.snorm_n,
                                  self.snorm_e)

        return torch.softmax(pred[0], dim=0)

In [5]:
# NOTE(review): top_node is not read by any visible cell — the PGM explanation
# loop passes its own top_node value to explain(). TODO confirm or remove.
top_node = 3

In [7]:
# Explain the first NUM_GRAPHS_TO_EXPLAIN test graphs with PGM-Explainer and
# record [index, predicted label, selected nodes, X_feat column 1, X_feat column 2].
# NOTE(review): `pe` (the PGM-Explainer module) is never imported in this
# notebook, so this cell fails on a fresh kernel — TODO add its import.
NUM_GRAPHS_TO_EXPLAIN = 10
NUM_SAMPLES = 400
PERCENTAGE = 20
TOP_NODE = 4
P_THRESHOLD = 0.05
P_VALUE_FILTER = 0.02        # stricter cut applied to explain()'s p-values
PRED_THRESHOLD_RATIO = 0.1   # pred_threshold = ratio * max softmax probability

index_to_explain = range(NUM_GRAPHS_TO_EXPLAIN)
Explanations = []
last_index = max(index_to_explain, default=-1)
for idx, (graph, label, snorm_n, snorm_e) in enumerate(test_loader):
    if idx > last_index:
        break  # test_loader is not shuffled; no later position can be wanted
    if idx not in index_to_explain:
        continue
    pred = model.forward(graph, graph.ndata['feat'], graph.edata['feat'], snorm_n, snorm_e)
    soft_pred = np.asarray(softmax(pred[0].detach().cpu().numpy()))
    pred_threshold = PRED_THRESHOLD_RATIO * np.max(soft_pred)
    pgm_explainer = pe.Graph_Explainer(model, graph,
                                       snorm_n=snorm_n, snorm_e=snorm_e,  # was snorm_e=snorm_n
                                       perturb_feature_list=[0],
                                       perturb_mode="mean",
                                       perturb_indicator="abs")
    pgm_nodes, p_values, candidates = pgm_explainer.explain(
        num_samples=NUM_SAMPLES, percentage=PERCENTAGE, top_node=TOP_NODE,
        p_threshold=P_THRESHOLD, pred_threshold=pred_threshold)
    # Keep the ground-truth `label` intact; store the prediction separately.
    predicted_label = np.argmax(soft_pred)
    pgm_nodes_filter = [node for node in pgm_nodes if p_values[node] < P_VALUE_FILTER]
    # NOTE(review): columns 1 and 2 of X_feat are treated as x/y coordinates — TODO confirm.
    x_cor = [pgm_explainer.X_feat[node][1] for node in pgm_nodes_filter]
    y_cor = [pgm_explainer.X_feat[node][2] for node in pgm_nodes_filter]
    result = [idx, predicted_label, pgm_nodes_filter, x_cor, y_cor]
    print(result)
    Explanations.append(result)
            

[0, 7, [57, 73, 24], [0.5161133, 0.33935547, 0.78564453], [0.6333008, 0.67626953, 0.91064453]]
[1, 2, [69, 38], [0.7182617, 0.7548828], [0.34594727, 0.5229492]]
[2, 1, [44, 5, 13], [0.4855957, 0.05355835, 0.625], [0.51171875, 0.33935547, 0.91064453]]
[3, 0, [19, 51, 7], [0.4465332, 0.14282227, 0.46435547], [0.31420898, 0.14282227, 0.4477539]]
[4, 4, [68, 42, 35], [0.46435547, 0.22851562, 0.47094727], [0.6430664, 0.6855469, 0.28125]]
[5, 1, [61], [0.2529297], [0.6010742]]
[6, 4, [40, 50], [0.4741211, 0.4675293], [0.5654297, 0.3310547]]
[7, 9, [5, 16, 48, 3], [0.23571777, 0.14282227, 0.035705566, 0.78564453], [0.43579102, 0.91064453, 0.91064453, 0.25]]
[8, 5, [62, 60], [0.7319336, 0.2142334], [0.6069336, 0.80371094]]
[9, 9, [29, 40], [0.14282227, 0.29296875], [0.25, 0.54296875]]


In [99]:
# Debug check: Python type of a single node-feature scalar.
g.ndata['feat'][0][0].item().__class__

float