# Coursework 2: Image segmentation

In this coursework you will develop and train a convolutional neural network for brain tumour segmentation. Please read both the text and the code in this notebook to get an idea of what you are expected to implement. Pay attention to the missing code blocks that look like this:

```
### Insert your code ###
...
### End of your code ###
```
## What is expected?

* Complete and run the code using `jupyter-lab`.

* Export (File | Save and Export Notebook As...) the notebook as a PDF file, which contains your code, results and answers, and upload the PDF file onto [Scientia](https://scientia.doc.ic.ac.uk).

* If Jupyter reports errors during export, [pandoc](https://pandoc.org/installing.html) or LaTeX is probably not installed, or is not on your `PATH`. You can install the missing tools and retry. Alternatively, use your browser's Print function to export the PDF file.

* If JupyterLab still does not work for you, you can use Google Colab instead to write the code and export the PDF file.

## Dependencies

You need to install JupyterLab (https://jupyterlab.readthedocs.io/en/stable/getting_started/installation.html) and the other libraries used in this coursework, for example by running:
`pip3 install [package_name]`

## GPU resource

The coursework is designed to run on a CPU, as all images have been pre-processed into 2D slices of a smaller size than the original 3D volumes.

However, to save training time, you may want to use a GPU. In that case, you can run this notebook on Google Colab: from the menu, go to Runtime > Change runtime type and select **GPU** as the hardware accelerator. At the end, please still export everything and submit it as a PDF file on Scientia.


In [None]:
# Import libraries
# These libraries should be sufficient for this tutorial.
# However, if any other library is needed, please install by yourself.
import tarfile
import imageio
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset
import numpy as np
import time
import os
import random
import matplotlib.pyplot as plt
from matplotlib import colors

## Q1. Download and visualise the imaging dataset.

The dataset is a public brain imaging dataset from the [Medical Segmentation Decathlon](http://medicaldecathlon.com/). To save storage and reduce the computational cost for this tutorial, we extract 2D image slices from the original 3D volumes (T1-Gd contrast-enhanced imaging) and downsample the 2D images.

The dataset consists of a training set and a test set. Each image is of dimension 120 x 120, with a corresponding label map of the same dimension. There are four classes in the label map:

- 0: background
- 1: edema
- 2: non-enhancing tumour
- 3: enhancing tumour
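Background pixels typically make up most of each label map, so the four classes are far from balanced. A quick way to inspect this is to count pixels per class with `np.bincount`. The sketch below uses a small hypothetical 6 x 6 label map for illustration; the same call works on any label map read from `training_labels`.

```python
import numpy as np

# Hypothetical 6 x 6 label map using the class codes listed above:
# 0 background, 1 edema, 2 non-enhancing tumour, 3 enhancing tumour
label = np.array([
    [0, 0, 0, 0, 0, 0],
    [0, 0, 1, 1, 0, 0],
    [0, 1, 2, 3, 1, 0],
    [0, 1, 3, 3, 1, 0],
    [0, 0, 1, 1, 0, 0],
    [0, 0, 0, 0, 0, 0],
])

num_class = 4
# Count pixels per class; minlength keeps absent classes as explicit zeros
counts = np.bincount(label.ravel(), minlength=num_class)
fractions = counts / counts.sum()

class_names = ['background', 'edema', 'non-enhancing tumour', 'enhancing tumour']
for name, n, f in zip(class_names, counts, fractions):
    print(f'{name:>22s}: {n:3d} pixels ({f:.1%})')
```

Even in this toy map, background covers two thirds of the pixels and non-enhancing tumour a single pixel; real slices are usually more skewed.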

In [None]:
# Download the dataset
# If you use Ubuntu, wget would natively work.
# If you use Mac or Windows, which do not have the wget command by default, you can copy the URL to the web browser and download the file.
!wget https://www.dropbox.com/s/zmytk2yu284af6t/Task01_BrainTumour_2D.tar.gz

# Unzip the '.tar.gz' file to the current directory
datafile = tarfile.open('Task01_BrainTumour_2D.tar.gz')
datafile.extractall()
datafile.close()

## Visualise a random set of 4 training images along with their label maps.

Suggested colour map for brain MR image:
```
cmap = 'gray'
```

Suggested colour map for segmentation map:
```
cmap = colors.ListedColormap(['black', 'green', 'blue', 'red'])
```
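When displaying a label map with this colour map, it helps to pass `vmin=0, vmax=3` to `imshow` so that each class index always maps to the same colour. Without a fixed range, matplotlib stretches the colour map to the minimum and maximum values present in the slice, so a slice containing only background and edema would draw edema in red. The sketch below checks this with `matplotlib.colors.Normalize`, the normalisation `imshow` applies by default; the label map is a synthetic example.

```python
import numpy as np
from matplotlib import colors

# Colour map suggested for the segmentation maps (one colour per class 0..3)
seg_cmap = colors.ListedColormap(['black', 'green', 'blue', 'red'])

# A synthetic label map that happens to contain only background (0) and edema (1)
label = np.array([[0, 0, 1],
                  [0, 1, 1]])

# Fixed range: class k is always mapped to the k-th colour
fixed_norm = colors.Normalize(vmin=0, vmax=3)

# Automatic range (what imshow does without vmin/vmax): stretched to [label.min(), label.max()]
auto_norm = colors.Normalize()
auto_norm.autoscale(label)

edema_fixed = seg_cmap(fixed_norm(1))  # colour of edema with the fixed range
edema_auto = seg_cmap(auto_norm(1))    # colour of edema with the automatic range
print(colors.to_hex(edema_fixed), colors.to_hex(edema_auto))
```

With the fixed range, edema is drawn in green as intended; with the automatic range it takes the last colour of the map (red), which would be misread as enhancing tumour.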

In [None]:
### Insert your code ###
# Visualize 4 random training images with their label maps
train_image_path = 'Task01_BrainTumour_2D/training_images'
train_label_path = 'Task01_BrainTumour_2D/training_labels'

# Get list of image names
image_names = sorted(os.listdir(train_image_path))

# Randomly select 4 images
random_indices = random.sample(range(len(image_names)), 4)

# Create figure with 4 rows and 2 columns
fig, axes = plt.subplots(4, 2, figsize=(8, 16))

# Define colormaps
img_cmap = 'gray'
seg_cmap = colors.ListedColormap(['black', 'green', 'blue', 'red'])

for i, idx in enumerate(random_indices):
    image_name = image_names[idx]
    
    # Read image and label
    image = imageio.v2.imread(os.path.join(train_image_path, image_name))
    label = imageio.v2.imread(os.path.join(train_label_path, image_name))
    
    # Display image
    axes[i, 0].imshow(image, cmap=img_cmap)
    axes[i, 0].set_title(f'Image: {image_name}')
    axes[i, 0].axis('off')
    
    # Display label map
    axes[i, 1].imshow(label, cmap=seg_cmap, vmin=0, vmax=3)
    axes[i, 1].set_title('Label map')
    axes[i, 1].axis('off')

plt.tight_layout()
plt.show()
### End of your code ###

## Q2. Implement a dataset class.

The class should read the imaging dataset and return items (pairs of images and label maps) to be used as training batches.
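The dimension conventions noted in the dataset class (XY for a single item, NCXY for a batch of images, NXY for a batch of label maps) match the layouts that PyTorch's `nn.Conv2d` and `nn.CrossEntropyLoss` expect. The sketch below assembles a batch from synthetic arrays to show how the shapes change; the random data and the batch size of 16 are assumptions for illustration only.

```python
import numpy as np

# Hypothetical stand-ins for the 120 x 120 slices: each item is an (X, Y) image and an (X, Y) label map
rng = np.random.default_rng(0)
items = [(rng.standard_normal((120, 120)), rng.integers(0, 4, size=(120, 120)))
         for _ in range(16)]

# Add a channel axis to each image (XY -> CXY with C = 1), then stack along a new batch axis
images = np.stack([image[np.newaxis, :, :] for image, _ in items], axis=0)
labels = np.stack([label for _, label in items], axis=0)

print(images.shape)  # (16, 1, 120, 120): NCXY, as nn.Conv2d expects
print(labels.shape)  # (16, 120, 120): NXY, class indices as nn.CrossEntropyLoss expects
```

Images carry an explicit channel axis because the network's first layer is a convolution with `input_channel` input channels; label maps do not, because the loss takes one class index per pixel.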

In [None]:
def normalise_intensity(image, thres_roi=1.0):
    """ Normalise the image intensity by the mean and standard deviation """
    # ROI defines the image foreground
    val_l = np.percentile(image, thres_roi)
    roi = (image >= val_l)
    mu, sigma = np.mean(image[roi]), np.std(image[roi])
    eps = 1e-6
    image2 = (image - mu) / (sigma + eps)
    return image2


class BrainImageSet(Dataset):
    """ Brain image set """
    def __init__(self, image_path, label_path='', deploy=False):
        self.image_path = image_path
        self.deploy = deploy
        self.images = []
        self.labels = []

        image_names = sorted(os.listdir(image_path))
        for image_name in image_names:
            # Read the image
            image = imageio.v2.imread(os.path.join(image_path, image_name))
            self.images += [image]

            # Read the label map
            if not self.deploy:
                label_name = os.path.join(label_path, image_name)
                label = imageio.v2.imread(label_name)
                self.labels += [label]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Get an image and perform intensity normalisation
        # Dimension: XY
        image = normalise_intensity(self.images[idx])

        # Get its label map
        # Dimension: XY
        label = self.labels[idx]
        return image, label

    def get_random_batch(self, batch_size):
        # Get a batch of paired images and label maps
        # Dimension of images: NCXY
        # Dimension of labels: NXY
        images, labels = [], []

        ### Insert your code ###
        # Randomly sample batch_size indices
        indices = random.sample(range(len(self)), batch_size)
        
        for idx in indices:
            image, label = self.__getitem__(idx)
            # Add channel dimension to image (from XY to CXY where C=1)
            image = image[np.newaxis, :, :]
            images.append(image)
            labels.append(label)
        
        # Stack to create batch dimension: NCXY for images, NXY for labels
        images = np.stack(images, axis=0)
        labels = np.stack(labels, axis=0)
        ### End of your code ###
        return images, labels

## Q3. Build a U-net architecture.

Implement a U-net architecture for image segmentation. If you are not familiar with U-net, you can read this paper:

[1] Olaf Ronneberger et al. [U-Net: Convolutional networks for biomedical image segmentation](https://link.springer.com/chapter/10.1007/978-3-319-24574-4_28). MICCAI, 2015.

For the first convolutional layer, you can start with 16 filters. We have implemented the encoder path. Please complete the decoder path.
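Because the decoder concatenates upsampled features with encoder features, the spatial sizes on both paths must match at every level. The sketch below applies PyTorch's output-size formulas for `nn.Conv2d` and `nn.ConvTranspose2d` (with dilation 1 and no output padding) to a 120 x 120 input. It assumes, as in the decoder implemented below, transposed convolutions with kernel size 2 and stride 2; the helper names `conv_out` and `conv_transpose_out` are hypothetical.

```python
def conv_out(size, kernel=3, stride=1, padding=0):
    """Output length of nn.Conv2d along one axis (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose_out(size, kernel=2, stride=2, padding=0):
    """Output length of nn.ConvTranspose2d along one axis (dilation 1, output_padding 0)."""
    return (size - 1) * stride - 2 * padding + kernel


# Encoder: the first conv of conv2, conv3 and conv4 uses kernel 3, stride 2, padding 1
sizes = [120]
for _ in range(3):
    sizes.append(conv_out(sizes[-1], kernel=3, stride=2, padding=1))
print('encoder:', sizes)  # encoder: [120, 60, 30, 15]

# Decoder (kernel 2, stride 2 transposed convolutions): each step doubles the size
up = [sizes[-1]]
for _ in range(3):
    up.append(conv_transpose_out(up[-1]))
print('decoder:', up)  # decoder: [15, 30, 60, 120]

# The skip connections only line up if the decoder sizes retrace the encoder sizes
assert up[::-1] == sizes
```

The 3 x 3 convolutions with stride 1 and padding 1 leave the size unchanged, so only the strided layers matter. With a 122 x 122 input, for example, the encoder would give 122, 61, 31, 16 and the decoder 16, 32, 64, 128, so the concatenation would fail; the input size needs to be divisible by 8 for this three-level design.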

In [None]:
""" U-net """
class UNet(nn.Module):
    def __init__(self, input_channel=1, output_channel=1, num_filter=16):
        super(UNet, self).__init__()

        # BatchNorm: by default during training this layer keeps running estimates
        # of its computed mean and variance, which are then used for normalization
        # during evaluation.

        # Encoder path
        n = num_filter  # 16
        self.conv1 = nn.Sequential(
            nn.Conv2d(input_channel, n, kernel_size=3, padding=1),
            nn.BatchNorm2d(n),
            nn.ReLU(),
            nn.Conv2d(n, n, kernel_size=3, padding=1),
            nn.BatchNorm2d(n),
            nn.ReLU()
        )

        n *= 2  # 32
        self.conv2 = nn.Sequential(
            nn.Conv2d(int(n / 2), n, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(n),
            nn.ReLU(),
            nn.Conv2d(n, n, kernel_size=3, padding=1),
            nn.BatchNorm2d(n),
            nn.ReLU()
        )

        n *= 2  # 64
        self.conv3 = nn.Sequential(
            nn.Conv2d(int(n / 2), n, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(n),
            nn.ReLU(),
            nn.Conv2d(n, n, kernel_size=3, padding=1),
            nn.BatchNorm2d(n),
            nn.ReLU()
        )

        n *= 2  # 128
        self.conv4 = nn.Sequential(
            nn.Conv2d(int(n / 2), n, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(n),
            nn.ReLU(),
            nn.Conv2d(n, n, kernel_size=3, padding=1),
            nn.BatchNorm2d(n),
            nn.ReLU()
        )
        
        # Decoder path
        ### Insert your code ###
        # n = 128 at this point
        # Upsample from 128 -> 64, then concat with skip (64) -> 128, then conv to 64
        self.upconv3 = nn.Sequential(
            nn.ConvTranspose2d(n, int(n / 2), kernel_size=2, stride=2),
            nn.BatchNorm2d(int(n / 2)),
            nn.ReLU()
        )
        n = int(n / 2)  # 64
        self.conv5 = nn.Sequential(
            nn.Conv2d(n * 2, n, kernel_size=3, padding=1),  # 128 -> 64 (after concat)
            nn.BatchNorm2d(n),
            nn.ReLU(),
            nn.Conv2d(n, n, kernel_size=3, padding=1),
            nn.BatchNorm2d(n),
            nn.ReLU()
        )
        
        # Upsample from 64 -> 32, then concat with skip (32) -> 64, then conv to 32
        self.upconv2 = nn.Sequential(
            nn.ConvTranspose2d(n, int(n / 2), kernel_size=2, stride=2),
            nn.BatchNorm2d(int(n / 2)),
            nn.ReLU()
        )
        n = int(n / 2)  # 32
        self.conv6 = nn.Sequential(
            nn.Conv2d(n * 2, n, kernel_size=3, padding=1),  # 64 -> 32 (after concat)
            nn.BatchNorm2d(n),
            nn.ReLU(),
            nn.Conv2d(n, n, kernel_size=3, padding=1),
            nn.BatchNorm2d(n),
            nn.ReLU()
        )
        
        # Upsample from 32 -> 16, then concat with skip (16) -> 32, then conv to 16
        self.upconv1 = nn.Sequential(
            nn.ConvTranspose2d(n, int(n / 2), kernel_size=2, stride=2),
            nn.BatchNorm2d(int(n / 2)),
            nn.ReLU()
        )
        n = int(n / 2)  # 16
        self.conv7 = nn.Sequential(
            nn.Conv2d(n * 2, n, kernel_size=3, padding=1),  # 32 -> 16 (after concat)
            nn.BatchNorm2d(n),
            nn.ReLU(),
            nn.Conv2d(n, n, kernel_size=3, padding=1),
            nn.BatchNorm2d(n),
            nn.ReLU()
        )
        
        # Final 1x1 convolution to get the output segmentation
        self.final_conv = nn.Conv2d(n, output_channel, kernel_size=1)
        ### End of your code ###

    def forward(self, x):
        # Use the convolutional operators defined above to build the U-net
        # The encoder part is already done for you.
        # You need to complete the decoder part.
        # Encoder
        x = self.conv1(x)
        conv1_skip = x

        x = self.conv2(x)
        conv2_skip = x

        x = self.conv3(x)
        conv3_skip = x

        x = self.conv4(x)

        # Decoder
        ### Insert your code ###
        # Upsample and concatenate with conv3_skip
        x = self.upconv3(x)
        x = torch.cat([x, conv3_skip], dim=1)
        x = self.conv5(x)
        
        # Upsample and concatenate with conv2_skip
        x = self.upconv2(x)
        x = torch.cat([x, conv2_skip], dim=1)
        x = self.conv6(x)
        
        # Upsample and concatenate with conv1_skip
        x = self.upconv1(x)
        x = torch.cat([x, conv1_skip], dim=1)
        x = self.conv7(x)
        
        # Final convolution to get output segmentation
        x = self.final_conv(x)
        ### End of your code ###
        return x
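To see why each decoder block starts with `nn.Conv2d(n * 2, n, ...)`, the sketch below traces tensor shapes through the deepest upsampling step using random tensors. It uses only standard PyTorch layers rather than the `UNet` class itself, and the batch size of 2 is an arbitrary assumption.

```python
import torch
import torch.nn as nn

# Shapes at the deepest decoder step for a batch of 2 slices of 120 x 120:
# the bottleneck (conv4 output) has 128 channels at 15 x 15, the conv3 skip has 64 channels at 30 x 30
bottleneck = torch.randn(2, 128, 15, 15)
conv3_skip = torch.randn(2, 64, 30, 30)

# Transposed convolution halves the channels and doubles the spatial size
upconv = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
x = upconv(bottleneck)
print(x.shape)  # torch.Size([2, 64, 30, 30])

# Concatenating along dim=1 (channels) restores 128 channels, hence nn.Conv2d(n * 2, n, ...) in conv5
x = torch.cat([x, conv3_skip], dim=1)
print(x.shape)  # torch.Size([2, 128, 30, 30])
```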

## Q4. Train the segmentation model.

In [None]:
# CUDA device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device: {0}'.format(device))

# Build the model
num_class = 4
model = UNet(input_channel=1, output_channel=num_class, num_filter=16)
model = model.to(device)
params = list(model.parameters())

model_dir = 'saved_models'
if not os.path.exists(model_dir):
    os.makedirs(model_dir)

# Optimizer
optimizer = optim.Adam(params, lr=1e-3)

# Segmentation loss
criterion = nn.CrossEntropyLoss()

# Datasets
train_set = BrainImageSet('Task01_BrainTumour_2D/training_images', 'Task01_BrainTumour_2D/training_labels')
test_set = BrainImageSet('Task01_BrainTumour_2D/test_images', 'Task01_BrainTumour_2D/test_labels')

# Train the model
# Note: when you debug the model, you may reduce the number of iterations or batch size to save time.
num_iter = 10000
train_batch_size = 16
eval_batch_size = 16
start = time.time()
for it in range(1, 1 + num_iter):
    # Set the modules in training mode, which will have effects on certain modules, e.g. dropout or batchnorm.
    start_iter = time.time()
    model.train()

    # Get a batch of images and labels
    images, labels = train_set.get_random_batch(train_batch_size)
    images, labels = torch.from_numpy(images), torch.from_numpy(labels)
    images, labels = images.to(device, dtype=torch.float32), labels.to(device, dtype=torch.long)
    logits = model(images)

    # Perform optimisation and print out the training loss
    ### Insert your code ###
    # Compute the loss
    loss = criterion(logits, labels)
    
    # Zero gradients, backward pass, and update weights
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    # Print training loss every 100 iterations
    if it % 100 == 0:
        print(f'Iteration {it}/{num_iter}, Training Loss: {loss.item():.4f}, Time: {time.time() - start_iter:.3f}s')
    ### End of your code ###

    # Evaluate
    if it % 1000 == 0:
        model.eval()
        # Disabling gradient calculation during inference to reduce memory consumption
        with torch.no_grad():
            # Evaluate on a batch of test images and print out the test loss
            ### Insert your code ###
            test_images, test_labels = test_set.get_random_batch(eval_batch_size)
            test_images, test_labels = torch.from_numpy(test_images), torch.from_numpy(test_labels)
            test_images, test_labels = test_images.to(device, dtype=torch.float32), test_labels.to(device, dtype=torch.long)
            test_logits = model(test_images)
            test_loss = criterion(test_logits, test_labels)
            
            # Calculate accuracy
            test_pred = torch.argmax(test_logits, dim=1)
            test_acc = (test_pred == test_labels).float().mean()
            
            print(f'Iteration {it}/{num_iter}, Test Loss: {test_loss.item():.4f}, Test Accuracy: {test_acc.item():.4f}')
            ### End of your code ###

    # Save the model
    if it % 5000 == 0:
        torch.save(model.state_dict(), os.path.join(model_dir, 'model_{0}.pt'.format(it)))
print('Training took {:.3f}s in total.'.format(time.time() - start))
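The training loop passes NCXY logits and NXY integer labels directly to `nn.CrossEntropyLoss`. For this layout, PyTorch applies a softmax over the class dimension (dim 1) at every pixel and, with the default `reduction='mean'`, averages the per-pixel negative log-likelihood. This is also why the labels are converted to `torch.long`. The sketch below checks the equivalence on random tensors; the sizes are arbitrary.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
N, C, H, W = 2, 4, 5, 5
logits = torch.randn(N, C, H, W)         # NCXY: one score per class per pixel
labels = torch.randint(0, C, (N, H, W))  # NXY: class index per pixel (dtype torch.long)

loss = nn.CrossEntropyLoss()(logits, labels)

# Manual equivalent: log-softmax over the class axis, pick the log-probability of the true class,
# then average the negative values over all N * H * W pixels
log_probs = F.log_softmax(logits, dim=1)
picked = log_probs.gather(1, labels.unsqueeze(1)).squeeze(1)  # shape (N, H, W)
manual = -picked.mean()

print(loss.item(), manual.item())  # the two values agree up to floating-point rounding
```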

## Q5. Deploy the trained model to a random set of 4 test images and visualise the automated segmentation.

You can show the images as a 4 x 3 panel. Each row shows one example, with the 3 columns being the test image, automated segmentation and ground truth segmentation.

In [None]:
### Insert your code ###
# Deploy the trained model to 4 random test images and visualize results

# Set model to evaluation mode
model.eval()

# Get 4 random test images
test_images, test_labels = test_set.get_random_batch(4)
test_images_tensor = torch.from_numpy(test_images).to(device, dtype=torch.float32)

# Run inference
with torch.no_grad():
    predictions = model(test_images_tensor)
    predictions = torch.argmax(predictions, dim=1).cpu().numpy()

# Create visualization: 4 rows x 3 columns (image, prediction, ground truth)
fig, axes = plt.subplots(4, 3, figsize=(12, 16))

# Define colormaps
img_cmap = 'gray'
seg_cmap = colors.ListedColormap(['black', 'green', 'blue', 'red'])

for i in range(4):
    # Display the test image (remove channel dimension)
    axes[i, 0].imshow(test_images[i, 0], cmap=img_cmap)
    axes[i, 0].set_title('Test Image')
    axes[i, 0].axis('off')
    
    # Display automated segmentation (prediction)
    axes[i, 1].imshow(predictions[i], cmap=seg_cmap, vmin=0, vmax=3)
    axes[i, 1].set_title('Automated Segmentation')
    axes[i, 1].axis('off')
    
    # Display ground truth segmentation
    axes[i, 2].imshow(test_labels[i], cmap=seg_cmap, vmin=0, vmax=3)
    axes[i, 2].set_title('Ground Truth')
    axes[i, 2].axis('off')

plt.tight_layout()
plt.show()

# Calculate and print Dice score for each class
def dice_score(pred, gt, class_idx):
    pred_class = (pred == class_idx)
    gt_class = (gt == class_idx)
    intersection = np.sum(pred_class & gt_class)
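    denominator = np.sum(pred_class) + np.sum(gt_class)  # |P| + |G| (not the set union)
    if denominator == 0:
        # Class absent from both prediction and ground truth: treat as perfect agreement
        return 1.0
    return 2 * intersection / denominator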

class_names = ['Background', 'Edema', 'Non-enhancing tumour', 'Enhancing tumour']
print('\nDice scores for the 4 test images:')
for class_idx in range(4):
    dice_scores = [dice_score(predictions[i], test_labels[i], class_idx) for i in range(4)]
    avg_dice = np.mean(dice_scores)
    print(f'{class_names[class_idx]}: {avg_dice:.4f}')
### End of your code ###
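As a worked example of the Dice score, $\text{Dice} = 2|P \cap G| / (|P| + |G|)$, the sketch below evaluates every class on a tiny hypothetical 3 x 3 patch. It follows the same convention as `dice_score` above (a class absent from both maps scores 1.0); the vectorised helper `dice_per_class` is hypothetical.

```python
import numpy as np


def dice_per_class(pred, gt, num_class=4):
    """Dice coefficient 2|P & G| / (|P| + |G|) for each class; 1.0 when a class is absent from both."""
    scores = []
    for c in range(num_class):
        p, g = (pred == c), (gt == c)
        denom = p.sum() + g.sum()
        scores.append(1.0 if denom == 0 else 2.0 * np.logical_and(p, g).sum() / denom)
    return np.array(scores)


# Tiny hypothetical example: ground truth and prediction for a 3 x 3 patch
gt = np.array([[0, 1, 1],
               [0, 1, 3],
               [0, 0, 3]])
pred = np.array([[0, 1, 0],
                 [0, 1, 1],
                 [0, 0, 3]])

# Class 0: 2*4/(5+4) = 8/9, class 1: 2*2/(3+3) = 2/3, class 2: absent from both -> 1.0,
# class 3: 2*1/(1+2) = 2/3
print(dice_per_class(pred, gt))
```

The absent-class convention means that slices without any tumour score a perfect 1.0 for the tumour classes whenever the model also predicts none, which can make averages over a few slices look optimistic.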

## Q6. Discussion. Does your trained model work well? How would you improve this model so it can be deployed to the real clinic?

## Discussion

**Does the trained model work well?**

The model performs reasonably well for brain tumour segmentation, identifying the tumour regions in most of the visualised test images. Note, however, that the pixel-wise test accuracy printed during training is inflated by the dominant background class, so the per-class Dice scores are the more informative measure. The main limitations are:
- The model may struggle with small tumour regions due to class imbalance (background pixels dominate each slice)
- Boundaries may be delineated imprecisely, leading to some over- or under-segmentation
- Performance varies across the tumour sub-regions (edema, non-enhancing tumour, enhancing tumour)
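To make the class-imbalance point concrete: a model that ignores the lesion entirely can still score high on pixel accuracy, the metric printed during training, while its Dice score for the lesion class is zero. The image size matches the dataset, but the 10 x 10 edema region and the all-background "model" below are synthetic, chosen only for illustration.

```python
import numpy as np

# Hypothetical 120 x 120 ground truth in which a 10 x 10 edema region (class 1) is the only foreground
gt = np.zeros((120, 120), dtype=np.int64)
gt[50:60, 50:60] = 1

# A degenerate "model" that predicts background everywhere
pred = np.zeros_like(gt)

pixel_accuracy = (pred == gt).mean()

# Dice for the edema class: 2|P & G| / (|P| + |G|)
p, g = (pred == 1), (gt == 1)
dice_edema = 2 * np.logical_and(p, g).sum() / (p.sum() + g.sum())

print(f'pixel accuracy: {pixel_accuracy:.4f}')  # pixel accuracy: 0.9931
print(f'edema Dice:     {dice_edema:.4f}')      # edema Dice:     0.0000
```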

**How would you improve this model for real clinic deployment?**

1. **Data improvements:**
   - Use larger training datasets with more diverse patient populations
   - Apply data augmentation (rotation, flipping, elastic deformation, intensity variations)
   - Use 3D volumes instead of 2D slices to capture spatial context
   - Include multi-modal MRI (T1, T2, FLAIR, T1-Gd) as multiple input channels

2. **Architecture improvements:**
   - Use deeper networks or attention mechanisms (e.g., Attention U-Net)
   - Implement residual connections for better gradient flow
   - Consider 3D U-Net for volumetric segmentation
   - Use ensemble methods combining multiple models

3. **Training improvements:**
   - Address class imbalance using Dice loss or focal loss, alone or combined with cross-entropy
   - Use learning rate scheduling and early stopping
   - Apply more extensive hyperparameter tuning
   - Implement cross-validation for robust evaluation

4. **Clinical deployment considerations:**
   - Rigorous validation on external datasets from different scanners/institutions
   - Uncertainty quantification to flag low-confidence predictions
   - Integration with the clinical workflow (e.g. PACS)
   - Regulatory approval (FDA/CE marking)
   - Human-in-the-loop verification by radiologists
   - Continuous monitoring and model updates as new data becomes available
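For the Dice-loss suggestion above, a minimal sketch of a multi-class soft Dice loss in PyTorch is given below. It is one common formulation (softmax probabilities, Dice averaged over classes, sums taken over the whole batch), not a prescribed implementation; the function name `soft_dice_loss` and the smoothing term `eps` are assumptions.

```python
import torch
import torch.nn.functional as F


def soft_dice_loss(logits, labels, eps=1e-6):
    """Multi-class soft Dice loss (a sketch).

    logits: (N, C, H, W) raw network outputs; labels: (N, H, W) class indices (torch.long).
    Returns 1 - mean over classes of the soft Dice, computed over the whole batch.
    """
    num_class = logits.shape[1]
    probs = F.softmax(logits, dim=1)
    # One-hot targets: (N, H, W) -> (N, H, W, C) -> (N, C, H, W)
    one_hot = F.one_hot(labels, num_class).permute(0, 3, 1, 2).to(probs.dtype)
    dims = (0, 2, 3)  # sum over batch and spatial axes, keep the class axis
    intersection = (probs * one_hot).sum(dims)
    denominator = probs.sum(dims) + one_hot.sum(dims)
    dice = (2 * intersection + eps) / (denominator + eps)
    return 1 - dice.mean()


torch.manual_seed(0)
labels = torch.randint(0, 4, (2, 8, 8))

# Near-perfect logits: a large score on the true class at every pixel
perfect_logits = 20.0 * F.one_hot(labels, 4).permute(0, 3, 1, 2).float()
random_logits = torch.randn(2, 4, 8, 8)

print(soft_dice_loss(perfect_logits, labels).item())  # close to 0
print(soft_dice_loss(random_logits, labels).item())   # noticeably larger

# One possible combination with the existing criterion (the 1:1 weighting is an assumption):
# loss = criterion(logits, labels) + soft_dice_loss(logits, labels)
```

Averaging the Dice over classes gives the small tumour classes the same weight as the background, which is what makes this loss less sensitive to class imbalance than plain cross-entropy.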