Skip to content

Commit

Permalink
Added variance-preserving Aggregation (#9075)
Browse files Browse the repository at this point in the history
A novel aggregation function as described in
[GNN-VPA](https://arxiv.org/pdf/2403.04747.pdf)

---------

Co-authored-by: Richard Freinschlag <freinschlag@ml.jku.at>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
  • Loading branch information
4 people committed Mar 25, 2024
1 parent 8c070ad commit 8a9ace7
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added the `VariancePreservingAggregation` (VPA) ([#9075](https://github.com/pyg-team/pytorch_geometric/pull/9075))
- Added option to pass custom` from_smiles` functionality to `PCQM4Mv2` and `MoleculeNet` ([#9073](https://github.com/pyg-team/pytorch_geometric/pull/9073))
- Added `group_cat` functionality ([#9029](https://github.com/pyg-team/pytorch_geometric/pull/9029))
- Added support for `EdgeIndex` in `spmm` ([#9026](https://github.com/pyg-team/pytorch_geometric/pull/9026))
Expand Down
28 changes: 28 additions & 0 deletions test/nn/aggr/test_variance_preserving.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import torch

from torch_geometric.nn import (
MeanAggregation,
SumAggregation,
VariancePreservingAggregation,
)


def test_variance_preserving():
x = torch.randn(6, 16)
index = torch.tensor([0, 0, 1, 1, 1, 3])
ptr = torch.tensor([0, 2, 5, 5, 6])

vpa_aggr = VariancePreservingAggregation()
mean_aggr = MeanAggregation()
sum_aggr = SumAggregation()

out_vpa = vpa_aggr(x, index)
out_mean = mean_aggr(x, index)
out_sum = sum_aggr(x, index)

# Equivalent formulation:
expected = torch.sqrt(out_mean.abs() * out_sum.abs()) * out_sum.sign()

assert out_vpa.size() == (4, 16)
assert torch.allclose(out_vpa, expected)
assert torch.allclose(out_vpa, vpa_aggr(x, ptr=ptr))
2 changes: 2 additions & 0 deletions torch_geometric/nn/aggr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from .deep_sets import DeepSetsAggregation
from .set_transformer import SetTransformerAggregation
from .lcm import LCMAggregation
from .variance_preserving import VariancePreservingAggregation

__all__ = classes = [
'Aggregation',
Expand Down Expand Up @@ -51,4 +52,5 @@
'DeepSetsAggregation',
'SetTransformerAggregation',
'LCMAggregation',
'VariancePreservingAggregation',
]
33 changes: 33 additions & 0 deletions torch_geometric/nn/aggr/variance_preserving.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from typing import Optional

from torch import Tensor

from torch_geometric.nn.aggr import Aggregation
from torch_geometric.utils import degree
from torch_geometric.utils._scatter import broadcast


class VariancePreservingAggregation(Aggregation):
r"""Performs the Variance Preserving Aggregation (VPA) from the `"GNN-VPA:
A Variance-Preserving Aggregation Strategy for Graph Neural Networks"
<https://arxiv.org/abs/2403.04747>`_ paper.
.. math::
\mathrm{vpa}(\mathcal{X}) = \frac{1}{\sqrt{|\mathcal{X}|}}
\sum_{\mathbf{x}_i \in \mathcal{X}} \mathbf{x}_i
"""
def forward(self, x: Tensor, index: Optional[Tensor] = None,
ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
dim: int = -2) -> Tensor:

out = self.reduce(x, index, ptr, dim_size, dim, reduce='sum')

if ptr is not None:
count = ptr.diff().to(out.dtype)
else:
count = degree(index, dim_size, dtype=out.dtype)

count = count.sqrt().clamp(min=1.0)
count = broadcast(count, ref=out, dim=dim)

return out / count
1 change: 1 addition & 0 deletions torch_geometric/utils/_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ def scatter(


def broadcast(src: Tensor, ref: Tensor, dim: int) -> Tensor:
dim = ref.dim() + dim if dim < 0 else dim
size = ((1, ) * dim) + (-1, ) + ((1, ) * (ref.dim() - dim - 1))
return src.view(size).expand_as(ref)

Expand Down

0 comments on commit 8a9ace7

Please sign in to comment.