Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Explain: Add PGExplainer #6204

Merged
merged 46 commits into from
Dec 27, 2022
Merged

Explain: Add PGExplainer #6204

merged 46 commits into from
Dec 27, 2022

Conversation

wsad1
Copy link
Member

@wsad1 wsad1 commented Dec 10, 2022

This PR makes the following changes.

  1. Adds PGExplainer a new explainer algorithm. Paper , Reference Implementation
  2. Updated the Explainer and ExplainerAlgorithm interface to support PGExplainer. Since PGExplainer uses a mlp to explain predictions that mlp needs to be trained before we call Explainer.forward.
    The usage of PGExplainer would look like:
  explainer = Explainer(
      model=model,
      algorithm=PGExplainer(epochs=2),
      explanation_type=explanation_type,
      node_mask_type=None, # pg explainer only generates edge masks
      edge_mask_type=edge_mask_type,
      model_config=model_config,)
explainer.train_explainer_algorithm(x, edge_index , target)
# Forward with throw an error if 'train_explainer_algorithm' is not called
explanation = explainer.forward(x...)

Open to feedback on updates to the interface.

TODOs

  1. Add example in follow up PR.

@codecov
Copy link

codecov bot commented Dec 10, 2022

Codecov Report

Merging #6204 (90ec896) into master (eb93ec0) will increase coverage by 0.03%.
The diff coverage is 91.20%.

❗ Current head 90ec896 differs from pull request most recent head 4688043. Consider uploading reports for the commit 4688043 to get more accurate results

@@            Coverage Diff             @@
##           master    #6204      +/-   ##
==========================================
+ Coverage   84.56%   84.59%   +0.03%     
==========================================
  Files         380      381       +1     
  Lines       21162    21264     +102     
==========================================
+ Hits        17896    17989      +93     
- Misses       3266     3275       +9     
Impacted Files Coverage Δ
torch_geometric/explain/algorithm/base.py 93.82% <90.90%> (-1.18%) ⬇️
torch_geometric/explain/algorithm/pg_explainer.py 91.08% <91.08%> (ø)
torch_geometric/explain/algorithm/__init__.py 100.00% <100.00%> (ø)
torch_geometric/explain/algorithm/gnn_explainer.py 97.65% <100.00%> (+1.01%) ⬆️

📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more

@wsad1 wsad1 changed the base branch from master to get_mp December 10, 2022 14:15
@rusty1s rusty1s changed the title Explain: Add PGExplainer. Explain: Add PGExplainer Dec 12, 2022
@github-actions github-actions bot removed the utils label Dec 15, 2022
Base automatically changed from get_mp to master December 15, 2022 10:44
@rusty1s rusty1s mentioned this pull request Dec 20, 2022
@github-actions github-actions bot removed the utils label Dec 23, 2022
Copy link
Member

@rusty1s rusty1s left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @wsad1. I made some slight change by moving the training loop outside of the explainer. I feel this allows the user more control, as the model might be training on a variety of graphs. Hope the changes are okay. Feel free to revise them if you feel they don't make sense.

@rusty1s rusty1s enabled auto-merge (squash) December 27, 2022 16:45
@rusty1s rusty1s merged commit 81723f4 into master Dec 27, 2022
@rusty1s rusty1s deleted the explainer_pg branch December 27, 2022 16:48
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants