-
Notifications
You must be signed in to change notification settings - Fork 3.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make models compatible to Captum #3990
Merged
Merged
Changes from 13 commits
Commits
Show all changes
22 commits
Select commit
Hold shift + click to select a range
92cc648
Add PNA model
RBendias e16edd2
Merge branch 'pyg-team:master' into master
RBendias 27c4d6c
Merge branch 'pyg-team:master' into master
RBendias 070dfdf
Merge branch 'pyg-team:master' into master
RBendias a1c131f
Merge branch 'pyg-team:master' into master
RBendias cfaf789
Merge branch 'pyg-team:master' into master
RBendias ad9c572
Merge branch 'pyg-team:master' into master
RBendias c5c8fb4
Merge remote-tracking branch 'origin/master' into to_captum
467dc5b
Add to_captum for edge_masks
83e1830
Add node explainability with captum
09e692e
Merge branch 'master' into to_captum
rusty1s 784877d
Update
63c28a5
Add CaptumModel documentation
3cbb02f
Update torch_geometric/nn/models/explainer.py
RBendias a183d77
Update torch_geometric/nn/models/explainer.py
RBendias d404997
Update torch_geometric/nn/models/explainer.py
RBendias bd63e0c
Update torch_geometric/nn/models/explainer.py
RBendias 6659486
Add Captum dependency
e783219
Update captum_explainability.py
RBendias 96f9786
Update docs
3f2d5fb
Update docs
7874309
update doc
rusty1s File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
import os.path as osp | ||
|
||
import matplotlib.pyplot as plt | ||
import torch | ||
import torch.nn.functional as F | ||
from captum.attr import IntegratedGradients | ||
|
||
import torch_geometric.transforms as T | ||
from torch_geometric.datasets import Planetoid | ||
from torch_geometric.nn import GCNConv, GNNExplainer, to_captum | ||
|
||
dataset = 'Cora' | ||
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid') | ||
transform = T.NormalizeFeatures() | ||
dataset = Planetoid(path, dataset, transform=transform) | ||
data = dataset[0] | ||
|
||
|
||
class GCN(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.conv1 = GCNConv(dataset.num_features, 16) | ||
self.conv2 = GCNConv(16, dataset.num_classes) | ||
|
||
def forward(self, x, edge_index): | ||
x = F.relu(self.conv1(x, edge_index)) | ||
x = F.dropout(x, training=self.training) | ||
x = self.conv2(x, edge_index) | ||
return F.log_softmax(x, dim=1) | ||
|
||
|
||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | ||
model = GCN().to(device) | ||
data = data.to(device) | ||
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4) | ||
x, edge_index = data.x, data.edge_index | ||
|
||
for epoch in range(1, 201): | ||
model.train() | ||
optimizer.zero_grad() | ||
log_logits = model(x, edge_index) | ||
loss = F.nll_loss(log_logits[data.train_mask], data.y[data.train_mask]) | ||
loss.backward() | ||
optimizer.step() | ||
|
||
node_idx = 10 | ||
target = int(data.y[node_idx]) | ||
|
||
# Edge explainability for node 10 | ||
# Captum assumes that for all given input tensors, dimension 0 is | ||
# equal to the number of samples. Therefore, we use unsqueeze(0). | ||
input_mask = torch.ones(data.num_edges, dtype=torch.float, requires_grad=True, | ||
device=device) | ||
captum_model = to_captum(model, mask_type='edge', node_idx=node_idx) | ||
ig = IntegratedGradients(captum_model) | ||
ig_attr = ig.attribute(input_mask.unsqueeze(0), target=target, | ||
RBendias marked this conversation as resolved.
Show resolved
Hide resolved
|
||
additional_forward_args=(x, edge_index), | ||
internal_batch_size=1) | ||
|
||
# Scale attributions to [0, 1] | ||
ig_attr = ig_attr.squeeze(0).abs() / ig_attr.abs().max() | ||
RBendias marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# Visualize attributions with GNNExplainer visualizer | ||
RBendias marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# TODO: Change to general Explainer visualizer | ||
explainer = GNNExplainer(model) | ||
ax, G = explainer.visualize_subgraph(node_idx, edge_index, ig_attr) | ||
plt.show() | ||
|
||
# Node explainability for node 10 | ||
captum_model = to_captum(model, mask_type='node', node_idx=node_idx) | ||
ig = IntegratedGradients(captum_model) | ||
ig_attr_node = ig.attribute(x.unsqueeze(0), target=target, | ||
additional_forward_args=(edge_index), | ||
internal_batch_size=1) | ||
ig_attr_node = ig_attr_node.squeeze(0).abs().sum(dim=1) | ||
ig_attr_node /= ig_attr_node.max() | ||
|
||
ax, G = explainer.visualize_subgraph(node_idx, edge_index, ig_attr, | ||
node_alpha=ig_attr_node) | ||
plt.show() | ||
|
||
# Node and edge explainability for node 10 | ||
captum_model = to_captum(model, mask_type='node_and_edge', node_idx=node_idx) | ||
ig = IntegratedGradients(captum_model) | ||
ig_attr_node_and_edge = ig.attribute( | ||
(x.unsqueeze(0), input_mask.unsqueeze(0)), target=target, | ||
additional_forward_args=(edge_index), internal_batch_size=1) | ||
ig_attr_node = ig_attr_node_and_edge[0].squeeze(0).abs().sum(dim=1) | ||
ig_attr_node /= ig_attr_node.max() | ||
ig_attr_edge = ig_attr_node_and_edge[1].squeeze(0).abs() | ||
ig_attr_edge /= ig_attr_edge.max() | ||
|
||
ax, G = explainer.visualize_subgraph(node_idx, edge_index, ig_attr_edge, | ||
node_alpha=ig_attr_node) | ||
plt.show() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import pytest | ||
import torch | ||
|
||
from torch_geometric.nn import GAT, GCN, to_captum | ||
|
||
GCN = GCN(3, 16, 2, 7, dropout=0.5) | ||
GAT = GAT(3, 16, 2, 7, heads=2, concat=False) | ||
|
||
mask_type = ['edge', 'node_and_edge', 'node'] | ||
|
||
|
||
@pytest.mark.parametrize('mask_type', mask_type) | ||
@pytest.mark.parametrize('model', [GCN, GAT]) | ||
@pytest.mark.parametrize('node_idx', [None, 1]) | ||
def test_to_captum(model, mask_type, node_idx): | ||
rusty1s marked this conversation as resolved.
Show resolved
Hide resolved
|
||
captum_model = to_captum(model, mask_type=mask_type, node_idx=node_idx) | ||
|
||
x = torch.randn(8, 3) | ||
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7], | ||
[1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6]]) | ||
pre_out = model(x, edge_index) | ||
|
||
captum_model = to_captum(model, mask_type=mask_type, node_idx=node_idx) | ||
|
||
if mask_type == 'node': | ||
mask = x * 0.0 | ||
out = captum_model(mask.unsqueeze(0), edge_index) | ||
elif mask_type == 'edge': | ||
mask = torch.ones(edge_index.shape[1], dtype=torch.float) * 0.5 | ||
out = captum_model(mask.unsqueeze(0), x, edge_index) | ||
elif mask_type == 'node_and_edge': | ||
node_mask = x * 0.0 | ||
edge_mask = torch.ones(edge_index.shape[1], dtype=torch.float) * 0.5 | ||
out = captum_model(node_mask.unsqueeze(0), edge_mask.unsqueeze(0), | ||
edge_index) | ||
|
||
if node_idx is not None: | ||
assert out.shape == (1, 7) | ||
assert torch.any(out != pre_out[[node_idx]]) | ||
else: | ||
assert out.shape == (8, 7) | ||
assert torch.any(out != pre_out) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,113 @@ | ||||||
from typing import Optional | ||||||
|
||||||
import torch | ||||||
from torch import Tensor | ||||||
|
||||||
from torch_geometric.nn import MessagePassing | ||||||
|
||||||
|
||||||
def set_masks(model: torch.nn.Module, mask: Tensor, edge_index: Tensor, | ||||||
apply_sigmoid: bool = True): | ||||||
"""Apply mask to every graph layer in the model.""" | ||||||
loop_mask = edge_index[0] != edge_index[1] | ||||||
|
||||||
# Loop over layers and set masks on MessagePassing layers | ||||||
for module in model.modules(): | ||||||
if isinstance(module, MessagePassing): | ||||||
module.__explain__ = True | ||||||
module.__edge_mask__ = mask | ||||||
module.__loop_mask__ = loop_mask | ||||||
module.__apply_sigmoid__ = apply_sigmoid | ||||||
|
||||||
|
||||||
def clear_masks(model: torch.nn.Module): | ||||||
"""Clear all masks from the model.""" | ||||||
for module in model.modules(): | ||||||
if isinstance(module, MessagePassing): | ||||||
module.__explain__ = False | ||||||
module.__edge_mask__ = None | ||||||
module.__loop_mask__ = None | ||||||
module.__apply_sigmoid__ = True | ||||||
return module | ||||||
|
||||||
|
||||||
class CaptumModel(torch.nn.Module): | ||||||
r"""Model with forward function that can be easily used for | ||||||
explainability with `Captum.ai <https://captum.ai/>`_. | ||||||
|
||||||
Args: | ||||||
model (torch.nn.Module): Model to be explained. | ||||||
mask_type (str): Denotes the type of mask to be created with a Captum | ||||||
explainer. Valid inputs are :obj:`'edge'`, :obj:`'node'`, and | ||||||
:obj:`'node_and_edge'`. The input for the forward function with | ||||||
mask_type :obj:`'edge'` should be an edge_mask tensor of shape | ||||||
RBendias marked this conversation as resolved.
Show resolved
Hide resolved
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
(1, num_edges), :obj:`x` and :obj:`edge_index`. The input for the | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
forward function with mask_type :obj:`'node'` should be a | ||||||
node_input of shape (1, num_nodes, num_features) and | ||||||
RBendias marked this conversation as resolved.
Show resolved
Hide resolved
RBendias marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
:obj:`edge_index`. The input for the forward function with | ||||||
mask_type :obj:`'node_and_edge'` should be a node_input tensor of | ||||||
RBendias marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
shape (1, num_nodes, num_features), an edge_mask tensor of shape | ||||||
RBendias marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
(1, num_edges), and :obj:`edge_index`. (default: :obj:`'edge'`) | ||||||
RBendias marked this conversation as resolved.
Show resolved
Hide resolved
RBendias marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
node_idx (int, optional): Index of the node to be explained. With | ||||||
:obj:`'node_idx'` set, the forward function will return the output | ||||||
of the model for the node at the index specified. | ||||||
(default: :obj:`None`) | ||||||
""" | ||||||
def __init__(self, model: torch.nn.Module, mask_type: str = "edge", | ||||||
node_idx: Optional[int] = None): | ||||||
super().__init__() | ||||||
|
||||||
assert mask_type in ['edge', 'node', 'node_and_edge'] | ||||||
|
||||||
self.mask_type = mask_type | ||||||
self.model = model | ||||||
self.node_idx = node_idx | ||||||
|
||||||
def forward(self, mask, *args): | ||||||
RBendias marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
"""""" | ||||||
# The mask tensor, which comes from Captum's attribution methods, | ||||||
# contains the number of samples in dimension 0. Since we are | ||||||
# working with only one sample, we squeeze the tensors below. | ||||||
assert mask.shape[0] == 1, "Dimension 0 of input should be 1" | ||||||
if self.mask_type == "edge": | ||||||
assert len(args) >= 2, "Expects at least x and edge_index as args." | ||||||
if self.mask_type == "node": | ||||||
assert len(args) >= 1, "Expects at least edge_index as args." | ||||||
if self.mask_type == "node_and_edge": | ||||||
assert args[0].shape[0] == 1, "Dimension 0 of input should be 1" | ||||||
assert len(args[1:]) >= 1, "Expects at least edge_index as args." | ||||||
|
||||||
# Set edge mask | ||||||
if self.mask_type == 'edge': | ||||||
set_masks(self.model, mask.squeeze(0), args[1], | ||||||
rusty1s marked this conversation as resolved.
Show resolved
Hide resolved
RBendias marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
apply_sigmoid=False) | ||||||
elif self.mask_type == 'node_and_edge': | ||||||
set_masks(self.model, args[0].squeeze(0), args[1], | ||||||
apply_sigmoid=False) | ||||||
args = args[1:] | ||||||
|
||||||
# Edge mask | ||||||
if self.mask_type == 'edge': | ||||||
x = self.model(*args) | ||||||
|
||||||
# Node mask | ||||||
elif self.mask_type == 'node': | ||||||
x = self.model(mask.squeeze(0), *args) | ||||||
|
||||||
# Node and edge mask | ||||||
else: | ||||||
x = self.model(mask[0], *args) | ||||||
|
||||||
# Clear mask | ||||||
if self.mask_type in ['edge', 'node_and_edge']: | ||||||
clear_masks(self.model) | ||||||
|
||||||
if self.node_idx is not None: | ||||||
x = x[self.node_idx].unsqueeze(0) | ||||||
return x | ||||||
|
||||||
|
||||||
def to_captum(model: torch.nn.Module, mask_type: str = "edge", | ||||||
RBendias marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
node_idx: Optional[int] = None) -> torch.nn.Module: | ||||||
"""Convert a model to a model that can be used for Captum explainers.""" | ||||||
return CaptumModel(model, mask_type, node_idx) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does Captum expect log outputs?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, I wanted the gnn_explainer and the captum to be comparable, so I chose the exact same model.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, what does that mean in particular? Is the usage of
log_softmax
correct?