Skip to content
Permalink
Browse files

fix graph saint test

  • Loading branch information
rusty1s committed Mar 24, 2020
1 parent 5824317 commit 970046c7bf8fd6d108a979308190fe87f7781b84
Showing with 8 additions and 17 deletions.
  1. +8 −17 test/data/test_graph_saint.py
@@ -1,12 +1,10 @@
from itertools import chain

import torch
from torch_geometric.data import (Data, GraphSAINTNodeSampler,
GraphSAINTEdgeSampler,
GraphSAINTRandomWalkSampler)


def test_cluster_gcn():
def test_graph_saint():
adj = torch.tensor([
[1, 1, 1, 0, 1, 0],
[1, 1, 0, 1, 0, 1],
@@ -20,35 +18,28 @@ def test_cluster_gcn():
x = torch.Tensor([[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5]])
data = Data(edge_index=edge_index, x=x, num_nodes=6)

loader1 = GraphSAINTNodeSampler(data, batch_size=2, num_steps=4, log=False)
loader2 = GraphSAINTNodeSampler(data, batch_size=2, num_steps=4,
num_workers=2, log=False)
loader = GraphSAINTNodeSampler(data, batch_size=2, num_steps=4, log=False)

for sample in chain(loader1, loader2):
for sample in loader:
assert len(sample) == 4
assert sample.num_nodes <= 2
assert sample.num_edges <= 3 * 2
assert sample.node_norm.numel() == sample.num_nodes
assert sample.edge_norm.numel() == sample.num_edges

loader1 = GraphSAINTEdgeSampler(data, batch_size=2, num_steps=4, log=False)
loader2 = GraphSAINTEdgeSampler(data, batch_size=2, num_steps=4,
num_workers=2, log=False)
loader = GraphSAINTEdgeSampler(data, batch_size=2, num_steps=4, log=False)

for sample in chain(loader1, loader2):
for sample in loader:
assert len(sample) == 4
assert sample.num_nodes <= 4
assert sample.num_edges <= 3 * 4
assert sample.node_norm.numel() == sample.num_nodes
assert sample.edge_norm.numel() == sample.num_edges

loader1 = GraphSAINTRandomWalkSampler(data, batch_size=2, walk_length=1,
num_steps=4, log=False)
loader2 = GraphSAINTRandomWalkSampler(data, batch_size=2, walk_length=1,
num_steps=4, num_workers=2,
log=False)
loader = GraphSAINTRandomWalkSampler(data, batch_size=2, walk_length=1,
num_steps=4, log=False)

for sample in chain(loader1, loader2):
for sample in loader:
assert len(sample) == 4
assert sample.num_nodes <= 4
assert sample.num_edges <= 3 * 4

0 comments on commit 970046c

Please sign in to comment.
You can’t perform that action at this time.