Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing Subgraph Methods #6613

Merged
merged 13 commits into from
Feb 7, 2023
Merged

Fixing Subgraph Methods #6613

merged 13 commits into from
Feb 7, 2023

Conversation

toenshoff
Copy link
Contributor

This pull request addresses a bug in both Data.subgraph and HeteroData.subgraph, which yield incorrect results when an unsorted list of integers is used to select a subgraph. In this case, the attributes of nodes are reordered according to the given list of integers, but the ids used in edge_index are not. Therefore, the node features will be distributed incorrectly in the output subgraph.
Note that Data.subgraph does use the utils.subgraph method with the relabel_nodes=True option to compute the reduced edge list. However, this is insufficient to ensure a correct output, as this only shifts node ids downward to ensure consecutive ids. It has no effect on the actual order of nodes, which causes this problem.

Consider the following example:

import torch
from torch_geometric.data import Data

data = Data(x=torch.tensor([0, 0, 1, 1]), edge_index=torch.tensor([[0, 2], [1, 3]]))
sub = data.subgraph(torch.tensor([0, 2, 1]))
print(sub.x)
print(sub.edge_index)

With torch_geometric==2.3.0 (current nightly) we obtain the following output:

x: [0 1 0]
Edge Idx:
 [[0]
 [1]]

Note that in the original graph each edge has a source and destination with the same node label. In the subgraph, the remaining edge has a source with label 0 and a destination with label 1. This is therefore not a valid subgraph.

I have addressed this bug by re-mapping the node ids in edge_index to their correct new index if the input is indeed an unsorted integer list. HeteroData.subgraph had the same issue with an analog solution. I also extended the tests to cover these issues.

Beyond this, I was wondering what the behaviour of Data.subgraph should be if the argument is a list of integers with duplicate entries. Should the method raise an Exception, ignore duplicates or create duplicate vertices? I have not implemented anything in this regard yet, but I do think this should be specified.

Copy link
Member

@wsad1 wsad1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for catching and fixing this.

Comment on lines 602 to 604
if (subset[:-1] > subset[1:]).any():
node_idx = torch.argsort(subset)
edge_index = node_idx[edge_index]
Copy link
Member

@wsad1 wsad1 Feb 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would simply calling torch.unique at the start of the function resolve this issue?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would fix the bug, too.

The solution here has the additional utility of allowing a user to change the order of vertices to a new order of their choice. If this is not considered useful for these methods, we could also simply switch to always sorting all integer lists.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Users can change the order outside of this function if needed. So lets switch to sorting.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair enough. I have switched to using torch.unique at the beginning of both methods. This also addresses my question regarding handling duplicate indices.

Perhaps we should add a reorder_nodes method as a utility function? But this is a separate PR I guess.

torch_geometric/data/hetero_data.py Outdated Show resolved Hide resolved
@codecov
Copy link

codecov bot commented Feb 6, 2023

Codecov Report

Merging #6613 (5a450e9) into master (a730437) will increase coverage by 0.00%.
The diff coverage is 100.00%.

❗ Current head 5a450e9 differs from pull request most recent head 035dcd3. Consider uploading reports for the commit 035dcd3 to get more accurate results

@@           Coverage Diff           @@
##           master    #6613   +/-   ##
=======================================
  Coverage   87.55%   87.56%           
=======================================
  Files         422      422           
  Lines       22879    22882    +3     
=======================================
+ Hits        20032    20036    +4     
+ Misses       2847     2846    -1     
Impacted Files Coverage Δ
torch_geometric/data/data.py 88.86% <100.00%> (+0.02%) ⬆️
torch_geometric/data/hetero_data.py 90.63% <100.00%> (+0.29%) ⬆️
torch_geometric/data/graph_store.py 90.40% <0.00%> (-0.16%) ⬇️

📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more

@wsad1 wsad1 enabled auto-merge (squash) February 7, 2023 11:21
@wsad1 wsad1 merged commit 3afd90e into pyg-team:master Feb 7, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants