# Brain tumor segmentation model using MONAI

## Quick overview of steps:

1. Build file lists (images, masks).
1. Build MONAI transforms for images & masks.
1. Create Dataset / DataLoader (CacheDataset for speed).
1. Create model (MONAI UNet for 2D).
1. Load pretrained checkpoint (if available) into model.
1. Define loss, optimizer, scheduler.
1. Train (with validation, metrics, checkpointing).
1. Fine-tuning tips (freeze, low LR, augmentations, etc.).

## 0. Imports

In [23]:
import os
from pathlib import Path

import torch
from torch.utils.data import DataLoader

from monai.transforms import (
    LoadImaged, EnsureChannelFirstd, ScaleIntensityd, ToTensord,
    RandFlipd, RandRotate90d, Compose
)
from monai.data import Dataset, CacheDataset
from monai.networks.nets import UNet
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.inferers import sliding_window_inference
from monai.utils import set_determinism

from tqdm import tqdm

# Fix the global random seeds (via MONAI's helper) so runs are reproducible.
set_determinism(seed=42)

# Use the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Notebook display: show the chosen device and the installed torch build.
device, torch.__version__

(device(type='cuda'), '2.6.0+cu126')

## 1. File lists

In [15]:
IMAGE_DIR = Path("brain_tumor_dataset/images")
MASK_DIR = Path("brain_tumor_dataset/masks")

# Collect all images and masks. `glob` yields entries in filesystem order,
# which is not guaranteed to match between the two directories, so sort both
# lists and pair each image with the mask that has the SAME file name rather
# than zipping by position (zip would silently mis-pair or truncate).
images = sorted(IMAGE_DIR.glob("*.png"))
masks = sorted(MASK_DIR.glob("*.png"))

masks_by_name = {m.name: m for m in masks}
missing = [img.name for img in images if img.name not in masks_by_name]
if missing:
    raise FileNotFoundError(
        f"{len(missing)} image(s) have no matching mask in {MASK_DIR}, "
        f"e.g. {missing[:5]}"
    )

data_dicts = [
    {"image": img, "mask": masks_by_name[img.name]}
    for img in images
]


# Split: shuffle with a fixed seed, then take 80% for training and the
# remaining 20% for testing.
import random
random.seed(42)  # set random seed for a reproducible split

random.shuffle(data_dicts)
n_train = int(0.8 * len(data_dicts))  # number of TRAINING samples
n_val = n_train  # old (misleading) name, kept in case later cells use it

train_files = data_dicts[:n_train]
test_files = data_dicts[n_train:]

# Show one sample from each split as a sanity check of the pairing.
print(train_files[5])
print(test_files[5])

{'image': WindowsPath('brain_tumor_dataset/images/1041.png'), 'mask': WindowsPath('brain_tumor_dataset/masks/1041.png')}
{'image': WindowsPath('brain_tumor_dataset/images/1284.png'), 'mask': WindowsPath('brain_tumor_dataset/masks/1284.png')}


## 2. 2D transforms

In [None]:
# Training pipeline: load image and mask, move the channel axis first,
# rescale the IMAGE intensities only, then apply random flips / 90-degree
# rotations jointly to image and mask so they stay spatially aligned.
train_transforms = Compose([
    LoadImaged(keys=["image", "mask"]),
    EnsureChannelFirstd(keys=["image", "mask"]),
    # The key must be "image" (the key used in the data dicts); the previous
    # "images" key does not exist, so MONAI raised a missing-key error.
    ScaleIntensityd(keys=["image"]),
    RandFlipd(keys=["image", "mask"], prob=0.5, spatial_axis=0),
    RandRotate90d(keys=["image", "mask"], prob=0.5, max_k=3),
    ToTensord(keys=["image", "mask"]),
])

# Evaluation pipeline: same deterministic preprocessing, no augmentation.
test_transforms = Compose([
    LoadImaged(keys=["image", "mask"]),
    EnsureChannelFirstd(keys=["image", "mask"]),
    ScaleIntensityd(keys=["image"]),
    ToTensord(keys=["image", "mask"]),
])
# NOTE(review): masks are loaded as-is. If the mask PNGs store foreground as
# 255 instead of 1, they probably need thresholding/binarising before being
# fed to DiceLoss — confirm against the dataset.