Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make models compatible to Captum #3990

Merged
merged 22 commits into from
Feb 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 106 additions & 0 deletions examples/captum_explainability.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import os.path as osp

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from captum.attr import IntegratedGradients

import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv, GNNExplainer, to_captum

# Load the Cora citation graph through the Planetoid dataset class.
# NOTE: `dataset` first holds the dataset name and is rebound to the loaded
# Planetoid dataset object two lines below.
dataset = 'Cora'
# The data directory is resolved relative to this script: ../data/Planetoid
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid')
# `NormalizeFeatures` is applied as the dataset transform.
dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())
# Only the first element of the dataset is used by the rest of the script.
data = dataset[0]


class GCN(torch.nn.Module):
    """Two-layer graph convolutional network used as the model to explain.

    The layer sizes are taken from the module-level ``dataset``: the first
    convolution maps ``dataset.num_features`` to 16 hidden channels and the
    second maps those 16 channels to ``dataset.num_classes``.
    """
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(dataset.num_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)

    def forward(self, x, edge_index):
        """Return the log-softmax (over dim 1) of the second layer's output.

        Args:
            x: Input tensor passed to the first convolution.
            edge_index: Graph connectivity passed to both convolutions.
        """
        hidden = self.conv1(x, edge_index)
        hidden = F.relu(hidden)
        # Dropout is only active while the module is in training mode.
        hidden = F.dropout(hidden, training=self.training)
        scores = self.conv2(hidden, edge_index)
        return F.log_softmax(scores, dim=1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does Captum expect log outputs?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No. I wanted the GNNExplainer and Captum examples to be comparable, so I used exactly the same model in both.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, what does that mean in particular? Is the usage of log_softmax correct?



# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN().to(device)
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

# Full-batch training for 200 epochs; the loss is computed only on the
# entries selected by `data.train_mask`.
for epoch in range(1, 201):
    model.train()
    optimizer.zero_grad()
    log_probs = model(data.x, data.edge_index)
    train_mask = data.train_mask
    loss = F.nll_loss(log_probs[train_mask], data.y[train_mask])
    loss.backward()
    optimizer.step()

# Node whose prediction is explained below; the attribution target is the
# label stored for that node in `data.y`.
node_idx = 10
target = int(data.y[node_idx])

# Edge explainability
# ===================

# Captum treats dimension 0 of every input tensor as the sample dimension,
# so the flat edge mask is lifted to shape [1, num_edges] via `unsqueeze(0)`
# before it is handed to `attribute`.
captum_model = to_captum(model, mask_type='edge', node_idx=node_idx)
edge_mask = torch.ones(data.num_edges, requires_grad=True, device=device)

ig = IntegratedGradients(captum_model)
ig_attr_edge = ig.attribute(
    edge_mask.unsqueeze(0),
    target=target,
    additional_forward_args=(data.x, data.edge_index),
    internal_batch_size=1,
)

# Scale attributions to [0, 1]: drop the sample dimension, take absolute
# values and divide by the largest one.
ig_attr_edge = ig_attr_edge.squeeze(0).abs()
ig_attr_edge /= ig_attr_edge.max()

# Visualize absolute values of attributions with GNNExplainer visualizer
explainer = GNNExplainer(model)  # TODO: Change to general Explainer visualizer
ax, G = explainer.visualize_subgraph(node_idx, data.edge_index, ig_attr_edge)
plt.show()

# Node explainability
# ===================

captum_model = to_captum(model, mask_type='node', node_idx=node_idx)

ig = IntegratedGradients(captum_model)
# The node features are the attributed input (with a leading sample dimension
# added by `unsqueeze(0)`). The extra forward argument is passed as an
# explicit one-element tuple -- note the trailing comma.
ig_attr_node = ig.attribute(
    data.x.unsqueeze(0),
    target=target,
    additional_forward_args=(data.edge_index, ),
    internal_batch_size=1,
)

# Scale attributions to [0, 1]: absolute feature attributions are summed per
# node (dim=1) and divided by the largest per-node sum.
ig_attr_node = ig_attr_node.squeeze(0).abs().sum(dim=1)
ig_attr_node /= ig_attr_node.max()

# Visualize absolute values of attributions with GNNExplainer visualizer.
# This reuses `explainer` and the edge attributions `ig_attr_edge` defined
# earlier in this script; the node attributions are passed as `node_alpha`.
ax, G = explainer.visualize_subgraph(node_idx, data.edge_index, ig_attr_edge,
                                     node_alpha=ig_attr_node)
plt.show()

# Node and edge explainability
# ============================

captum_model = to_captum(model, mask_type='node_and_edge', node_idx=node_idx)

ig = IntegratedGradients(captum_model)
# Both the node features and the edge mask (the `edge_mask` tensor created
# earlier in this script) are attributed inputs, each with a leading sample
# dimension. The extra forward argument is passed as an explicit one-element
# tuple -- note the trailing comma.
ig_attr_node, ig_attr_edge = ig.attribute(
    (data.x.unsqueeze(0), edge_mask.unsqueeze(0)),
    target=target,
    additional_forward_args=(data.edge_index, ),
    internal_batch_size=1,
)

# Scale attributions to [0, 1]:
# per-node sums of absolute feature attributions, divided by their maximum ...
ig_attr_node = ig_attr_node.squeeze(0).abs().sum(dim=1)
ig_attr_node /= ig_attr_node.max()
# ... and absolute edge attributions, divided by their maximum.
ig_attr_edge = ig_attr_edge.squeeze(0).abs()
ig_attr_edge /= ig_attr_edge.max()

# Visualize with the `explainer` created earlier in this script.
ax, G = explainer.visualize_subgraph(node_idx, data.edge_index, ig_attr_edge,
                                     node_alpha=ig_attr_node)
plt.show()
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
full_install_requires = [
'h5py',
'numba',
'captum',
'rdflib',
'trimesh',
'networkx',
Expand Down
111 changes: 111 additions & 0 deletions test/nn/models/test_explainer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import pytest
import torch

from torch_geometric.nn import GAT, GCN, to_captum

try:
from captum import attr # noqa
with_captum = True
except ImportError:
with_captum = False

# Shared inputs for the tests below: `x` has shape (8, 3) and `edge_index`
# lists the 14 directed edges of the path 0-1-2-...-7 (every edge appears in
# both directions).
x = torch.randn(8, 3, requires_grad=True)
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7],
[1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6]])

