In [None]:
import torch
# Seed RNGs up front so the randomly sampled graph and the explainer's
# optimisation are reproducible under Restart & Run All.
# NOTE(review): if graph_data samples with numpy, seed numpy here as well.
import random

SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU generator and all CUDA devices

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
from torch_geometric.nn import GNNExplainer
from tqdm import tqdm



class GNNExplainerUpdated(GNNExplainer):
    """GNNExplainer whose graph-level explanation forwards an explicit
    ``batch`` vector (all nodes in graph 0) to the wrapped model.

    Adapted from
    https://github.com/rusty1s/pytorch_geometric/blob/master/torch_geometric/nn/models/gnn_explainer.py

    NOTE(review): relies on the parent's private helpers ``__clear_masks__``,
    ``__set_masks__``, ``__to_log_prob__`` and ``__loss__``; these are not a
    public API and may differ between torch_geometric releases, so pin the
    version this was written against.
    """

    def explain_graph(self, x, edge_index, **kwargs):
        r"""Learns and returns a node feature mask and an edge mask that play a
        crucial role to explain the prediction made by the GNN for a graph.

        All nodes are treated as belonging to one graph: unless the caller
        supplies its own ``batch`` keyword, an all-zeros ``batch`` vector of
        length ``x.size(0)`` is built and passed to the model.

        Args:
            x (Tensor): The node feature matrix (presumably
                ``[num_nodes, num_features]``; the feature mask is broadcast
                across rows via ``view(1, -1)``).
            edge_index (LongTensor): The edge indices.
            **kwargs (optional): Additional arguments passed to the GNN module.
                A ``batch`` entry, if present, replaces the default all-zeros
                batch vector.

        Returns:
            (Tensor, Tensor): ``(node_feat_mask, edge_mask)``, both detached
            and passed through ``sigmoid`` so every entry lies in (0, 1).

        :rtype: (:class:`Tensor`, :class:`Tensor`)
        """
        self.model.eval()
        self.__clear_masks__()

        # Pop a caller-supplied batch so it is not passed twice to the model
        # (which would raise "got multiple values for keyword argument").
        batch = kwargs.pop('batch', None)
        if batch is None:
            # all nodes belong to same graph
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

        # Get the initial prediction; the explanation targets this label.
        with torch.no_grad():
            out = self.model(x=x, edge_index=edge_index, batch=batch, **kwargs)
            log_logits = self.__to_log_prob__(out)
            pred_label = log_logits.argmax(dim=-1)

        self.__set_masks__(x, edge_index)
        self.to(x.device)

        pbar = None
        try:
            optimizer = torch.optim.Adam([self.node_feat_mask, self.edge_mask],
                                         lr=self.lr)

            if self.log:  # pragma: no cover
                pbar = tqdm(total=self.epochs)
                pbar.set_description('Explain graph')

            for _ in range(self.epochs):
                optimizer.zero_grad()
                # Soft feature mask shared by all nodes (one weight per column).
                h = x * self.node_feat_mask.view(1, -1).sigmoid()
                out = self.model(x=h, edge_index=edge_index, batch=batch,
                                 **kwargs)
                log_logits = self.__to_log_prob__(out)
                # -1 tells the parent's loss this is a graph-level explanation.
                loss = self.__loss__(-1, log_logits, pred_label)
                loss.backward()
                optimizer.step()

                if pbar is not None:  # pragma: no cover
                    pbar.update(1)

            # Read the masks before __clear_masks__ tears the explain state down.
            node_feat_mask = self.node_feat_mask.detach().sigmoid()
            edge_mask = self.edge_mask.detach().sigmoid()
        finally:
            # Always restore the model, even if optimisation raised or was
            # interrupted; otherwise the injected masks stay active on it.
            if pbar is not None:  # pragma: no cover
                pbar.close()
            self.__clear_masks__()

        return node_feat_mask, edge_mask

In [None]:
from models import DeeperGCN

# NOTE(review): the positional args are presumably (num_layers, hidden_channels,
# device) — confirm against models.DeeperGCN and consider named constants.
model = DeeperGCN(5, 256, device)
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU still loads on a CPU-only machine (and vice versa).
state_dict = torch.load('gnn.pt', map_location=device)
model.load_state_dict(state_dict)
# Ensure parameters live on the same device as the input graphs (which are
# moved with `.to(device)` before explanation); a no-op if already there.
model = model.to(device)
model.eval()

In [None]:
from graph_data import GraphData

# Dataset wrapper from the local graph_data module; its return_random_graph()
# is used to pick the graph to explain.
graph_data = GraphData()

In [None]:
# Draw one graph and move it onto the same device as the model.
sampled_graph = graph_data.return_random_graph()
random_sample = sampled_graph.to(device)

# Fit feature/edge masks explaining the model's prediction for this graph.
# NOTE(review): return_type='log_prob' assumes the model already outputs
# log-probabilities — confirm against DeeperGCN's final layer.
explainer = GNNExplainerUpdated(model, epochs=100, return_type='log_prob')
explanation = explainer.explain_graph(
    x=random_sample.x,
    edge_index=random_sample.edge_index,
)
node_feat_mask, edge_mask = explanation

In [None]:
import matplotlib.pyplot as plt
# figure.figsize is read when a figure is created, so it must be set before
# visualize_subgraph draws; setting it afterwards left this plot at the
# default size and only affected later figures.
plt.rcParams["figure.figsize"] = (40, 30)

# -1 selects the whole graph (graph-level explanation) rather than one node's
# neighbourhood; edges are weighted by the learned edge mask.
ax, G = explainer.visualize_subgraph(-1, random_sample.edge_index, edge_mask)

plt.show()

In [None]:
# Inspect the learned masks: first their shapes, then their values.
masks = (node_feat_mask, edge_mask)
print(*(mask.shape for mask in masks))
print(*masks)
random_sample

In [None]:
# TODO: aggregate the edge mask into a per-node relevance score, e.g.:
# for edge in edge_index.T:
# for edge, relevance in zip(edge_index.T, edge_mask):
#     node_relevance[edge[0]] += relevance
#     node_relevance[edge[1]] += relevance
#     node_count[edge[0]] += 1
#     node_count[edge[1]] += 1
# node_relevance /= node_count  # guard against isolated nodes (node_count == 0)
# meshio