In [22]:
import os
import torch
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
import nibabel as nib

In [29]:
# create medical volume dataset
class MedicalVolumeDataset(Dataset):
    def __init__(self, image_dir, label_dir, transform=None):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.transform = transform

        # list filenames (acts like your CSV "filename" column)
        self.filenames = sorted(
            f for f in os.listdir(image_dir)
            if f.endswith(".nii.gz") and not f.startswith("._")
        )

        # len() --> Dataset size
    def __len__(self):
        return len(self.filenames)

    # getitem() --> one sample
    def __getitem__(self, idx):
        # get filename and label
        filename = self.filenames[idx]

        # load volume
        image_path = os.path.join(self.image_dir, filename)
        label_path = os.path.join(self.label_dir, filename)

        volume = nib.load(image_path).get_fdata()  # (D, H, W)
        label= nib.load(label_path).get_fdata()  # (D, H, W)

        # convert to tensor
        volume = torch.tensor(volume, dtype = torch.float32)
        label = torch.tensor(label, dtype= torch.long)

        # add channel dimension (1, D, H, W )
        volume = volume.unsqueeze(0)

        # apply transform
        if self.transform:
            volume, label = self.transform(volume, label)

        return volume, label


In [30]:
# using the dataset
dataset = MedicalVolumeDataset(
    image_dir="../Task09_Spleen/imagesTr", 
    label_dir= "../Task09_Spleen/labelsTr"
)

#print("Total samples:", len(dataset))

volume, label = dataset[0]

print("volume shape:", volume.shape) # (1, D, H, W)
print("Label:", label.shape) # (D, W, H)
print("Unique labels:", torch.unique(label))

volume shape: torch.Size([1, 512, 512, 55])
Label: torch.Size([512, 512, 55])
Unique labels: tensor([0, 1])


In [32]:
# dataset + dataloader 
from torch.utils.data import DataLoader

loader = DataLoader (
    dataset, 
    batch_size = 1, 
    shuffle = True
)

for volumes, labels in loader:
    print(volumes.shape) # (B, 1, D, H, W)
    print(labels)
    break

torch.Size([1, 1, 512, 512, 77])
tensor([[[[0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          ...,
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0]],

         [[0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          ...,
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0]],

         [[0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          ...,
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0]],

         ...,

         [[0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          ...,
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0],
          [0, 0, 0,  ..., 0, 0, 0]],

         [[0, 0, 0, 