Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Jan 3, 2022
1 parent 78c6e10 commit 7014c50
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
5 changes: 3 additions & 2 deletions test/transforms/test_random_link_split.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import torch

from torch_geometric.utils import is_undirected
from torch_geometric.data import Data, HeteroData
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.utils import is_undirected, to_undirected


def get_edge_index(num_src_nodes, num_dst_nodes, num_edges):
Expand Down Expand Up @@ -92,7 +92,8 @@ def test_random_link_split_on_hetero_data():
data['a'].x = torch.arange(100, 300)

data['p', 'p'].edge_index = get_edge_index(100, 100, 500)
data['p', 'p'].edge_attr = torch.arange(500)
data['p', 'p'].edge_index = to_undirected(data['p', 'p'].edge_index)
data['p', 'p'].edge_attr = torch.arange(data['p', 'p'].num_edges)
data['p', 'a'].edge_index = get_edge_index(100, 200, 1000)
data['p', 'a'].edge_attr = torch.arange(500, 1500)
data['a', 'p'].edge_index = data['p', 'a'].edge_index.flip([0])
Expand Down
2 changes: 1 addition & 1 deletion torch_geometric/utils/undirected.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def is_undirected(
edge_attr = [] if edge_attr is None else edge_attr
edge_attr = [edge_attr] if isinstance(edge_attr, Tensor) else edge_attr

edge_index1, edge_attr1 = coalesce(
edge_index1, edge_attr1 = sort_edge_index(
edge_index,
edge_attr,
num_nodes=num_nodes,
Expand Down

0 comments on commit 7014c50

Please sign in to comment.