In [2]:
import trimesh
import networkx
from torch.utils.data import Dataset
import torch
from pathlib import Path

In [7]:
class STLDataset(Dataset):

    def __init__(self, root_path, tensor_size):
        self.stl_path = root_path
        self.stl_list = sorted(self._get_filenames(self.stl_path))
        self.tensor_size = tensor_size

    def __getitem__(self, idx):
        mesh = trimesh.load_mesh(self.stl_list[idx])
        adj = torch.from_numpy(networkx.adjacency_matrix(trimesh.graph.vertex_adjacency_graph(mesh)).toarray())
        triangles_num = torch.tensor(mesh.faces.shape[0])
        triangle_vertices_coords = self.fix_size_coords(torch.from_numpy(mesh.vertices))
        adj = self.fix_size_adj(adj, tensor_size=self.tensor_size)
        return adj, triangles_num, triangle_vertices_coords

    def __len__(self):
        return len(self.stl_list)
    
    @staticmethod
    def _get_filenames(path):
        return [f for f in path.iterdir() if f.is_file()]

    @staticmethod
    def fix_size_adj(input_tensor, tensor_size=42):
        if input_tensor.shape[0] < tensor_size:
            zeros = torch.zeros(input_tensor.shape[0], tensor_size - input_tensor.shape[0])
            tensor = torch.cat([input_tensor, zeros], dim=1)
    
            zeros = torch.zeros(tensor_size - input_tensor.shape[0], tensor_size)
            tensor = torch.cat([tensor, zeros], dim=0)
            return tensor
        elif input_tensor.shape[0] > tensor_size:
            return input_tensor[:tensor_size, :tensor_size]
        else:
            return input_tensor

    @staticmethod
    def fix_size_coords(input_tensor, tensor_size=42):
        if input_tensor.shape[0] < tensor_size:
            zeros = torch.zeros(tensor_size - input_tensor.shape[0], input_tensor.shape[1])
            tensor = torch.cat([input_tensor, zeros], dim=0)

            return tensor
        elif input_tensor.shape[0] > tensor_size:
            return input_tensor[:tensor_size]
        else:
            return input_tensor

In [8]:
input_dir = "../data/Thingi10K/models"
dataset = STLDataset(Path(input_dir), tensor_size=750)

In [9]:
batch_size = 8
data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, pin_memory=True
)

In [10]:
batch = next(iter(data_loader))

In [16]:
ma_mesh = trimesh.load_mesh('../data/Thingi10K/models/test.stl')
print(ma_mesh.vertices)

[[ 0.74320864  2.65091561 14.21195256]
 [-1.2008943   2.82786574 14.21195256]
 [ 0.74320864  2.65091561 12.33923256]
 [ 1.23969685 24.1757437  12.33923256]
 [-1.35483165 23.98082942 12.33923256]
 [ 1.23969685 24.1757437  14.21195256]
 [-1.35483165 23.98082942 14.21195256]
 [-3.79850166 23.18254406 12.33923256]
 [-3.79850166 23.18254406 14.21195256]
 [-5.90395909 21.89065524 12.33923256]
 [-5.90395909 21.89065524 14.21195256]
 [-7.56142887 20.27367705 12.33923256]
 [-7.56142887 20.27367705 14.21195256]
 [-8.67007384 18.64810448 12.33923256]
 [-8.67007384 18.64810448 14.21195256]
 [-9.50149604 16.7355093  12.33923256]
 [-9.50149604 16.7355093  14.21195256]
 [-9.95865138 14.61196434 12.33923256]
 [-9.95865138 14.61196434 14.21195256]
 [-9.97600615 12.39373942 12.33923256]
 [-9.97600615 12.39373942 14.21195256]
 [-9.53845737 10.219027   12.33923256]
 [-9.53845737 10.219027   14.21195256]
 [-8.68703131  8.2206507  12.33923256]
 [-8.68703131  8.2206507  14.21195256]
 [-7.5081146   6.50034881

In [14]:
adjacency_matrices, triangles_num, coordinates = batch
print(adjacency_matrices)
print(adjacency_matrices.shape)
print(triangles_num)
print(coordinates)
print(coordinates.shape)

tensor([[[0., 1., 1.,  ..., 0., 0., 0.],
         [1., 0., 1.,  ..., 0., 0., 0.],
         [1., 1., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]])
torch.Size([1, 750, 750])
tensor([160])
tensor([[[ 0.7432,  2.6509, 14.2120],
         [-1.2009,  2.8279, 14.2120],
         [ 0.7432,  2.6509, 12.3392],
         [ 1.2397, 24.1757, 12.3392],
         [-1.3548, 23.9808, 12.3392],
         [ 1.2397, 24.1757, 14.2120],
         [-1.3548, 23.9808, 14.2120],
         [-3.7985, 23.1825, 12.3392],
         [-3.7985, 23.1825, 14.2120],
         [-5.9040, 21.8907, 12.3392],
         [-5.9040, 21.8907, 14.2120],
         [-7.5614, 20.2737, 12.3392],
         [-7.5614, 20.2737, 14.2120],
         [-8.6701, 18.6481, 12.3392],
         [-8.6701, 18.6481, 14.2120],
         [-9.5015, 16.7355, 12.3392],
         [-9.5015, 16.7355, 14.2120],
         [-9.9587, 14.6120, 12.3392],
         [-9.9