
Make models compatible with Captum #3990

Merged · 22 commits · Feb 9, 2022
Changes from 13 commits
95 changes: 95 additions & 0 deletions examples/captum_explainability.py
@@ -0,0 +1,95 @@
import os.path as osp

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from captum.attr import IntegratedGradients

import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv, GNNExplainer, to_captum

dataset = 'Cora'
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid')
transform = T.NormalizeFeatures()
dataset = Planetoid(path, dataset, transform=transform)
data = dataset[0]


class GCN(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = GCNConv(dataset.num_features, 16)
self.conv2 = GCNConv(16, dataset.num_classes)

def forward(self, x, edge_index):
x = F.relu(self.conv1(x, edge_index))
x = F.dropout(x, training=self.training)
x = self.conv2(x, edge_index)
return F.log_softmax(x, dim=1)
Review thread on the model's output head:

Member: Does Captum expect log outputs?

Contributor (author): No, I wanted the GNNExplainer and Captum explanations to be comparable, so I chose the exact same model.

Member: Ok, what does that mean in particular? Is the usage of log_softmax correct?
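
As an aside to this thread: Captum's IntegratedGradients attributes whatever output the wrapped forward function returns for `target`, so log-probabilities and probabilities both work; they merely yield differently scaled attributions. A minimal sketch for comparing the two heads (the `ProbModel` wrapper is a hypothetical helper, not part of this PR, and the snippet assumes the trained `model` and the `node_idx`/`target` variables defined further below):

import torch
from captum.attr import IntegratedGradients

from torch_geometric.nn import to_captum


class ProbModel(torch.nn.Module):  # hypothetical helper, not part of the PR
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x, edge_index):
        # `model` returns log-probabilities; `exp()` recovers probabilities.
        return self.model(x, edge_index).exp()


ig_log = IntegratedGradients(to_captum(model, mask_type='node',
                                       node_idx=node_idx))
ig_prob = IntegratedGradients(to_captum(ProbModel(model), mask_type='node',
                                        node_idx=node_idx))
attr_log = ig_log.attribute(x.unsqueeze(0), target=target,
                            additional_forward_args=(edge_index,),
                            internal_batch_size=1)
attr_prob = ig_prob.attribute(x.unsqueeze(0), target=target,
                              additional_forward_args=(edge_index,),
                              internal_batch_size=1)
# The two attributions will generally differ in scale; comparing their
# rankings is one way to sanity-check the choice of output head.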



device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN().to(device)
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
x, edge_index = data.x, data.edge_index

for epoch in range(1, 201):
model.train()
optimizer.zero_grad()
log_logits = model(x, edge_index)
loss = F.nll_loss(log_logits[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()

model.eval()  # Disable dropout before computing attributions.
node_idx = 10
target = int(data.y[node_idx])

# Edge explainability for node 10
# Captum assumes that for all given input tensors, dimension 0 is
# equal to the number of samples. Therefore, we use unsqueeze(0).
input_mask = torch.ones(data.num_edges, dtype=torch.float, requires_grad=True,
device=device)
captum_model = to_captum(model, mask_type='edge', node_idx=node_idx)
ig = IntegratedGradients(captum_model)
ig_attr_edge = ig.attribute(input_mask.unsqueeze(0), target=target,
                            additional_forward_args=(x, edge_index),
                            internal_batch_size=1)

# Scale attributions to [0, 1]
ig_attr_edge = ig_attr_edge.squeeze(0).abs() / ig_attr_edge.abs().max()

# Visualize attributions with the GNNExplainer visualizer
# TODO: Change to general Explainer visualizer
explainer = GNNExplainer(model)
ax, G = explainer.visualize_subgraph(node_idx, edge_index, ig_attr_edge)
plt.show()

# Node explainability for node 10
captum_model = to_captum(model, mask_type='node', node_idx=node_idx)
ig = IntegratedGradients(captum_model)
ig_attr_node = ig.attribute(x.unsqueeze(0), target=target,
                            additional_forward_args=(edge_index,),
internal_batch_size=1)
ig_attr_node = ig_attr_node.squeeze(0).abs().sum(dim=1)
ig_attr_node /= ig_attr_node.max()

ax, G = explainer.visualize_subgraph(node_idx, edge_index, ig_attr_edge,
                                     node_alpha=ig_attr_node)
plt.show()

# Node and edge explainability for node 10
captum_model = to_captum(model, mask_type='node_and_edge', node_idx=node_idx)
ig = IntegratedGradients(captum_model)
ig_attr_node_and_edge = ig.attribute(
(x.unsqueeze(0), input_mask.unsqueeze(0)), target=target,
    additional_forward_args=(edge_index,), internal_batch_size=1)
ig_attr_node = ig_attr_node_and_edge[0].squeeze(0).abs().sum(dim=1)
ig_attr_node /= ig_attr_node.max()
ig_attr_edge = ig_attr_node_and_edge[1].squeeze(0).abs()
ig_attr_edge /= ig_attr_edge.max()

ax, G = explainer.visualize_subgraph(node_idx, edge_index, ig_attr_edge,
node_alpha=ig_attr_node)
plt.show()
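
Since the wrapped model is an ordinary torch.nn.Module, other Captum attribution methods plug in the same way. A brief sketch using captum.attr.Saliency in place of IntegratedGradients (this variant is illustrative and not part of the PR; it reuses `model`, `x`, `edge_index`, `node_idx`, and `target` from above):

from captum.attr import Saliency

# Saliency only needs gradients w.r.t. the input at a single point, so no
# baselines or internal batching are required:
captum_model = to_captum(model, mask_type='node', node_idx=node_idx)
saliency = Saliency(captum_model)
sal_attr_node = saliency.attribute(x.unsqueeze(0), target=target,
                                   additional_forward_args=(edge_index,))
sal_attr_node = sal_attr_node.squeeze(0).abs().sum(dim=1)
sal_attr_node /= sal_attr_node.max()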
42 changes: 42 additions & 0 deletions test/nn/models/test_explainer.py
@@ -0,0 +1,42 @@
import pytest
import torch

from torch_geometric.nn import GAT, GCN, to_captum

GCN = GCN(3, 16, 2, 7, dropout=0.5)
GAT = GAT(3, 16, 2, 7, heads=2, concat=False)

mask_types = ['edge', 'node_and_edge', 'node']


@pytest.mark.parametrize('mask_type', mask_types)
@pytest.mark.parametrize('model', [GCN, GAT])
@pytest.mark.parametrize('node_idx', [None, 1])
def test_to_captum(model, mask_type, node_idx):

x = torch.randn(8, 3)
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7],
[1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6]])
pre_out = model(x, edge_index)

captum_model = to_captum(model, mask_type=mask_type, node_idx=node_idx)

if mask_type == 'node':
mask = x * 0.0
out = captum_model(mask.unsqueeze(0), edge_index)
elif mask_type == 'edge':
mask = torch.ones(edge_index.shape[1], dtype=torch.float) * 0.5
out = captum_model(mask.unsqueeze(0), x, edge_index)
elif mask_type == 'node_and_edge':
node_mask = x * 0.0
edge_mask = torch.ones(edge_index.shape[1], dtype=torch.float) * 0.5
out = captum_model(node_mask.unsqueeze(0), edge_mask.unsqueeze(0),
edge_index)

if node_idx is not None:
assert out.shape == (1, 7)
assert torch.any(out != pre_out[[node_idx]])
else:
assert out.shape == (8, 7)
assert torch.any(out != pre_out)
5 changes: 4 additions & 1 deletion torch_geometric/nn/conv/message_passing.py
@@ -114,6 +114,7 @@ def __init__(self, aggr: Optional[str] = "add",
         self.__explain__ = False
         self.__edge_mask__ = None
         self.__loop_mask__ = None
+        self.__apply_sigmoid__ = True

         # Hooks:
         self._propagate_forward_pre_hooks = OrderedDict()
@@ -323,7 +324,9 @@ def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
             # aggregate procedure since this allows us to inject the
             # `edge_mask` into the message passing computation scheme.
             if self.__explain__:
-                edge_mask = self.__edge_mask__.sigmoid()
+                edge_mask = self.__edge_mask__
+                if self.__apply_sigmoid__:
+                    edge_mask = edge_mask.sigmoid()
                 # Some ops add self-loops to `edge_index`. We need to do
                 # the same for `edge_mask` (but do not train those).
                 if out.size(self.node_dim) != edge_mask.size(0):
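The new __apply_sigmoid__ flag exists because the two consumers of the edge mask feed values from different domains: GNNExplainer trains an unbounded mask parameter that must be squashed inside propagate(), while the Captum wrapper passes attribution inputs that already live in [0, 1]. A minimal sketch of the two conventions (the layer choice and mask values are illustrative):

import torch

from torch_geometric.nn import GCNConv

conv = GCNConv(16, 16)  # any MessagePassing layer works here
num_edges = 14          # illustrative edge count

conv.__explain__ = True

# GNNExplainer-style: an unbounded learnable mask, squashed to (0, 1)
# by `sigmoid()` inside `propagate()`:
conv.__edge_mask__ = torch.randn(num_edges)
conv.__apply_sigmoid__ = True

# Captum-style: mask values are already in [0, 1] and must be used as-is:
conv.__edge_mask__ = torch.rand(num_edges)
conv.__apply_sigmoid__ = False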
2 changes: 2 additions & 0 deletions torch_geometric/nn/models/__init__.py
@@ -19,6 +19,7 @@
 from .rect import RECT_L
 from .linkx import LINKX
 from .lightgcn import LightGCN
+from .explainer import to_captum

 __all__ = [
     'MLP',
@@ -50,6 +51,7 @@
     'RECT_L',
     'LINKX',
     'LightGCN',
+    'to_captum',
 ]

 classes = __all__
113 changes: 113 additions & 0 deletions torch_geometric/nn/models/explainer.py
@@ -0,0 +1,113 @@
from typing import Optional

import torch
from torch import Tensor

from torch_geometric.nn import MessagePassing


def set_masks(model: torch.nn.Module, mask: Tensor, edge_index: Tensor,
apply_sigmoid: bool = True):
"""Apply mask to every graph layer in the model."""
loop_mask = edge_index[0] != edge_index[1]

# Loop over layers and set masks on MessagePassing layers
for module in model.modules():
if isinstance(module, MessagePassing):
module.__explain__ = True
module.__edge_mask__ = mask
module.__loop_mask__ = loop_mask
module.__apply_sigmoid__ = apply_sigmoid


def clear_masks(model: torch.nn.Module):
"""Clear all masks from the model."""
for module in model.modules():
if isinstance(module, MessagePassing):
module.__explain__ = False
module.__edge_mask__ = None
module.__loop_mask__ = None
module.__apply_sigmoid__ = True
    return module  # Returns the last module visited (used by GNNExplainer).


class CaptumModel(torch.nn.Module):
    r"""Model with a forward function that can be easily used for
    explainability with `Captum.ai <https://captum.ai/>`_.

    Args:
        model (torch.nn.Module): The model to be explained.
        mask_type (str): Denotes the type of mask to be created with a Captum
            explainer. Valid inputs are :obj:`'edge'`, :obj:`'node'`, and
            :obj:`'node_and_edge'`. For mask type :obj:`'edge'`, the forward
            function expects an edge mask tensor of shape
            :obj:`[1, num_edges]`, followed by :obj:`x` and :obj:`edge_index`.
            For mask type :obj:`'node'`, it expects a node feature tensor of
            shape :obj:`[1, num_nodes, num_features]`, followed by
            :obj:`edge_index`. For mask type :obj:`'node_and_edge'`, it
            expects a node feature tensor of shape
            :obj:`[1, num_nodes, num_features]`, an edge mask tensor of shape
            :obj:`[1, num_edges]`, and :obj:`edge_index`.
            (default: :obj:`'edge'`)
        node_idx (int, optional): Index of the node to be explained. With
            :obj:`node_idx` set, the forward function will return the output
            of the model for the node at that index. (default: :obj:`None`)
    """
def __init__(self, model: torch.nn.Module, mask_type: str = "edge",
node_idx: Optional[int] = None):
super().__init__()

assert mask_type in ['edge', 'node', 'node_and_edge']

self.mask_type = mask_type
self.model = model
self.node_idx = node_idx

def forward(self, mask, *args):
""""""
# The mask tensor, which comes from Captum's attribution methods,
# contains the number of samples in dimension 0. Since we are
# working with only one sample, we squeeze the tensors below.
assert mask.shape[0] == 1, "Dimension 0 of input should be 1"
if self.mask_type == "edge":
assert len(args) >= 2, "Expects at least x and edge_index as args."
if self.mask_type == "node":
assert len(args) >= 1, "Expects at least edge_index as args."
if self.mask_type == "node_and_edge":
assert args[0].shape[0] == 1, "Dimension 0 of input should be 1"
assert len(args[1:]) >= 1, "Expects at least edge_index as args."

# Set edge mask
if self.mask_type == 'edge':
set_masks(self.model, mask.squeeze(0), args[1],
apply_sigmoid=False)
elif self.mask_type == 'node_and_edge':
set_masks(self.model, args[0].squeeze(0), args[1],
apply_sigmoid=False)
args = args[1:]

# Edge mask
if self.mask_type == 'edge':
x = self.model(*args)

# Node mask
elif self.mask_type == 'node':
x = self.model(mask.squeeze(0), *args)

# Node and edge mask
else:
x = self.model(mask[0], *args)

# Clear mask
if self.mask_type in ['edge', 'node_and_edge']:
clear_masks(self.model)

if self.node_idx is not None:
x = x[self.node_idx].unsqueeze(0)
return x


def to_captum(model: torch.nn.Module, mask_type: str = "edge",
node_idx: Optional[int] = None) -> torch.nn.Module:
"""Convert a model to a model that can be used for Captum explainers."""
return CaptumModel(model, mask_type, node_idx)
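
To make the calling convention concrete, here is a short sketch of the three forward signatures the wrapper accepts (it reuses `model`, `x`, and `edge_index` from the example above; `edge_mask` is illustrative):

import torch

from torch_geometric.nn import to_captum

edge_mask = torch.ones(edge_index.size(1))

# mask_type='edge': (edge_mask, x, edge_index)
wrapped = to_captum(model, mask_type='edge', node_idx=10)
out = wrapped(edge_mask.unsqueeze(0), x, edge_index)  # shape: [1, num_classes]

# mask_type='node': (x, edge_index)
wrapped = to_captum(model, mask_type='node', node_idx=10)
out = wrapped(x.unsqueeze(0), edge_index)

# mask_type='node_and_edge': (x, edge_mask, edge_index)
wrapped = to_captum(model, mask_type='node_and_edge', node_idx=10)
out = wrapped(x.unsqueeze(0), edge_mask.unsqueeze(0), edge_index)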
14 changes: 3 additions & 11 deletions torch_geometric/nn/models/gnn_explainer.py
@@ -7,6 +7,7 @@

 from torch_geometric.data import Data
 from torch_geometric.nn import MessagePassing
+from torch_geometric.nn.models.explainer import clear_masks, set_masks
 from torch_geometric.utils import k_hop_subgraph, to_networkx

 EPS = 1e-15
@@ -98,20 +99,11 @@ def __set_masks__(self, x, edge_index, init="normal"):
         if not self.allow_edge_mask:
             self.edge_mask.requires_grad_(False)
             self.edge_mask.fill_(float('inf'))  # `sigmoid()` returns `1`.
-        self.loop_mask = edge_index[0] != edge_index[1]
-
-        for module in self.model.modules():
-            if isinstance(module, MessagePassing):
-                module.__explain__ = True
-                module.__edge_mask__ = self.edge_mask
-                module.__loop_mask__ = self.loop_mask
+        set_masks(self.model, self.edge_mask, edge_index, apply_sigmoid=True)

     def __clear_masks__(self):
-        for module in self.model.modules():
-            if isinstance(module, MessagePassing):
-                module.__explain__ = False
-                module.__edge_mask__ = None
-                module.__loop_mask__ = None
+        module = clear_masks(self.model)
         self.node_feat_masks = None
         self.edge_mask = None
+        module.loop_mask = None