Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Apr 25, 2024
1 parent bfdd2ee commit 5c4aa64
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 17 deletions.
22 changes: 12 additions & 10 deletions test/loader/test_neighbor_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,12 @@ def test_homo_neighbor_loader_basic(
assert batch.n_id.size() == (batch.num_nodes, )
assert batch.input_id.numel() == batch.batch_size == 20
assert batch.x.min() >= 0 and batch.x.max() < 100
assert isinstance(batch.edge_index, EdgeIndex)
batch.edge_index.validate()
size = (batch.num_nodes, batch.num_nodes)
assert batch.edge_index.sparse_size() == size
assert batch.edge_index.sort_order == 'col'
# TODO Re-enable once `EdgeIndex` is stable.
assert not isinstance(batch.edge_index, EdgeIndex)
# batch.edge_index.validate()
# size = (batch.num_nodes, batch.num_nodes)
# assert batch.edge_index.sparse_size() == size
# assert batch.edge_index.sort_order == 'col'
assert batch.edge_index.device == device
assert batch.edge_index.min() >= 0
assert batch.edge_index.max() < batch.num_nodes
Expand Down Expand Up @@ -224,11 +225,12 @@ def test_hetero_neighbor_loader_basic(subgraph_type, dtype):

for edge_type, edge_index in batch.edge_index_dict.items():
src, _, dst = edge_type
assert isinstance(edge_index, EdgeIndex)
edge_index.validate()
size = (batch[src].num_nodes, batch[dst].num_nodes)
assert edge_index.sparse_size() == size
assert edge_index.sort_order == 'col'
# TODO Re-enable once `EdgeIndex` is stable.
assert not isinstance(edge_index, EdgeIndex)
# edge_index.validate()
# size = (batch[src].num_nodes, batch[dst].num_nodes)
# assert edge_index.sparse_size() == size
# assert edge_index.sort_order == 'col'

row, col = batch['paper', 'paper'].edge_index
assert row.min() >= 0 and row.max() < batch['paper'].num_nodes
Expand Down
15 changes: 8 additions & 7 deletions torch_geometric/loader/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from torch import Tensor

import torch_geometric.typing
from torch_geometric import EdgeIndex
from torch_geometric.data import (
Data,
FeatureStore,
Expand Down Expand Up @@ -105,13 +104,15 @@ def filter_edge_store_(store: EdgeStorage, out_store: EdgeStorage, row: Tensor,
# which represents the new graph as denoted by `(row, col)`:
for key, value in store.items():
if key == 'edge_index':
edge_index = torch.stack([row, col], dim=0).to(value.device)
# TODO Integrate `EdgeIndex` into `custom_store`.
out_store.edge_index = EdgeIndex(
torch.stack([row, col], dim=0).to(value.device),
sparse_size=out_store.size(),
sort_order='col',
# TODO Support `is_undirected`.
)
# edge_index = EdgeIndex(
# torch.stack([row, col], dim=0).to(value.device),
# sparse_size=out_store.size(),
# sort_order='col',
# # TODO Support `is_undirected`.
# )
out_store.edge_index = edge_index

elif key == 'adj_t':
# NOTE: We expect `(row, col)` to be sorted by `col` (CSC layout).
Expand Down

0 comments on commit 5c4aa64

Please sign in to comment.