In [8]:
import numpy as np
from torch.utils.data import Dataset
import torchvision
import os
import h5py
import pickle  # TODO or use h5py instead?
import trimesh
import open3d as o3d
import glob

import config as cfg
import dataset.augmentation as Transforms

In [28]:
class CustomDataset(Dataset):
    
    def __init__(self, split, noise_type):
        dataset_path = cfg.CUSTOM_PATH

        self.samples, self.labels = self.get_samples(dataset_path, split)
        self.transforms = self.get_transforms(split, noise_type)

    def __len__(self):
        return self.samples.shape[0]

    def __getitem__(self, item):
        sample = {'points': self.samples[item, :, :], 'label': self.labels[item], 'idx': np.array(item, dtype=np.int32)}

        if self.transforms:
            sample = self.transforms(sample)
        return sample

    def get_transforms(self, split, noise_type):
        # prepare augmentations
        if noise_type == "clean":
            # 1-1 correspondence for each point (resample first before splitting), no noise
            if split == "train":
                transforms = [Transforms.Resampler(2048),
                              Transforms.SplitSourceRef(),
                              Transforms.Scale(), Transforms.Shear(), Transforms.Mirror(),
                              Transforms.RandomTransformSE3_euler(),
                              Transforms.ShufflePoints()]
            else:
                transforms = [Transforms.SetDeterministic(),
                              Transforms.FixedResampler(2048),
                              Transforms.SplitSourceRef(),
                              Transforms.RandomTransformSE3_euler(),
                              Transforms.ShufflePoints()]
        elif noise_type == "jitter":
            # Points randomly sampled (might not have perfect correspondence), gaussian noise to position
            if split == "train":
                transforms = [Transforms.SplitSourceRef(),
                              Transforms.Scale(), Transforms.Shear(), Transforms.Mirror(),
                              Transforms.RandomTransformSE3_euler(),
                              Transforms.Resampler(2048),
                              Transforms.RandomJitter(),
                              Transforms.ShufflePoints()]
            else:
                transforms = [Transforms.SetDeterministic(),
                              Transforms.SplitSourceRef(),
                              Transforms.RandomTransformSE3_euler(),
                              Transforms.Resampler(2048),
                              Transforms.RandomJitter(),
                              Transforms.ShufflePoints()]
        else:
            raise ValueError(f"Noise type {noise_type} not supported for CustomData.")

        return torchvision.transforms.Compose(transforms)

    def get_samples(self, dataset_path, split):
        if split == 'train':
            path = os.path.join(dataset_path, 'train_data')
        elif split == 'val':
            path = os.path.join(dataset_path, 'val_data')
        else:
            path = os.path.join(dataset_path, 'test_data')
            
        all_data = []
        all_labels = []
        for item in glob.glob(path + '/*.obj'):
            mesh = o3d.io.read_triangle_mesh(item)
            pcd = mesh.sample_points_uniformly(number_of_points=2048)
    
            xyz = np.array(pcd.points)
            data = xyz.astype(np.float32)
            labels = 0

            all_data.append(data)
            all_labels.append(labels)

        return np.array(all_data), np.array(all_labels)
    

In [29]:
if __name__ == '__main__':
    dataset = CustomDataset(split = 'train', noise_type='clean')
    print(len(dataset))
    print(dataset[0])

160
{'label': 0, 'idx': array(0, dtype=int32), 'points_raw': array([[-0.34620836, -0.2937608 , -0.00196991],
       [-0.43061003, -0.30287626,  0.0634876 ],
       [-0.3337034 , -0.2682038 , -0.14593735],
       ...,
       [-0.15968841, -0.42347932,  0.09959693],
       [-0.26971498, -0.44452402,  0.21975738],
       [-0.32791668, -0.30645508,  0.08058053]], dtype=float32), 'points_src': array([[-0.03257793, -0.5947261 ,  0.32602638],
       [ 0.2051293 , -0.6731661 ,  0.26315087],
       [-0.1286774 , -0.5606637 , -0.25961742],
       ...,
       [-0.23216912, -0.74582636,  0.22245607],
       [-0.02384359, -0.70644456,  0.38735688],
       [ 0.13447371, -0.65190655,  0.12604025]], dtype=float32), 'points_ref': array([[-0.36624637, -0.30454427,  0.1120059 ],
       [-0.34865776, -0.3240514 ,  0.16821665],
       [-0.43703634, -0.3382782 ,  0.21494418],
       ...,
       [-0.22368574, -0.42774227,  0.13192919],
       [-0.56998146, -0.22775926,  0.17650536],
       [-0.34087044, -0.2