# Brain tumor segmentation model using MONAI

## Quick overview of steps:

1. Build file lists (images, masks).
1. Build MONAI transforms for images & masks.
1. Create Dataset / DataLoader (CacheDataset for speed).
1. Create model (MONAI UNet for 2D).
1. Load pretrained checkpoint (if available) into model.
1. Define loss, optimizer, scheduler.
1. Train (with validation, metrics, checkpointing).
1. Fine-tuning tips (freeze, low LR, augmentations, etc.).

## 0. Imports

In [14]:
import os
from pathlib import Path

import torch
from torch.utils.data import DataLoader
from torchinfo import summary

from monai.transforms import (
    LoadImaged, EnsureChannelFirstd, ScaleIntensityd, ToTensord,
    RandFlipd, RandRotate90d, Compose
)
from monai.data import Dataset, CacheDataset
from monai.networks.nets import UNet
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.inferers import sliding_window_inference
from monai.utils import set_determinism

from tqdm import tqdm

set_determinism(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device, torch.__version__

(device(type='cuda'), '2.6.0+cu126')

## 1. File lists

In [4]:
# Define paths to the image and mask directories
IMAGE_DIR = Path("brain_tumor_dataset/images")
MASK_DIR = Path("brain_tumor_dataset/masks")

# Get all image and mask file paths, assuming they are PNG files.
# glob() does not guarantee any ordering, so sort both lists to keep them aligned.
images = sorted(IMAGE_DIR.glob("*.png"))
masks = sorted(MASK_DIR.glob("*.png"))

# Create a list of dictionaries, pairing each image with its corresponding mask.
# This assumes every image has a mask with the same file name in the mask directory.
data_dicts = [
    {
        "image": img, 
        "mask": m
        
    } for img, m in zip(images, masks)
]

# --- Data Splitting ---
import random
# Set a random seed for reproducibility, ensuring the shuffle is the same every time.
random.seed(42) 
# Shuffle the data dictionaries in place to randomize the dataset.
random.shuffle(data_dicts)

# Calculate the split index for an 80/20 train/test split.
n_train = int(0.8 * len(data_dicts))
# Split the data into training and testing sets.
train_files = data_dicts[:n_train]
test_files = data_dicts[n_train:]

# Print a sample from each set to verify the structure.
print("Sample from training set:", train_files[5])
print("Sample from testing set:", test_files[5])
print(f"Length of training set: {len(train_files)}")
print(f"Length of testing set: {len(test_files)}")



Sample from training set: {'image': WindowsPath('brain_tumor_dataset/images/1041.png'), 'mask': WindowsPath('brain_tumor_dataset/masks/1041.png')}
Sample from testing set: {'image': WindowsPath('brain_tumor_dataset/images/1284.png'), 'mask': WindowsPath('brain_tumor_dataset/masks/1284.png')}
Length of training set: 2451
Length of testing set: 613
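
Pairing by list position only works when both directories list their files in the same order. A minimal sketch of a more defensive alternative is shown below: it matches each image to the mask with the same file stem and fails loudly when a mask is missing. `pair_by_stem` is a hypothetical helper (not part of MONAI), and it assumes masks reuse the image file name, as in the sample output above.

```python
from pathlib import Path


def pair_by_stem(image_dir, mask_dir, pattern="*.png"):
    """Pair every image with the mask that has the same file stem."""
    # Index the masks by stem so lookup does not depend on directory order.
    masks_by_stem = {p.stem: p for p in Path(mask_dir).glob(pattern)}

    pairs = []
    for img in sorted(Path(image_dir).glob(pattern)):
        mask = masks_by_stem.get(img.stem)
        if mask is None:
            # Surface missing masks instead of silently mis-pairing files.
            raise FileNotFoundError(f"No mask found for image {img.name}")
        pairs.append({"image": img, "mask": mask})
    return pairs
```

Under these assumptions, `data_dicts = pair_by_stem(IMAGE_DIR, MASK_DIR)` would produce the same list-of-dictionaries structure used in the cell above.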


## 2. Transforms (2D)

In [5]:
# Define the sequence of transformations for the training data. These include data augmentation steps.
train_transforms = Compose([ 
    LoadImaged(keys=["image", "mask"]),           # Loads the image and mask data from the file paths specified in the dictionary.
    
    EnsureChannelFirstd(keys=["image", "mask"]),  # Ensures the data format is [Channel, Height, Width], which is standard for PyTorch.
    
    ScaleIntensityd(keys=["image"]),              # Normalizes the intensity values of the image (e.g., to a [0, 1] range). 
    
    RandFlipd(keys=["image", "mask"],             # Randomly flips the image and mask for data augmentation.
                prob=0.5, spatial_axis=0),        # Flips along the first spatial axis (e.g., horizontal) with a 50% probability.
    
    RandRotate90d(keys=["image", "mask"],         # Randomly rotates the image and mask in 90-degree increments for augmentation.
                prob=0.5, max_k=3),               # Applies rotation with a 50% probability, up to 3 times (90, 180, or 270 degrees).
    
    ToTensord(keys=["image", "mask"]),            # Converts the image and mask from NumPy arrays to PyTorch Tensors.
])

# Define the sequence of transformations for the testing/validation data.
# Note: This pipeline does not include random augmentations (like flip or rotate) to ensure consistent evaluation.
test_transforms = Compose([ 
    LoadImaged(keys=["image", "mask"]),           # Loads the image and mask data from the file paths specified in the dictionary.
    
    EnsureChannelFirstd(keys=["image", "mask"]),  # Ensures the data format is [Channel, Height, Width].
    
    ScaleIntensityd(keys=["image"]),              # Normalizes the intensity values of the image.
    
    ToTensord(keys=["image", "mask"]),            # Converts the image and mask from NumPy arrays to PyTorch Tensors.
])


train_transforms, test_transforms

(<monai.transforms.compose.Compose at 0x2cb7f329fc0>,
 <monai.transforms.compose.Compose at 0x2cb7fa7c5e0>)
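
Note that `ScaleIntensityd` is applied to the `image` key only, so the masks keep whatever values are stored in the PNG files. If the masks happen to be saved as 0/255 (an assumption; the notebook does not say), they would need to be mapped to 0/1 before computing a Dice loss, for example with a thresholding transform such as MONAI's `AsDiscreted`. The NumPy sketch below illustrates both operations on a toy array; `min_max_scale` and `binarize_mask` are illustrative helpers, not MONAI functions.

```python
import numpy as np


def min_max_scale(image):
    """Rescale intensities linearly to the [0, 1] range."""
    image = image.astype(np.float32)
    lo, hi = image.min(), image.max()
    if hi == lo:
        # Constant image: nothing to stretch. Returning zeros is this sketch's choice.
        return np.zeros_like(image)
    return (image - lo) / (hi - lo)


def binarize_mask(mask, threshold=0):
    """Map every pixel above `threshold` to 1.0 and the rest to 0.0."""
    return (mask > threshold).astype(np.float32)


# Toy 2x2 "slice" and a mask stored as 0/255 (assumed storage format).
image = np.array([[0, 64], [128, 255]], dtype=np.uint8)
mask = np.array([[0, 0], [255, 255]], dtype=np.uint8)

scaled = min_max_scale(image)
binary = binarize_mask(mask)
print("image range:", scaled.min(), scaled.max())
print("mask values:", np.unique(binary))
```

A quick check of `np.unique` on a few loaded masks is usually enough to tell whether this step is needed.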

## 3. Datasets and DataLoaders

In [7]:
os.cpu_count(), os.cpu_count()//2

(16, 8)

In [9]:
# --- Create MONAI Datasets and PyTorch DataLoaders ---

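# Create a training dataset with caching.
# CacheDataset caches the output of the deterministic transforms that precede the first random transform;
# the random augmentations (flip, rotate) are still re-applied every time a sample is fetched.
train_DS = CacheDataset(
    data=train_files,           # The list of file dictionaries for training.
    transform=train_transforms, # The transformations (including augmentations) to apply.
    cache_rate=1.0              # Cache 100% of the samples in RAM for fast access during training epochs.
)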

# Create a testing dataset. Caching is also beneficial here for faster repeated evaluations.
test_DS = CacheDataset(
    data=test_files,          # The list of file dictionaries for testing.
    transform=test_transforms,  # The transformations (without augmentations) to apply.
    cache_rate=1.0            # Cache the whole test set to speed up repeated evaluation.
)


# Create a DataLoader for the training set.
train_loader = DataLoader(
    dataset=train_DS,         # The dataset to load from.
    batch_size=8,             # Number of samples per batch.
    shuffle=True,             # Shuffle the data at the beginning of each epoch to improve model generalization.
    num_workers=os.cpu_count()  # One worker process per CPU core loads batches in parallel; with data already cached in RAM, fewer workers may be enough.
)

# Create a DataLoader for the testing set.
test_loader = DataLoader(
    dataset=test_DS,          # The dataset to load from.
    batch_size=4,             # Number of samples per evaluation batch.
    shuffle=False,            # Do not shuffle the test data to ensure consistent and reproducible evaluation.
    num_workers=os.cpu_count()//2 # Use half the CPU cores for loading test data.
)

Loading dataset: 100%|██████████| 2451/2451 [00:33<00:00, 74.21it/s] 
Loading dataset: 100%|██████████| 613/613 [00:08<00:00, 73.53it/s]


In [11]:
train_DS, test_DS, train_loader, test_loader

(<monai.data.dataset.CacheDataset at 0x2cb072bf700>,
 <monai.data.dataset.CacheDataset at 0x2cb556eeef0>,
 <torch.utils.data.dataloader.DataLoader at 0x2cb274187f0>,
 <torch.utils.data.dataloader.DataLoader at 0x2cb7fdfeef0>)
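
The object representations above do not show what a batch looks like. The sketch below uses stand-in random tensors instead of the real dataset to illustrate how a `DataLoader` collates dictionary samples into batched tensors. The 224x224 slice size is an assumption borrowed from the model summary call later in the notebook, and `fake_samples` is purely illustrative.

```python
import torch
from torch.utils.data import DataLoader

# Stand-in data: random tensors shaped like one grayscale slice and its binary mask.
# The 224x224 size is an assumption, not a property read from the real files.
fake_samples = [
    {
        "image": torch.rand(1, 224, 224),
        "mask": torch.randint(0, 2, (1, 224, 224)).float(),
    }
    for _ in range(16)
]

# A plain list works as a map-style dataset; num_workers=0 keeps the sketch simple.
loader = DataLoader(fake_samples, batch_size=8, shuffle=False, num_workers=0)

# The default collate function stacks each dictionary key into one batched tensor.
batch = next(iter(loader))
print(batch["image"].shape, batch["mask"].shape)
```

Running the same `next(iter(...))` check on `train_loader` is a cheap way to confirm the channel-first layout before starting training.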

## 4. Model

In [16]:
"""
The following code block initializes the U-Net model with specific parameters tailored for this brain tumor segmentation task.

U-Net Parameters Explained:
---------------------------
spatial_dims: Defines the dimensionality of the data. 
            For 2D MRI slices, this is 2. 
            For 3D MRI volumes, it would be 3.

in_channels: The number of input channels for the model. 
                For a single grayscale MRI slice, this is 1. 
                If using multi-modal data (like T1, T1ce, T2, FLAIR from BraTS), this would be 4.

out_channels: The number of output channels, which corresponds to the number of classes to be segmented. 
            For binary segmentation (tumor vs. background), this is 1.

channels: A tuple defining the number of feature channels at each level of the U-Net's encoder path. 
        The model starts with 16 channels, 
        then downsamples and increases the channel count to 32, 64, 128, and finally 256 at the bottleneck. 
        The decoder path will mirror this in reverse.

strides: A tuple defining the stride for the downsampling convolutions at each level of the encoder. 
        A stride of 2 halves the spatial dimensions (height and width) of the feature map at each step, which is standard for U-Nets.

num_res_units: The number of convolutional subunits inside the residual unit at each level of the U-Net. 
            Using 2 subunits per level increases the model's capacity to learn complex features at different resolutions.
"""

# Set the number of classes for segmentation.
# For binary segmentation (tumor vs. background), this is 1.
num_classes = 1

# Set the number of input channels.
# Since we are using 2D grayscale MRI slices, this is 1.
in_channels = 1

# The number of output channels is equal to the number of classes.
out_channels = num_classes

model = UNet(
    spatial_dims=2,                 # Using 2D convolutions and operations.
    
    in_channels=in_channels,        # Number of channels in the input image.
    out_channels=out_channels,      # Number of channels in the output mask.
    
    channels=(16, 32, 64, 128, 256),# Feature maps at each level of the encoder.
    strides=(2, 2, 2, 2),           # Downsampling factor at each encoder level.
    
    num_res_units=2,                # Number of convolutional blocks per level.
).to(device) # Move the model to the specified device (GPU or CPU).



summary(
    model,
    # The input size should match the model's 2D configuration: (batch_size, channels, height, width)
    input_size=(1, in_channels, 224, 224),
    col_names=("input_size", "output_size", "num_params", "trainable"),
)



Layer (type:depth-idx)                                                                     Input Shape               Output Shape              Param #                   Trainable
UNet                                                                                       [1, 1, 224, 224]          [1, 1, 224, 224]          --                        True
├─Sequential: 1-1                                                                          [1, 1, 224, 224]          [1, 1, 224, 224]          --                        True
│    └─ResidualUnit: 2-1                                                                   [1, 1, 224, 224]          [1, 16, 112, 112]         --                        True
│    │    └─Conv2d: 3-1                                                                    [1, 1, 224, 224]          [1, 16, 112, 112]         160                       True
│    │    └─Sequential: 3-2                                                                [1, 1, 224, 224]          [1, 16, 
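
To make the `channels` and `strides` settings concrete, the sketch below works out the feature-map size at each encoder level for a 224x224 input. It assumes each stride-2 convolution maps a side of length n to ceil(n / 2), which matches the first row of the summary (224 to 112), and that the bottom level changes the channel count without downsampling further. `encoder_shapes` is a hypothetical helper written for this illustration.

```python
import math


def encoder_shapes(size, channels, strides):
    """Return (channels, spatial_size) after each encoder level.

    Assumes each strided convolution maps a side of length n to ceil(n / stride).
    """
    shapes = []
    for ch, stride in zip(channels, strides):
        size = math.ceil(size / stride)
        shapes.append((ch, size))
    # Assumption: the bottom level changes the channel count but keeps the resolution.
    shapes.append((channels[-1], size))
    return shapes


for ch, side in encoder_shapes(224, (16, 32, 64, 128, 256), (2, 2, 2, 2)):
    print(f"{ch:>3} channels at {side}x{side}")
```

With four stride-2 steps, an input side that is a multiple of 16 (such as 224) halves cleanly at every level, which keeps encoder and decoder feature maps the same size at each skip connection.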

## 5. Loss, optimizer, and scheduler

In [17]:
"""
This block sets up the core components for training the model: 
the loss function, 
the optimizer, 
and a learning rate scheduler.

- Loss Function (DiceLoss): Measures how well the model's prediction matches the 
  ground truth mask. 
  Dice Loss is particularly effective for segmentation tasks as 
  it focuses on the overlap between the predicted and actual regions, which is 
  robust to class imbalance (e.g., small tumors in a large image).

- Optimizer (Adam): The algorithm that updates the model's weights to minimize the loss.

- Scheduler (ReduceLROnPlateau): A strategy to dynamically adjust the learning 
  rate during training. It will "reduce" the learning rate when the model's 
  performance on the validation set "plateaus" (stops improving). The smaller 
  steps help the model settle into a minimum instead of oscillating around it.
"""
# Define the loss function. DiceLoss is excellent for segmentation tasks.
# sigmoid=True is necessary as our model outputs raw logits. 
# It applies a sigmoid to the output before calculating loss.
loss_fn = DiceLoss(sigmoid=True)

# Define the optimizer.
optimizer = torch.optim.Adam(model.parameters(), 
                            lr=1e-4)

# Define a learning rate scheduler.
# It will reduce the learning rate when the validation loss stops improving ('min' mode).
# Call scheduler.step(val_loss) once per epoch, after validation.
# To log the current rate, read optimizer.param_groups[0]["lr"].
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                        mode='min',
                                                        factor=0.1,
                                                        patience=10)
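
As an illustration of the two less obvious pieces, the sketch below first computes a simplified soft Dice loss by hand and then shows when `ReduceLROnPlateau` lowers the learning rate. `soft_dice_loss` is a simplified stand-in written for this example, not MONAI's `DiceLoss` implementation (which has more options), and the constant validation loss is a placeholder for a real metric.

```python
import torch


def soft_dice_loss(logits, target, eps=1e-5):
    """1 - Dice overlap between sigmoid(logits) and a 0/1 target, averaged over the batch."""
    probs = torch.sigmoid(logits)
    dims = tuple(range(1, logits.ndim))  # sum over channel and spatial axes
    intersection = (probs * target).sum(dim=dims)
    total = probs.sum(dim=dims) + target.sum(dim=dims)
    return (1 - (2 * intersection + eps) / (total + eps)).mean()


# A 4x4 "tumor" in an 8x8 slice, and logits that agree with it almost perfectly.
target = torch.zeros(1, 1, 8, 8)
target[..., 2:6, 2:6] = 1
good_logits = (target * 2 - 1) * 20
print("near-perfect prediction:", soft_dice_loss(good_logits, target).item())
print("inverted prediction:", soft_dice_loss(-good_logits, target).item())

# Plateau scheduling: feed a validation loss that never improves and record the rate.
param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.Adam([param], lr=1e-4)
sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode="min", factor=0.1, patience=10)

lrs = []
for epoch in range(12):
    val_loss = 0.5        # placeholder for a real validation loss
    sched.step(val_loss)  # called once per epoch, after validation
    lrs.append(opt.param_groups[0]["lr"])
print("rate after 11 and 12 epochs:", lrs[10], lrs[11])
```

The rate only drops once the number of epochs without improvement exceeds `patience`, so with `patience=10` a stalled run trains at the original rate for a while before the first reduction.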