# NOTE: these two assignments rebind the imported class names `GCN` and `GAT`
# to model *instances*; the tests below use the instances.
GCN = GCN(3, 16, 2, 7, dropout=0.5)
GAT = GAT(3, 16, 2, 7, heads=2, concat=False)
# Values passed as `mask_type` to `to_captum` in the parametrized tests.
mask_types = ['edge', 'node_and_edge', 'node']
# Names of attribution classes that the tests look up on `captum.attr`.
methods = [
'Saliency',
'InputXGradient',
'Deconvolution',
'FeatureAblation',
'ShapleyValueSampling',
'IntegratedGradients',
'GradientShap',
'Occlusion',
'GuidedBackprop',
'KernelShap',
'Lime',
]


@pytest.mark.parametrize('mask_type', mask_types)
@pytest.mark.parametrize('model', [GCN, GAT])
@pytest.mark.parametrize('node_idx', [None, 1])
def test_to_captum(model, mask_type, node_idx):
    """Check output shape and mask sensitivity of the `to_captum` wrapper.

    The wrapped model is called with an all-zero node mask and/or an edge
    mask filled with 0.5, and its output is required to differ from the
    plain model output in at least one entry.
    """
    # NOTE(review): the GCN fixture is built with dropout=0.5 and no `eval()`
    # call is visible here, so outputs may already differ because of dropout
    # alone -- confirm that the inequality checks below are meaningful.
    captum_model = to_captum(model, mask_type=mask_type, node_idx=node_idx)
    pre_out = model(x, edge_index)

    def half_edge_mask():
        # One entry per edge, every entry equal to 0.5.
        ones = torch.ones(edge_index.shape[1], dtype=torch.float,
                          requires_grad=True)
        return ones * 0.5

    # Assemble the positional arguments expected for each mask type; masks
    # get a leading sample dimension via `unsqueeze(0)`.
    if mask_type == 'node':
        call_args = ((x * 0.0).unsqueeze(0), edge_index)
    elif mask_type == 'edge':
        call_args = (half_edge_mask().unsqueeze(0), x, edge_index)
    elif mask_type == 'node_and_edge':
        zero_node_mask = x * 0.0
        call_args = (zero_node_mask.unsqueeze(0),
                     half_edge_mask().unsqueeze(0), edge_index)
    out = captum_model(*call_args)

    if node_idx is None:
        # Without a node index the output has one row per node.
        assert out.shape == (8, 7)
        assert torch.any(out != pre_out)
    else:
        # With a node index only that node's row is returned.
        assert out.shape == (1, 7)
        assert torch.any(out != pre_out[[node_idx]])


