In [1]:
# data loading, implement __getitem__ , to load NIfTI  files on demand
import os
import torch
import nibabel as nib
import numpy as np
from torch.utils.data import Dataset

In [2]:
class MedicalVolumeDataset(Dataset):
    def __init__(self, image_dir, label_dir, transform=None):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.transform = transform

        # list filenames (acts like your CSV "filename" column)
        self.filenames = sorted(
            f for f in os.listdir(image_dir)
            if f.endswith(".nii.gz") and not f.startswith("._")
        )

    def __len__(self):
        return len(self.filenames)
    
    def __getitem__(self, idx):
        filename = self.filenames[idx]

        image_path = os.path.join(self.image_dir, filename)
        label_path = os.path.join(self.label_dir, filename)

        # load nifti volumes
        volume = nib.load(image_path).get_fdata()
        label = nib.load(label_path).get_fdata()

        volume = np.array(volume, dtype= np.float32)
        label = np.array(label, dtype=np.int64)

        volume = torch.from_numpy(volume).unsqueeze(0) # (1, D, H, W)
        label = torch.from_numpy(label)

        if self.transform:
            volume, label = self.transform(volume, label)

        return volume, label

In [3]:
# using the dataset
dataset = MedicalVolumeDataset(
    image_dir="../Task09_Spleen/imagesTr", 
    label_dir= "../Task09_Spleen/labelsTr"
)

volume, label = dataset[0]

print("single Volume shape:", volume.shape)
print("single Label shape:", label.shape)


single Volume shape: torch.Size([1, 512, 512, 55])
single Label shape: torch.Size([512, 512, 55])


In [11]:
from torch.utils.data import Dataset, DataLoader
loader = DataLoader(
    dataset,
    batch_size = 2,    # try changing to 1,2,4
    shuffle= True,
    num_workers =0     # keep 0 for windows
)

In [12]:
for volumes, label in loader:
    print("Batch volume shape:", volumes.shape)
    print("Batch label shape:", label.shape)
    break

RuntimeError: stack expects each tensor to be equal size, but got [1, 512, 512, 164] at entry 0 and [1, 512, 512, 40] at entry 1

In [20]:
import torch.nn.functional as F 
class ResizeNormalize3D:
    def __init__(self, target_shape):
        self.target_shape = target_shape
    
    def __call__(self, volume, label):
        # volume: (1, D, H, W)
        # label: (D, H, W)

        # Resize volume 
        volume = volume.unsqueeze(0)  # (1, 1, D, H, W)
        volume = F.interpolate(
            volume, 
            size=self.target_shape,
            mode="trilinear",
            align_corners=False
        )
        volume  = volume.squeeze(0) #(1, D, H, W)

        # resize label
        label= label.float()
        label = label.unsqueeze(0).unsqueeze(0)  # (1, 1, D, H, W)
        label = F.interpolate(
            label,
            size = self.target_shape,
            mode="nearest"
        )
        label = label.squeeze(0).squeeze(0) #(D, H, W)

        # normalize volume to 0-1
        volume = (volume-volume.min()) / (volume.max() - volume.min() + 1e-8)

        return volume, label

In [21]:
TARGET_SHAPE = (64, 256, 256) # (D, H, W) 
transform = ResizeNormalize3D(target_shape= TARGET_SHAPE)

dataset = MedicalVolumeDataset(
    image_dir = "../Task09_Spleen/imagesTr",
    label_dir="../Task09_Spleen/labelsTr",
    transform = transform
)

In [24]:
# create a DataLoader 
loader = DataLoader(
    dataset, 
    batch_size=4,
    shuffle=True,
    num_workers=0
)

In [25]:
# test batch shapes
for volumes, label in loader:
    print("Batch volume shape:", volumes.shape)
    print("Batch label shape:", label.shape)
    break

Batch volume shape: torch.Size([4, 1, 64, 256, 256])
Batch label shape: torch.Size([4, 64, 256, 256])
