# GNNExplainer Tutorial

This tutorial demonstrates how to use `GNNExplainer` from the Captum library to explain predictions made by a Graph Neural Network (GNN). GNNExplainer identifies a compact subgraph structure and a small subset of node features that are most influential for a GNN's prediction.

**Reference:** [GNNExplainer: Generating Explanations for Graph Neural Networks](https://arxiv.org/abs/1903.03894)
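In the paper, GNNExplainer learns soft masks by maximizing the mutual information between the prediction and the masked subgraph. In practice this reduces to minimizing the negative log-probability of the target class, plus penalties that keep the masks small and close to binary. Each mask value is the sigmoid of an unconstrained parameter.

The sketch below writes that objective in plain Python and then runs a tiny optimization on a hypothetical two-edge "model". Several pieces are assumptions for illustration only and do not reflect Captum's implementation:

- the coefficient values;
- the toy model `toy_log_prob`;
- the use of plain gradient descent with finite differences in place of autograd.

```python
import math

EPS = 1e-15

def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))

def mask_entropy(m):
    """Binary entropy of one soft mask value; smallest when m is near 0 or 1."""
    return -m * math.log(m + EPS) - (1 - m) * math.log(1 - m + EPS)

def explainer_loss(log_prob_target, edge_mask, feat_mask,
                   edge_size=0.005, edge_ent=1.0, feat_size=1.0, feat_ent=0.1):
    """Sketch of the GNNExplainer objective for a single target node.

    log_prob_target: log P(y_c | masked graph) produced by the model.
    edge_mask, feat_mask: soft mask values in [0, 1].
    The four coefficients are illustrative assumptions.
    """
    loss = -log_prob_target                                   # keep the prediction
    loss += edge_size * sum(edge_mask)                        # prefer few edges
    loss += edge_ent * sum(map(mask_entropy, edge_mask)) / len(edge_mask)
    loss += feat_size * sum(feat_mask) / len(feat_mask)       # prefer few features
    loss += feat_ent * sum(map(mask_entropy, feat_mask)) / len(feat_mask)
    return loss

def toy_log_prob(edge_mask):
    """Hypothetical stand-in for a GNN: only edge 0 supports the target class."""
    return math.log(sigmoid(4.0 * edge_mask[0] - 2.0))

def optimize_masks(num_epochs=150, lr=0.1, h=1e-5):
    """Gradient descent on the mask parameters (finite-difference gradients)."""
    theta = [0.0, 0.0]      # unconstrained parameters; sigmoid(0) = 0.5
    feat_mask = [0.5]       # held fixed in this toy example

    def objective(params):
        masks = [sigmoid(t) for t in params]
        return explainer_loss(toy_log_prob(masks), masks, feat_mask)

    for _ in range(num_epochs):
        grads = []
        for i in range(len(theta)):
            up, down = theta[:], theta[:]
            up[i] += h
            down[i] -= h
            grads.append((objective(up) - objective(down)) / (2 * h))
        theta = [t - lr * g for t, g in zip(theta, grads)]
    return [sigmoid(t) for t in theta]
```

Calling `optimize_masks()` drives the mask of edge 0 (the edge the toy model relies on) toward 1. The mask of the irrelevant edge 1 drifts toward 0, because only the size and entropy penalties act on it.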

## 1. Setup

First, let's install and import the necessary libraries. We'll need `torch`, `torch_geometric` for graph data and GNN layers, `captum` for GNNExplainer, and `networkx` / `matplotlib` for visualization.

In [None]:
!pip install -q torch torchvision torchaudio
!pip install -q torch_geometric
!pip install -q captum
!pip install -q networkx matplotlib

In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.datasets import KarateClub
from torch_geometric.nn import GCNConv
from torch_geometric.utils import to_networkx
import networkx as nx
import matplotlib.pyplot as plt

# This import path assumes a Captum build that ships GNNExplainer under
# captum.attr._gnn. If your build exports it publicly, use instead:
#     from captum.attr import GNNExplainer
from captum.attr._gnn.gnn_explainer import GNNExplainer  # adjust to your installation

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## 2. Data Loading

We'll use the classic Zachary's Karate Club dataset from `torch_geometric.datasets`.

In [None]:
dataset = KarateClub()
data = dataset[0].to(device)
print(f"Dataset: {dataset.name}")
print(f"Number of nodes: {data.num_nodes}")
print(f"Number of edges: {data.num_edges}")
print(f"Number of features: {data.num_node_features}")
print(f"Number of classes: {dataset.num_classes}")
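PyG stores an undirected graph such as Karate Club with every edge in both directions, (u, v) and (v, u). As a result, `data.num_edges` counts each friendship twice. This matters later, because the edge mask has one entry per *directed* edge.

PyG already provides `data.is_undirected()` for the symmetry check. The small pure-Python helper below is only a sketch that also counts the distinct undirected edges. The name `undirected_edge_count` is hypothetical.

```python
def undirected_edge_count(edges):
    """Return (number of undirected edges, whether every edge has its reverse).

    edges: iterable of (src, dst) pairs, e.g. data.edge_index.t().tolist().
    """
    edge_set = {(u, v) for u, v in edges}                 # works for lists or tuples
    is_symmetric = all((v, u) in edge_set for u, v in edge_set)
    undirected = {(min(u, v), max(u, v)) for u, v in edge_set}
    return len(undirected), is_symmetric
```

For the loaded dataset, call it as `undirected_edge_count(data.edge_index.t().tolist())`.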

## 3. Model Definition

Let's define a simple Graph Convolutional Network (GCN) model for node classification.

In [None]:
class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index, edge_weight=None):
        x = self.conv1(x, edge_index, edge_weight=edge_weight)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        # Passing edge_weight to both layers lets a soft edge mask rescale every message-passing step
        x = self.conv2(x, edge_index, edge_weight=edge_weight)
        return F.log_softmax(x, dim=1)

model = GCN(dataset.num_node_features, 16, dataset.num_classes).to(device)
print(model)
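`GCNConv` implements the propagation rule of Kipf and Welling, X' = D^{-1/2} (A + I) D^{-1/2} X W, where D is the degree matrix of A + I. Because `forward` accepts `edge_weight`, an edge mask can presumably enter through the entries of A. A masked edge then both carries a weaker message and lowers its endpoints' degrees.

The pure-Python sketch below computes the normalized coefficients for a list of directed edges. It assumes GCNConv's default behavior (self-loops added with weight 1, symmetric normalization). PyG's real implementation is vectorized and handles more cases.

```python
import math
from collections import defaultdict

def gcn_norm(edges, num_nodes, edge_weight=None):
    """Coefficients of D^-1/2 (A + I) D^-1/2 as a dict {(src, dst): value}.

    edges: directed (src, dst) pairs without self-loops; undirected graphs list
    both directions. edge_weight: optional per-edge weights (e.g. a soft edge
    mask); defaults to 1 for every edge.
    """
    if edge_weight is None:
        edge_weight = [1.0] * len(edges)
    # The "+ I" term: one self-loop of weight 1 per node.
    all_edges = list(edges) + [(i, i) for i in range(num_nodes)]
    weights = list(edge_weight) + [1.0] * num_nodes
    deg = defaultdict(float)
    for (_, dst), w in zip(all_edges, weights):
        deg[dst] += w                      # weighted in-degree, self-loop included
    return {
        (src, dst): w / math.sqrt(deg[src] * deg[dst])
        for (src, dst), w in zip(all_edges, weights)
    }
```

On the path graph 0-1-2, node 1 has degree 3 and the end nodes have degree 2 (self-loops included). The edge (0, 1) therefore gets 1/sqrt(6). Setting the weight of that edge to 0 in both directions cuts the message entirely and leaves node 0 with only its self-loop.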

### Model Training
GNNExplainer explains the predictions of the model as it currently is. Explaining an untrained model would therefore only describe arbitrary predictions, so we train the model first.

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    if epoch % 20 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

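model.eval()
with torch.no_grad():
    pred_labels = model(data.x, data.edge_index).argmax(dim=1)

# KarateClub only provides a train_mask (one labeled node per community),
# so evaluate on all remaining, unlabeled nodes instead of a test_mask.
test_mask = ~data.train_mask
correct_nodes = (pred_labels[test_mask] == data.y[test_mask]).sum()
accuracy = int(correct_nodes) / int(test_mask.sum())
print(f'Accuracy on unlabeled nodes: {accuracy:.4f}')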

## 4. Attribution with GNNExplainer

Now, let's use GNNExplainer to understand the prediction for a specific node.

In [None]:
# Instantiate GNNExplainer
explainer = GNNExplainer(model)

# Choose a target node to explain
target_node_idx = 0 
print(f"Explaining node: {target_node_idx}")
print(f"Node true label: {data.y[target_node_idx].item()}")
print(f"Node predicted label: {pred_labels[target_node_idx].item()}")

# Get attributions.
# Explain the class the model actually predicted for this node. Passing the
# true label instead would explain the evidence for that class, which may
# differ when the prediction is wrong.
target_class = pred_labels[target_node_idx].item()

node_feat_mask, edge_mask = explainer.attribute(
    inputs=data.x, 
    edge_index=data.edge_index, 
    target_node=target_node_idx,
    target_class=target_class, # Specify class if GNNExplainer needs it
    num_epochs=150, # Number of epochs to train the masks
    lr=0.01 # Learning rate for mask optimization
)

print("\nNode Feature Mask (first 5 features):")
print(node_feat_mask[:5])
print("\nEdge Mask (first 5 edges):")
print(edge_mask[:5])

The `node_feat_mask` scores how important each input feature is for the target node's prediction. The `edge_mask` scores each edge of `edge_index`, in the same order as its columns. Higher values indicate more important features and edges.
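Because the GCN above has two layers, a node's prediction can only depend on nodes within two hops of it. Edges outside that neighborhood cannot carry messages to the target node. The paper calls this neighborhood the computation graph, and mask values on edges outside it say nothing useful about the prediction. This ignores second-order effects such as GCN's degree normalization.

The pure-Python sketch below extracts that neighborhood with a breadth-first search. The helper name `computation_subgraph` is hypothetical. PyG's `torch_geometric.utils.k_hop_subgraph` offers a vectorized counterpart.

```python
from collections import deque

def computation_subgraph(target, num_hops, edges):
    """Nodes within num_hops of target, and indices of edges among those nodes.

    edges: directed (src, dst) pairs, e.g. data.edge_index.t().tolist().
    """
    neighbors = {}
    for u, v in edges:                         # treat the graph as undirected
        neighbors.setdefault(u, set()).add(v)
        neighbors.setdefault(v, set()).add(u)
    dist = {target: 0}
    queue = deque([target])
    while queue:                               # breadth-first search
        node = queue.popleft()
        if dist[node] == num_hops:
            continue
        for nb in neighbors.get(node, ()):
            if nb not in dist:
                dist[nb] = dist[node] + 1
                queue.append(nb)
    subset = set(dist)
    edge_ids = [i for i, (u, v) in enumerate(edges) if u in subset and v in subset]
    return subset, edge_ids
```

For example, `subset, edge_ids = computation_subgraph(target_node_idx, 2, data.edge_index.t().tolist())` followed by `edge_mask[edge_ids]` restricts attention to the mask values that can actually matter.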

## 5. Visualization

Let's visualize the explanation. We can highlight the important edges based on the `edge_mask`.

In [None]:
def plot_graph_with_masks(edge_index, edge_mask, target_node_idx, title, node_labels=None, threshold=0.5):
    num_nodes = edge_index.max().item() + 1
    g_nx = nx.Graph()
    g_nx.add_nodes_from(range(num_nodes))

    # PyG stores each undirected edge twice, as (u, v) and (v, u). Keep the
    # larger mask value of the two directions instead of silently letting
    # the second one overwrite the first.
    for (u, v), w in zip(edge_index.t().tolist(), edge_mask.tolist()):
        if g_nx.has_edge(u, v):
            g_nx[u][v]['weight'] = max(g_nx[u][v]['weight'], w)
        else:
            g_nx.add_edge(u, v, weight=w)

    pos = nx.spring_layout(g_nx, seed=42)  # nx.kamada_kawai_layout is an alternative

    plt.figure(figsize=(10, 8))

    # Draw nodes, highlighting the target node in red
    node_colors = ['lightblue'] * num_nodes
    node_colors[target_node_idx] = 'red'
    nx.draw_networkx_nodes(g_nx, pos, node_color=node_colors, node_size=500)

    # Draw edges: important ones thicker and more opaque, the rest thin and faint.
    # Alpha is clipped to 1 in case the mask values exceed 1.
    edge_weights = [g_nx[u][v]['weight'] for u, v in g_nx.edges()]
    edge_alphas = [min(w, 1.0) if w > threshold else 0.1 for w in edge_weights]
    edge_widths = [3 * w if w > threshold else 0.5 for w in edge_weights]
    nx.draw_networkx_edges(g_nx, pos, width=edge_widths, alpha=edge_alphas, edge_color='gray')

    # Label each node with its index and, optionally, its class label
    if node_labels is not None:
        labels = {i: f"{i}\n(L:{node_labels[i].item()})" for i in g_nx.nodes()}
    else:
        labels = {i: str(i) for i in g_nx.nodes()}
    nx.draw_networkx_labels(g_nx, pos, labels=labels, font_size=10)

    plt.title(title)
    plt.axis('off')
    plt.show()

# Visualize the explanation; a lower threshold makes more edges visible
plot_graph_with_masks(
    data.edge_index,
    edge_mask,
    target_node_idx,
    f'GNNExplainer Explanation for Node {target_node_idx} (Predicted: {pred_labels[target_node_idx].item()})',
    node_labels=data.y,
    threshold=0.2,
)

In the plot above, the target node is drawn in red. Edges whose `edge_mask` score exceeds the threshold are drawn thicker and more opaque in proportion to their score, while the remaining edges are thin and faint. The highlighted edges form the part of the target node's computation subgraph that GNNExplainer deems important for its prediction.
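A fixed threshold depends on the scale of the mask values, which can vary from run to run. A common alternative is to keep only the k highest-scoring edges.

The pure-Python sketch below first merges the two directed copies of each edge by taking the maximum score, then returns the top k undirected edges. The helper name `top_k_edges` is hypothetical.

```python
def top_k_edges(edges, edge_mask, k):
    """Return the k highest-scoring undirected edges as ((u, v), score) pairs.

    edges: directed (src, dst) pairs; edge_mask: one score per directed edge.
    """
    best = {}
    for (u, v), score in zip(edges, edge_mask):
        key = (min(u, v), max(u, v))            # merge (u, v) and (v, u)
        best[key] = max(best.get(key, float("-inf")), score)
    ranked = sorted(best.items(), key=lambda item: item[1], reverse=True)
    return ranked[:k]
```

With the tensors from this notebook, the call would be `top_k_edges(data.edge_index.t().tolist(), edge_mask.tolist(), k=10)`.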

### Node Feature Importance
The `node_feat_mask` indicates which input features are important. In PyG's Karate Club dataset, the features are one-hot node identities: feature *j* is 1 only for node *j*. A high mask value for feature *j* therefore means that node *j*'s identity matters, rather than some semantic attribute. With semantically meaningful features (e.g., age or degree in a social network), the mask would instead highlight which of those attributes contributed most.

In [None]:
print(f"Node Feature Mask for Node {target_node_idx}:")
for i, val in enumerate(node_feat_mask):
    if val > 0.1: # Show features with some importance
        print(f"  Feature {i}: {val.item():.4f}")
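The original GNNExplainer formulation learns one feature-mask value per feature dimension, shared by all nodes. Assuming this implementation follows it, applying that mask to an identity feature matrix scales each node's own one-hot row by that node's mask value. In other words, on Karate Club the feature mask effectively acts as a per-node mask.

The pure-Python sketch below demonstrates this on a 4-node identity matrix with made-up mask values.

```python
def apply_feature_mask(x, feat_mask):
    """Multiply every row of the feature matrix x element-wise by a shared mask."""
    return [[value * m for value, m in zip(row, feat_mask)] for row in x]

num_nodes = 4
identity = [[1.0 if i == j else 0.0 for j in range(num_nodes)] for i in range(num_nodes)]
feat_mask = [0.9, 0.1, 0.0, 0.5]           # made-up values for illustration
masked = apply_feature_mask(identity, feat_mask)
# Row i keeps only its diagonal entry, scaled by feat_mask[i]
```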

## 6. Conclusion

This tutorial demonstrated the basic workflow of using `GNNExplainer` with Captum:
1. Training a GNN model.
2. Instantiating `GNNExplainer` with the trained model.
3. Calling the `attribute` method to get node feature and edge masks.
4. Visualizing these masks to interpret the model's prediction for a specific node.

GNNExplainer helps in understanding which parts of the graph (edges) and which node features are crucial for the GNN's decision-making process, enhancing transparency and trust in GNN models.