
[Explainability] binary_classification mode + link prediction example #6083

Merged: 25 commits into pyg-team:master on Dec 9, 2022

Conversation

@camillepradel (Contributor, author) commented on Nov 28, 2022

Progress towards #5924

Done:

  • Adds support for a binary classification mode by splitting ModelMode.classification into ModelMode.binary_classification and ModelMode.multiclass_classification.
  • Adds an example script, gnn_explainer_link_pred.py, which trains a link prediction model and explains one of its outputs (see the sketch below).
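For context, here is a minimal, hypothetical sketch of how such an example might wire up the explainer. It is not taken from the PR: it assumes the torch_geometric.explain API as it exists after this PR was merged, and ToyLinkPredictor and the random tensors are stand-ins for the trained model and dataset used in the real script.

```python
import torch
from torch_geometric.explain import Explainer
from torch_geometric.explain.algorithm import GNNExplainer
from torch_geometric.nn import GCNConv


class ToyLinkPredictor(torch.nn.Module):
    # Stand-in for a trained link prediction model (illustrative only; the
    # real example trains a proper encoder/decoder on a dataset).
    def __init__(self):
        super().__init__()
        self.conv = GCNConv(16, 32)

    def forward(self, x, edge_index, edge_label_index):
        z = self.conv(x, edge_index)
        src, dst = edge_label_index
        return (z[src] * z[dst]).sum(dim=-1)  # one raw logit per candidate edge


x = torch.randn(10, 16)
edge_index = torch.randint(0, 10, (2, 40))
edge_label_index = torch.randint(0, 10, (2, 5))

explainer = Explainer(
    model=ToyLinkPredictor(),
    algorithm=GNNExplainer(epochs=200),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='binary_classification',  # the mode introduced by this PR
        task_level='edge',
        return_type='raw',
    ),
)

# Explain the prediction for the first candidate edge.
explanation = explainer(x, edge_index, index=0,
                        edge_label_index=edge_label_index)
print(explanation.edge_mask.shape)
```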

@codecov (bot) commented on Nov 28, 2022

Codecov Report

Merging #6083 (a553fb7) into master (f343295) will increase coverage by 0.00%.
The diff coverage is 100.00%.

❗ The current head a553fb7 differs from the pull request's most recent head b01f4af. Consider uploading reports for commit b01f4af to get more accurate results.

@@           Coverage Diff           @@
##           master    #6083   +/-   ##
=======================================
  Coverage   84.48%   84.49%           
=======================================
  Files         371      371           
  Lines       20725    20738   +13     
=======================================
+ Hits        17509    17522   +13     
  Misses       3216     3216           
Impacted Files Coverage Δ
torch_geometric/explain/algorithm/base.py 96.61% <ø> (-0.36%) ⬇️
torch_geometric/explain/algorithm/gnn_explainer.py 96.29% <100.00%> (+0.32%) ⬆️
torch_geometric/explain/config.py 100.00% <100.00%> (ø)
torch_geometric/explain/explainer.py 100.00% <100.00%> (ø)


@avgupta456 (Contributor) commented:

I wonder if we could combine this PR and #6056. This PR contributes an example script while #6056 handles the k_hop_subgraph and adds edge-level tests for GNNExplainer.

@camillepradel (Contributor, author) commented:

Oops, I didn't see your PR! I will have a look, thanks!

@rusty1s (Member) commented on Nov 29, 2022

@camillepradel Any reason to close this? We would still like to integrate an example of this :)

@camillepradel (Contributor, author) commented:

You are right. I wanted to open a new PR, but there is actually no good reason not to keep using this one.

@camillepradel reopened this on Nov 30, 2022.
@camillepradel changed the title from "Link prediction explanation" to "Link prediction explanation example" on Nov 30, 2022.
@camillepradel changed the title from "Link prediction explanation example" to "[Explainability] binary_classification mode + link prediction example" on Dec 2, 2022.
@camillepradel (Contributor, author) commented:

So I updated the example, but since it is a link prediction task, I needed the explainability framework to support binary classification. I ended up adding a new binary_classification mode (and renaming the original classification mode to multiclass_classification).
If somebody can think of a simpler way to support binary classification, I am curious to hear about it.
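A rough sketch of the split described above. The actual enum lives in torch_geometric/explain/config.py; the member values here follow the PR description and may not match the merged code verbatim.

```python
from enum import Enum


class ModelMode(Enum):
    # Previously a single `classification` member covered both cases.
    binary_classification = 'binary_classification'
    multiclass_classification = 'multiclass_classification'
    regression = 'regression'
```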

@camillepradel marked this pull request as ready for review on December 2, 2022, 19:05.
@RBendias (Contributor) left a review comment:

Hi, thanks a lot for the PR! I left some comments in the example file. Regarding the binary_classification type, I also think we should find a simpler solution, e.g., by selecting the loss based on the model's output. I'll take a closer look.

Seven review comments on examples/gnn_explainer_link_pred.py (all outdated or resolved).
@RBendias (Contributor) commented on Dec 5, 2022

I checked the code again. The main reason I see for the additional type is to check if the return_type is set correctly in the ModelConfig. The return_type needs to be probs, as we use binary_cross_entropy. However, we could also measure the MSE between the raw/probs/log_probs values (also using the raw model output in get_prediction instead of setting a threshold of 0.5). @camillepradel What do you think?
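As an illustration of this alternative, here is a hypothetical loss selector driven only by return_type. It is not code from the PR; the function name and structure are made up, and only standard PyTorch calls (binary_cross_entropy, binary_cross_entropy_with_logits, mse_loss) are used.

```python
import torch
import torch.nn.functional as F


def binary_loss(y_hat: torch.Tensor, y: torch.Tensor, return_type: str) -> torch.Tensor:
    """Pick a loss for binary targets based on what the model returns (illustrative)."""
    if return_type == 'raw':
        # Logits: numerically stable BCE on raw outputs.
        return F.binary_cross_entropy_with_logits(y_hat, y.float())
    if return_type == 'probs':
        # Probabilities in [0, 1]: plain BCE.
        return F.binary_cross_entropy(y_hat, y.float())
    # Fallback discussed above: compare the values directly with MSE.
    return F.mse_loss(y_hat, y.float())
```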

@camillepradel (Contributor, author) commented:

The distinction between the multiclass_classification and binary_classification modes is indeed used in ModelConfig, but also when processing the prediction in Explainer and the loss in GNNExplainer. In the latter two cases, it allows the model output to be handled differently according to the mode. I initially thought we could also infer the difference from the shape of the output, but that looked tricky to me (depending on the number of classes and the optional batching, we might not be able to tell), which is why I chose to define two distinct modes explicitly.

You are right, we don't have to enforce return_type to be probs for binary classification (I have never seen a log_probs output in that setup, but it seems legitimate); I will change that.
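To illustrate why the explicit mode matters when post-processing predictions, here is a hypothetical sketch (not the actual Explainer code; the function name is made up):

```python
import torch


def to_hard_prediction(out: torch.Tensor, mode: str) -> torch.Tensor:
    """Turn raw model output into discrete targets, depending on the mode (illustrative)."""
    if mode == 'binary_classification':
        # A 1-D tensor of probabilities cannot be told apart from a batched
        # multiclass output by shape alone, hence the explicit mode.
        return (out > 0.5).long()
    if mode == 'multiclass_classification':
        return out.argmax(dim=-1)
    return out  # regression: keep the raw values
```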

@camillepradel (Contributor, author) commented:

I updated the code to allow the raw and log_probs return types, but I am still a bit confused about log_probs.

More specifically, I don't know how to apply an MSE loss directly on log_probs, since log probabilities range over (−∞, 0]. In the current version, I apply binary_cross_entropy to the exp() of y_hat.
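A minimal sketch of the workaround described above, assuming y_hat holds log probabilities and y holds binary targets (the tensor values are illustrative stand-ins, not the PR code):

```python
import torch
import torch.nn.functional as F

# Stand-in for model output as log probabilities in (-inf, 0].
y_hat = F.logsigmoid(torch.randn(8))
y = torch.randint(0, 2, (8,)).float()   # binary ground-truth targets

# Map log probabilities back to probabilities in (0, 1] before applying BCE,
# as described in the comment above.
loss = F.binary_cross_entropy(y_hat.exp(), y)
```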

@rusty1s (Member) left a review comment:

LGTM, a few nit-picky comments. I think we need to drop log_probs support in binary_classification.

Further review comments on torch_geometric/explain/explainer.py, torch_geometric/explain/algorithm/gnn_explainer.py, and test/explain/algorithm/test_gnn_explainer.py (all resolved).
@rusty1s (Member) left a follow-up review comment:

Thanks for the updates. Looks great!

@rusty1s enabled auto-merge (squash) on December 9, 2022, 13:30.
@rusty1s merged commit 737cc76 into pyg-team:master on Dec 9, 2022.