In [2]:
import nibabel as nib
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from skimage.transform import rotate
from skimage.util import montage
import os
import numpy as np
import tarfile
import pandas as pd
from skimage.transform import resize
from scipy.ndimage import gaussian_filter

In [None]:
path_to_tar_file = "/kaggle/input/brats-2021-task1/BraTS2021_Training_Data.tar"
output_directory = "brats21-dataset-training-validation"

# Extracting the archive is slow, so skip it when a previous run already
# populated the output directory. This keeps "Restart & Run All" cheap.
# Delete the directory to force a fresh extraction.
if os.path.isdir(output_directory) and os.listdir(output_directory):
    print(f"Skipping extraction: {output_directory} already exists and is not empty")
else:
    # Open the tar file in read mode
    with tarfile.open(path_to_tar_file, "r") as tar_ref:
        # The "data" extraction filter rejects members that would land outside
        # output_directory (absolute paths, "..", links escaping the tree).
        # Fall back to plain extraction on interpreters that predate filters.
        if hasattr(tarfile, "data_filter"):
            tar_ref.extractall(output_directory, filter="data")
        else:
            tar_ref.extractall(output_directory)

    print(f"Extraction completed. Files are saved to {output_directory}")

In [None]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset
import nibabel as nib


class Brats2021LSTMDataset(Dataset):
    """BraTS2021 dataset yielding stacks of consecutive slices for an LSTM-CNN.

    Subjects are read from disk on every ``__getitem__`` call; nothing is cached
    between calls.
    """

    def __init__(self, data_dir=None, slice_window=3, train=True, transform=None):
        """
        Dataset for BraTS2021 with LSTM-CNN using lazy loading.
        Args:
            data_dir (str): Path to dataset directory. Every sub-directory is
                treated as one subject. When None, ``subject_paths`` starts empty
                and is expected to be assigned by the caller.
            slice_window (int): Number of consecutive slices for LSTM input.
            train (bool): If True, ``__getitem__`` returns one random window of
                ``slice_window`` slices; otherwise it returns the full volume.
            transform (callable): Optional augmentation applied to training
                windows, called as ``transform(slice_vol, slice_label)`` and
                expected to return the transformed ``(slice_vol, slice_label)``.
        """
        self.data_dir = data_dir
        self.slice_window = slice_window
        self.train = train
        self.transform = transform  # Augmentation pipeline
        self.modalities = ['flair', 't1', 't1ce', 't2']

        # Collect subject directories in sorted order so that index -> subject
        # is stable across machines (os.listdir order is arbitrary).
        if data_dir:
            self.subject_paths = sorted(
                os.path.join(data_dir, subject) for subject in os.listdir(data_dir)
                if os.path.isdir(os.path.join(data_dir, subject))
            )
        else:
            self.subject_paths = []

    def __len__(self):
        return len(self.subject_paths)

    def __getitem__(self, index):
        """Return ``(volume, label)`` tensors for subject ``index``.

        Train mode: ``volume`` is (slice_window, modalities, H, W) float32 and
        ``label`` is (slice_window, H, W) int64. They come from a window chosen
        uniformly among the windows containing at least one non-zero slice.

        Eval mode: the full volume, (depth, modalities, H, W), and the full
        label, (depth, H, W).

        Raises:
            ValueError: in train mode, if the volume is shallower than
                ``slice_window`` or every slice is zero.
        """
        # Load the subject data lazily
        subject_path = self.subject_paths[index]
        volume, label = self._load_subject(subject_path)

        # Normalize
        volume = self._normalize_volume(volume)

        if self.train:
            start = self._sample_window_start(volume, subject_path)
            stop = start + self.slice_window
            slice_vol = volume[:, :, :, start:stop]
            slice_label = label[:, :, start:stop]

            # Apply augmentation pipeline
            if self.transform:
                slice_vol, slice_label = self.apply_transforms(slice_vol, slice_label)

            slice_vol = torch.tensor(slice_vol, dtype=torch.float32).permute(3, 0, 1, 2)
            slice_label = torch.tensor(slice_label, dtype=torch.long).permute(2, 0, 1)
            return slice_vol, slice_label

        # Return full volume for evaluation
        volume = torch.tensor(volume, dtype=torch.float32).permute(3, 0, 1, 2)  # (depth, modalities, height, width)
        label = torch.tensor(label, dtype=torch.long).permute(2, 0, 1)  # (depth, height, width)
        return volume, label

    def _sample_window_start(self, volume, subject_path):
        """Pick a random start index whose window is not entirely zero.

        Samples uniformly over all valid start positions. This is the same
        distribution that rejection sampling would give, but it never fails
        while a valid window exists.

        Args:
            volume (np.ndarray): (modalities, H, W, depth) array.
            subject_path (str): Used only in error messages.
        Returns:
            int: start index ``i`` such that ``volume[..., i:i + slice_window]``
            contains at least one non-zero voxel.
        Raises:
            ValueError: if ``slice_window`` < 1, the depth is smaller than
                ``slice_window``, or every slice is zero.
        """
        depth = volume.shape[3]
        window = self.slice_window
        if window < 1:
            raise ValueError(f"slice_window must be >= 1, got {window}")
        if depth < window:
            raise ValueError(
                f"Volume depth {depth} is smaller than slice_window {window} "
                f"for subject {subject_path}"
            )
        # 1 for every slice holding any non-zero voxel in any modality.
        slice_has_signal = np.any(volume != 0, axis=(0, 1, 2)).astype(np.int64)
        # Number of non-empty slices inside each window [i, i + window).
        # Length is depth - window + 1, i.e. one entry per possible start.
        signal_per_window = np.convolve(
            slice_has_signal, np.ones(window, dtype=np.int64), mode='valid'
        )
        valid_starts = np.flatnonzero(signal_per_window)
        if valid_starts.size == 0:
            raise ValueError(f"All slices are zero for subject {subject_path}")
        return int(valid_starts[np.random.randint(valid_starts.size)])

    def apply_transforms(self, slice_vol, slice_label):
        """Apply ``self.transform`` to a training window and its labels.

        ``slice_vol`` is (modalities, H, W, slice_window) and ``slice_label`` is
        (H, W, slice_window). Both are still numpy arrays at this point, before
        the axis permutation.

        Assumes the transform takes ``(slice_vol, slice_label)`` and returns the
        transformed pair, matching how ``__getitem__`` unpacks the result.
        """
        return self.transform(slice_vol, slice_label)

    def _load_subject(self, subject_path):
        """
        Load modalities and segmentation labels for a given subject.
        Args:
            subject_path (str): Path to the subject's directory. Files are
                expected to be named ``<dirname>_<modality>.nii.gz`` and
                ``<dirname>_seg.nii.gz``.
        Returns:
            volume (np.array): float32 array of shape (modalities, height, width, depth),
                with modalities ordered as in ``self.modalities``.
            label (np.array): float32 array of shape (height, width, depth).
        Raises:
            FileNotFoundError: if a modality or the segmentation file is missing.
        """
        # normpath drops a trailing separator, which would otherwise make
        # basename() return an empty subject name.
        subject_name = os.path.basename(os.path.normpath(subject_path))

        volume = []
        for modal in self.modalities:
            modal_path = os.path.join(subject_path, f"{subject_name}_{modal}.nii.gz")
            if not os.path.exists(modal_path):
                raise FileNotFoundError(f"Missing modality: {modal_path}")
            # float32 halves memory versus get_fdata's float64 default; the
            # data becomes a float32 tensor in __getitem__ anyway.
            volume.append(nib.load(modal_path, mmap=True).get_fdata(dtype=np.float32))

        label_path = os.path.join(subject_path, f"{subject_name}_seg.nii.gz")
        if not os.path.exists(label_path):
            raise FileNotFoundError(f"Missing segmentation: {label_path}")
        # NOTE(review): BraTS segmentations are usually labelled {0, 1, 2, 4}
        # and no remapping happens here. Confirm the model/loss expects
        # class index 4.
        label = nib.load(label_path, mmap=True).get_fdata(dtype=np.float32)

        # np.stack fails loudly if modalities disagree in shape.
        return np.stack(volume), label

    @staticmethod
    def _normalize_volume(volume):
        """
        Min-max normalize the volume to [0, 1].

        Min and max are taken over the whole array, i.e. all modalities jointly.
        The 1e-8 avoids division by zero; a constant volume maps to all zeros.
        NOTE(review): per-modality normalization may suit MRI better, since each
        sequence has its own intensity range.
        """
        vmin = np.min(volume)
        vmax = np.max(volume)
        return (volume - vmin) / (vmax - vmin + 1e-8)

