Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MessagePassing: Support for min and mul aggregation #4219

Merged
merged 1 commit into from
Mar 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions test/nn/conv/test_message_passing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@

class MyConv(MessagePassing):
def __init__(self, in_channels: Union[int, Tuple[int, int]],
out_channels: int):
super().__init__(aggr='add')
out_channels: int, aggr: str = 'add'):
super().__init__(aggr=aggr)

if isinstance(in_channels, int):
in_channels = (in_channels, in_channels)
Expand Down Expand Up @@ -107,6 +107,17 @@ def test_my_conv():
jit.fuse = True


@pytest.mark.parametrize('aggr', ['add', 'sum', 'mean', 'min', 'max', 'mul'])
def test_my_conv_aggr(aggr):
x = torch.randn(4, 8)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
edge_weight = torch.randn(edge_index.size(1))

conv = MyConv(8, 32, aggr=aggr)
out = conv(x, edge_index, edge_weight)
assert out.size() == (4, 32)


def test_my_static_graph_conv():
x1 = torch.randn(3, 4, 8)
x2 = torch.randn(3, 2, 16)
Expand Down
16 changes: 8 additions & 8 deletions torch_geometric/nn/conv/message_passing.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,16 @@ class MessagePassing(torch.nn.Module):
\left(\mathbf{x}_i, \mathbf{x}_j,\mathbf{e}_{j,i}\right) \right),

where :math:`\square` denotes a differentiable, permutation invariant
function, *e.g.*, sum, mean or max, and :math:`\gamma_{\mathbf{\Theta}}`
and :math:`\phi_{\mathbf{\Theta}}` denote differentiable functions such as
MLPs.
function, *e.g.*, sum, mean, min, max or mul, and
:math:`\gamma_{\mathbf{\Theta}}` and :math:`\phi_{\mathbf{\Theta}}` denote
differentiable functions such as MLPs.
See `here <https://pytorch-geometric.readthedocs.io/en/latest/notes/
create_gnn.html>`__ for the accompanying tutorial.

Args:
aggr (string, optional): The aggregation scheme to use
(:obj:`"add"`, :obj:`"mean"`, :obj:`"max"` or :obj:`None`).
(default: :obj:`"add"`)
(:obj:`"add"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`,
:obj:`"mul"` or :obj:`None`). (default: :obj:`"add"`)
flow (string, optional): The flow direction of message passing
(:obj:`"source_to_target"` or :obj:`"target_to_source"`).
(default: :obj:`"source_to_target"`)
Expand Down Expand Up @@ -85,7 +85,7 @@ def __init__(self, aggr: Optional[str] = "add",
super().__init__()

self.aggr = aggr
assert self.aggr in ['add', 'mean', 'max', None]
assert self.aggr in ['add', 'sum', 'mean', 'min', 'max', 'mul', None]

self.flow = flow
assert self.flow in ['source_to_target', 'target_to_source']
Expand Down Expand Up @@ -417,8 +417,8 @@ def aggregate(self, inputs: Tensor, index: Tensor,
argument which was initially passed to :meth:`propagate`.

By default, this function will delegate its call to scatter functions
that support "add", "mean" and "max" operations as specified in
:meth:`__init__` by the :obj:`aggr` argument.
that support "add", "mean", "min", "max" and "mul" operations as
specified in :meth:`__init__` by the :obj:`aggr` argument.
"""
if ptr is not None:
ptr = expand_left(ptr, dim=self.node_dim, dims=inputs.dim())
Expand Down