Skip to content

Commit

Permalink
is_undirected without coalesce (#3789)
Browse files Browse the repository at this point in the history
* is_undirected without coalesce

* fix

* update
  • Loading branch information
rusty1s committed Jan 3, 2022
1 parent 1a5e65c commit 47dfb9c
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 8 deletions.
5 changes: 3 additions & 2 deletions test/transforms/test_random_link_split.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import torch

from torch_geometric.utils import is_undirected
from torch_geometric.data import Data, HeteroData
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.utils import is_undirected, to_undirected


def get_edge_index(num_src_nodes, num_dst_nodes, num_edges):
Expand Down Expand Up @@ -92,7 +92,8 @@ def test_random_link_split_on_hetero_data():
data['a'].x = torch.arange(100, 300)

data['p', 'p'].edge_index = get_edge_index(100, 100, 500)
data['p', 'p'].edge_attr = torch.arange(500)
data['p', 'p'].edge_index = to_undirected(data['p', 'p'].edge_index)
data['p', 'p'].edge_attr = torch.arange(data['p', 'p'].num_edges)
data['p', 'a'].edge_index = get_edge_index(100, 200, 1000)
data['p', 'a'].edge_attr = torch.arange(500, 1500)
data['a', 'p'].edge_index = data['p', 'a'].edge_index.flip([0])
Expand Down
20 changes: 14 additions & 6 deletions torch_geometric/utils/undirected.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

import torch
from torch import Tensor
from torch_geometric.utils.coalesce import coalesce

from torch_geometric.utils import sort_edge_index, coalesce

from .num_nodes import maybe_num_nodes

Expand All @@ -26,16 +27,23 @@ def is_undirected(
:rtype: bool
"""

num_nodes = maybe_num_nodes(edge_index, num_nodes)

edge_attr = [] if edge_attr is None else edge_attr
edge_attr = [edge_attr] if isinstance(edge_attr, Tensor) else edge_attr

edge_index1, edge_attr1 = coalesce(edge_index, edge_attr,
num_nodes=num_nodes, sort_by_row=True)
edge_index2, edge_attr2 = coalesce(edge_index1, edge_attr1,
num_nodes=num_nodes, sort_by_row=False)
edge_index1, edge_attr1 = sort_edge_index(
edge_index,
edge_attr,
num_nodes=num_nodes,
sort_by_row=True,
)
edge_index2, edge_attr2 = sort_edge_index(
edge_index1,
edge_attr1,
num_nodes=num_nodes,
sort_by_row=False,
)

return (bool(torch.all(edge_index1[0] == edge_index2[1]))
and bool(torch.all(edge_index1[1] == edge_index2[0])) and all([
Expand Down

0 comments on commit 47dfb9c

Please sign in to comment.