Implement PEG from ICLR 2022 #4571
base: master
Conversation
for more information, see https://pre-commit.ci
Codecov Report
@@ Coverage Diff @@
## master #4571 +/- ##
==========================================
+ Coverage 82.70% 82.83% +0.13%
==========================================
Files 329 330 +1
Lines 17844 17744 -100
==========================================
- Hits 14758 14699 -59
+ Misses 3086 3045 -41
Continue to review full report at Codecov.
torch_geometric/nn/conv/peg_conv.py
Outdated
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """
    def __init__(self, in_feats_dim: int, pos_dim: int, out_feats_dim: int,
def __init__(self, in_feats_dim: int, pos_dim: int, out_feats_dim: int,
def __init__(self, in_channels: int, pos_dim: int, out_channels: int,
            torch.Tensor(in_feats_dim + in_feats_dim, out_feats_dim))
        self.weight_noformer = Parameter(
            torch.Tensor(in_feats_dim, out_feats_dim))
        self.edge_mlp = nn.Sequential(nn.Linear(1, edge_mlp_dim),
[newbie question] Somewhat confused by this: if you go from 1 dim back to 1 dim, are you learning much?
Maybe I can use the 'information bottleneck' idea to explain it. Although the MLP goes from 1 dim back to 1 dim, the network rids the noisy input of extraneous detail, as if squeezing the information through a bottleneck, retaining only the features most relevant to the edge weight.
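For readers following along, the bottleneck idea above can be sketched in a few lines of plain PyTorch. The hidden width `edge_mlp_dim = 32` and the `Sigmoid` squashing are illustrative assumptions, not necessarily the PR's exact architecture:

```python
import torch
from torch import nn

edge_mlp_dim = 32  # assumed hidden width, for illustration only
edge_mlp = nn.Sequential(
    nn.Linear(1, edge_mlp_dim),   # lift the scalar distance to a hidden space
    nn.ReLU(),
    nn.Linear(edge_mlp_dim, 1),   # squeeze back down to one edge weight
    nn.Sigmoid(),                 # assumption: keep the learned weight in (0, 1)
)

dist = torch.rand(10, 1)    # toy per-edge distances ||Z_u - Z_v||, shape [E, 1]
weights = edge_mlp(dist)    # learned per-edge weights, shape [E, 1]
```

The hidden layer gives the map capacity to shape a nonlinear distance-to-weight curve; the 1-dim output is the bottleneck.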
torch_geometric/nn/conv/peg_conv.py
Outdated
        # return tuple
        return self.update((hidden_out, coors_out), **update_kwargs)

    def glorot(self, tensor):
available in torch_geometric.nn.inits
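For reference, the helper in `torch_geometric.nn.inits` applies the same Glorot-uniform rule as the inlined code below; a minimal pure-torch sketch of that rule (standalone, so it avoids the PyG import):

```python
import math
import torch

def glorot(tensor):
    # Glorot/Xavier uniform: stdv = sqrt(6 / (fan_in + fan_out)),
    # matching the inlined code in this PR
    if tensor is not None:
        stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1)))
        tensor.data.uniform_(-stdv, stdv)

w = torch.empty(8, 16)
glorot(w)  # every entry now lies in [-0.5, 0.5] for this shape
```

Importing `glorot` (and `zeros`) from `torch_geometric.nn.inits` keeps the layer consistent with the rest of the library.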
torch_geometric/nn/conv/peg_conv.py
Outdated
        stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1)))
        tensor.data.uniform_(-stdv, stdv)

    def zeros(self, tensor):
same as above
torch_geometric/nn/conv/peg_conv.py
Outdated
        r"""The initial call to start propagating messages.

        Args:
            `edge_index` holds the indices of a general (sparse)
`edge_index` holds the indices of a general (sparse)
edge_index (Tensor): holds the indices of a general (sparse)
torch_geometric/nn/conv/peg_conv.py
Outdated
        Args:
            `edge_index` holds the indices of a general (sparse)
                assignment matrix of shape :obj:`[N, M]`.
            size (tuple, optional) if none, the size will be inferred
size (tuple, optional) if none, the size will be inferred
size (tuple, optional): if none, the size will be inferred
torch_geometric/nn/conv/peg_conv.py
Outdated
        return x_j if edge_weight is None else (PE_edge_weight *
                                                edge_weight.view(-1, 1) * x_j)

    def propagate(self, edge_index: Adj, size: Size = None, **kwargs):
Sorry, I haven't fully followed the implementation yet, but I think you should be able to do this without reimplementing propagate, i.e., you can put most of this logic inside your forward function.
+1. We definitely do not want to override propagate here. Any reason you did this in the first place?
Thanks for your advice. I put the operational logic inside my forward function.
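For readers following along, a rough standalone sketch of the structure suggested above: the PEG-specific logic (distance MLP, edge weighting, pass-through positions) lives in `forward`, while a hand-written `index_add_` stands in for what `propagate`/`aggregate` do internally in PyG's `MessagePassing`. Class name, argument names, and the hidden width are hypothetical:

```python
import torch
from torch import nn

class TinyPEGLayer(nn.Module):
    """Illustrative sketch only; not the PR's actual PEGConv."""
    def __init__(self, in_channels, out_channels, edge_mlp_dim=32):
        super().__init__()
        self.lin = nn.Linear(in_channels, out_channels)
        # scalar distance -> scalar edge weight, as discussed above
        self.edge_mlp = nn.Sequential(nn.Linear(1, edge_mlp_dim),
                                      nn.ReLU(),
                                      nn.Linear(edge_mlp_dim, 1))

    def forward(self, x, pos, edge_index):
        row, col = edge_index
        dist = (pos[row] - pos[col]).norm(dim=-1, keepdim=True)  # [E, 1]
        w = self.edge_mlp(dist)                                  # [E, 1]
        msg = w * self.lin(x)[col]        # what `message` would compute
        out = torch.zeros(x.size(0), msg.size(-1))
        out.index_add_(0, row, msg)       # what the default sum-aggregation does
        return out, pos                   # positions pass through unchanged
```

In the real layer, `forward` would simply call the inherited `propagate(edge_index, x=..., ...)` with the precomputed edge weights instead of doing the scatter by hand.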
@rusty1s I have just modified the implementation. Please review all the files and let me know if anything more is required. Thanks so much.
Thanks for the updates. This looks good. Please add your layer to the README.md as well.
CHANGELOG.md
Outdated
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.0.5] - 2022-MM-DD
### Added
- Added the `PEG` model ([#4571](https://github.com/pyg-team/pytorch_geometric/pull/4571))
- Added the `PEG` model ([#4571](https://github.com/pyg-team/pytorch_geometric/pull/4571))
- Added the `PEGConv` layer from the ["Equivariant and Stable Positional Encoding for More Powerful Graph Neural Networks"](https://arxiv.org/abs/2203.00199) paper ([#4571](https://github.com/pyg-team/pytorch_geometric/pull/4571))
torch_geometric/nn/conv/peg_conv.py
Outdated
    for More Powerful Graph Neural Networks"
    <https://arxiv.org/abs/2203.00199>`_ paper.

    $$X^{'},Z^{'}=(\sigma [(\hat{A} \odot M)XW],Z)$$
Can we follow the .. math:: syntax here as in other layers? The formula also looks wrong: https://pytorch-geometric--4571.org.readthedocs.build/en/4571/modules/nn.html#torch_geometric.nn.conv.PEGConv
torch_geometric/nn/conv/peg_conv.py
Outdated
    <https://arxiv.org/abs/2203.00199>`_ paper.

    $$X^{'},Z^{'}=(\sigma [(\hat{A} \odot M)XW],Z)$$
    where\ $M_{uv}=MLP(||Z_u-Z_v||),\forall u,v \in V$.
Here again, use :math:`M_{u,v} ...`
        return x_j if edge_weight is None else (PE_edge_weight *
                                                edge_weight.view(-1, 1) * x_j)

    def __repr__(self):
Can be dropped.
torch_geometric/nn/conv/peg_conv.py
Outdated
    def message(self, x_i: Tensor, x_j: Tensor, edge_weight: OptTensor,
                pos) -> Tensor:
        PE_edge_weight = self.edge_mlp(pos)
        return x_j if edge_weight is None else (PE_edge_weight *
Shouldn't we always use edge_mlp here?
out = x_j * self.edge_mlp(pos)
return out if edge_weight is None else edge_weight.view(-1, 1) * out
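The behavioral difference can be checked in isolation with toy tensors; the free function below mirrors the suggested two-line body (names are illustrative, detached from the class):

```python
import torch

def message(x_j, edge_weight, pe_edge_weight):
    # always apply the learned PE weight; the optional `edge_weight`
    # only scales on top of it, per the suggestion above
    out = x_j * pe_edge_weight
    return out if edge_weight is None else edge_weight.view(-1, 1) * out

x_j = torch.ones(3, 4)
pe = torch.full((3, 1), 0.5)
a = message(x_j, None, pe)                          # PE weight alone
b = message(x_j, torch.tensor([2.0, 2.0, 2.0]), pe)  # both weights applied
```

In the original body, passing `edge_weight=None` skipped the PE weighting entirely, which is presumably not intended.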
Hi @ZoomWang666, thanks again for this PR. I would like to merge it once the open conversations are resolved. Can you take a look if you have time and let me know? Best
PEG uses a GCN layer whose edge weights depend on the distance between the end nodes of each edge, and keeps the positional encodings unchanged. The positional encodings can be computed with graph embedding techniques such as DeepWalk or the Laplacian eigenmap.
Examples and descriptions can be found here.
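As an illustration of the Laplacian-eigenmap option mentioned above, a minimal sketch that derives 2-dimensional positional encodings from the eigenvectors of the normalized graph Laplacian (a simplified recipe, not the PR's exact preprocessing):

```python
import torch

def laplacian_eigenmap_pe(edge_index, num_nodes, k=2):
    # dense normalized Laplacian L = I - D^{-1/2} A D^{-1/2}
    adj = torch.zeros(num_nodes, num_nodes)
    adj[edge_index[0], edge_index[1]] = 1.0
    adj = torch.maximum(adj, adj.t())                # symmetrize
    deg = adj.sum(dim=1)
    d_inv_sqrt = deg.clamp(min=1).pow(-0.5)
    lap = torch.eye(num_nodes) - d_inv_sqrt[:, None] * adj * d_inv_sqrt[None, :]
    # eigh returns eigenvalues in ascending order; drop the trivial
    # first eigenvector and keep the next k as node coordinates Z
    eigvals, eigvecs = torch.linalg.eigh(lap)
    return eigvecs[:, 1:k + 1]

# toy 4-cycle graph
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
pe = laplacian_eigenmap_pe(edge_index, num_nodes=4, k=2)  # shape [4, 2]
```

These per-node coordinates would then be passed to the layer as the positions `Z` whose pairwise distances feed the edge MLP.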