# Training a 3D UNet with MONAI



# Part 1 - Basic training loop

## Set up environment

In [None]:
!pip -q install "monai-weekly" "torch>=2.1" "tqdm"

In [None]:
import monai, torch, os, tempfile, matplotlib.pyplot as plt
from monai.data import DataLoader, CacheDataset
from monai.networks.nets import UNet
from monai.losses import DiceLoss
from monai.metrics import DiceMetric

## Download the dataset

In [None]:

from monai.apps import download_and_extract
# Local folder that will hold both the downloaded archive and the extracted dataset.
root_dir = 'data'
# Spleen segmentation task archive (Task09_Spleen).
resource = 'https://msd-for-monai.s3-us-west-2.amazonaws.com/Task09_Spleen.tar'
compressed_file = os.path.join(root_dir, "spleen.tar")
# Download the archive to `compressed_file` and unpack it into `root_dir`.
download_and_extract(url=resource, filepath=compressed_file, output_dir=root_dir)


## Prepare the dataset

In [None]:
from monai.data import Dataset
from glob import glob
# Collect image and label volumes; sorting both lists keeps each image aligned with its label.
images = sorted(glob(os.path.join(root_dir,"Task09_Spleen/imagesTr/*.nii.gz")))
labels = sorted(glob(os.path.join(root_dir,"Task09_Spleen/labelsTr/*.nii.gz")))

# zip() silently drops unmatched entries and an empty list only fails later with an
# IndexError, so check the pairing explicitly before building the file list.
assert images, f"No training images found under {root_dir!r} - did the download cell run?"
assert len(images) == len(labels), f"Found {len(images)} images but {len(labels)} labels"

# One dict per case, in the {"image": path, "label": path} layout used by the dictionary transforms.
train_files = [{"image":img, "label":lbl} for img,lbl in zip(images,labels)]

print(f"Number of files: {len(train_files)}")
print(train_files[0])


In [None]:
from monai import transforms as T

