Merge pull request #2597 from wsad1/gnnexpupdate
Extending Gnnexplainer for graph classification.
rusty1s committed May 19, 2021
2 parents da2e1d0 + 65d3f6e commit ae783c0
Showing 2 changed files with 117 additions and 7 deletions.
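
The change adds an `explain_graph` method next to the existing node-level `explain_node`, so a trained graph-classification model can be explained as a whole. A minimal usage sketch, assuming a model with the `forward(x, edge_index, batch)` signature of the `Net` module from the test file below (the values here are illustrative only):

import torch
from torch_geometric.nn import GNNExplainer

model = Net()  # graph-classification model, defined in the test file below
x = torch.randn(8, 3)  # 8 nodes, 3 features each
edge_index = torch.tensor([[0, 1, 1, 2],  # toy undirected graph
                           [1, 0, 2, 1]])

explainer = GNNExplainer(model, log=False)
node_feat_mask, edge_mask = explainer.explain_graph(x, edge_index)
# node_feat_mask: one weight in [0, 1] per input feature dimension
# edge_mask: one weight in [0, 1] per column of edge_index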
46 changes: 45 additions & 1 deletion test/nn/models/test_gnn_explainer.py
@@ -1,7 +1,9 @@
import pytest
import torch
from torch.nn import Sequential, Linear, ReLU, Dropout
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv, GNNExplainer
from torch_geometric.nn import (GCNConv, GATConv, GNNExplainer,
global_add_pool, MessagePassing)


class GCN(torch.nn.Module):
@@ -58,3 +60,45 @@ def test_to_log_prob(model):

assert torch.allclose(raw_to_log(raw), prob_to_log(prob))
assert torch.allclose(prob_to_log(prob), log_to_log(log_prob))


def assert_edgemask_clear(model):
for layer in model.modules():
if isinstance(layer, MessagePassing):
assert not layer.__explain__
assert layer.__edge_mask__ is None


class Net(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = GCNConv(3, 16)
self.conv2 = GCNConv(16, 16)
self.fc1 = Sequential(Linear(16, 16), ReLU(), Dropout(0.2),
Linear(16, 7))

def forward(self, x, edge_index, batch, get_embedding=False):
x = self.conv1(x, edge_index)
x = F.relu(x)
x = self.conv2(x, edge_index)
if get_embedding:
return x
x = global_add_pool(x, batch)
x = self.fc1(x)
return x.log_softmax(dim=1)


@pytest.mark.parametrize('model', [Net()])
def test_graph_explainer(model):
x = torch.randn(8, 3)
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 4, 5, 5, 6, 6, 7],
[1, 0, 2, 1, 3, 2, 5, 4, 6, 5, 7, 6]])

explainer = GNNExplainer(model, log=False)

node_feat_mask, edge_mask = explainer.explain_graph(x, edge_index)
assert_edgemask_clear(model)
assert node_feat_mask.size() == (x.size(1), )
assert node_feat_mask.min() >= 0 and node_feat_mask.max() <= 1
assert edge_mask.shape[0] == edge_index.shape[1]
assert edge_mask.max() <= 1 and edge_mask.min() >= 0
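
The new `assert_edgemask_clear` helper guards against explainer state leaking into later use of the model: during optimization, GNNExplainer flips an `__explain__` flag and attaches its learnable `__edge_mask__` to every `MessagePassing` layer, and both must be reset once the masks are trained. A simplified sketch of that set/clear pattern, paraphrased from the library internals rather than copied from them:

import torch
from torch_geometric.nn import MessagePassing

def set_masks(model: torch.nn.Module, edge_mask: torch.nn.Parameter):
    # While explaining, each MessagePassing layer scales its messages by
    # the sigmoid of this shared, learnable per-edge mask.
    for module in model.modules():
        if isinstance(module, MessagePassing):
            module.__explain__ = True
            module.__edge_mask__ = edge_mask

def clear_masks(model: torch.nn.Module):
    # Afterwards, restore normal message passing.
    for module in model.modules():
        if isinstance(module, MessagePassing):
            module.__explain__ = False
            module.__edge_mask__ = None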
78 changes: 72 additions & 6 deletions torch_geometric/nn/models/gnn_explainer.py
@@ -124,7 +124,10 @@ def __subgraph__(self, node_idx, x, edge_index, **kwargs):
return x, edge_index, mapping, edge_mask, kwargs

def __loss__(self, node_idx, log_logits, pred_label):
loss = -log_logits[node_idx, pred_label[node_idx]]
# node_idx is -1 for explaining graphs
loss = (-log_logits[node_idx, pred_label[node_idx]]
        if node_idx != -1 else -log_logits[0, pred_label[0]])

m = self.edge_mask.sigmoid()
edge_reduce = getattr(torch, self.coeffs['edge_reduction'])
@@ -145,6 +148,61 @@ def __to_log_prob__(self, x: torch.Tensor) -> torch.Tensor:
x = x.log() if self.return_type == 'prob' else x
return x

def explain_graph(self, x, edge_index, **kwargs):
r"""Learns and returns a node feature mask and an edge mask that play a
crucial role to explain the prediction made by the GNN for a graph.
Args:
x (Tensor): The node feature matrix.
edge_index (LongTensor): The edge indices.
**kwargs (optional): Additional arguments passed to the GNN module.
:rtype: (:class:`Tensor`, :class:`Tensor`)
"""

self.model.eval()
self.__clear_masks__()

# All nodes belong to the same graph.
batch = torch.zeros(x.shape[0], dtype=torch.long, device=x.device)

# Get the initial prediction.
with torch.no_grad():
out = self.model(x=x, edge_index=edge_index, batch=batch, **kwargs)
log_logits = self.__to_log_prob__(out)
pred_label = log_logits.argmax(dim=-1)

self.__set_masks__(x, edge_index)
self.to(x.device)

optimizer = torch.optim.Adam([self.node_feat_mask, self.edge_mask],
lr=self.lr)

if self.log: # pragma: no cover
pbar = tqdm(total=self.epochs)
pbar.set_description('Explain graph')

for epoch in range(1, self.epochs + 1):
optimizer.zero_grad()
h = x * self.node_feat_mask.view(1, -1).sigmoid()
out = self.model(x=h, edge_index=edge_index, batch=batch, **kwargs)
log_logits = self.__to_log_prob__(out)
loss = self.__loss__(-1, log_logits, pred_label)
loss.backward()
optimizer.step()

if self.log: # pragma: no cover
pbar.update(1)

if self.log: # pragma: no cover
pbar.close()

node_feat_mask = self.node_feat_mask.detach().sigmoid()
edge_mask = self.edge_mask.detach().sigmoid()

self.__clear_masks__()
return node_feat_mask, edge_mask

def explain_node(self, node_idx, x, edge_index, **kwargs):
r"""Learns and returns a node feature mask and an edge mask that play a
crucial role to explain the prediction made by the GNN for node
@@ -209,11 +267,12 @@ def explain_node(self, node_idx, x, edge_index, **kwargs):

def visualize_subgraph(self, node_idx, edge_index, edge_mask, y=None,
threshold=None, **kwargs):
r"""Visualizes the subgraph around :attr:`node_idx` given an edge mask
r"""Visualizes the subgraph given an edge mask
:attr:`edge_mask`.

Args:
    node_idx (int): The node id to explain.
        Set to :obj:`-1` to explain the whole graph.
edge_index (LongTensor): The edge indices.
edge_mask (Tensor): The edge mask.
y (Tensor, optional): The ground-truth node-prediction labels used
@@ -230,10 +289,17 @@ def visualize_subgraph(self, node_idx, edge_index, edge_mask, y=None,
import matplotlib.pyplot as plt
assert edge_mask.size(0) == edge_index.size(1)

# Only operate on a k-hop subgraph around `node_idx`.
subset, edge_index, _, hard_edge_mask = k_hop_subgraph(
node_idx, self.num_hops, edge_index, relabel_nodes=True,
num_nodes=None, flow=self.__flow__())
if node_idx == -1:
hard_edge_mask = torch.ones(edge_index.size(1), dtype=torch.bool,
                            device=edge_mask.device)
subset = torch.arange(edge_index.max().item() + 1,
                      device=edge_index.device if y is None else y.device)
else:
# Only operate on a k-hop subgraph around `node_idx`.
subset, edge_index, _, hard_edge_mask = k_hop_subgraph(
node_idx, self.num_hops, edge_index, relabel_nodes=True,
num_nodes=None, flow=self.__flow__())

edge_mask = edge_mask[hard_edge_mask]

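Because `node_idx = -1` now stands for "the whole graph", a graph-level explanation can be drawn directly. A hedged usage sketch continuing the example above (`model`, `x`, and `edge_index` as before; matplotlib and networkx are assumed to be installed):

import matplotlib.pyplot as plt

explainer = GNNExplainer(model, log=False)
node_feat_mask, edge_mask = explainer.explain_graph(x, edge_index)

# With node_idx=-1 the k-hop extraction is skipped and every node and edge
# is drawn, edge transparency following the learned mask weights.
ax, G = explainer.visualize_subgraph(-1, edge_index, edge_mask,
                                     threshold=0.5)
plt.show()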
