In [1]:
import sys

# NOTE(review): `transformations` does not appear to be used anywhere in this
# notebook — confirm before removing it.
from sympy.parsing.sympy_parser import transformations

# Make the project-local helper packages importable, in this order, before
# importing from them.
sys.path.extend(['./packages', './wrappers'])
from dataset_wrapper import initialise_dataset

In [2]:
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
from monai.networks.nets import UNet
from monai.utils import set_determinism
from monai.data import Dataset, DataLoader
from monai.losses import DiceLoss
from monai.data import PersistentDataset
from monai.metrics import DiceMetric
import torch
import torch.optim as optim
import torch.nn.functional as F

In [3]:
# Root directory of the dataset, relative to this notebook. Adjust for the local machine.
data_path = "../data/"

In [4]:
# DO NOT TOUCH
# Select what datatypes you would like to load from the dataset
# flair: Fluid-Attenuated Inversion Recovery. Highlights edema (swelling around tumor)
# t1: Standard T1-weighted scan. Good anatomical detail
# t1ce: T1-weighted with contrast enhancement. Shows areas where tumor enhances after injection of gadolinium
# t2: T2-weighted scan. Bright scan for fluids

modalities = ["flair", "t1", "t1ce", "t2"]

# Lookup from modality name to its position in `modalities`
# (presumably its channel index in the loaded image — confirm in dataset_wrapper).
mod_dic = {name: position for position, name in enumerate(modalities)}

In [5]:
# This import list should only ever be ADDED to, and NOT subtracted from.
# Import more transforms here if needed, then add them to a Compose pipeline below to apply them to the data.
from monai.transforms import (
    Compose,
    LoadImaged,
    EnsureChannelFirstd,
    EnsureTyped,
    ToTensord,
    AdjustContrastd,
    HistogramNormalized,
    NormalizeIntensityd,
    Spacingd,
    Orientationd,
    ScaleIntensityRanged,
    CenterSpatialCropd,
    SpatialPadd,
)

from wrapped_transformations import N4BiasFieldCorrectionCustomd

<wrapped_transformations.N4BiasFieldCorrectionCustomd at 0x266b983d8d0>

In [6]:
# Pre-processing. Add as many pre-processing transformations as you wish. Must be from monai
# IMPORTANT: Do not remove transformations outside the "Personal Transformations" section
#
# Deterministic pipeline applied to every case; its output is what the
# PersistentDataset cell caches to disk. Order matters: each transform
# receives the output of the previous one.
deterministic_transformations = Compose([
    LoadImaged(keys=["image", "seg_mask"]),  # loads NIfTI files
    EnsureChannelFirstd(keys=["image"]),
    EnsureChannelFirstd(keys="seg_mask"),
    # Store the label map as uint8 integers.
    EnsureTyped(keys="seg_mask", dtype=np.uint8),

# ----------------- Personal Transformations ----------------- #

    # Reorient image and mask to the same RAS axis convention.
    Orientationd(keys=["image", "seg_mask"], axcodes="RAS"),
    # Resample to 1.0 voxel spacing on every axis (presumably mm — confirm the
    # affine units): bilinear for the image, nearest-neighbour for the mask so
    # label values are never blended.
    Spacingd(
        keys=["image", "seg_mask"],
        pixdim=(1.0, 1.0, 1.0),
        mode=("bilinear", "nearest")
    ),
    # Project-local bias-field correction wrapper (image only); its behaviour is
    # defined in wrapped_transformations.
    N4BiasFieldCorrectionCustomd(keys=["image"]),
    # Crop, then pad, so every case ends up exactly 160x192x160.
    CenterSpatialCropd(keys=["image", "seg_mask"], roi_size=(160, 192, 160)),
    SpatialPadd(keys=["image","seg_mask"], spatial_size=(160,192,160)),  # each dim is a multiple of 16, matching the UNet's four stride-2 levels
    # NOTE(review): -1000..3000 looks like a CT (Hounsfield) window, but these are
    # MRI volumes. With these bounds a raw value of 0 maps to 0.25 and anything
    # above 3000 is clipped to 1.0 — confirm the range suits this data.
    ScaleIntensityRanged(
        keys=["image"],
        a_min=-1000, a_max=3000,
        b_min=0.0, b_max=1.0,
        clip=True
    ),
    # Histogram equalisation with 256 bins. NOTE(review): MONAI's default output
    # range for this transform is [0, 255], which would supersede the [0, 1]
    # scaling above — confirm this is intended.
    HistogramNormalized(keys=["image"], num_bins=256),
    # Z-score each channel separately (presumably one channel per modality),
    # using only non-zero voxels.
    NormalizeIntensityd(keys=["image"], nonzero=True, channel_wise=True),
    # Gamma correction with gamma = 1.2.
    AdjustContrastd(keys=["image"], gamma=1.2),

# ------------------------------------------------------------ #

    ToTensord(keys=["image", "seg_mask"])  # convert both to torch tensors
])