In [None]:
from sklearn.model_selection import train_test_split
import os

# Directory containing all data
data_dir = '/kaggle/working/brats21-dataset-training-validation'

# List all subject directories. os.listdir returns entries in an arbitrary,
# filesystem-dependent order, so sort them: without this, random_state=42 does
# not yield a reproducible split across machines or re-extractions.
all_subjects = sorted(
    os.path.join(data_dir, subject) for subject in os.listdir(data_dir)
    if os.path.isdir(os.path.join(data_dir, subject))
)
assert all_subjects, f"No subject directories found in {data_dir}"

# Split into train (~70%), validation (~21%) and test (~9%) sets.
# NOTE(review): the second split sends 30% of the held-out 30% to test. If an
# even 70/15/15 split was intended, use test_size=0.5 there.
train_subjects, temp_subjects = train_test_split(all_subjects, test_size=0.3, random_state=42)
val_subjects, test_subjects = train_test_split(temp_subjects, test_size=0.3, random_state=42)

# The three splits must be disjoint and together cover every subject.
assert not (set(train_subjects) & set(val_subjects)
            or set(train_subjects) & set(test_subjects)
            or set(val_subjects) & set(test_subjects))
assert len(train_subjects) + len(val_subjects) + len(test_subjects) == len(all_subjects)

# Print sizes of each split
print(f"Total Subjects: {len(all_subjects)}")
print(f"Training: {len(train_subjects)}, Validation: {len(val_subjects)}, Testing: {len(test_subjects)}")


In [None]:
from torch.utils.data import DataLoader

SLICE_WINDOW = 3  # consecutive slices fed to the LSTM
BATCH_SIZE = 2
NUM_WORKERS = 2


def make_split_dataset(subjects, train):
    """Create a Brats2021LSTMDataset over an explicit list of subject directories.

    data_dir=None skips the class's own directory scan. That scan's result would
    be discarded anyway, because subject_paths is assigned right after.
    """
    dataset = Brats2021LSTMDataset(data_dir=None, slice_window=SLICE_WINDOW, train=train)
    dataset.subject_paths = list(subjects)
    return dataset


# Training dataset: random windows. No transform is passed, so no augmentation yet.
train_dataset = make_split_dataset(train_subjects, train=True)
# Validation and test datasets: full volumes, no augmentation.
val_dataset = make_split_dataset(val_subjects, train=False)
test_dataset = make_split_dataset(test_subjects, train=False)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# Sanity check: pull a single training batch and inspect its shapes.
images, labels = next(iter(train_loader))
print("Batch 0:")
print(f"Images shape: {images.shape}")
print(f"Labels shape: {labels.shape}")