-
Notifications
You must be signed in to change notification settings - Fork 3.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make models compatible to Captum #3990
Merged
Merged
Changes from all commits
Commits
Show all changes
22 commits
Select commit
Hold shift + click to select a range
92cc648
Add PNA model
RBendias e16edd2
Merge branch 'pyg-team:master' into master
RBendias 27c4d6c
Merge branch 'pyg-team:master' into master
RBendias 070dfdf
Merge branch 'pyg-team:master' into master
RBendias a1c131f
Merge branch 'pyg-team:master' into master
RBendias cfaf789
Merge branch 'pyg-team:master' into master
RBendias ad9c572
Merge branch 'pyg-team:master' into master
RBendias c5c8fb4
Merge remote-tracking branch 'origin/master' into to_captum
467dc5b
Add to_captum for edge_masks
83e1830
Add node explainability with captum
09e692e
Merge branch 'master' into to_captum
rusty1s 784877d
Update
63c28a5
Add CaptumModel documentation
3cbb02f
Update torch_geometric/nn/models/explainer.py
RBendias a183d77
Update torch_geometric/nn/models/explainer.py
RBendias d404997
Update torch_geometric/nn/models/explainer.py
RBendias bd63e0c
Update torch_geometric/nn/models/explainer.py
RBendias 6659486
Add Captum dependency
e783219
Update captum_explainability.py
RBendias 96f9786
Update docs
3f2d5fb
Update docs
7874309
update doc
rusty1s File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
import os.path as osp | ||
|
||
import matplotlib.pyplot as plt | ||
import torch | ||
import torch.nn.functional as F | ||
from captum.attr import IntegratedGradients | ||
|
||
import torch_geometric.transforms as T | ||
from torch_geometric.datasets import Planetoid | ||
from torch_geometric.nn import GCNConv, GNNExplainer, to_captum | ||
|
||
dataset = 'Cora' | ||
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid') | ||
dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures()) | ||
data = dataset[0] | ||
|
||
|
||
class GCN(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.conv1 = GCNConv(dataset.num_features, 16) | ||
self.conv2 = GCNConv(16, dataset.num_classes) | ||
|
||
def forward(self, x, edge_index): | ||
x = F.relu(self.conv1(x, edge_index)) | ||
x = F.dropout(x, training=self.training) | ||
x = self.conv2(x, edge_index) | ||
return F.log_softmax(x, dim=1) | ||
|
||
|
||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | ||
model = GCN().to(device) | ||
data = data.to(device) | ||
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4) | ||
|
||
for epoch in range(1, 201): | ||
model.train() | ||
optimizer.zero_grad() | ||
log_logits = model(data.x, data.edge_index) | ||
loss = F.nll_loss(log_logits[data.train_mask], data.y[data.train_mask]) | ||
loss.backward() | ||
optimizer.step() | ||
|
||
node_idx = 10 | ||
target = int(data.y[node_idx]) | ||
|
||
# Edge explainability | ||
# =================== | ||
|
||
# Captum assumes that for all given input tensors, dimension 0 is | ||
# equal to the number of samples. Therefore, we use unsqueeze(0). | ||
captum_model = to_captum(model, mask_type='edge', node_idx=node_idx) | ||
edge_mask = torch.ones(data.num_edges, requires_grad=True, device=device) | ||
|
||
ig = IntegratedGradients(captum_model) | ||
ig_attr_edge = ig.attribute(edge_mask.unsqueeze(0), target=target, | ||
additional_forward_args=(data.x, data.edge_index), | ||
internal_batch_size=1) | ||
|
||
# Scale attributions to [0, 1]: | ||
ig_attr_edge = ig_attr_edge.squeeze(0).abs() | ||
ig_attr_edge /= ig_attr_edge.max() | ||
|
||
# Visualize absolute values of attributions with GNNExplainer visualizer | ||
explainer = GNNExplainer(model) # TODO: Change to general Explainer visualizer | ||
ax, G = explainer.visualize_subgraph(node_idx, data.edge_index, ig_attr_edge) | ||
plt.show() | ||
|
||
# Node explainability | ||
# =================== | ||
|
||
captum_model = to_captum(model, mask_type='node', node_idx=node_idx) | ||
|
||
ig = IntegratedGradients(captum_model) | ||
ig_attr_node = ig.attribute(data.x.unsqueeze(0), target=target, | ||
additional_forward_args=(data.edge_index), | ||
internal_batch_size=1) | ||
|
||
# Scale attributions to [0, 1]: | ||
ig_attr_node = ig_attr_node.squeeze(0).abs().sum(dim=1) | ||
ig_attr_node /= ig_attr_node.max() | ||
|
||
# Visualize absolute values of attributions with GNNExplainer visualizer | ||
ax, G = explainer.visualize_subgraph(node_idx, data.edge_index, ig_attr_edge, | ||
node_alpha=ig_attr_node) | ||
plt.show() | ||
|
||
# Node and edge explainability | ||
# ============================ | ||
|
||
captum_model = to_captum(model, mask_type='node_and_edge', node_idx=node_idx) | ||
|
||
ig = IntegratedGradients(captum_model) | ||
ig_attr_node, ig_attr_edge = ig.attribute( | ||
(data.x.unsqueeze(0), edge_mask.unsqueeze(0)), target=target, | ||
additional_forward_args=(data.edge_index), internal_batch_size=1) | ||
|
||
# Scale attributions to [0, 1]: | ||
ig_attr_node = ig_attr_node.squeeze(0).abs().sum(dim=1) | ||
ig_attr_node /= ig_attr_node.max() | ||
ig_attr_edge = ig_attr_edge.squeeze(0).abs() | ||
ig_attr_edge /= ig_attr_edge.max() | ||
|
||
ax, G = explainer.visualize_subgraph(node_idx, data.edge_index, ig_attr_edge, | ||
node_alpha=ig_attr_node) | ||
plt.show() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,6 +20,7 @@ | |
full_install_requires = [ | ||
'h5py', | ||
'numba', | ||
'captum', | ||
'rdflib', | ||
'trimesh', | ||
'networkx', | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
import pytest | ||
import torch | ||
|
||
from torch_geometric.nn import GAT, GCN, to_captum | ||
|
||
try: | ||
from captum import attr # noqa | ||
with_captum = True | ||
except ImportError: | ||
with_captum = False | ||
|
||
x = torch.randn(8, 3, requires_grad=True) | ||
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7], | ||
[1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6]]) | ||
|
||
GCN = GCN(3, 16, 2, 7, dropout=0.5) | ||
GAT = GAT(3, 16, 2, 7, heads=2, concat=False) | ||
mask_types = ['edge', 'node_and_edge', 'node'] | ||
methods = [ | ||
'Saliency', | ||
'InputXGradient', | ||
'Deconvolution', | ||
'FeatureAblation', | ||
'ShapleyValueSampling', | ||
'IntegratedGradients', | ||
'GradientShap', | ||
'Occlusion', | ||
'GuidedBackprop', | ||
'KernelShap', | ||
'Lime', | ||
] | ||
|
||
|
||
@pytest.mark.parametrize('mask_type', mask_types) | ||
@pytest.mark.parametrize('model', [GCN, GAT]) | ||
@pytest.mark.parametrize('node_idx', [None, 1]) | ||
def test_to_captum(model, mask_type, node_idx): | ||
rusty1s marked this conversation as resolved.
Show resolved
Hide resolved
|
||
captum_model = to_captum(model, mask_type=mask_type, node_idx=node_idx) | ||
pre_out = model(x, edge_index) | ||
if mask_type == 'node': | ||
mask = x * 0.0 | ||
out = captum_model(mask.unsqueeze(0), edge_index) | ||
elif mask_type == 'edge': | ||
mask = torch.ones(edge_index.shape[1], dtype=torch.float, | ||
requires_grad=True) * 0.5 | ||
out = captum_model(mask.unsqueeze(0), x, edge_index) | ||
elif mask_type == 'node_and_edge': | ||
node_mask = x * 0.0 | ||
edge_mask = torch.ones(edge_index.shape[1], dtype=torch.float, | ||
requires_grad=True) * 0.5 | ||
out = captum_model(node_mask.unsqueeze(0), edge_mask.unsqueeze(0), | ||
edge_index) | ||
|
||
if node_idx is not None: | ||
assert out.shape == (1, 7) | ||
assert torch.any(out != pre_out[[node_idx]]) | ||
else: | ||
assert out.shape == (8, 7) | ||
assert torch.any(out != pre_out) | ||
|
||
|
||
@pytest.mark.skipif(not with_captum, reason="no 'captum' package") | ||
@pytest.mark.parametrize('mask_type', mask_types) | ||
@pytest.mark.parametrize('method', methods) | ||
def test_captum_attribution_methods(mask_type, method): | ||
model = GCN | ||
captum_model = to_captum(model, mask_type, 0) | ||
input_mask = torch.ones((1, edge_index.shape[1]), dtype=torch.float, | ||
requires_grad=True) | ||
explainer = getattr(attr, method)(captum_model) | ||
|
||
if mask_type == 'node': | ||
input = x.clone().unsqueeze(0) | ||
additional_forward_args = (edge_index, ) | ||
sliding_window_shapes = (3, 3) | ||
elif mask_type == 'edge': | ||
input = input_mask | ||
additional_forward_args = (x, edge_index) | ||
sliding_window_shapes = (5, ) | ||
elif mask_type == 'node_and_edge': | ||
input = (x.clone().unsqueeze(0), input_mask) | ||
additional_forward_args = (edge_index, ) | ||
sliding_window_shapes = ((3, 3), (5, )) | ||
|
||
if method == 'IntegratedGradients': | ||
attributions, delta = explainer.attribute( | ||
input, target=0, internal_batch_size=1, | ||
additional_forward_args=additional_forward_args, | ||
return_convergence_delta=True) | ||
elif method == 'GradientShap': | ||
attributions, delta = explainer.attribute( | ||
input, target=0, return_convergence_delta=True, baselines=input, | ||
n_samples=1, additional_forward_args=additional_forward_args) | ||
elif method == 'DeepLiftShap' or method == 'DeepLift': | ||
attributions, delta = explainer.attribute( | ||
input, target=0, return_convergence_delta=True, baselines=input, | ||
additional_forward_args=additional_forward_args) | ||
elif method == 'Occlusion': | ||
attributions = explainer.attribute( | ||
input, target=0, sliding_window_shapes=sliding_window_shapes, | ||
additional_forward_args=additional_forward_args) | ||
else: | ||
attributions = explainer.attribute( | ||
input, target=0, additional_forward_args=additional_forward_args) | ||
if mask_type == 'node': | ||
assert attributions.shape == (1, 8, 3) | ||
elif mask_type == 'edge': | ||
assert attributions.shape == (1, 14) | ||
else: | ||
assert attributions[0].shape == (1, 8, 3) | ||
assert attributions[1].shape == (1, 14) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does Captum expect log outputs?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, I wanted the gnn_explainer and the captum to be comparable, so I chose the exact same model.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, what does that mean in particular? Is the usage of
log_softmax
correct?