Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extending Gnnexplainer for graph classification. #2597

Merged
merged 5 commits into from
May 19, 2021

Conversation

wsad1
Copy link
Member

@wsad1 wsad1 commented May 18, 2021

#1374
Updated Gnnexplainer to support graph classification. Following are the key updates.

  1. Added explain_graph which should be called for graph classification. This has lots of code similar to explain_node. Another solution would be to have a explain function that handles both graph or node explanation, however that would mean explain_node would have to be retired/depreciated, let me know if you feel that's a better solution.
  2. In visualize_subgraph setting node_idx to -1 implies a graph classification task.

Further i'll add an example for this in a day or two. But feel free to review the code.

@rusty1s
Copy link
Member

rusty1s commented May 19, 2021

This is super awesome. Thanks a lot! I already fixed the failing PairNorm test in master.

@rusty1s rusty1s merged commit ae783c0 into pyg-team:master May 19, 2021
@Luoyunsong
Copy link

Hello , I have got an erro "type object 'GNNExplainer' has no attribute 'explain_graph'"...Is that a wrong version for me?

@rusty1s
Copy link
Member

rusty1s commented May 19, 2021

This feature just got merged. You need to install PyG from master to access it:

pip install git+https://github.com/rusty1s/pytorch_geometric.git

@Luoyunsong
Copy link

Thanks a lot!

@Luoyunsong
Copy link

My code:

x1, edge_index1 = testing_dataset[1].x, testing_dataset[1].edge_index
explainer = GNNExplainer(model, epochs=200)
node_idx = -1
graph_feat_mask, edge_mask = explainer.explain_graph( x1, edge_index1)
ax, G = explainer.visualize_subgraph(node_idx,edge_index1, edge_mask, y=testing_dataset[1].y)
plt.show()

And the erro:

Explain graph: 100%|██████████| 200/200 [00:00<00:00, 203.63it/s]

IndexError Traceback (most recent call last)
in ()
2 node_idx = -1
3 graph_feat_mask, edge_mask = explainer.explain_graph( x1, edge_index1)
----> 4 ax, G = explainer.visualize_subgraph(node_idx,edge_index1, edge_mask, y=testing_dataset[1].y)
5 plt.show()

