Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Heterogeneous Explanation #6091

Merged
merged 24 commits into from
Dec 10, 2022
Merged

Heterogeneous Explanation #6091

merged 24 commits into from
Dec 10, 2022

Conversation

avgupta456
Copy link
Contributor

@avgupta456 avgupta456 commented Nov 29, 2022

Starts heterogeneous explanations using the new Explainer framework. Creates HeterogeneousExplanation and extends Explainer, ExplainerAlgorithm, and DummyExplainer to handle heterogeneous graphs. Tests and functionalities are partially complete, looking to get some initial validation before continuing.

Todo

  • Heterogeneous explanation tests
  • Add and test thresholding for heterogeneous explanations
  • Finish some heterogeneous explanation methods
  • Extend GNNExplainer (maybe in separate PR)

Closes #6014

Copy link
Member

@wsad1 wsad1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, overall the interfaces look good. Left some initial comments.

torch_geometric/explain/algorithm/utils.py Outdated Show resolved Hide resolved
torch_geometric/explain/explainer.py Outdated Show resolved Hide resolved
torch_geometric/explain/explainer.py Outdated Show resolved Hide resolved
torch_geometric/explain/algorithm/base.py Show resolved Hide resolved
torch_geometric/explain/algorithm/base.py Outdated Show resolved Hide resolved
torch_geometric/explain/explanation.py Outdated Show resolved Hide resolved
@codecov
Copy link

codecov bot commented Dec 1, 2022

Codecov Report

Merging #6091 (bf44c90) into master (f82d6a2) will increase coverage by 0.05%.
The diff coverage is 96.90%.

@@            Coverage Diff             @@
##           master    #6091      +/-   ##
==========================================
+ Coverage   84.49%   84.55%   +0.05%     
==========================================
  Files         371      371              
  Lines       20741    20821      +80     
==========================================
+ Hits        17525    17605      +80     
  Misses       3216     3216              
Impacted Files Coverage Δ
torch_geometric/explain/algorithm/gnn_explainer.py 95.18% <60.00%> (-1.12%) ⬇️
...rch_geometric/explain/algorithm/dummy_explainer.py 93.75% <94.44%> (-0.37%) ⬇️
torch_geometric/data/hetero_data.py 95.65% <100.00%> (+0.68%) ⬆️
torch_geometric/explain/__init__.py 100.00% <100.00%> (ø)
torch_geometric/explain/algorithm/base.py 96.72% <100.00%> (+0.11%) ⬆️
torch_geometric/explain/explainer.py 100.00% <100.00%> (ø)
torch_geometric/explain/explanation.py 100.00% <100.00%> (ø)

📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more

@avgupta456
Copy link
Contributor Author

avgupta456 commented Dec 1, 2022

I added logic and testing for thresholding, validate() function in heterogeneous explanation, and fixed some bugs. The last item left for this PR is creating heterogeneous explanation subgraph/complement_subgraph using _apply_masks.

I am unsure if my implementation is good, overall having trouble extracting the x_dict, edge_index_dict, and edge_attrs from a HeterogeneousExplanation

Currently I am doing it like this

x_dict = None
edge_index_dict = None
edge_attr_dict = None
for key, value in self.node_items():
    if key == "x_dict":
        x_dict = value
    elif key == "edge_index_dict":
        edge_index_dict = value
    elif key == "edge_attr_dict":
        edge_attr_dict = value
        continue

And setting the updated dictionaries like this:

out.edge_index_dict.update(edge_index_dict)
out.edge_attr_dict.update(edge_attr_dict)

@wsad1 Is there a better way to access the HeterogeneousExplanation underlying graph?

Copy link
Member

@wsad1 wsad1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the updates @avgupta456 . Will take a look again once you are done with TODOs in test.

torch_geometric/explain/explainer.py Outdated Show resolved Hide resolved
torch_geometric/explain/explanation.py Outdated Show resolved Hide resolved
torch_geometric/explain/explanation.py Outdated Show resolved Hide resolved
torch_geometric/explain/algorithm/dummy_explainer.py Outdated Show resolved Hide resolved
test/explain/algorithm/test_base.py Outdated Show resolved Hide resolved
test/explain/algorithm/test_base.py Outdated Show resolved Hide resolved
test/explain/test_explainer_hetero.py Show resolved Hide resolved
@avgupta456
Copy link
Contributor Author

I'm currently running into an issue with HeteroExplanation that is blocking my further progress.

Unlike Explanation (inherits from Data), HeteroExplanation (inherits from HeteroData) does not correctly store the x_dict, edge_index_dict, and edge_attrs_dict when initialized with these arguments. It does work if the dictionaries are set as attributes afterwards.

HeteroExplanation(x_dict=x_dict, edge_index_dict=edge_index_dict, edge_attrs_dict=edge_attrs_dict)  # does not work

explanation = HeteroExplanation()  # does work
explanation.x_dict = x_dict
explanation.edge_index_dict = edge_index_dict
explanation.edge_attrs_dict = edge_attrs_dict

One workaround is that when initialized with the dictionaries, they are accessible through self.node_items(). Perhaps this is a bug with HeteroData? Or maybe I am doing something else incorrectly.

As a result of this bug, the subgraph method isn't working as intended. Open to any advice.

Copy link
Member

@wsad1 wsad1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks clean now. Will approve once the subgraph issue is resolved.

torch_geometric/explain/explanation.py Outdated Show resolved Hide resolved
torch_geometric/explain/explanation.py Outdated Show resolved Hide resolved
Copy link
Member

@dufourc1 dufourc1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the work ! LGTM, these are nitpick comments

torch_geometric/explain/explainer.py Show resolved Hide resolved
torch_geometric/explain/explainer.py Outdated Show resolved Hide resolved
torch_geometric/explain/explanation.py Outdated Show resolved Hide resolved
Copy link
Member

@wsad1 wsad1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @avgupta456, this looks good.

torch_geometric/explain/explanation.py Outdated Show resolved Hide resolved
torch_geometric/explain/explanation.py Outdated Show resolved Hide resolved
torch_geometric/explain/explanation.py Outdated Show resolved Hide resolved
torch_geometric/explain/explanation.py Outdated Show resolved Hide resolved
torch_geometric/explain/explanation.py Outdated Show resolved Hide resolved
torch_geometric/explain/explainer.py Outdated Show resolved Hide resolved
torch_geometric/explain/algorithm/dummy_explainer.py Outdated Show resolved Hide resolved
torch_geometric/data/hetero_data.py Show resolved Hide resolved
test/explain/test_explainer_hetero.py Outdated Show resolved Hide resolved
test/explain/algorithm/test_base.py Outdated Show resolved Hide resolved
@wsad1 wsad1 merged commit 293203d into pyg-team:master Dec 10, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Explainability]: explain support for heterogenous graphs.
4 participants