Skip to content
Permalink
Browse files

first io package contribution

  • Loading branch information...
rusty1s committed Oct 2, 2019
1 parent d2c8ef6 commit b62e1b085c39c210dcefc04b8d0010bc2debefcb
Showing with 88 additions and 6 deletions.
  1. +0 −1 .coveragerc
  2. +8 −0 test/io/example1.off
  3. +7 −0 test/io/example2.off
  4. +31 −0 test/io/test_off.py
  5. +2 −1 torch_geometric/io/__init__.py
  6. +40 −4 torch_geometric/io/off.py
@@ -2,7 +2,6 @@
source=torch_geometric
omit=
torch_geometric/datasets/*
torch_geometric/io/*
torch_geometric/data/extract.py
torch_geometric/nn/data_parallel.py
[report]
@@ -0,0 +1,8 @@
OFF
4 2 0
0.0 0.0 0.0
0.0 1.0 0.0
1.0 0.0 0.0
1.0 1.0 0.0
3 0 1 2
3 1 2 3
@@ -0,0 +1,7 @@
OFF
4 1 0
0.0 0.0 0.0
0.0 1.0 0.0
1.0 0.0 0.0
1.0 1.0 0.0
4 0 1 2 3
@@ -0,0 +1,31 @@
import os
import os.path as osp

import torch
from torch_geometric.data import Data
from torch_geometric.io import read_off, write_off


def test_read_off():
data = read_off(osp.join('test', 'io', 'example1.off'))
assert len(data) == 2
assert data.pos.tolist() == [[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 1, 0]]
assert data.face.tolist() == [[0, 1], [1, 2], [2, 3]]

data = read_off(osp.join('test', 'io', 'example2.off'))
assert len(data) == 2
assert data.pos.tolist() == [[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 1, 0]]
assert data.face.tolist() == [[0, 1], [1, 2], [2, 3]]


def test_write_off():
pos = torch.tensor([[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 1, 0]])
face = torch.tensor([[0, 1], [1, 2], [2, 3]])

path = osp.join('test', 'io', 'example.off')
write_off(Data(pos=pos, face=face), path)
data = read_off(path)
os.unlink(path)

assert data.pos.tolist() == pos.tolist()
assert data.face.tolist() == face.tolist()
@@ -4,7 +4,7 @@
from .ply import read_ply
from .obj import read_obj
from .sdf import read_sdf, parse_sdf
from .off import read_off, parse_off
from .off import read_off, parse_off, write_off
from .npz import read_npz, parse_npz

__all__ = [
@@ -17,6 +17,7 @@
'read_sdf',
'parse_sdf',
'read_off',
'write_off',
'parse_off',
'read_npz',
'parse_npz',
@@ -1,7 +1,12 @@
import re

import torch
from torch._tensor_str import PRINT_OPTS
from torch_geometric.io import parse_txt_array
from torch_geometric.data import Data

REGEX = r'[\[\]\(\),]|tensor|(\s)\1{2,}'


def parse_off(src):
# Some files may contain a bug and do not have a carriage return after OFF.
@@ -33,14 +38,45 @@ def face_to_tri(face):
rect = rect.to(torch.int64)

if rect.numel() > 0:
first, second = rect[:, [0, 1, 2]], rect[:, [0, 2, 3]]
first, second = rect[:, [0, 1, 2]], rect[:, [1, 2, 3]]
return torch.cat([triangle, first, second], dim=0).t().contiguous()
else:
first, second = rect, rect

return torch.cat([triangle, first, second], dim=0).t().contiguous()
return triangle.t().contiguous()


def read_off(path):
r"""Reads an OFF (Object File Format) file, returning both the position of
nodes and their connectivity in a :class:`torch_geometric.data.Data`
object.
Args:
path (str): The path to the file.
"""
with open(path, 'r') as f:
src = f.read().split('\n')[:-1]
return parse_off(src)


def write_off(data, path):
r"""Writes a :class:`torch_geometric.data.Data` object to an OFF (Object
File Format) file.
Args:
data (:class:`torch_geometric.data.Data`): The data object.
path (str): The path to the file.
"""
num_nodes, num_faces = data.pos.size(0), data.face.size(1)

face = data.face.t()
num_vertices = torch.full((num_faces, 1), face.size(1), dtype=torch.long)
face = torch.cat([num_vertices, face], dim=-1)

threshold = PRINT_OPTS.threshold
torch.set_printoptions(threshold=float('inf'))
with open(path, 'w') as f:
f.write('OFF\n{} {} 0\n'.format(num_nodes, num_faces))
f.write(re.sub(REGEX, r'', data.pos.__repr__()))
f.write('\n')
f.write(re.sub(REGEX, r'', face.__repr__()))
f.write('\n')
torch.set_printoptions(threshold=threshold)

0 comments on commit b62e1b0

Please sign in to comment.
You can’t perform that action at this time.