In [7]:
# Transformations to use when augmenting training dataset
from monai.transforms import (
    RandFlipd,
    RandAffined,
    Rand3DElasticd,
    RandGaussianNoised,
    RandGaussianSmoothd,
    RandScaleIntensityd,
    RandShiftIntensityd
)

In [8]:
# Edit to change which transformations are used when augmenting training dataset.
# Random augmentations applied on top of the cached deterministic output.
# Spatial transforms receive image and mask together so the two stay aligned
# (bilinear for the image, nearest for the mask); intensity transforms touch
# the image only.
_paired_keys = ["image", "seg_mask"]
_image_keys = ["image"]
_paired_modes = ("bilinear", "nearest")

# ---------------- Spatial (image + mask) ---------------- #
_spatial_augmentations = [
    # Left-right flip (safe for brain)
    RandFlipd(keys=_paired_keys, spatial_axis=0, prob=0.5),
    # Small rotations (~±6°), translations (in voxels) and scaling
    RandAffined(
        keys=_paired_keys,
        prob=0.3,
        rotate_range=(0.1,) * 3,
        translate_range=(5,) * 3,
        scale_range=(0.1,) * 3,
        mode=_paired_modes,
        padding_mode="border",
    ),
    # Mild elastic deformation (simulates anatomy variance)
    Rand3DElasticd(
        keys=_paired_keys,
        prob=0.15,
        sigma_range=(5, 8),
        magnitude_range=(50, 100),
        mode=_paired_modes,
        padding_mode="border",
    ),
]

# ---------------- Intensity (image only) ---------------- #
_intensity_augmentations = [
    # Scanner noise
    RandGaussianNoised(keys=_image_keys, prob=0.2, mean=0.0, std=0.01),
    # Slight smoothing (resolution variation)
    RandGaussianSmoothd(
        keys=_image_keys,
        prob=0.2,
        sigma_x=(0.5, 1.0),
        sigma_y=(0.5, 1.0),
        sigma_z=(0.5, 1.0),
    ),
    # Intensity scaling, then intensity shift
    RandScaleIntensityd(keys=_image_keys, factors=0.1, prob=0.3),
    RandShiftIntensityd(keys=_image_keys, offsets=0.1, prob=0.3),
]

augmentation_transformations = Compose(_spatial_augmentations + _intensity_augmentations)

In [9]:
# Build the file list and deterministic pipeline, then wrap them in a
# PersistentDataset so each case is pre-processed once and re-read from
# ./cache on later accesses.
# NOTE(review): PersistentDataset derives its cache keys from the data items,
# not from the transform. After editing `deterministic_transformations`, clear
# ./cache or previously cached volumes may be reused — confirm against the MONAI docs.
dataset_access = initialise_dataset(
    data_path,
    modalities=modalities,
    transformations=deterministic_transformations,
)
_persistent_kwargs = dict(
    data=dataset_access.files,
    transform=dataset_access.transform,
    cache_dir="./cache",
)
base_dataset = PersistentDataset(**_persistent_kwargs)

In [10]:
# Training view of the data: every access reads the cached deterministic
# output from `base_dataset` and then applies the random augmentations.
train_dataset = Dataset(base_dataset, augmentation_transformations)

