In [14]:
import os
import json
import torch
from torch.utils.data import Dataset
from plyfile import PlyData
from glob import glob

class ScanNet200Dataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.scans = self._load_scans()

    def _load_scans(self):
        scans = []
        for scan_dir in os.listdir(self.root_dir): 
            scan_path = os.path.join(self.root_dir, scan_dir)
            if os.path.isdir(scan_path):
                ply_files = glob(os.path.join(scan_path, "*.ply"))
                json_files = glob(os.path.join(scan_path, "*.segs.json"))
                
                if ply_files and json_files:
                    ply_file = ply_files[0]
                    json_file = json_files[0]
                    scans.append((ply_file, json_file))
                else:
                    print(f"Missing files in {scan_path}")  # Debug print
        return scans

    def __len__(self):
        return len(self.scans)

    def _load_ply(self, ply_file):
        ply_data = PlyData.read(ply_file)
        vertices = torch.tensor(ply_data['vertex'].data.tolist(), dtype=torch.float32)
        return vertices

    def _load_json(self, json_file):
        with open(json_file, 'r') as f:
            data = json.load(f)
        segments = torch.tensor(data['segIndices'], dtype=torch.int64)
        return segments

    def __getitem__(self, idx):
        ply_file, json_file = self.scans[idx]
        vertices = self._load_ply(ply_file)
        segments = self._load_json(json_file)
        
        sample = {'vertices': vertices, 'segments': segments}
        
        if self.transform:
            sample = self.transform(sample)
        
        return sample

Custom collate function to pad the sequence for resizing in the dataloader.

In [10]:
def custom_collate_fn(batch):
    batch = list(filter(lambda x: x is not None, batch))  # Remove None values
    return {
        'vertices': torch.nn.utils.rnn.pad_sequence([item['vertices'] for item in batch], batch_first=True),
        'segments': torch.nn.utils.rnn.pad_sequence([item['segments'] for item in batch], batch_first=True)
    }

In [11]:
from path import Path

path = Path("../data/scannet200/scans") 
dataset = ScanNet200Dataset(root_dir=path)
print(dataset.__len__())

1513


In [13]:
for data in dataset:
    print(data['vertices'].shape, data['segments'].shape)


torch.Size([211497, 7]) torch.Size([211497])
torch.Size([138888, 7]) torch.Size([138888])
torch.Size([89242, 7]) torch.Size([89242])
torch.Size([49093, 7]) torch.Size([49093])
torch.Size([206210, 7]) torch.Size([206210])
torch.Size([182152, 7]) torch.Size([182152])
torch.Size([120683, 7]) torch.Size([120683])


KeyboardInterrupt: 

In [None]:
from torch.utils.data import DataLoader

dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=1, collate_fn=custom_collate_fn)

for batch in dataloader:
    print(batch['vertices'].shape, batch['segments'].shape)
