Skip to content

Commit

Permalink
Add temporal leakage test (#9267)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed May 2, 2024
1 parent 36bc925 commit ab8f3fd
Showing 1 changed file with 29 additions and 0 deletions.
29 changes: 29 additions & 0 deletions test/loader/test_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -966,3 +966,32 @@ def test_neighbor_loader_input_id():
expected = [(2 * i) + 1]

assert batch['a'].input_id.tolist() == expected


@withPackage('pyg_lib')
def test_temporal_neighbor_loader_single_link():
data = HeteroData()
data['a'].x = torch.arange(10)
data['b'].x = torch.arange(10)
data['c'].x = torch.arange(10)

data['b'].time = torch.arange(0, 10)
data['c'].time = torch.arange(1, 11)

data['a', 'b'].edge_index = torch.arange(10).view(1, -1).repeat(2, 1)
data['b', 'a'].edge_index = torch.arange(10).view(1, -1).repeat(2, 1)
data['a', 'c'].edge_index = torch.arange(10).view(1, -1).repeat(2, 1)
data['c', 'a'].edge_index = torch.arange(10).view(1, -1).repeat(2, 1)

loader = NeighborLoader(
data,
num_neighbors=[-1],
input_nodes='a',
time_attr='time',
input_time=torch.arange(0, 10),
batch_size=10,
)
batch = next(iter(loader))
assert batch['a'].num_nodes == 10
assert batch['b'].num_nodes == 10
assert batch['c'].num_nodes == 0

0 comments on commit ab8f3fd

Please sign in to comment.