Merge pull request #486 from fdiehl/fix_globalpool_tests
test_permuted_global_pool: Fixed seldom-occurring (~1% of cases) assert failure
rusty1s committed Jul 5, 2019
2 parents 14b628e + 818291e commit 89afdb9
Showing 1 changed file with 17 additions and 9 deletions.
26 changes: 17 additions & 9 deletions test/nn/glob/test_glob.py
@@ -30,14 +30,22 @@ def test_permuted_global_pool():
     batch = torch.cat([torch.zeros(N_1), torch.ones(N_2)]).to(torch.long)
     perm = torch.randperm(N_1 + N_2)
 
-    out_1 = global_add_pool(x, batch)
-    out_2 = global_add_pool(x[perm], batch[perm])
-    assert torch.allclose(out_1, out_2)
+    px = x[perm]
+    pbatch = batch[perm]
+    px1 = px[pbatch == 0]
+    px2 = px[pbatch == 1]
 
-    out_1 = global_mean_pool(x, batch)
-    out_2 = global_mean_pool(x[perm], batch[perm])
-    assert torch.allclose(out_1, out_2)
+    out = global_add_pool(px, pbatch)
+    assert out.size() == (2, 4)
+    assert torch.allclose(out[0], px1.sum(dim=0))
+    assert torch.allclose(out[1], px2.sum(dim=0))
+
+    out = global_mean_pool(px, pbatch)
+    assert out.size() == (2, 4)
+    assert torch.allclose(out[0], px1.mean(dim=0))
+    assert torch.allclose(out[1], px2.mean(dim=0))
 
-    out_1 = global_max_pool(x, batch)
-    out_2 = global_max_pool(x[perm], batch[perm])
-    assert torch.allclose(out_1, out_2)
+    out = global_max_pool(px, pbatch)
+    assert out.size() == (2, 4)
+    assert torch.allclose(out[0], px1.max(dim=0)[0])
+    assert torch.allclose(out[1], px2.max(dim=0)[0])
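
For context, a hedged sketch (an editor's illustration, not part of the commit): the removed asserts compared pooled outputs of the original and permuted node orderings with torch.allclose. Floating-point addition is not associative, so accumulating the same features in a different order can shift the result by a few ULPs, and the default tolerances (rtol=1e-05, atol=1e-08) leave little slack when a pooled value lands near zero, which is one plausible reason the old asserts failed in roughly 1% of runs. The snippet below only illustrates the order dependence of summation; the variable names are illustrative.

import torch

torch.manual_seed(0)
x = torch.randn(10, 4)            # toy node features, same shape idea as the test
perm = torch.randperm(x.size(0))

s_orig = x.sum(dim=0)             # features summed in the original order
s_perm = x[perm].sum(dim=0)       # same values, accumulated in a different order

# Mathematically identical, but float32 rounding can make the two results
# differ by a few ULPs; a component close to zero gets almost no slack from
# allclose's relative tolerance, so such a comparison can occasionally fail.
print((s_orig - s_perm).abs().max().item())
print(torch.allclose(s_orig, s_perm))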