def make_transform(n_pixels=64):
    """Build the preprocessing pipeline that turns a file record into model inputs.

    Args:
        n_pixels: target size of the first two spatial axes; the third axis is
            resized to ``n_pixels // 2``.

    Returns:
        A ``T.Compose`` that takes a dict with ``'image'`` and ``'label'`` entries
        and loads, channel-expands, resizes and (image only) intensity-scales them.
    """
    # Steps to transform the file information to model inputs:

    transforms = []

    # step 1: load the data, nii.gz format to tensor
    transforms.append(
        T.LoadImaged(keys=['image', 'label'])
    )

    # step 2: Add an extra "channel" dimension (pytorch/monai convention)
    transforms.append(
        T.EnsureChannelFirstd(keys=['image', 'label'])
    )

    # step 3: Resize the data to a uniform size.
    # The volumes are 3D, so the linear interpolation mode must be 'trilinear'
    # ('bilinear' is only defined for 2D inputs). The label uses 'nearest' so that
    # no new, in-between label values are created.
    transforms.append(
        T.ResizeD(keys=['image', 'label'], spatial_size=(n_pixels, n_pixels, n_pixels//2), mode=['trilinear', 'nearest'])
    )

    # step 4: rescale the image intensity between 0 and 1
    transforms.append(T.ScaleIntensityD(keys=['image']))

    transform = T.Compose(transforms)
    return transform


transform = make_transform(256)

In [None]:
# look at the output
data = transform(train_files[0])
image, label = data['image'], data['label']

# Show one slice along the last axis: plain image on the left, image with label overlay on the right.
fig, ax = plt.subplots(1, 2)
slice_idx = 75
image_slice = image[0, ..., slice_idx]
for axis in ax:
    axis.imshow(image_slice)
ax[1].imshow(label[0, ..., slice_idx], alpha=0.5)
fig.tight_layout()

# Quick sanity checks on intensity statistics, shapes and label values.
print(f"pixel mean: {image.mean()}")
print(f"pixel std: {image.std()}")
print(f"Image shape: {image.shape}")
print(f"Label shape: {label.shape}")
print(f"Label values: {torch.unique(label)}")

In [None]:
# create dataset and dataloader
# CacheDataset with cache_rate=1 keeps the transformed samples in memory so they are not recomputed every epoch.
train_ds = train_dataset = monai.data.CacheDataset(train_files, transform=make_transform(64), cache_rate=1)
# Use MONAI's DataLoader rather than the plain torch one: it subclasses the torch DataLoader
# and adds collation that is aware of MONAI data types plus random-state handling for MONAI transforms.
train_loader = monai.data.DataLoader(train_ds, batch_size=4, shuffle=True)

# Pull a single batch to check the shapes the model will receive.
batch = next(iter(train_loader))
image = batch['image']
label = batch['label']
print(f'Image shape: {image.shape}')
print(f'Label shape: {label.shape}')


## Prepare And Train the Model

In [None]:
# define a unet using monai

def get_model():
    """Create a new 3D MONAI UNet with one input channel and two output channels."""
    unet_config = dict(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
        channels=(16, 32, 64, 128),
        strides=(2, 2, 2),
        num_res_units=2,
    )
    return UNet(**unet_config)

model = get_model()

# Report the total number of parameter elements in the network.
n_params = sum(param.numel() for param in model.parameters())
print(f"Model has {n_params} parameters.")

# device to run calculations - "cuda" means gpu; fall back to the cpu when no gpu is available
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device {device}")
model = model.to(device)  # move the model to the selected device

# Dice loss on softmax outputs, with the integer label map one-hot encoded inside the loss.
loss_fn = DiceLoss(softmax=True, to_onehot_y=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

In [None]:
from tqdm import tqdm
from collections import defaultdict


def _run_epoch(model, loss_fn, dataloader, device, desc, optimizer, train):
    """Run one pass over `dataloader` and return the mean per-batch loss.

    With ``train=True`` the model is put in train mode and `optimizer` takes one
    step per batch; otherwise the model is put in eval mode and gradients are disabled.
    """
    model.train(train)
    total_loss = 0.0
    n_batches = 0
    with torch.set_grad_enabled(train):
        for batch_data in tqdm(dataloader, desc=desc, leave=False):
            inputs, labels = batch_data["image"].to(device), batch_data["label"].to(device)
            if train:
                optimizer.zero_grad()
            loss = loss_fn(model(inputs), labels)
            if train:
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
            n_batches += 1
    # Count batches while iterating instead of calling len(), so loaders without a
    # length also work, and report an empty loader clearly instead of dividing by zero.
    if n_batches == 0:
        raise ValueError(f"{desc}: the dataloader produced no batches")
    return total_loss / n_batches


def run_training(model, optimizer, loss_fn, train_dataloader, val_dataloader=None, max_epochs=100, device=None):
    """Train `model` for `max_epochs` epochs and record the mean loss per epoch.

    Args:
        model: network to train; called as ``model(inputs)``.
        optimizer: optimizer that updates the parameters of `model`.
        loss_fn: callable ``loss_fn(outputs, labels)`` returning a scalar tensor.
        train_dataloader: iterable of batches, each a dict with "image" and "label" tensors.
        val_dataloader: optional iterable of the same form; when given, the model is
            also evaluated on it (without gradient updates) after every epoch.
        max_epochs: number of passes over `train_dataloader`.
        device: device the batches are moved to. Defaults to the device of the model's
            first parameter (cpu if the model has no parameters), so the function
            does not rely on a global `device` variable.

    Returns:
        A ``defaultdict(list)`` with one entry per epoch under 'train_loss' and,
        when `val_dataloader` is given, under 'val_loss'.

    Raises:
        ValueError: if a dataloader yields no batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    history = defaultdict(list)

    for epoch in range(max_epochs):
        # Use the same 1-based epoch number in the progress bar and in the printed summary.
        desc = f"Epoch {epoch+1}"

        train_loss = _run_epoch(model, loss_fn, train_dataloader, device, desc, optimizer, train=True)
        print(f"Epoch {epoch+1}, avg loss: {train_loss:.4f}")
        history['train_loss'].append(train_loss)

        if val_dataloader is not None:
            val_loss = _run_epoch(model, loss_fn, val_dataloader, device, desc, optimizer, train=False)
            history['val_loss'].append(val_loss)

    return history


# Train on the full training loader (no validation yet) and keep the per-epoch history.
history = run_training(model=model, optimizer=optimizer, loss_fn=loss_fn, train_dataloader=train_loader, max_epochs=30)

# Plot every recorded curve whose name contains `key`.
key = 'loss'
for name, metrics in history.items():
    if key in name:
        plt.plot(metrics, label=name)
plt.legend()


In [None]:
# test inference

# Predict on the first training sample with gradients disabled.
model.eval()
with torch.no_grad():
    sample = train_ds[0]
    input_volume = sample["image"].unsqueeze(0).to(device)  # add a batch axis
    logits = model(input_volume)
    pred = logits.argmax(dim=1).cpu()[0]  # most likely class per voxel

import numpy as np, matplotlib.pyplot as plt
mid_slice = pred.shape[-1]//2  # middle index of the last axis
plt.figure(figsize=(12,4))
panels = [
    ('Image', input_volume.cpu()[0,0,:, :, mid_slice]),
    ('Ground truth', sample["label"][0, :, :, mid_slice]),
    ('Prediction', pred[:, :, mid_slice]),
]
for panel_idx, (panel_title, panel_slice) in enumerate(panels, start=1):
    plt.subplot(1, 3, panel_idx)
    plt.imshow(panel_slice, cmap='gray')
    plt.title(panel_title)
    plt.axis('off')
plt.show()


# Part 2 - Monitoring Model Performance

Objectives:
- Use metrics to measure model performance
- Use a held-out validation set to test for overfitting
- Use techniques to mitigate overfitting

Step 1: implement the dice metric. Hint: use the monai implementation https://docs.monai.io/en/stable/metrics.html#mean-dice and read its documentation. Incorporate it into the training loop defined above to get the new training loop.

In [None]:
# TODO Implement this function using MONAI dice metric and the previously defined training function
def run_training_with_dice(model, optimizer, loss_fn, train_dataloader, val_dataloader=None, max_epochs=100):
    """
    Training loop that is meant to also track the dice metric (exercise).

    Placeholder: until the exercise is completed this simply forwards all
    arguments to `run_training` and returns its history unchanged, so no
    dice values are recorded yet.
    """
    # placeholder - replace with a loop that also records the dice metric
    history = run_training(
        model,
        optimizer,
        loss_fn,
        train_dataloader,
        val_dataloader,
        max_epochs,
    )
    return history


In [None]:
# Start again from a fresh model so this run does not continue the previous training.
model = get_model()

n_params = sum(param.numel() for param in model.parameters())
print(f"Model has {n_params} parameters.")

# device to run calculations - "cuda" means gpu; fall back to the cpu when no gpu is available
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device {device}")
model = model.to(device)  # move the model to the selected device
loss_fn = DiceLoss(softmax=True, to_onehot_y=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

history = run_training_with_dice(model=model, optimizer=optimizer, loss_fn=loss_fn, train_dataloader=train_loader, max_epochs=30)

# First figure: every history entry whose name contains 'loss'.
key = 'loss'
for name, metrics in history.items():
    if key in name:
        plt.plot(metrics, label=name)
plt.legend()

# Second figure: every history entry whose name contains 'dice'.
plt.figure()
key = 'dice'
for name, metrics in history.items():
    if key in name:
        plt.plot(metrics, label=name)
plt.legend()

Step 2. Because we don't have a validation set yet, we don't know how well the model will generalize to new data. It could be overfitting to the training set! Create a validation set and run the training loop again. Compare the validation metrics to the training metrics. A small gap between the two is normal, but if the training loss is noticeably lower than the validation loss, or the training dice is noticeably higher than the validation dice — especially if the gap keeps growing as training continues — this indicates overfitting.

In [None]:
# Helpful demo: sklearn train_test_split function
# Import the submodule explicitly: a bare `import sklearn` does not guarantee that
# `sklearn.model_selection` is available as an attribute on every scikit-learn version.
import sklearn.model_selection

# A fixed random_state makes the split identical on every run, so results can be compared between runs.
new_train_files, val_files = sklearn.model_selection.train_test_split(train_files, random_state=0)

print(len(new_train_files))
print(new_train_files)
print(len(val_files))
print(val_files)

In [None]:
new_train_loader = None
val_loader = None
# TODO - implement the validation loader. Hint: use the sklearn train_test_split function, then follow the "Prepare dataset"
# recipe with the two different sets of files to create two datasets and dataloaders.

Run the training again with the validation loader. We want to improve the validation metrics (lower the validation loss and increase the validation dice) as much as possible while avoiding overfitting. When did overfitting occur? What were the best validation metrics?

In [None]:
# Fail early with a clear message if the loaders from the TODO cell are still the None
# placeholders; otherwise the training loop would stop with a cryptic TypeError.
if new_train_loader is None or val_loader is None:
    raise RuntimeError("new_train_loader and val_loader are not defined yet - complete the TODO in the previous cell first.")

# Start again from a fresh model so this run does not continue the previous training.
model = get_model()

n_params = sum(p.numel() for p in model.parameters())
print(f"Model has {n_params} parameters.")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # device to run calculations - "cuda" means gpu
print(f"Using device {device}")
model.to(device) # convert model to the correct device
loss_fn = DiceLoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

history = run_training_with_dice(model, optimizer, loss_fn, new_train_loader, val_dataloader=val_loader, max_epochs=100)

# First figure: every history entry whose name contains 'loss' (training and validation).
key = 'loss'
for name, metrics in history.items():
    if key not in name:
        continue
    plt.plot(metrics, label=name)
plt.legend()

# Second figure: every history entry whose name contains 'dice'.
plt.figure()
key = 'dice'
for name, metrics in history.items():
    if key not in name:
        continue
    plt.plot(metrics, label=name)
plt.legend()


Step 3. Mitigating overfitting

You are now free to experiment with modifying any part of the training pipeline. The goal is to reduce overfitting and improve the best validation dice. Suggestions and hints are outlined below.

In [None]:
# Suggestion 1 - Modifying the model architecture.

# read the monai documentation about UNet and edit the model configuration and experiment with the results.
# feel free to copy the documentation into a language model for advice on which configuration changes could mitigate overfitting!
def get_model():
    """Create the UNet used for training; edit the arguments below to experiment."""
    return UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
        channels=(16, 32, 64, 128),
        strides=(2, 2, 2),
        num_res_units=2,
    )

In [None]:
# Suggestion 2 - Modifying the optimizer.

# read the documentation about torch.optim.AdamW with the help of a language model. See if configuration changes could mitigate overfitting to improve performance
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

In [None]:
# Suggestion 3 - Modifying the data pipeline


# Data augmentation is a common technique to reduce overfitting by creating more variations in the data being used to train the model.
# Read the monai documentation for transforms such as https://docs.monai.io/en/1.3.0/transforms.html#randadjustcontrastd and https://docs.monai.io/en/1.3.0/transforms.html#randcoarseshuffled
# to see if adding these into the data pipeline can improve outcomes.
# ! WARNING ! be sure to disable them when you are creating the transform for the validation set by setting `is_train=False`

def make_transform(n_pixels=64, is_train=True):
    """Build the preprocessing pipeline, with a slot for training-only augmentations.

    Args:
        n_pixels: target size of the first two spatial axes; the third axis is
            resized to ``n_pixels // 2``.
        is_train: when True, the augmentation section below is executed. Pass
            False when building the transform for the validation set.

    Returns:
        A ``T.Compose`` that takes a dict with ``'image'`` and ``'label'`` entries
        and loads, channel-expands, resizes and (image only) intensity-scales them.
    """
    # Steps to transform the file information to model inputs:

    transforms = []

    # step 1: load the data, nii.gz format to tensor
    transforms.append(
        T.LoadImaged(keys=['image', 'label'])
    )

    # step 2: Add an extra "channel" dimension (pytorch/monai convention)
    transforms.append(
        T.EnsureChannelFirstd(keys=['image', 'label'])
    )

    if is_train:
        # do any data augmentations here <-----------------
        ...


    # step 3: Resize the data to a uniform size.
    # The volumes are 3D, so the linear interpolation mode must be 'trilinear'
    # ('bilinear' is only defined for 2D inputs). The label uses 'nearest' so that
    # no new, in-between label values are created.
    transforms.append(
        T.ResizeD(keys=['image', 'label'], spatial_size=(n_pixels, n_pixels, n_pixels//2), mode=['trilinear', 'nearest'])
    )

    # step 4: rescale the image intensity between 0 and 1
    transforms.append(T.ScaleIntensityD(keys=['image']))

    transform = T.Compose(transforms)
    return transform