/usr/local/lib/python3.7/dist-packages/torch_geometric/nn/models/gnn_explainer.py in visualize_subgraph(self, node_idx, edge_index, edge_mask, y, threshold, **kwargs)
311 device=edge_index.device)
312 else:
--> 313 y = y[subset].to(torch.float) / y.max().item()
314
315 data = Data(edge_index=edge_index, att=edge_mask, y=y,

IndexError: too many indices for tensor of dimension 0

The shape of my edge_index0 is "torch.Size([2, 450])",I really dont know how to set these parameters

@wsad1
Copy link
Member Author

wsad1 commented May 20, 2021

@Luoyunsong , could you please share

  1. the shape of testing_dataset[1].y and testing_dataset[1].x
  2. testing_dataset[1].edge_index.max().

This should help me debug this better.

@Luoyunsong
Copy link

I was doing a brain network classfication.
The shape of testing_dataset[1].y is torch.Size([]),which is the label of a brain network graph,'0' and '1'.
The shape of testing_dataset[1].x is torch.Size([90, 90]).
The testing_dataset[1].edge_index.max() is tensor(89),and its shape is torch.Size([])
Thanks for your help

@wsad1
Copy link
Member Author

wsad1 commented May 21, 2021

@Luoyunsong, thanks for bringing this up. This is a bug, i'll create a PR to fix this asap. In the meanwhile could to try setting y to None to get around the bug explainer.visualize_subgraph(node_idx,edge_index1, edge_mask, y=None).

@wsad1 wsad1 deleted the gnnexpupdate branch May 21, 2021 04:49
@Luoyunsong
Copy link

Thanks for your help and its a perfect work

@wsad1
Copy link
Member Author

wsad1 commented Jun 17, 2021

@panisson are you referring to this line in the loss function.

loss = -log_logits[ node_idx, pred_label[node_idx]] if node_idx == -1 else -log_logits[0,pred_label[0]]                                                                                                                             

@rusty1s
Copy link
Member

rusty1s commented Jun 17, 2021

Yes, this seems to be indeed a crucial bug. We might need to hot-fix the new release.

@wsad1
Copy link
Member Author

wsad1 commented Jun 18, 2021

Thanks @panisson for catching this. @rusty1s fixed it 35f5ac3.

@htlee6
Copy link

htlee6 commented Mar 28, 2023

Hi all! Is this feature still available in version 2.3.0? Thanks

@rusty1s
Copy link
Member

rusty1s commented Mar 28, 2023

Yes, graph classification within GNNExplainer is fully supported.

@Paul1086
Copy link

Paul1086 commented Sep 5, 2023

Hi, I am a beginner at GNN. Just have a question, Does the node_feat_mask of explain_graph directly give the node feature importance for the entire graph classification? Or, do we need to follow other steps to find the feature importance after obtaining node_feat_mask?

@rusty1s
Copy link
Member

rusty1s commented Sep 5, 2023

You should be able to directly use it, and it will contain the feature importance of the entire graph.

@sumone-compbio
Copy link

Hi, is there any tutorial or example code available for explain_graph? Thank you.

@rusty1s
Copy link
Member

rusty1s commented Jan 22, 2024

No example yet, but hopefully the test case gets you going: https://github.com/pyg-team/pytorch_geometric/blob/master/test/explain/algorithm/test_gnn_explainer.py#L80

@sumone-compbio
Copy link

sumone-compbio commented Jan 23, 2024

@rusty1s thanks a lot. I'm new to XAI methods. Can you please tell me what would be the input here? The training set on which the model (e.g. GCNConv) is trained (to see how the model can differentiate between 2 classes for a binary classification problem) or the held-out test set?
What I want is to find common substructures in graphs that contribute to its behavior (e.g. antibiotic or not).

@rusty1s
Copy link
Member

rusty1s commented Jan 23, 2024

It depends on what explanations you want to receive (explaining training data or explaining hold out data). In general, I think it is more common to run explanations on hold out data. You can then find common substructures by looking at the edge attribution.

@sumone-compbio
Copy link

sumone-compbio commented Feb 16, 2024

@rusty1s thank you I'm still working on it. I got the error: 'GNNExplainer' object has no attribute 'explain_graph'. Even though I installed PyG from the master link above. Am I missing something? I am using torch 2.2.0 and torch_geometric 2.5.0

@rusty1s
Copy link
Member

rusty1s commented Feb 17, 2024

Are you using GNNExplainer from torch_geometric.explain?

@sumone-compbio
Copy link

sumone-compbio commented Feb 20, 2024

@rusty1s yes. This is the code I am running:
`from torch_geometric.explain import GNNExplainer

x1, edge_index1 = test_dataset[1].x, test_dataset[1].edge_index

explainer = GNNExplainer(model=model, epochs=400, lr=0.0001)

graph_feat_mask, edge_mask = explainer.explain_graph( x1, edge_index1)`

Following is the error:
AttributeError Traceback (most recent call last)
Cell In[40], line 2
1 node_idx = -1
----> 2 graph_feat_mask, edge_mask = explainer.explain_graph( x1, edge_index1)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1688, in Module.getattr(self, name)
1686 if name in modules:
1687 return modules[name]
-> 1688 raise AttributeError(f"'{type(self).name}' object has no attribute '{name}'")

AttributeError: 'GNNExplainer' object has no attribute 'explain_graph'

The model I am using is GraphConv from PyG. I couldn't fix this error.

@rusty1s
Copy link
Member

rusty1s commented Feb 20, 2024

That would be an incorrect use of GNNExplainer. You need to wrap it inside an Explainer module (take a look at our examples on how to do this).

@sumone-compbio
Copy link

sorry, I couldn't understand by looking at the example.
I was referring to this:
test/nn/models/test_gnn_explainer.py in 7512004.

For other code in https://github.com/pyg-team/pytorch_geometric/blob/master/examples/explain/gnn_explainer.py (I assume for node classification) I didn't understand what parameters are different for graph-level classification.

@rusty1s
Copy link
Member

rusty1s commented Feb 21, 2024

Here is how you would use GNNExplainer now:

explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=200),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='graph',
        return_type='log_probs',
    ),
)

