Skip to content

Commit

Permalink
fixed dropout random seed
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Apr 8, 2019
1 parent 921aec9 commit e4d7eb6
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions test/utils/test_dropout.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import torch
from torch_geometric.utils import dropout_adj

torch.manual_seed(5)


def test_dropout_adj():
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]])
Expand All @@ -12,10 +10,12 @@ def test_dropout_adj():
assert edge_index.tolist() == out[0].tolist()
assert edge_attr.tolist() == out[1].tolist()

torch.manual_seed(5)
out = dropout_adj(edge_index, edge_attr)
assert out[0].tolist() == [[0, 1, 2, 3], [1, 2, 1, 2]]
assert out[1].tolist() == [1, 3, 4, 6]
assert out[0].tolist() == [[1, 3], [0, 2]]
assert out[1].tolist() == [2, 6]

torch.manual_seed(5)
out = dropout_adj(edge_index, edge_attr, force_undirected=True)
assert out[0].tolist() == [[2, 3], [3, 2]]
assert out[1].tolist() == [5, 5]
assert out[0].tolist() == [[1, 2], [2, 1]]
assert out[1].tolist() == [3, 3]

0 comments on commit e4d7eb6

Please sign in to comment.