@pytest.mark.skipif(not with_captum, reason="no 'captum' package")
@pytest.mark.parametrize('mask_type', mask_types)
@pytest.mark.parametrize('method', methods)
def test_captum_attribution_methods(mask_type, method):
    """Run one Captum attribution method on the wrapped GCN fixture.

    For every combination of mask type and attribution method the test only
    checks that `attribute` runs and that the returned attributions have the
    same shapes as the attributed inputs.
    """
    model = GCN
    captum_model = to_captum(model, mask_type, 0)
    input_mask = torch.ones((1, edge_index.shape[1]), dtype=torch.float,
                            requires_grad=True)
    explainer = getattr(attr, method)(captum_model)

    # Attributed inputs, extra forward arguments and (for Occlusion) the
    # sliding-window shapes depend on the mask type.
    if mask_type == 'node':
        inputs = x.clone().unsqueeze(0)
        additional_forward_args = (edge_index, )
        sliding_window_shapes = (3, 3)
    elif mask_type == 'edge':
        inputs = input_mask
        additional_forward_args = (x, edge_index)
        sliding_window_shapes = (5, )
    elif mask_type == 'node_and_edge':
        inputs = (x.clone().unsqueeze(0), input_mask)
        additional_forward_args = (edge_index, )
        sliding_window_shapes = ((3, 3), (5, ))

    # Keyword arguments common to every method, extended per method below.
    attribute_kwargs = dict(target=0,
                            additional_forward_args=additional_forward_args)
    returns_delta = True
    if method == 'IntegratedGradients':
        attribute_kwargs.update(internal_batch_size=1,
                                return_convergence_delta=True)
    elif method == 'GradientShap':
        attribute_kwargs.update(return_convergence_delta=True,
                                baselines=inputs, n_samples=1)
    elif method in ('DeepLiftShap', 'DeepLift'):
        attribute_kwargs.update(return_convergence_delta=True,
                                baselines=inputs)
    elif method == 'Occlusion':
        returns_delta = False
        attribute_kwargs.update(sliding_window_shapes=sliding_window_shapes)
    else:
        returns_delta = False

    result = explainer.attribute(inputs, **attribute_kwargs)
    if returns_delta:
        # These methods return an (attributions, delta) pair.
        attributions, delta = result
    else:
        attributions = result

    if mask_type == 'node':
        assert attributions.shape == (1, 8, 3)
    elif mask_type == 'edge':
        assert attributions.shape == (1, 14)
    else:
        assert attributions[0].shape == (1, 8, 3)
        assert attributions[1].shape == (1, 14)
5 changes: 4 additions & 1 deletion torch_geometric/nn/conv/message_passing.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def __init__(self, aggr: Optional[str] = "add",
self.__explain__ = False
self.__edge_mask__ = None
self.__loop_mask__ = None
self.__apply_sigmoid__ = True

# Hooks:
self._propagate_forward_pre_hooks = OrderedDict()
Expand Down Expand Up @@ -323,7 +324,9 @@ def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
# aggregate procedure since this allows us to inject the
# `edge_mask` into the message passing computation scheme.
if self.__explain__:
edge_mask = self.__edge_mask__.sigmoid()
edge_mask = self.__edge_mask__
if self.__apply_sigmoid__:
edge_mask = edge_mask.sigmoid()
# Some ops add self-loops to `edge_index`. We need to do
# the same for `edge_mask` (but do not train those).
if out.size(self.node_dim) != edge_mask.size(0):
Expand Down
2 changes: 2 additions & 0 deletions torch_geometric/nn/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .schnet import SchNet
from .dimenet import DimeNet
from .gnn_explainer import GNNExplainer
from .explainer import to_captum
from .metapath2vec import MetaPath2Vec
from .deepgcn import DeepGCNLayer
from .tgn import TGNMemory
Expand Down Expand Up @@ -41,6 +42,7 @@
'SchNet',
'DimeNet',
'GNNExplainer',
'to_captum',
'MetaPath2Vec',
'DeepGCNLayer',
'TGNMemory',
Expand Down
Loading