Skip to content

Commit

Permalink
Merge pull request #483 from fdiehl/globalpool_batch_order
Browse files Browse the repository at this point in the history
Allowing for permuted batch references for global pool operations
  • Loading branch information
rusty1s committed Jul 4, 2019
2 parents a076bc2 + aa5e34d commit 122745d
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 3 deletions.
26 changes: 26 additions & 0 deletions test/nn/glob/test_glob.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
import pytest
from torch_geometric.nn import (global_add_pool, global_mean_pool,
global_max_pool)

Expand All @@ -22,3 +23,28 @@ def test_global_pool():
assert out.size() == (2, 4)
assert out[0].tolist() == x[:4].max(dim=0)[0].tolist()
assert out[1].tolist() == x[4:].max(dim=0)[0].tolist()


def test_permuted_global_pool():
N_1, N_2 = 4, 6
x = torch.randn(N_1 + N_2, 4)
batch = torch.tensor([0 for _ in range(N_1)] + [1 for _ in range(N_2)])

permutation = torch.randperm(N_1 + N_2)
perm_x = x[permutation]
perm_batch = batch[permutation]

out = global_add_pool(perm_x, perm_batch)
assert out.size() == (2, 4)
assert out[0].tolist() == pytest.approx(x[:4].sum(dim=0).tolist())
assert out[1].tolist() == pytest.approx(x[4:].sum(dim=0).tolist())

out = global_mean_pool(perm_x, perm_batch)
assert out.size() == (2, 4)
assert out[0].tolist() == pytest.approx(x[:4].mean(dim=0).tolist())
assert out[1].tolist() == pytest.approx(x[4:].mean(dim=0).tolist())

out = global_max_pool(perm_x, perm_batch)
assert out.size() == (2, 4)
assert out[0].tolist() == pytest.approx(x[:4].max(dim=0)[0].tolist())
assert out[1].tolist() == pytest.approx(x[4:].max(dim=0)[0].tolist())
7 changes: 4 additions & 3 deletions torch_geometric/nn/glob/glob.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import torch
from torch_geometric.utils import scatter_


Expand All @@ -20,7 +21,7 @@ def global_add_pool(x, batch, size=None):
:rtype: :class:`Tensor`
"""

size = batch[-1].item() + 1 if size is None else size
size = torch.max(batch) + 1 if size is None else size
return scatter_('add', x, batch, dim_size=size)


Expand All @@ -43,7 +44,7 @@ def global_mean_pool(x, batch, size=None):
:rtype: :class:`Tensor`
"""

size = batch[-1].item() + 1 if size is None else size
size = torch.max(batch) + 1 if size is None else size
return scatter_('mean', x, batch, dim_size=size)


Expand All @@ -66,5 +67,5 @@ def global_max_pool(x, batch, size=None):
:rtype: :class:`Tensor`
"""

size = batch[-1].item() + 1 if size is None else size
size = torch.max(batch) + 1 if size is None else size
return scatter_('max', x, batch, dim_size=size)

0 comments on commit 122745d

Please sign in to comment.