Skip to content
Permalink
Browse files

fixes failure modes of from_networkx (#830)

  • Loading branch information
rusty1s committed Jan 6, 2020
1 parent 696ed25 commit c0458f0c0fb84c27570d91c6a971e29109649fca
Showing with 63 additions and 24 deletions.
  1. +52 −14 test/utils/test_convert.py
  2. +11 −10 torch_geometric/utils/convert.py
@@ -3,8 +3,9 @@
import networkx as nx
from torch_sparse import coalesce
from torch_geometric.data import Data
from torch_geometric.utils import to_scipy_sparse_matrix, to_networkx
from torch_geometric.utils import from_scipy_sparse_matrix, from_networkx
from torch_geometric.utils import (to_scipy_sparse_matrix,
from_scipy_sparse_matrix)
from torch_geometric.utils import to_networkx, from_networkx
from torch_geometric.utils import to_trimesh, from_trimesh
from torch_geometric.utils import subgraph

@@ -69,18 +70,7 @@ def test_from_networkx():
assert edge_attr.tolist() == [3, 1, 2]


def test_subgraph_convert():
G = nx.complete_graph(5)

edge_index = from_networkx(G).edge_index
sub_edge_index_1, _ = subgraph([0, 1, 3, 4], edge_index)

sub_edge_index_2 = from_networkx(G.subgraph([0, 1, 3, 4])).edge_index

assert sub_edge_index_1.tolist() == sub_edge_index_2.tolist()


def test_vice_versa_convert():
def test_networkx_vice_versa_convert():
G = nx.complete_graph(5)
assert G.is_directed() is False
data = from_networkx(G)
@@ -91,6 +81,54 @@ def test_vice_versa_convert():
assert G.is_directed() is False


def test_from_networkx_non_consecutive():
graph = nx.Graph()
graph.add_node(4)
graph.add_node(2)
graph.add_edge(4, 2)
for node in graph.nodes():
graph.nodes[node]['x'] = node

data = from_networkx(graph)
assert len(data) == 2
assert data.x.tolist() == [4, 2]
assert data.edge_index.tolist() == [[0, 1], [1, 0]]


def test_from_networkx_non_numeric_labels():
graph = nx.Graph()
graph.add_node('4')
graph.add_node('2')
graph.add_edge('4', '2')
for node in graph.nodes():
graph.nodes[node]['x'] = node
data = from_networkx(graph)
assert len(data) == 2
assert data.x == ['4', '2']
assert data.edge_index.tolist() == [[0, 1], [1, 0]]


def test_from_networkx_without_edges():
graph = nx.Graph()
graph.add_node(1)
graph.add_node(2)
data = from_networkx(graph)
assert len(data) == 1
assert data.edge_index.size() == (2, 0)


def test_subgraph_convert():
G = nx.complete_graph(5)

edge_index = from_networkx(G).edge_index
sub_edge_index_1, _ = subgraph([0, 1, 3, 4], edge_index,
relabel_nodes=True)

sub_edge_index_2 = from_networkx(G.subgraph([0, 1, 3, 4])).edge_index

assert sub_edge_index_1.tolist() == sub_edge_index_2.tolist()


def test_trimesh():
pos = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0]],
dtype=torch.float)
@@ -86,26 +86,27 @@ def from_networkx(G):
G (networkx.Graph or networkx.DiGraph): A networkx graph.
"""

G = nx.convert_node_labels_to_integers(G)
G = G.to_directed() if not nx.is_directed(G) else G
edge_index = torch.tensor(list(G.edges)).t().contiguous()

keys = []
keys += list(list(G.nodes(data=True))[0][1].keys())
keys += list(list(G.edges(data=True))[0][2].keys())
data = {key: [] for key in keys}
data = {}

for _, feat_dict in G.nodes(data=True):
for i, (_, feat_dict) in enumerate(G.nodes(data=True)):
for key, value in feat_dict.items():
data[key].append(value)
data[key] = [value] if i == 0 else data[key] + [value]

for _, _, feat_dict in G.edges(data=True):
for i, (_, _, feat_dict) in enumerate(G.edges(data=True)):
for key, value in feat_dict.items():
data[key].append(value)
data[key] = [value] if i == 0 else data[key] + [value]

for key, item in data.items():
data[key] = torch.tensor(item)
try:
data[key] = torch.tensor(item)
except ValueError:
pass

data['edge_index'] = edge_index
data['edge_index'] = edge_index.view(2, -1)
data = torch_geometric.data.Data.from_dict(data)
data.num_nodes = G.number_of_nodes()

0 comments on commit c0458f0

Please sign in to comment.
You can’t perform that action at this time.