In [1]:
# Custom Torch Dataset for our phantom / sinogram data

import os
from collections.abc import Callable
from typing import Optional

import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import Dataset

class LircstAnaDataset(Dataset):
    """Map-style dataset over (phantom slice, sinogram) pairs stored as ``.npy`` files.

    Our data for our analytical simulation is laid out as such:

        /data
          /<phantom-id>
              /meta.npy              (metadata about the phantom, applies to all slices)
              /phan-<slice-idx>.npy  (the slice Ground Truth) (2x128x128)
              /sino-<slice-idx>.npy  (the sinogram of the slice) (128x200x100)

    ``meta.npy`` is not read by this class; the phantom id is returned with every
    sample so the caller can look the metadata up when needed.

    Samples are ordered by phantom id (lexicographic) and then by slice index
    (numeric), so a given dataset index always refers to the same pair of files.
    """

    _PHAN_PREFIX = 'phan-'
    _NPY_SUFFIX = '.npy'

    def __init__(
        self,
        data_dir: str,
        transform_phan: Optional[Callable] = None,
        transform_sino: Optional[Callable] = None,
    ):
        """Index every phantom slice found under ``data_dir``.

        Args:
            data_dir: Root folder containing one sub-directory per phantom id.
            transform_phan: Optional callable applied to each loaded phantom slice.
            transform_sino: Optional callable applied to each loaded sinogram.

        Raises:
            FileNotFoundError: If ``data_dir`` does not exist (from ``os.listdir``).
        """
        self.data_dir: str = data_dir
        self.transform_phan: Optional[Callable] = transform_phan
        self.transform_sino: Optional[Callable] = transform_sino

        # Only sub-directories are phantoms. A stray file next to them would
        # otherwise make the per-phantom os.listdir() raise NotADirectoryError.
        self.phantom_ids: list[str] = sorted(
            entry
            for entry in os.listdir(data_dir)
            if os.path.isdir(os.path.join(data_dir, entry))
        )

        # Flat list of (phantom_id, slice_idx); its position is the dataset index.
        self.idxs: list[tuple[str, int]] = [
            (phantom_id, slice_idx)
            for phantom_id in self.phantom_ids
            for slice_idx in self._slice_indices(os.path.join(data_dir, phantom_id))
        ]

    @classmethod
    def _slice_indices(cls, phantom_dir: str) -> list[int]:
        """Return the slice indices of all ``phan-<int>.npy`` files in ``phantom_dir``, ascending."""
        found: list[int] = []
        for name in os.listdir(phantom_dir):
            if not (name.startswith(cls._PHAN_PREFIX) and name.endswith(cls._NPY_SUFFIX)):
                continue
            stem = name[len(cls._PHAN_PREFIX):-len(cls._NPY_SUFFIX)]
            # Ignore names such as 'phan-notes.npy' instead of crashing in int().
            if stem.isascii() and stem.isdigit():
                found.append(int(stem))
        # os.listdir() returns entries in arbitrary order; sort numerically so the
        # sample order is reproducible (and 10 comes after 2, unlike a string sort).
        return sorted(found)

    def __len__(self) -> int:
        """Return the total number of phantom slices across all phantoms."""
        return len(self.idxs)

    def __getitem__(self, idx: int) -> tuple[np.ndarray, np.ndarray, str]:
        """Load one sample.

        Args:
            idx: Position in ``self.idxs``.

        Returns:
            A tuple of the phantom slice, the sinogram, and the phantom_id (in case
            we need to look up the metadata). The first two are the arrays loaded
            from disk, or whatever the corresponding transform returns when one is set.

        Raises:
            IndexError: If ``idx`` is out of range.
            FileNotFoundError: If the phantom or sinogram file for the slice is missing.
        """
        phantom_id, slice_idx = self.idxs[idx]
        phantom_dir = os.path.join(self.data_dir, phantom_id)
        phan = np.load(os.path.join(phantom_dir, f'phan-{slice_idx}.npy'))
        sino = np.load(os.path.join(phantom_dir, f'sino-{slice_idx}.npy'))

        # Compare against None explicitly: a valid transform object may be falsy.
        if self.transform_phan is not None:
            phan = self.transform_phan(phan)
        if self.transform_sino is not None:
            sino = self.transform_sino(sino)

        return phan, sino, phantom_id


In [6]:
# Test the dataset

def test_dataset(data_dir: str = '/home/samnub/dev/lircst-ana/data/'):
    """Smoke-test LircstAnaDataset by printing its length and the first sample.

    Args:
        data_dir: Root data folder to load. Defaults to the path this notebook
            was originally run against; pass another path to run the check on
            a different machine.
    """
    dataset = LircstAnaDataset(data_dir)
    print(len(dataset))
    # First sample: shapes of the phantom slice and sinogram, plus the phantom id.
    phan, sino, phantom_id = dataset[0]
    print(phan.shape, sino.shape, phantom_id)

test_dataset()

1559
(2, 128, 128) (128, 200, 100) 1742413505