@sumone-compbio
Copy link

Thank you so much for the clarification. I just saw that explain_graph was in the deprecated version. However, the batch_index argument of my model's forward function is causing the problem. I think I didn't understand how to load a test graph in the explainer. Below is the error:

Cell In[108], line 1
----> 1 explanation = explainer(x=x0, edge_index=edge_index0, target=None)
2 print(f'Generated explanations in {explanation.available_explanations}')

File ~/.local/lib/python3.10/site-packages/torch_geometric/explain/explainer.py:196, in Explainer.call(self, x, edge_index, target, index, **kwargs)
192 if target is not None:
193 warnings.warn(
194 f"The 'target' should not be provided for the explanation "
195 f"type '{self.explanation_type.value}'")
--> 196 prediction = self.get_prediction(x, edge_index, **kwargs)
197 target = self.get_target(prediction)
199 if isinstance(index, int):

File ~/.local/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator..decorate_context(*args, **kwargs)
112 @functools.wraps(func)
113 def decorate_context(*args, **kwargs):
114 with ctx_factory():
--> 115 return func(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/torch_geometric/explain/explainer.py:115, in Explainer.get_prediction(self, *args, **kwargs)
112 self.model.eval()
114 with torch.no_grad():
--> 115 out = self.model(*args, **kwargs)
...
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None

TypeError: GCN.forward() missing 1 required positional argument: 'batch_index'

@rusty1s
Copy link
Member

rusty1s commented Feb 23, 2024

If batch_index is an expected argument of the model, you also need to add it to the explainer:

explanation = explainer(x=x0, edge_index=edge_index0, batch_index=..., target=None)

@sumone-compbio
Copy link

@rusty1s thank you so much it worked finally. However, it generates a slightly different graph every time I re-run the explanation.visualize_graph function. Could you let me know what this is? Also, instead of a directed graph is there a way to get an undirected graph so that later I can map the nodes to atoms to visualize molecules with functional groups or substructures important to the prediction? Thanks a ton, once again <3.

@sumone-compbio
Copy link

Hi @rusty1s, I have one more doubt, please. For my binary classification problem, I ran GNNExplainer as guided by you. I noticed that most edges have low edge_masks scores for the class predicted '0' graphs with a low probability score (<0.1 obtained by GCN model), there are some edges with high edge_masks scores but those are rare in this class.
My doubt is which of the following cases is true:
CASE 1: The edge_masks score is the same for both classes '0' and '1' i.e. a high edge_mask for an edge would contribute towards a graph predicted '1' and a low edge_mask would contribute towards a graph predicted '0'.
CASE 2: A high edge_mask for an edge in a graph predicted '0' would contribute towards the graph being '0' and its low edge_masks don't contribute much to the graph being '0'. Similarly, for the class '1', a high edge_mask would contribute towards the graph being '1' and its low edge_masks don't contribute much to the graph being '1'.

I hope you get the point of what I'm asking. Sorry for bothering you again but this is part of my thesis. I hope you're doing great. Thanks a lot, once again.

@rusty1s
Copy link
Member

rusty1s commented Jun 24, 2024

You would interpret it as case 2.

@sumone-compbio
Copy link

@rusty1s thank you, I was in doubt because a lot of my graph '0' has low edge_masks. Case 2 makes sense as the explainer also works on multiclass classification.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

7 participants