In [7]:
import trimesh
import networkx
from torch.utils.data import Dataset
import torch
from pathlib import Path

In [54]:
class STLDataset(Dataset):

    def __init__(self, root_path, tensor_size):
        self.stl_path = root_path
        self.stl_list = sorted(self._get_filenames(self.stl_path))
        self.tensor_size = tensor_size

    def __getitem__(self, idx):
        mesh = trimesh.load_mesh(self.stl_list[idx])
        adj = torch.from_numpy(networkx.adjacency_matrix(trimesh.graph.vertex_adjacency_graph(mesh)).toarray())
        facades = torch.tensor(mesh.faces.shape[0])
        coords = self.fix_size_coords(torch.from_numpy(mesh.vertices))
        adj = self.fix_size_adj(adj, tensor_size=self.tensor_size)
        print(adj.shape, facades.shape, coords.shape)
        return adj, facades, coords

    def __len__(self):
        return len(self.stl_list)
    
    @staticmethod
    def _get_filenames(path):
        return [f for f in path.iterdir() if f.is_file()]

    @staticmethod
    def fix_size_adj(input_tensor, tensor_size=42):
        if input_tensor.shape[0] < tensor_size:
            zeros = torch.zeros(input_tensor.shape[0], tensor_size - input_tensor.shape[0])
            tensor = torch.cat([input_tensor, zeros], dim=1)
    
            zeros = torch.zeros(tensor_size - input_tensor.shape[0], tensor_size)
            tensor = torch.cat([tensor, zeros], dim=0)
            return tensor
        elif input_tensor.shape[0] > tensor_size:
            return input_tensor[:tensor_size, :tensor_size]
        else:
            return input_tensor

    @staticmethod
    def fix_size_coords(input_tensor, tensor_size=42):
        if input_tensor.shape[0] < tensor_size:
            zeros = torch.zeros(tensor_size - input_tensor.shape[0], input_tensor.shape[1])
            tensor = torch.cat([input_tensor, zeros], dim=0)

            return tensor
        elif input_tensor.shape[0] > tensor_size:
            return input_tensor[:tensor_size]
        else:
            return input_tensor

In [55]:
input_dir = "../data/Thingi10K/models"
dataset = STLDataset(Path(input_dir), tensor_size=750)

In [56]:
batch_size = 8
data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, pin_memory=True
)

In [57]:
batch = next(iter(data_loader))

torch.Size([500, 500]) torch.Size([]) torch.Size([42, 3])
torch.Size([500, 500]) torch.Size([]) torch.Size([42, 3])
torch.Size([500, 500]) torch.Size([]) torch.Size([42, 3])
torch.Size([500, 500]) torch.Size([]) torch.Size([42, 3])
torch.Size([500, 500]) torch.Size([]) torch.Size([42, 3])
torch.Size([500, 500]) torch.Size([]) torch.Size([42, 3])
torch.Size([500, 500]) torch.Size([]) torch.Size([42, 3])
torch.Size([500, 500]) torch.Size([]) torch.Size([42, 3])
