In [4]:
# data loading, implement __getitem__ , to load NIfTI  files on demand
import os
import torch
import pandas as pd
import nibabel as nib
import numpy as np
from torch.utils.data import Dataset

In [5]:
class MedicalVolumeDataset(Dataset):
    def __init__(self, image_dir, label_dir, transform=None):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.transform = transform

        # list filenames (acts like your CSV "filename" column)
        self.filenames = sorted(
            f for f in os.listdir(image_dir)
            if f.endswith(".nii.gz") and not f.startswith("._")
        )

    def __len__(self):
        return len(self.filenames)
    
    def __getitem__(self, idx):
        filename = self.filenames[idx]

        image_path = os.path.join(self.image_dir, filename)
        label_path = os.path.join(self.label_dir, filename)

        # load nifti volumes
        volume = nib.load(image_path).get_fdata()
        label = nib.load(label_path).get_fdata()

        volume = np.array(volume, dtype= np.float32)
        label = np.array(label, dtype=np.int64)

        volume = torch.from_numpy(volume).unsqueeze(0) # (1, D, H, W)
        label = torch.from_numpy(label)

        if self.transform:
            volume, label = self.transform(volume, label)

        return volume, label

In [6]:
# using the dataset
dataset = MedicalVolumeDataset(
    image_dir="../Task09_Spleen/imagesTr", 
    label_dir= "../Task09_Spleen/labelsTr"
)

volume, label = dataset[0]

print("Volume shape:", volume.shape)
print("Label:", label)


Volume shape: torch.Size([1, 512, 512, 55])
Label: tensor([[[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        ...,

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 0,  ..., 0, 0, 

In [9]:
# dataset to DataLoader 
from torch.utils.data import DataLoader

loader = DataLoader(
    dataset,
    batch_size = 1,
    shuffle= True
)

for volumes, lables in loader:
    print(volumes.shape)  # (B, 1, D, H, W)
    print(label)
    break

torch.Size([1, 1, 512, 512, 83])
tensor([[[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        ...,

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0