In [18]:
# Seed all RNGs *before* building the loader. With num_workers=0, MONAI's
# DataLoader assigns the dataset's transform random states from torch's
# generator at construction time, so seeding only afterwards leaves the
# augmentations non-reproducible.
set_determinism(seed=42)
train_loader = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=True,
    num_workers=0,
    # Pinned host memory only speeds up host->GPU copies; skip it on CPU-only machines.
    pin_memory=torch.cuda.is_available(),
)

In [19]:
# Re-seed the Python, NumPy and torch RNGs (via MONAI) so the model weight
# initialisation in the next cell is reproducible, then pick the compute device.
set_determinism(seed=42)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [20]:
# MONAI 3D UNet baseline. NOTE(review): the intent appears to be comparing a
# 2D network against this 3D one — confirm.
# Input: 4 channels (the MRI modalities). Output: 1 channel of logits.
# Four stride-2 downsamplings, so each spatial dim should presumably be
# divisible by 2**4 = 16 (the 160x192x160 pre-processing crop satisfies this).
_unet_kwargs = dict(
    spatial_dims=3,
    in_channels=4,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
)
model = UNet(**_unet_kwargs).to(device)

In [21]:
# The UNet emits a single (binary foreground) channel. A softmax over one
# channel is meaningless: MONAI's DiceLoss ignores `softmax=True` in that case
# (with a warning), which left Dice computed on raw, unbounded logits. A sigmoid
# maps each logit to a foreground probability in [0, 1] before Dice is computed.
loss_function = DiceLoss(to_onehot_y=False, sigmoid=True)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [None]:
num_epochs = 50

# NOTE(review): img_idx / mask_idx are not used by this loop; kept in case other cells rely on them.
img_idx = 0
mask_idx = 1

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    num_batches = 0
    for batch_data in train_loader:
        inputs = batch_data["image"].to(device)
        # The model predicts one foreground channel, so train against a binary
        # "any label" target. If the mask holds multi-class values (presumably
        # BraTS-style labels such as 1, 2, 4 — confirm), those would otherwise be
        # fed to Dice as-is. This is a no-op for an already-binary mask.
        labels = (batch_data["seg_mask"] > 0).float().to(device)

        optimizer.zero_grad()

        outputs = model(inputs)   # logits

        loss = loss_function(outputs, labels)
        loss.backward()

        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

    # Guard against an empty loader rather than raising ZeroDivisionError.
    mean_loss = epoch_loss / max(num_batches, 1)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {mean_loss:.4f}")


Inputs acquired
Labels acquired
Optimiser zeroed
Outputs calculated
Loss calculated




Loss backpropagated
Grad. Descent done
Inputs acquired
Labels acquired
Optimiser zeroed
Outputs calculated
Loss calculated
Loss backpropagated
Grad. Descent done
Inputs acquired
Labels acquired
Optimiser zeroed
Outputs calculated
Loss calculated
Loss backpropagated
Grad. Descent done
Inputs acquired
Labels acquired
Optimiser zeroed
Outputs calculated
Loss calculated
Loss backpropagated
Grad. Descent done
Inputs acquired
Labels acquired
Optimiser zeroed
Outputs calculated
Loss calculated
Loss backpropagated
Grad. Descent done
Inputs acquired
Labels acquired
Optimiser zeroed
Outputs calculated
Loss calculated
Loss backpropagated
Grad. Descent done
Inputs acquired
Labels acquired
Optimiser zeroed
Outputs calculated
Loss calculated
Loss backpropagated
Grad. Descent done
Inputs acquired
Labels acquired
Optimiser zeroed
Outputs calculated
Loss calculated
Loss backpropagated
Grad. Descent done
Inputs acquired
Labels acquired
Optimiser zeroed
Outputs calculated
Loss calculated
Loss backpropaga

In [None]:
# Persist only the learned weights (state_dict). To restore them, rebuild the
# UNet with the same arguments and call `load_state_dict`.
checkpoint_path = "monai_3d_unet.pt"
torch.save(model.state_dict(), checkpoint_path)