Skip to content

Commit

Permalink
message passing min+mul (#4219)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Mar 8, 2022
1 parent 258aeae commit ecd5ebb
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 10 deletions.
15 changes: 13 additions & 2 deletions test/nn/conv/test_message_passing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@

class MyConv(MessagePassing):
def __init__(self, in_channels: Union[int, Tuple[int, int]],
out_channels: int):
super().__init__(aggr='add')
out_channels: int, aggr: str = 'add'):
super().__init__(aggr=aggr)

if isinstance(in_channels, int):
in_channels = (in_channels, in_channels)
Expand Down Expand Up @@ -107,6 +107,17 @@ def test_my_conv():
jit.fuse = True


@pytest.mark.parametrize('aggr', ['add', 'sum', 'mean', 'min', 'max', 'mul'])
def test_my_conv_aggr(aggr):
x = torch.randn(4, 8)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
edge_weight = torch.randn(edge_index.size(1))

conv = MyConv(8, 32, aggr=aggr)
out = conv(x, edge_index, edge_weight)
assert out.size() == (4, 32)


def test_my_static_graph_conv():
x1 = torch.randn(3, 4, 8)
x2 = torch.randn(3, 2, 16)
Expand Down
16 changes: 8 additions & 8 deletions torch_geometric/nn/conv/message_passing.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,16 @@ class MessagePassing(torch.nn.Module):
\left(\mathbf{x}_i, \mathbf{x}_j,\mathbf{e}_{j,i}\right) \right),
where :math:`\square` denotes a differentiable, permutation invariant
function, *e.g.*, sum, mean or max, and :math:`\gamma_{\mathbf{\Theta}}`
and :math:`\phi_{\mathbf{\Theta}}` denote differentiable functions such as
MLPs.
function, *e.g.*, sum, mean, min, max or mul, and
:math:`\gamma_{\mathbf{\Theta}}` and :math:`\phi_{\mathbf{\Theta}}` denote
differentiable functions such as MLPs.
See `here <https://pytorch-geometric.readthedocs.io/en/latest/notes/
create_gnn.html>`__ for the accompanying tutorial.
Args:
aggr (string, optional): The aggregation scheme to use
(:obj:`"add"`, :obj:`"mean"`, :obj:`"max"` or :obj:`None`).
(default: :obj:`"add"`)
(:obj:`"add"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`,
:obj:`"mul"` or :obj:`None`). (default: :obj:`"add"`)
flow (string, optional): The flow direction of message passing
(:obj:`"source_to_target"` or :obj:`"target_to_source"`).
(default: :obj:`"source_to_target"`)
Expand Down Expand Up @@ -85,7 +85,7 @@ def __init__(self, aggr: Optional[str] = "add",
super().__init__()

self.aggr = aggr
assert self.aggr in ['add', 'mean', 'max', None]
assert self.aggr in ['add', 'sum', 'mean', 'min', 'max', 'mul', None]

self.flow = flow
assert self.flow in ['source_to_target', 'target_to_source']
Expand Down Expand Up @@ -417,8 +417,8 @@ def aggregate(self, inputs: Tensor, index: Tensor,
argument which was initially passed to :meth:`propagate`.
By default, this function will delegate its call to scatter functions
that support "add", "mean" and "max" operations as specified in
:meth:`__init__` by the :obj:`aggr` argument.
that support "add", "mean", "min", "max" and "mul" operations as
specified in :meth:`__init__` by the :obj:`aggr` argument.
"""
if ptr is not None:
ptr = expand_left(ptr, dim=self.node_dim, dims=inputs.dim())
Expand Down

0 comments on commit ecd5ebb

Please sign in to comment.