In [1]:
import numpy as np
from torch.utils.data import Dataset
import torchvision
import os
import h5py
import pickle  # TODO or use h5py instead?
import trimesh

import config as cfg
import dataset.augmentation as Transforms

In [2]:
class DatasetModelnet40(Dataset):
    """ModelNet40 point-cloud dataset restricted to one half of the categories.

    Samples are read from the HDF5 files listed in ``<split>_files.txt`` inside
    ``cfg.M40_PATH``. The "train" and "val" splits keep categories 0-19, every
    other split keeps categories 20-39. Because ModelNet40 ships no validation
    set, "val" is served from the test files (categories 0-19).
    """

    def __init__(self, split, noise_type):
        """Load the samples for ``split`` and build the augmentation pipeline.

        Args:
            split: "train", "val" or another split name that has a matching
                ``<split>_files.txt`` (e.g. "test").
            noise_type: "clean" or "jitter"; selects the transform pipeline.

        Raises:
            ValueError: if ``noise_type`` is not supported.
        """
        dataset_path = cfg.M40_PATH
        categories = np.arange(20) if split in ["train", "val"] else np.arange(20, 40)
        # ModelNet40 has no validation set - use cat 0-19 with test set.
        # From here on "val" is treated as "test", so it also gets the
        # non-training transform pipeline below.
        split = "test" if split == "val" else split

        self.samples, self.labels = self.get_samples(dataset_path, split, categories)
        self.transforms = self.get_transforms(split, noise_type)

    def __len__(self):
        """Return the number of loaded samples."""
        return self.samples.shape[0]

    def __getitem__(self, item):
        """Return the sample at index ``item`` after applying the transforms.

        Before the transforms run, the sample is a dict with the keys
        'points' (the stored array for this item), 'label' and 'idx'
        (the index as a 0-d int32 array).
        """
        sample = {'points': self.samples[item, :, :], 'label': self.labels[item], 'idx': np.array(item, dtype=np.int32)}

        if self.transforms:
            sample = self.transforms(sample)
        return sample

    def get_transforms(self, split, noise_type):
        """Compose the augmentation pipeline for the given split and noise type.

        Args:
            split: "train" selects the training pipeline; any other value
                selects the evaluation pipeline, which starts with
                ``SetDeterministic``.
            noise_type: "clean" or "jitter".

        Returns:
            A ``torchvision.transforms.Compose`` over the selected transforms.

        Raises:
            ValueError: if ``noise_type`` is neither "clean" nor "jitter".
        """
        # prepare augmentations
        if noise_type == "clean":
            # 1-1 correspondence for each point (resample first before splitting), no noise
            if split == "train":
                transforms = [Transforms.Resampler(1024),
                              Transforms.SplitSourceRef(),
                              Transforms.Scale(), Transforms.Shear(), Transforms.Mirror(),
                              Transforms.RandomTransformSE3_euler(),
                              Transforms.ShufflePoints()]
            else:
                transforms = [Transforms.SetDeterministic(),
                              Transforms.FixedResampler(1024),
                              Transforms.SplitSourceRef(),
                              Transforms.RandomTransformSE3_euler(),
                              Transforms.ShufflePoints()]
        elif noise_type == "jitter":
            # Points randomly sampled (might not have perfect correspondence), gaussian noise to position
            if split == "train":
                transforms = [Transforms.SplitSourceRef(),
                              Transforms.Scale(), Transforms.Shear(), Transforms.Mirror(),
                              Transforms.RandomTransformSE3_euler(),
                              Transforms.Resampler(1024),
                              Transforms.RandomJitter(),
                              Transforms.ShufflePoints()]
            else:
                transforms = [Transforms.SetDeterministic(),
                              Transforms.SplitSourceRef(),
                              Transforms.RandomTransformSE3_euler(),
                              Transforms.Resampler(1024),
                              Transforms.RandomJitter(),
                              Transforms.ShufflePoints()]
        else:
            raise ValueError(f"Noise type {noise_type} not supported for ModelNet40.")

        return torchvision.transforms.Compose(transforms)

    def get_samples(self, dataset_path, split, categories):
        """Read all HDF5 files of a split and keep only the requested categories.

        Args:
            dataset_path: directory holding ``<split>_files.txt`` and the HDF5 files.
            split: split name used to pick the file list.
            categories: label values to keep, or None to keep every sample.

        Returns:
            Tuple ``(all_data, all_labels)``: the per-file 'data' and 'normal'
            datasets joined along the last axis and stacked over all files,
            and the matching flat int64 label array.
        """
        list_path = os.path.join(dataset_path, f'{split}_files.txt')
        # Only the basename of each listed entry is used; the files are looked
        # up directly inside dataset_path. Blank lines are skipped, since they
        # would otherwise resolve to the dataset directory itself.
        with open(list_path) as list_file:
            filelist = [os.path.join(dataset_path, line.strip().split("/")[-1])
                        for line in list_file if line.strip()]

        all_data = []
        all_labels = []
        for fname in filelist:
            # Slicing with [:] copies the datasets into memory, so the file
            # can be closed as soon as the arrays have been read.
            with h5py.File(fname, mode='r') as f:
                data = np.concatenate([f['data'][:], f['normal'][:]], axis=-1)
                labels = f['label'][:].flatten().astype(np.int64)

            if categories is not None:  # Filter out unwanted categories
                mask = np.isin(labels, categories).flatten()
                data = data[mask, ...]
                labels = labels[mask, ...]

            all_data.append(data)
            all_labels.append(labels)
        all_data = np.concatenate(all_data, axis=0)
        all_labels = np.concatenate(all_labels, axis=0)
        return all_data, all_labels

In [9]:
if __name__ == '__main__':
    # Smoke test: build the clean training split and show its first sample.
    dataset = DatasetModelnet40(split="train", noise_type="clean")
    first_sample = dataset[0]
    print(first_sample)

{'label': 7, 'idx': array(0, dtype=int32), 'points_raw': array([[-0.28381357,  0.03711117, -0.32101122,  0.615526  ,  0.492254  ,
        -0.615478  ],
       [ 0.5578954 , -0.41320965, -0.15411651, -0.270454  ,  0.923956  ,
         0.27048   ],
       [-0.35169974,  0.20326908, -0.42259684,  0.653771  , -0.381099  ,
        -0.653717  ],
       ...,
       [ 0.38058695, -0.20426388,  0.4324199 ,  0.707123  ,  0.0021871 ,
        -0.707087  ],
       [ 0.28581825, -0.44051182,  0.34911266, -0.00156324,  0.999997  ,
         0.00156325],
       [-0.25075912, -0.10793268, -0.22646542, -0.0033314 ,  0.999989  ,
         0.00333113]], dtype=float32), 'points_src': array([[ 0.08782253,  0.14693572,  0.48750055,  0.95640266,  0.11370768,
         0.26900512],
       [ 0.38300854,  0.1162279 ,  0.21228456, -0.9563935 , -0.11376374,
        -0.26901668],
       [ 0.4660697 ,  0.68940806,  0.6999218 , -0.13051961,  0.9519471 ,
         0.27705857],
       ...,
       [ 0.60018444,  0.19838978,