In [43]:
import os
import torch

from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from segment.data.aapm_thoracic.aapm import Patient

In [71]:
class Patient:
    """
    AAPM Thoracic patient.

    Parameters
    ----------
    path : str
        Path to patient records.

    patient_number : int 
        Integer value corresponding to patient ID.
        - Must be between [0, 59]

    References
    ----------
    """
    def __init__(self, path, patient_number):
        self.path = path
        self.patient_number = patient_number
        self.img = f'image_{patient_number}.pth'
        self.mask = f'label_{patient_number}.pth'

    def __repr__(self):
        img = self.load_img()
        return f'AAPM Thoracic Patient: {self.patient_number} '\
               f'\n Image Size: {img.size()}'

    def load_img(self):
        img_path = os.path.join(self.path, self.img)
        return torch.load(img_path)

    def load_mask(self):
        mask_path = os.path.join(self.path, self.mask)
        try:
            return torch.load(mask_path)
        except:
            raise FileNotFoundError(f'Patient {self.patient_number} '\
                                    f'has no label mask.')

In [72]:
class AAPM(Dataset):
    """
    AAPM Thoracic Dataset.

    Parameters
    ----------
    path : str
        Path to AAPM data.

    References
    ----------
    """
    def __init__(self, path, semisupervised=False, transform=None):
        self.path = path
        self.transform = transform
        self.semisupervised

    def __repr__(self):
        return f'AAPM Thoracic Dataset.'

    def __len__(self):
        return 50

    def __getitem__(self, idx):
        patient = Patient(self.path, idx)
        try:
            mask = patient.load_mask()
        except:
            print(f'Patient {idx} has no labels')
            if self.semisupervised:
                img = patient.load_img()
            
        if self.transform is not None:
            img = self.transform(img)

        return img, mask

In [73]:
PATH = '/mnt/USB/AAPM17CTSegChallenge/pth'

In [74]:
patient0 = Patient(PATH, 4)
patient0

AAPM Thoracic Patient: 4 
 Image Size: torch.Size([512, 512, 165])

In [69]:
img = patient0.load_img()

In [70]:
mask = patient0.load_mask()

FileNotFoundError: Patient 4 has no label mask.

In [26]:
print(f'Image: {img.shape}')
print(f'Mask:  {mask.shape}')

Image: torch.Size([512, 512, 122])
Mask:  torch.Size([512, 512, 122])


In [30]:
aapm = AAPM(PATH)

In [34]:
loader = DataLoader(aapm, 1)

In [35]:
for idx, (img, mask) in enumerate(loader):
    print(f'Image: {img.shape}, Mask: {mask.shape}')

Image: torch.Size([1, 512, 512, 122]), Mask: torch.Size([1, 512, 512, 122])
Image: torch.Size([1, 512, 512, 134]), Mask: torch.Size([1, 512, 512, 134])
Image: torch.Size([1, 512, 512, 140]), Mask: torch.Size([1, 512, 512, 140])
Image: torch.Size([1, 512, 512, 152]), Mask: torch.Size([1, 512, 512, 152])


FileNotFoundError: [Errno 2] No such file or directory: '/mnt/USB/AAPM17CTSegChallenge/pth/label_4.pth'