Skip to content
Permalink
Browse files

add trimesh conversions

  • Loading branch information
rusty1s committed Nov 29, 2019
1 parent 7faf260 commit eacb7d3e24a28aa50d7ae6d20a42676bc8ca1536
Showing with 58 additions and 2 deletions.
  1. +1 −0 .travis.yml
  2. +2 −2 test/transforms/test_gdc.py
  3. +15 −0 test/utils/test_convert.py
  4. +3 −0 torch_geometric/utils/__init__.py
  5. +37 −0 torch_geometric/utils/convert.py
@@ -43,6 +43,7 @@ install:
- pip install torch-cluster
- pip install torch-spline-conv
- pip install cython && pip install gdist
- pip install trimesh
- pip install pycodestyle
- pip install flake8
- pip install codecov
@@ -17,7 +17,7 @@ def test_gdc():
data = gdc(data)
mat = to_dense_adj(data.edge_index, edge_attr=data.edge_attr).squeeze()
assert torch.all(mat >= -1e-8)
assert torch.allclose(mat, mat.t(), atol=1e-6)
assert torch.allclose(mat, mat.t(), atol=1e-4)

data = Data(edge_index=edge_index, num_nodes=5)
gdc = GDC(self_loop_weight=1, normalization_in='sym',
@@ -28,7 +28,7 @@ def test_gdc():
data = gdc(data)
mat = to_dense_adj(data.edge_index, edge_attr=data.edge_attr).squeeze()
assert torch.all(mat >= -1e-8)
assert torch.allclose(mat, mat.t(), atol=1e-6)
assert torch.allclose(mat, mat.t(), atol=1e-4)

data = Data(edge_index=edge_index, num_nodes=5)
gdc = GDC(self_loop_weight=1, normalization_in='col',
@@ -5,6 +5,7 @@
from torch_geometric.data import Data
from torch_geometric.utils import to_scipy_sparse_matrix, to_networkx
from torch_geometric.utils import from_scipy_sparse_matrix, from_networkx
from torch_geometric.utils import to_trimesh, from_trimesh
from torch_geometric.utils import subgraph


@@ -88,3 +89,17 @@ def test_vice_versa_convert():
assert G.is_directed() is True
G = nx.to_undirected(G)
assert G.is_directed() is False


def test_trimesh():
pos = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0]],
dtype=torch.float)
face = torch.tensor([[0, 1, 2], [1, 2, 3]]).t()

data = Data(pos=pos, face=face)
mesh = to_trimesh(data)
data = from_trimesh(mesh)

perm = (data.pos * torch.tensor([[1.0, 2.0, 3.0]])).sum(dim=-1).argsort()
assert pos.tolist() == data.pos[perm].tolist()
assert face.tolist() == perm[data.face].tolist()
@@ -18,6 +18,7 @@
from .geodesic import geodesic_distance
from .convert import to_scipy_sparse_matrix, from_scipy_sparse_matrix
from .convert import to_networkx, from_networkx
from .convert import to_trimesh, from_trimesh
from .random import (erdos_renyi_graph, stochastic_blockmodel_graph,
barabasi_albert_graph)
from .negative_sampling import (negative_sampling,
@@ -54,6 +55,8 @@
'from_scipy_sparse_matrix',
'to_networkx',
'from_networkx',
'to_trimesh',
'from_trimesh',
'erdos_renyi_graph',
'stochastic_blockmodel_graph',
'barabasi_albert_graph',
@@ -3,6 +3,11 @@
import networkx as nx
import torch_geometric.data

try:
import trimesh
except ImportError:
trimesh = None

from .num_nodes import maybe_num_nodes


@@ -105,3 +110,35 @@ def from_networkx(G):
data.num_nodes = G.number_of_nodes()

return data


def to_trimesh(data):
r"""Converts a :class:`torch_geometric.data.Data` instance to a
:obj:`trimesh.Trimesh`.
Args:
data (torch_geometric.data.Data): The data object.
"""

if trimesh is None:
raise ImportError('Package `trimesh` could not be found.')

return trimesh.Trimesh(vertices=data.pos.detach().cpu().numpy(),
faces=data.face.detach().t().cpu().numpy())


def from_trimesh(mesh):
r"""Converts a :obj:`trimesh.Trimesh` to a
:class:`torch_geometric.data.Data` instance.
Args:
mesh (trimesh.Trimesh): A :obj:`trimesh` mesh.
"""

if trimesh is None:
raise ImportError('Package `trimesh` could not be found.')

pos = torch.from_numpy(mesh.vertices).to(torch.float)
face = torch.from_numpy(mesh.faces).t().contiguous()

return torch_geometric.data.Data(pos=pos, face=face)

0 comments on commit eacb7d3

Please sign in to comment.
You can’t perform that action at this time.