
Implement PEG from ICLR 2022 #4571

Open · wants to merge 52 commits into master
Conversation

zoom-wang112358

PEG uses a GCN layer whose edge weights are computed from the distance between the positional encodings of each edge's end nodes, while keeping the positional encodings themselves unchanged. The positional encodings can be computed by graph embedding techniques such as DeepWalk or the Laplacian eigenmap.

Examples and descriptions can be found here.
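For context, a usage sketch; the constructor and call signatures below mirror the dimensions discussed in this PR and are illustrative, not necessarily the final API:

```python
import torch
from torch_geometric.nn import PEGConv  # the layer added by this PR

num_nodes, feat_dim, pos_dim = 100, 16, 8
x = torch.randn(num_nodes, feat_dim)    # node features
pos = torch.randn(num_nodes, pos_dim)   # positional encodings, e.g. from
                                        # DeepWalk or a Laplacian eigenmap
edge_index = torch.randint(0, num_nodes, (2, 500))

conv = PEGConv(in_feats_dim=feat_dim, pos_dim=pos_dim, out_feats_dim=32)
out, pos_out = conv(x, pos, edge_index)  # features are updated; positional
                                         # encodings pass through unchanged
```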

@zoom-wang112358 zoom-wang112358 changed the title Implement PEG Implement PEG from ICLR 2022 Apr 30, 2022
@codecov

codecov bot commented Apr 30, 2022

Codecov Report

Merging #4571 (ed0ce31) into master (e8b6def) will increase coverage by 0.13%.
The diff coverage is 100.00%.

❗ Current head ed0ce31 differs from pull request most recent head dc8c640. Consider uploading reports for commit dc8c640 to get more accurate results.

@@            Coverage Diff             @@
##           master    #4571      +/-   ##
==========================================
+ Coverage   82.70%   82.83%   +0.13%     
==========================================
  Files         329      330       +1     
  Lines       17844    17744     -100     
==========================================
- Hits        14758    14699      -59     
+ Misses       3086     3045      -41     
| Impacted Files | Coverage Δ |
| --- | --- |
| torch_geometric/nn/conv/__init__.py | 100.00% <100.00%> (ø) |
| torch_geometric/nn/conv/peg_conv.py | 100.00% <100.00%> (ø) |
| torch_geometric/data/in_memory_dataset.py | 77.27% <0.00%> (-3.89%) ⬇️ |
| torch_geometric/data/separate.py | 97.77% <0.00%> (-2.23%) ⬇️ |
| torch_geometric/data/storage.py | 81.01% <0.00%> (-0.64%) ⬇️ |
| torch_geometric/data/data.py | 91.43% <0.00%> (-0.14%) ⬇️ |
| torch_geometric/data/hetero_data.py | 94.11% <0.00%> (-0.10%) ⬇️ |
| torch_geometric/data/__init__.py | 100.00% <0.00%> (ø) |
| torch_geometric/nn/dense/mincut_pool.py | 100.00% <0.00%> (ø) |
| torch_geometric/nn/models/autoencoder.py | 100.00% <0.00%> (ø) |

... and 7 more

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update e8b6def...dc8c640.

**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.MessagePassing`.
"""
def __init__(self, in_feats_dim: int, pos_dim: int, out_feats_dim: int,
Contributor

Suggested change
def __init__(self, in_feats_dim: int, pos_dim: int, out_feats_dim: int,
def __init__(self, in_channels: int, pos_dim: int, out_channels: int,

    torch.Tensor(in_feats_dim + in_feats_dim, out_feats_dim))
self.weight_noformer = Parameter(
    torch.Tensor(in_feats_dim, out_feats_dim))
self.edge_mlp = nn.Sequential(nn.Linear(1, edge_mlp_dim),
Contributor

@Padarn Padarn May 2, 2022

[newbie question] I'm somewhat confused by this: if you go from 1 dim back to 1 dim, are you learning much?

Author

Maybe I can use the 'information bottleneck' idea to explain it. Although the MLP goes from 1 dim back to 1 dim, the network rids the noisy input of extraneous detail, as if squeezing the information through a bottleneck, retaining only the features most relevant to the edge weight.
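For concreteness, here is a sketch of the edge MLP being discussed. Only `nn.Linear(1, edge_mlp_dim)` is visible in this diff; the remaining layers (activation, projection back to a scalar, sigmoid) are assumptions:

```python
import torch
import torch.nn as nn

edge_mlp_dim = 32  # hidden width; illustrative value

# Maps the scalar distance ||Z_u - Z_v|| of each edge to a scalar weight.
edge_mlp = nn.Sequential(
    nn.Linear(1, edge_mlp_dim),   # lift the scalar distance to a hidden space
    nn.ReLU(),
    nn.Linear(edge_mlp_dim, 1),   # project back to one weight per edge
    nn.Sigmoid(),                 # assumed: keep weights in (0, 1)
)

dist = torch.rand(500, 1)   # one distance per edge
weight = edge_mlp(dist)     # shape [500, 1]: a learned weight per edge
```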

# return tuple
return self.update((hidden_out, coors_out), **update_kwargs)

def glorot(self, tensor):
Contributor

available in torch_geometric.nn.inits

stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1)))
tensor.data.uniform_(-stdv, stdv)

def zeros(self, tensor):
Contributor

same as above
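For reference, a sketch of what reusing the shared initializers could look like; `weight_noformer` appears in this diff, while `bias` is illustrative:

```python
from torch_geometric.nn.inits import glorot, zeros

def reset_parameters(self):
    # Reuse PyG's shared initializers instead of local re-implementations.
    glorot(self.weight_noformer)  # parameter visible in this diff
    zeros(self.bias)              # illustrative; whichever bias the layer defines
```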

r"""The initial call to start propagating messages.

Args:
`edge_index` holds the indices of a general (sparse)
Contributor

Suggested change
`edge_index` holds the indices of a general (sparse)
edge_index (Tensor): holds the indices of a general (sparse)

Args:
`edge_index` holds the indices of a general (sparse)
assignment matrix of shape :obj:`[N, M]`.
size (tuple, optional) if none, the size will be inferred
Contributor

Suggested change
size (tuple, optional) if none, the size will be inferred
size (tuple, optional): if none, the size will be inferred

return x_j if edge_weight is None else (PE_edge_weight *
                                        edge_weight.view(-1, 1) * x_j)

def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
Contributor

Sorry, I haven't fully followed the implementation yet, but I think you should be able to do this without needing to reimplement propagate, i.e., you can put most of this logic inside your forward function.

Member

+1. We definitely do not want to override propagate here. Any reason you did this in the first place?

Author

Thanks for your advice. I put the operational logic inside my forward function.
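For reference, a minimal sketch of that structure (the names `pe_weight` and `dist` are illustrative, not necessarily the PR's exact code): the MLP edge weights are computed in `forward()` and handed to the stock `propagate()`:

```python
def forward(self, x: Tensor, pos: Tensor, edge_index: Tensor,
            edge_weight: OptTensor = None):
    # Distance between the positional encodings of each edge's end nodes.
    row, col = edge_index
    dist = (pos[row] - pos[col]).norm(p=2, dim=-1, keepdim=True)
    pe_weight = self.edge_mlp(dist)  # one learned scalar weight per edge

    # Stock MessagePassing.propagate; no override needed.
    out = self.propagate(edge_index, x=x, edge_weight=edge_weight,
                         pe_weight=pe_weight)
    return out, pos  # positional encodings are returned unchanged

def message(self, x_j: Tensor, edge_weight: OptTensor,
            pe_weight: Tensor) -> Tensor:
    out = x_j * pe_weight
    return out if edge_weight is None else edge_weight.view(-1, 1) * out
```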

@zoom-wang112358
Author

@rusty1s I have just modified the implementation. Please review all the files and let me know if anything more is required. Thanks so much.

Member

@rusty1s rusty1s left a comment

Thanks for the updates. This looks good. Please add your layer to the README.md as well.

CHANGELOG.md Outdated
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.0.5] - 2022-MM-DD
### Added
- Added the `PEG` model ([#4571](https://github.com/pyg-team/pytorch_geometric/pull/4571))
Member

Suggested change
- Added the `PEG` model ([#4571](https://github.com/pyg-team/pytorch_geometric/pull/4571))
- Added the `PEGConv` layer from the ["Equivariant and Stable Positional Encoding for More Powerful Graph Neural Networks"](https://arxiv.org/abs/2203.00199) paper ([#4571](https://github.com/pyg-team/pytorch_geometric/pull/4571))

for More Powerful Graph Neural Networks"
<https://arxiv.org/abs/2203.00199>`_ paper.

$$X^{'},Z^{'}=(\sigma [(\hat{A} \odot M)XW],Z)$$
Member

Can we follow the .. math:: syntax here as in other layers? The formula also looks wrong: https://pytorch-geometric--4571.org.readthedocs.build/en/4571/modules/nn.html#torch_geometric.nn.conv.PEGConv

<https://arxiv.org/abs/2203.00199>`_ paper.

$$X^{'},Z^{'}=(\sigma [(\hat{A} \odot M)XW],Z)$$
where\ $M_{uv}=MLP(||Z_u-Z_v||),\forall u,v \in V$.
Member

Here again. Use

:math:`M_{u,v} ...`
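Putting the two notation comments together, the docstring math could read something like this (a sketch following PyG's reStructuredText conventions; the exact symbols are up to the author):

```rst
.. math::
    \mathbf{X}^{\prime}, \mathbf{Z}^{\prime} =
    \left( \sigma \left[ (\hat{\mathbf{A}} \odot M) \mathbf{X} \mathbf{W}
    \right], \mathbf{Z} \right)

where :math:`M_{u,v} = \textrm{MLP}(\| \mathbf{Z}_u - \mathbf{Z}_v \|)` for
all :math:`u, v \in \mathcal{V}`.
```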

return x_j if edge_weight is None else (PE_edge_weight *
                                        edge_weight.view(-1, 1) * x_j)

def __repr__(self):
Member

Can be dropped.

def message(self, x_i: Tensor, x_j: Tensor, edge_weight: OptTensor,
            pos) -> Tensor:
    PE_edge_weight = self.edge_mlp(pos)
    return x_j if edge_weight is None else (PE_edge_weight *
Member

Shouldn't we always use edge_mlp here?

out = x_j * self.edge_mlp(pos)
return out if edge_weight is None else edge_weight.view(-1, 1) * out

@rusty1s
Member

rusty1s commented Jun 21, 2022

Hi @ZoomWang666, thanks again for this PR. I would like to merge it once the open conversations are resolved. Can you take a look if you have time and let me know? Best
