# RadNet AI for Optimising Radiotherapy Outcomes Workshop - Coding demonstration

This notebook will show, briefly, how to build an autosegmentation model for thoracic OARs using pytorch and pytorch-lightning. We will be using some open data from [The Cancer Imaging Archive (TCIA)](https://www.cancerimagingarchive.net/), originally used for an [AAPM challenge](http://www.autocontouringchallenge.org/). This dataset contains 60 patients, each of whom has five OARs segmented.

To handle the data, we will use [pydicom](https://github.com/pydicom/pydicom) to load slices and ideas from [dicom-contour](https://github.com/KeremTurgutlu/dicom-contour) to convert RTSTRUCT objects into masks.

We will be using a suite of pre-built pytorch segmentation models in the excellent [segmentation-models](https://github.com/qubvel/segmentation_models.pytorch) package. This package simplifies the building of a pretrained segmentation network in 2D. We will use a 2D approach, looping over slices in the data to segment 3D organs.

Pytorch can be quite intimidating, but is very powerful when you get to grips with it. In the interests of simplicity, we will use a wrapper around pytorch called [pytorch-lightning](https://pytorch-lightning.readthedocs.io/en/latest/). Lightning separates out the different bits of ML, allowing you to write a bit less boilerplate code, and letting us very quickly and easily use best-practice methods to train our models.


# Overview
This notebook works through the following steps:

0. Install prerequisites and set up
1. Load DICOM data containing CT and segmentation and convert to numpy arrays
2. Define some preprocessing and apply it to the CT slices
3. Create a segmentation model, using a library to make a pre-trained model for our segmentation task
4. Train the model to reproduce the training examples
5. Test the model against the testing data and produce the AAPM competition ranking score

# How to use colab & jupyter notebooks
If you're new to colab and/or jupyter notebooks, here are some tips on how they work.

## Colab
Colab is a free ML playground from Google. It gives you free access to limited resources, including a GPU and some storage space. The limits are quite small (~30GB disk, 12GB RAM and a random GPU from K80 up to P100), but you can still do some pretty cool stuff with it. You will need a Google account to sign in.

For us, we are going to be training a Convolutional Neural Network (CNN), so we need to get a GPU. To do this, click "Runtime" in the menu across the top of the colab page, then select "Change Runtime Type". From the dropdown, select GPU and click save. The runtime will then reboot and you should have a GPU. To find out what you got, run the cell below this text.

## Jupyter
Jupyter is a tool for running python in a notebook form. A notebook is simply a document containing code and accompanying text describing/explaining the code. You're reading one right now!

Notebooks are divided into cells which, for the most part, come in two flavours - Markdown and code. A markdown cell is where you can type words to explain the code. Try double clicking on this text, and you should be shown the markdown that created it. 

Code cells contain python code. There are a couple of things to bear in mind about notebooks that differ from normal python scripts:
- Notebook cells can be run in any order
- The output of any cell is available in any other cell (as if it were all global scope in python)
- Typos can really screw you over. If you make a typo in a variable name, that variable still exists, and anywhere where you made the same typo will use the old variable instead of the new one. This has personally led to at least three hours debugging that could have been saved by being able to spell.
- Default plotting behaviour is to just give you a picture with no interactivity. We can override it though

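To make the typo pitfall concrete, here is a small sketch of what can happen over a few cell runs. The variable names are made up for illustration; the comments mark which (imaginary) cell each line lives in. Restarting the runtime and running all cells from the top clears out stale names like this.

```python
# Cell 1, first run: the variable name contains a typo
traning_slices = [1, 2, 3]

# Cell 1, edited and re-run: typo fixed and new data assigned.
# The misspelt variable is NOT removed, it still lives in the notebook's memory.
training_slices = [10, 20, 30]

# Cell 2, never edited: still uses the misspelt name and silently gets the old data
total = sum(traning_slices)
print(total)  # 6, not 60
```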
All cells are executed by either clicking the play button at the top left corner of the cell, or by clicking in it and pressing ctrl+enter. You can also press shift+enter, which will run the current cell and move to the next.

### Jupyter escapes and magic
You will see a few cells with lines starting with either an exclamation mark (!) or a percentage sign (%). These lines are called escapes and magics. An escape simply makes jupyter run the command after it in a bash shell (or cmd if you're on windows); this allows us to do stuff like run wget to download data.

Magics do things to alter the state of jupyter, for example by turning matplotlib interactivity on, or enabling the browser-in-browser that allows us to use tensorboard monitoring. The most important thing to note about magics is that they can't have a comment after them. You can look up some jupyter magics [here](https://ipython.readthedocs.io/en/stable/interactive/magics.html).


Now you know how to drive this notebook, let's start running some stuff!

In [None]:
## Find out what GPU we got (and make sure we actually have one!)
!nvidia-smi

Because of the volume of data and the time we have, I've saved the output of each step (where applicable). This is because colab has a RAM limit of 12GB, and some of these steps need more than that and will crash the runtime. The code for each step is in the notebook, but we will be doing a "Blue Peter" (here's one I prepared earlier) at pretty much every step.

The cell below downloads all the necessary data and places it on the colab machine. Run it; you can then read through the functions to see how the data was manipulated, without having to run them if we don't have time.

In [None]:
## Download all the data! This may take a little while...

## Processed data
try:
    import google.colab
    IN_COLAB = True
    !wget https://www.dropbox.com/s/la05h49y9ths7x3/autoseg_data.tar.gz?dl=0 -O /content/AllProcessedData.tar
    !tar -xf /content/AllProcessedData.tar -C /content/
    !rm /content/AllProcessedData.tar
except ImportError:  ## google.colab is only importable when running inside colab
    IN_COLAB = False
    !wget https://www.dropbox.com/s/la05h49y9ths7x3/autoseg_data.tar.gz?dl=0 -O ./AllProcessedData.tar
    !tar -xf ./AllProcessedData.tar -C ./
    !rm AllProcessedData.tar

# utilities python file
if IN_COLAB:
    !wget https://www.dropbox.com/s/rng7h9mgkwaolt8/utils.py?dl=0 -O /content/utils.py
else:
    !wget https://www.dropbox.com/s/rng7h9mgkwaolt8/utils.py?dl=0 -O ./utils.py

## 0. Install prerequisites

Here we install the pytorch flavour of the segmentation-models library, along with pydicom and pytorch-lightning

In [None]:
if IN_COLAB:
    %pip install git+https://github.com/qubvel/segmentation_models.pytorch
    %pip install pytorch_lightning tqdm ipympl

## 0. Set up monitoring and enable matplotlib notebook interactivity

In [None]:
%matplotlib inline
import tensorboard
%load_ext tensorboard

## 0. Load required libraries

In [None]:
import segmentation_models_pytorch as smp
import matplotlib.pyplot as plt 
import pytorch_lightning as pl
import albumentations as A
import numpy as np
# import pydicom
import torch
import os
from utils import getFiles
from os.path import join
import pickle
## These are the structures defined in the FLARE data
structure_names = ['Body', 'Liver', 'Kidneys', 'Spleen', 'Pancreas']
## See if we're in colab...
if IN_COLAB:
    datapath = "/content/FLARE_data/"
else:
    datapath = "/home/ed/autoseg_workshop_2023/FLARE_data/"  ## <---- You will need to change this if running locally

At this point, we may start to skip running things during the live demo. This is purely because of the time constraints we have. I will still go through each cell, and we will look at the output of each cell when I load the Blue Peter packs.



# 1. Load and visualise the data
Earlier, we downloaded and decompressed a load of data for us to use. Part of that is the raw DICOM from the TCIA; the cells below will process that data into numpy arrays, which is what we need to do our ML with.

The data is provided in DICOM format, as most RT data will be when we want to use it for ML. In the following cell, I define some functions that will help us load the DICOM data and convert it into numpy arrays so that we can use it to train an ML model. Each function is described in its own docstring, but the broad idea is to get the UID for every slice in a CT image, then look for any structure that references that UID. Then we transform the coordinates of the contour points from the DICOM frame of reference into image pixels, and burn those pixels into a mask. Then we can use a binary hole filling to create a solid 2D mask. This process is repeated for every slice that has a contour on it, and for every ROI, to create a full 3D mask. Note that I set the pixel value based on the index of the ROI - this will become the pixel's class label later on.

We have to do this now because of the way segmentation models work. 'Classic' CNNs classify an image into one of several classes (e.g. dog, bird, cat etc). This classification is across the whole image, and if there were to be an image containing both a dog and cat, it would be difficult to classify that image, since the network can "see" both things. The next level of complexity is object detection networks; these work by drawing a bounding box around each object in the image - this means they can handle images with more than one class in. However, when we're doing radiotherapy, or lots of other tasks, we need to know exactly which bits of the image are what object - this is where segmentation, or semantic segmentation comes in. Semantic segmentation gives a class label to every pixel in the image, allowing it to accurately track the edges of organs and other things.

To train a semantic segmentation model, we need labels for every pixel in the image. These are derived from the DICOM RTSTRUCT file, but we have to convert contours (a series of points in space) to masks (images with 0 for background, 1 for foreground). Lots of the code in the cell below is inspired by [dicom-contour](https://github.com/KeremTurgutlu/dicom-contour), but doesn't make the assumption that filename = SOPInstanceUID. This cell loops over all the patients and their available contours and links up the contour with the correct image slice, converting it to a mask on the way. This is what we then use to train our network.

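As a rough illustration of the contour-to-mask idea, here is a minimal numpy-only sketch. It is not the code used to prepare the workshop data: the function name and arguments are hypothetical, it assumes the image axes are aligned with the patient x/y axes, and it fills the polygon directly with an even-odd (ray casting) test rather than burning in the outline and hole filling afterwards.

```python
import numpy as np

def contour_to_mask(points_mm, origin_mm, spacing_mm, shape, label=1):
    """Rasterise one closed planar contour into a labelled 2D mask.

    points_mm  : (N, 2) contour points (x, y) in patient coordinates, in mm
    origin_mm  : (x, y) position of the centre of pixel [0, 0], in mm
    spacing_mm : (dx, dy) pixel spacing in mm
    shape      : (rows, cols) of the CT slice
    """
    # Patient coordinates -> fractional pixel coordinates (x = column, y = row)
    pts = (np.asarray(points_mm, dtype=float) - np.asarray(origin_mm)) / np.asarray(spacing_mm)
    px, py = pts[:, 0], pts[:, 1]

    rows, cols = np.mgrid[0:shape[0], 0:shape[1]]
    inside = np.zeros(shape, dtype=bool)

    # Even-odd rule: a pixel is inside if a ray towards +x crosses an odd number of edges
    j = len(px) - 1
    with np.errstate(divide="ignore", invalid="ignore"):
        for i in range(len(px)):
            crosses = (py[i] > rows) != (py[j] > rows)
            # x position at which this edge meets each pixel row
            x_at_row = (px[j] - px[i]) * (rows - py[i]) / (py[j] - py[i]) + px[i]
            inside ^= crosses & (cols < x_at_row)
            j = i

    # The pixel value is the ROI index, which becomes the class label
    return np.where(inside, label, 0).astype(np.uint8)

# A 20 mm square centred on the origin, on a 2 mm grid, labelled as ROI 2
square_mm = [(-10, -10), (10, -10), (10, 10), (-10, 10)]
mask = contour_to_mask(square_mm, origin_mm=(-32, -32), spacing_mm=(2, 2), shape=(32, 32), label=2)
print(np.unique(mask), int((mask > 0).sum()))
```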
To load the preprocessed data, run the next cell.

In [None]:
## Get the training data
train_datapath = join(datapath, "train")
fnames = sorted(getFiles(join(train_datapath, 'ims')))
train_ct_slices = np.zeros((len(fnames), 256, 256), dtype=np.float32)
train_mask_slices = np.zeros((len(fnames), 256, 256), dtype=np.float32)


for fdx, fname in enumerate(fnames):
    train_ct_slices[fdx] = np.load(join(train_datapath, "ims", fname))
    train_mask_slices[fdx] = np.load(join(train_datapath, "masks", fname))

# load pixel spacing
with open(join(train_datapath, "spacings.pkl"), "rb") as f:
    train_pixel_sizes = pickle.load(f)


# Get the test data
test_datapath = join(datapath, "test")
fnames = sorted(getFiles(join(test_datapath, 'ims')))
test_ct_slices = np.zeros((len(fnames), 256, 256), dtype=np.float32)
test_mask_slices = np.zeros((len(fnames), 256, 256), dtype=np.float32)

for fdx, fname in enumerate(fnames):
    test_ct_slices[fdx] = np.load(join(test_datapath, "ims", fname))
    test_mask_slices[fdx] = np.load(join(test_datapath, "masks", fname))

# load pixel spacing
with open(join(test_datapath, "spacings.pkl"), "rb") as f:
    test_pixel_sizes = pickle.load(f)

print(train_ct_slices.shape)
print(train_mask_slices.shape)
print(test_ct_slices.shape)
print(test_mask_slices.shape)

# 2. Preprocessing

Preprocessing your data is an extremely important part of machine learning, and so I'm going to do a little bit here. Preprocessing is used to standardise the images and, especially in machine learning, to compress their intensities down to a given range (usually 0-1).

The preprocessing I will write in the next cell is the simplest I can come up with that will still do the job. We will standardise the images by applying a level/window transformation and standardising the output in the range 0-255. The slices are then used as if they were normal grayscale images, and the preprocessing/augmentation pipeline takes care of the rest. 

First we define a window/level like function and use it to set a mediastinal window on the data - most of what we're trying to visualise is in that region, so this should be ok.

In [None]:
def window_level(data, window=350, level=50):
    """
    Apply a window and level transformation to CT slices. 

    The default values are taken from https://radiopaedia.org/articles/windowing-ct?lang=gb and are recommended for visualising the mediastinum
    
    The returned array has the same shape as the input. Values will be in the range 0-255 and type will be uint8 to mimic a 'normal' image
    """
    ## calculate high & low edges of level & window
    low_edge  = level - (window//2)
    high_edge = level + (window//2)
    ## use np.clip to clip into that level/window, then adjust to range 0 - 255 and convert to uint8
    windowed_data = (((np.clip(data, low_edge, high_edge) - low_edge)/window) * 255).astype(np.uint8)
    

    ## repeat the array in the last axis to make a 3 channel image
    # windowed_data = np.repeat(windowed_data, 3, axis=-1)

    return windowed_data

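As a quick worked example of the arithmetic: with the default window of 350 and level of 50, the window runs from -125 HU to 225 HU. Everything at or below -125 HU maps to 0, everything at or above 225 HU maps to 255, and the level itself lands in the middle. The sketch below repeats the function so it can run on its own; the HU values are made up for illustration.

```python
import numpy as np

def window_level(data, window=350, level=50):
    low_edge = level - (window // 2)    # 50 - 175 = -125 HU
    high_edge = level + (window // 2)   # 50 + 175 = 225 HU
    return (((np.clip(data, low_edge, high_edge) - low_edge) / window) * 255).astype(np.uint8)

# Air, the lower window edge, the level, the upper window edge, dense bone
hu = np.array([-1000.0, -125.0, 50.0, 225.0, 1500.0])
windowed = window_level(hu)
print(windowed)  # [  0   0 127 255 255] (127.5 is truncated to 127 by the uint8 cast)
```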
Now we can apply this transformation to the data we loaded from the DICOM. This doesn't take too long, but there is a pre-processed npz file if you're impatient.

In [None]:
## Apply the preprocessing 
window_levelled_slices_train = window_level(train_ct_slices)
window_levelled_slices_test = window_level(test_ct_slices)

## If you run out of RAM, you can save memory by deleting the original array - will need to re-load it if we change something
# del train_ct_slices
# del test_ct_slices

print(window_levelled_slices_train.shape)
print(window_levelled_slices_test.shape)

## Sanity checking
It is a good idea to periodically check that your data actually makes sense. I call these sanity checks, and they are as simple as just plotting the CT slice with the masks overlaid and making sure they look more or less lined up. Let's do this quickly now

In [None]:
fig, ax = plt.subplots()
ax.imshow(window_levelled_slices_train[47,...], cmap='Greys_r')
ax.imshow(train_mask_slices[47,...].squeeze(), alpha=0.5, cmap='viridis', vmax=5)
ax.invert_yaxis()
plt.show()

If you've been paying attention, you will know that we have ~5000 images to train on. That's a lot. To make training a bit more tractable, I will now randomly select ~1000 training examples and ~500 validation ones. Ideally, you would just use the whole dataset, but we will either run out of memory or time if we do. We set the numpy random seed to a known value so everyone should get roughly the same results, then re-seed randomly afterwards so nothing else is affected.

In [None]:
np.random.seed(1234)
subset_indices = np.random.randint(0, window_levelled_slices_train.shape[0], size=1500)

wl_slice_subset_train = window_levelled_slices_train[subset_indices[0:1000]]
mask_subset_train = train_mask_slices[subset_indices[0:1000]]
spacings_subset_train = np.array(list(train_pixel_sizes.values()))[subset_indices[0:1000]]

wl_slice_subset_val = window_levelled_slices_train[subset_indices[-500:]]
mask_subset_val =  train_mask_slices[subset_indices[-500:]]
spacings_subset_val = np.array(list(train_pixel_sizes.values()))[subset_indices[-500:]]

# del window_levelled_slices_train
# del train_mask_slices

np.random.seed()

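One thing to be aware of: `np.random.randint` samples with replacement, so the same slice can be picked more than once and can end up in both the training and validation subsets. If you want the two subsets to be disjoint, one option is to shuffle the indices once and slice the result. This is a minimal sketch, not what the cell above does; `n_slices` is a stand-in for `window_levelled_slices_train.shape[0]`.

```python
import numpy as np

n_slices = 5000  # stand-in for window_levelled_slices_train.shape[0]

# A local generator gives reproducible results without touching the global numpy seed
rng = np.random.default_rng(1234)

# A permutation contains every index exactly once, so slices of it cannot overlap
shuffled = rng.permutation(n_slices)
train_indices = shuffled[:1000]
val_indices = shuffled[1000:1500]

print(len(np.intersect1d(train_indices, val_indices)))  # 0
```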
In [None]:
## Another sanity check to be sure that did what we expected...
fig, ax = plt.subplots()
ax.imshow(wl_slice_subset_train[7,...], cmap='Greys_r')
ax.imshow(mask_subset_train[7,...].squeeze(), alpha=0.5, cmap='viridis', vmax=5)
ax.invert_yaxis()
print(wl_slice_subset_train.shape)
plt.show()

Now we can build the data loading, augmentation and normalisation pipeline. This is a little more involved in pytorch than in keras for example, but still isn't too difficult.

In the cell below, we create a subclass of the pytorch Dataset object, specific to our task. This class is made from three arguments - the array of images, the array of masks and a set of transformations we wish to apply. The two arrays are easy, we already have them ready to go, but the transformations may take a little bit of thinking.

These transformations are where we can introduce some data augmentation. Recall that data augmentation is the process of applying random transformations to our data to create "new" synthetic data to train our model with. We will only use a little bit of data augmentation here, namely a small random rotation of up to 5 degrees (horizontal flipping is another common choice, but it is not used in the next cell). We will set that up in the next cell.

There are also some transformations that are more or less mandatory. For example, pytorch's pretrained models expect images to be of type float, and have a specific mean and standard deviation (derived from the imagenet dataset). We apply a transformation that handles normalising to the imagenet mean and standard deviation; the conversion from numpy array to pytorch tensor then happens when the DataLoader collates each batch, so it can be sent through the model. Because we're working on a single channel, we normalise to the mean of means, and mean standard deviation across the RGB channels (this will hopefully make more sense in the code).

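To see what "mean of means" works out to, here is a small numpy sketch of the arithmetic. It assumes the normalisation follows the usual formula `(pixel - mean * 255) / (std * 255)`, which is my understanding of what `A.Normalize` does with its default `max_pixel_value`; the helper `normalise` is hypothetical and only here for illustration.

```python
import numpy as np

mean = np.mean([0.485, 0.456, 0.406])  # mean of the imagenet channel means
std = np.mean([0.229, 0.224, 0.225])   # mean of the imagenet channel standard deviations

def normalise(image_uint8, max_pixel_value=255.0):
    # mean and std are quoted on a 0-1 scale, so scale them up to pixel units first
    return (image_uint8.astype(np.float32) - mean * max_pixel_value) / (std * max_pixel_value)

pixels = np.array([0, 127, 255], dtype=np.uint8)
print(round(float(mean), 3), round(float(std), 3))  # 0.449 0.226
print(normalise(pixels))
```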
We will use a library called albumentations to handle our augmentations. It is very fast and easy to use, but only works for 2D images. For 3D augmentations, there is a library called kornia which can do some augmentations.

It is important to note that we use different pipelines for the training and validation data. For training, we can do whatever we like to make the data go as far as possible, but when validating, we're meant to be getting an idea of the network's performance on 'real' data, so we should not do any augmentation. This is easy to do with the way things are set up.

In [None]:
## Define a subclass of Dataset that handles our image-mask pair loaded from arrays
## This is the bare minimum example!
class LCTSCDataGen(torch.utils.data.Dataset):
  def __init__(self, image_array, mask_array, transform, spacings):
    super().__init__()
    self.image_array = image_array
    self.mask_array = mask_array
    self.transform = transform
    self.spacings = spacings

  def __len__(self):
    return self.image_array.shape[0]

  def __getitem__(self, idx):
    image = self.image_array[idx,...]
    mask = self.mask_array[idx, ...]
    if self.transform is not None:
      transformed = self.transform(image=image, mask=mask)
      image = transformed['image']
      mask = transformed['mask']
    spacing = self.spacings[idx]
    return image[np.newaxis,...], mask, spacing

## Now create the augmentation pipeline

train_transforms = A.Compose([
    A.Rotate(5),
    A.Normalize(mean=(np.mean([0.485, 0.456, 0.406])), std=(np.mean([0.229, 0.224, 0.225]))) ## Note mean of means, mean of stds
])

## validation pipeline just does normalisation (the DataLoader handles the conversion to tensor)
val_transforms = A.Compose([
    A.Normalize(mean=(np.mean([0.485, 0.456, 0.406])), std=(np.mean([0.229, 0.224, 0.225]))) ## Note mean of means, mean of stds
])



## Now create some datasets and dataloaders
train_dataset = LCTSCDataGen(wl_slice_subset_train, mask_subset_train, train_transforms, spacings_subset_train)
val_dataset = LCTSCDataGen(wl_slice_subset_val, mask_subset_val, val_transforms, spacings_subset_val)

## Create dataloaders from these datasets
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=8)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=8)

print(mask_subset_train.shape, wl_slice_subset_train.shape)

In [None]:
## One more sanity check...
train_img, train_msk, _ = next(iter(train_dataloader))

fig, ax = plt.subplots()
ax.imshow(train_img.numpy()[3,...].squeeze(), cmap='Greys_r')
ax.imshow(train_msk.numpy()[3,...].squeeze(), alpha=0.5, cmap='viridis', vmax=5)
ax.invert_yaxis()
plt.show()


# 3. Creating the Segmentation model

For simplicity, we are going to use a python library of pre-built segmentation networks. The library is called `segmentation_models` and I highly recommend you have a look at the [github page](https://github.com/qubvel/segmentation_models.pytorch). We already imported the library at the top of the notebook, so we can use it straight away here.

We will be using a pre-trained segmentation model, in which the feature extraction parts of the network have been pre-trained on imagenet (A large database of natural images). This should mean we can train a good model with relatively little data, but also means we have to convert our 1 channel CT image into a 3 channel RGB image. Fortunately, the library can do the 1 -> 3 channel conversion for us by repeating the image in the channel dimension.

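A note on why a single-channel input can still use RGB-pretrained weights: repeating the image three times and applying the 3-channel weights gives the same answer as keeping one channel and summing the weights over the input-channel axis. I believe the library takes the second route when `in_channels=1`, but treat that as an assumption. The sketch below checks the equivalence with numpy for a 1x1 convolution; the array sizes are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
image = rng.random((1, 8, 8))   # (channels, height, width): a single-channel "CT slice"
weights = rng.random((4, 3))    # 1x1 convolution: 4 output channels, 3 input (RGB) channels

# Option 1: repeat the image into 3 identical channels and use the RGB weights directly
rgb_like = np.repeat(image, 3, axis=0)
out_repeated = np.einsum("oc,chw->ohw", weights, rgb_like)

# Option 2: keep 1 channel and sum the weights over the input-channel axis
summed_weights = weights.sum(axis=1, keepdims=True)
out_summed = np.einsum("oc,chw->ohw", summed_weights, image)

print(np.allclose(out_repeated, out_summed))  # True
```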
We will be using an architecture called a Feature Pyramid Network (FPN); it looks something like this:

![FPN architecture, from segmentation_models github page](https://github.com/qubvel/segmentation_models/raw/master/images/fpn.png)

The part on the left, in gray, is the bit that is pretrained on imagenet; this is called the backbone of the network. Pretraining means that the backbone already knows about some features in images that are useful, in this case to classify them. We will use these features as a starting point from which we will learn features useful to our task - segmentation in CT.

The choice of backbone is somewhat arbitrary - more modern and bigger networks should have better performance, but not always. We will use the smallest of the ResNet backbones, ResNet-18, because it will hopefully train faster and use less GPU memory.

As I alluded to at the beginning, writing/explaining a pytorch training loop in the ~1 hour we have is probably not realistic, so we are going to use a library called pytorch-lightning to do all the heavy lifting for us. To be able to use it, we have to wrap our pytorch model in a special class that inherits from a LightningModule (If none of that makes sense, don't worry - this bit is making the model so that pytorch-lightning knows what to do with it). To do the wrapping, we create a class LightningFPN, and define a few mandatory functions for pytorch-lightning to use.

While creating the wrapper, we also have to select which optimiser to use, and the most appropriate loss function for the problem. There are lots of considerations in selecting an optimiser, but most people use one of either Stochastic Gradient Descent (SGD) with momentum, or Adaptive Moment Estimation (Adam). Usually, Adam converges faster than SGD because of the fancy stuff it is doing inside the optimiser, but sometimes SGD can find a better solution by avoiding a local minimum. In the interests of speed, we will use Adam.

The choice of loss function can also have a profound impact on the quality of the trained model. There are a few loss functions we could consider here. Segmentation is just a classification problem applied to every pixel in the image, so we could use a loss designed for classification, apply it in every pixel and then take the average across the whole image. This is what the categorical crossentropy loss will do. Categorical crossentropy can fall over though - especially when there is a class imbalance. Since we are segmenting organs that occupy a small fraction of the image, we have a lot of background and not much foreground, therefore the loss will be dominated by the background. There are ways around this (e.g. weighted cross entropy, or focal loss) but as a first attempt, we should probably use something else.

The Dice similarity coefficient (DSC) is very well known, and can be used as a loss function here. DSC is rightly vilified in radiotherapy because it has no spatial component (how far apart were the contours) and is very sensitive to volume (small contours will always be much worse than big ones). The segmentation_models library has implementations of all these losses; for now we will just use DSC, but it would be trivial to use one of the other losses.

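Both points above can be seen with a toy example in numpy. This is only a sketch on made-up binary masks, not the library's loss implementation. First, a "lazy" prediction that calls everything background gets a low cross entropy when 99% of pixels really are background, yet its DSC is zero. Second, the same 2-pixel shift halves the DSC of a small square while barely denting that of a large one.

```python
import numpy as np

def dice_coefficient(a, b):
    """DSC = 2 * |A & B| / (|A| + |B|) for two boolean masks."""
    total = a.sum() + b.sum()
    return 2.0 * np.logical_and(a, b).sum() / total if total > 0 else np.nan

# Ground truth: a 10x10 organ in a 100x100 slice (1% foreground)
gt = np.zeros((100, 100), dtype=bool)
gt[45:55, 45:55] = True

# A lazy "model" that predicts 1% foreground probability everywhere
p_foreground = np.full(gt.shape, 0.01)
cross_entropy = -np.mean(np.where(gt, np.log(p_foreground), np.log(1 - p_foreground)))
dice_lazy = dice_coefficient(gt, p_foreground > 0.5)

# Volume sensitivity: the same 2-pixel shift applied to a small and a large square
small = np.zeros((100, 100), dtype=bool)
big = np.zeros((100, 100), dtype=bool)
small[48:52, 48:52] = True
big[20:80, 20:80] = True
dice_small = dice_coefficient(small, np.roll(small, 2, axis=1))
dice_big = dice_coefficient(big, np.roll(big, 2, axis=1))

print(f"lazy prediction: cross entropy {cross_entropy:.3f}, DSC {dice_lazy:.3f}")
print(f"2-pixel shift: small organ DSC {dice_small:.2f}, large organ DSC {dice_big:.2f}")
```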

In [None]:
## Define the class that will wrap the pytorch model up for ptl
class LightningFPN(pl.LightningModule):
  def __init__(self):
    super().__init__()
    ## Create the pytorch model 
    self.model = smp.FPN("resnet18", in_channels=1, classes=len(structure_names)+1, encoder_weights='imagenet')
    
    ## Construct a loss function, this is DSC, configured for multiple classes (as written, all classes including the background contribute)
    self.loss_fcn = smp.losses.DiceLoss("multiclass", from_logits=True)

    ## Specify which optimiser to use here
    self.optimizer = torch.optim.Adam

  def forward(self, x):
    return self.model(x)

  def configure_optimizers(self):
    optimizer = self.optimizer(self.parameters(), lr=1e-4)## May need to handle other kwargs here!
    ## Note - we reduce the learning rate when the validation loss plateaus for a while - this should improve the model
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
    return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"}}


  def training_step(self, batch, batch_idx):
    img, msk, _ = batch
    msk_hat = self(img)
    loss = self.loss_fcn(msk_hat, msk.long())
    self.log("loss", loss)
    return loss

  def validation_step(self, batch, batch_idx):
    img, msk, _ = batch
    msk_hat = self(img)
    val_loss = self.loss_fcn(msk_hat, msk.long())
    self.log("val_loss", val_loss)
    return val_loss


## Now we can wrap the prebuilt model up inside a pytorch lightning module:

pl_model = LightningFPN()

## Done!


This model has about 13 million parameters, which might stretch our GPU a bit - we will need to think about batch size if this becomes an issue during training.

# 4. Training

We're now ready to train the model. It is very easy when using pytorch-lightning, only taking 2 lines to do what would be hundreds in pure pytorch!

To be able to keep track of what is going on, we will use the tensorboard log viewer, which I activate in the next cell. This will allow us to see the training and validation loss change as the network learns.

To keep things quick, we will only train for 5 epochs. Ideally, we would train for a few hundred, to make sure the network loss is properly saturated.


In [None]:
# %tensorboard --logdir lightning_logs/

In [None]:
trainer = pl.Trainer(max_epochs=5)
trainer.fit(pl_model, train_dataloader, val_dataloader)

You should be able to see the training progress in the tensorboard browser. 5 epochs is not enough, but as we will see, it is actually surprisingly good...


In [None]:
# save a checkpoint in case colab crashes
trainer.save_checkpoint("model.ckpt")

# code to reload the model if colab crashes
#pl_model = LightningFPN.load_from_checkpoint("model.ckpt")

We can now try running this model on some test data. I've chosen a slice somewhat at random so that it has all the structures; in a moment we will run the segmentation over the whole test set.

In [None]:
in_image, in_mask, in_spacing = next(iter(val_dataloader))

# take the first example from the batch
in_image = in_image[0].unsqueeze(0)
all_gt = in_mask[0].numpy().astype(float)
in_spacing = in_spacing[0].numpy()

all_gt[all_gt == 0.0] = np.nan ## make the background invisible
test = pl_model.model.predict(in_image).detach() ## the model doesn't do softmax activation, so we have to do it ourselves
probs = torch.nn.functional.softmax(test, dim=1)[0].numpy().squeeze()
test_mask = np.argmax(probs, axis=0).astype(float)
test_mask[test_mask == 0.0] = np.nan


## Show the results
fig =  plt.figure(figsize=(10,10)) 
ax_gt = fig.add_subplot(121)
ax_gt.set_title("Ground Truth")
ax_cnn = fig.add_subplot(122)
ax_cnn.set_title("CNN contour")

ax_gt.imshow(in_image.squeeze(), cmap='Greys_r')
ax_gt.imshow(all_gt.squeeze(), alpha=0.75, cmap='viridis', vmin=0, vmax=5)
ax_gt.invert_yaxis()

ax_cnn.imshow(in_image.squeeze(), cmap='Greys_r')
ax_cnn.imshow(test_mask, alpha=0.75, cmap='viridis', vmin=0, vmax=5)
ax_cnn.invert_yaxis()

If training a segmentation model is no longer possible in colab, I have pre-trained a model on the exact same data as here, but for 200 epochs...

In [None]:
# model_trained = LightningFPN.load_from_checkpoint(os.path.join(datapath, "pretrained_checkpoint.ckpt"))

# in_image, in_mask, in_spacing = next(iter(val_dataloader))

# # take the first example from the batch
# in_image = in_image[0].unsqueeze(0)
# all_gt = in_mask[0].numpy().astype(float)
# in_spacing = in_spacing[0].numpy()

# all_gt[all_gt == 0.0] = np.nan ## make the background invisible
# test = model_trained.model.predict(in_image).detach() ## the model doesn't do softmax activation, so we have to do it ourselves
# probs = torch.nn.functional.softmax(test, dim=1)[0].numpy().squeeze()
# test_mask = np.argmax(probs, axis=0).astype(float)
# test_mask[test_mask == 0.0] = np.nan


# ## Show the results
# fig =  plt.figure(figsize=(10,10)) 
# ax_gt = fig.add_subplot(121)
# ax_gt.set_title("Ground Truth")
# ax_cnn = fig.add_subplot(122)
# ax_cnn.set_title("CNN contour")

# ax_gt.imshow(in_image.squeeze(), cmap='Greys_r')
# ax_gt.imshow(all_gt.squeeze(), alpha=0.75, cmap='viridis', vmin=1, vmax=5)
# ax_gt.invert_yaxis()

# ax_cnn.imshow(in_image.squeeze(), cmap='Greys_r')
# ax_cnn.imshow(test_mask, alpha=0.75, cmap='viridis', vmin=1, vmax=5)
# ax_cnn.invert_yaxis()

# 5. Evaluation

Visually, the results look pretty good, but now let's try to quantify them a bit. We're going to use a few metrics: 
- Dice Coefficient - the same thing we used for the loss; it just measures overlap and is sensitive to volume
- Mean surface distance - the mean of the directed average surface distances from GT -> CNN and CNN -> GT
- 95th percentile Hausdorff distance - the mean of the directed 95th percentile HD from GT -> CNN and CNN -> GT

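To give a feel for what the two distance metrics measure, here is a brute-force numpy sketch for 2D masks. The function names are hypothetical, and the helpers we import from `utils` in the next cell may well compute things differently (for example in how surface points are defined or weighted), so treat this as an illustration of the definitions above rather than a drop-in replacement.

```python
import numpy as np

def boundary_points(mask, spacing_mm):
    """Positions (row, col), in mm, of mask pixels that touch the background."""
    padded = np.pad(mask, 1)
    # A pixel is interior if all four of its direct neighbours are also in the mask
    interior = padded[:-2, 1:-1] & padded[2:, 1:-1] & padded[1:-1, :-2] & padded[1:-1, 2:]
    return np.argwhere(mask & ~interior) * np.asarray(spacing_mm)

def directed_distances(points_a, points_b):
    """For every point in A, the distance to the nearest point in B."""
    diffs = points_a[:, None, :] - points_b[None, :, :]
    return np.sqrt((diffs ** 2).sum(axis=-1)).min(axis=1)

def surface_metrics(gt, pred, spacing_mm=(1.0, 1.0)):
    a = boundary_points(gt, spacing_mm)
    b = boundary_points(pred, spacing_mm)
    a_to_b, b_to_a = directed_distances(a, b), directed_distances(b, a)
    msd = (a_to_b.mean() + b_to_a.mean()) / 2                            # mean surface distance
    hd95 = (np.percentile(a_to_b, 95) + np.percentile(b_to_a, 95)) / 2   # 95th percentile HD
    return msd, hd95

# Toy example: a 20x20 square and the same square shifted sideways by 3 pixels
gt = np.zeros((64, 64), dtype=bool)
gt[20:40, 20:40] = True
pred = np.roll(gt, 3, axis=1)

msd, hd95 = surface_metrics(gt, pred, spacing_mm=(1.0, 1.0))
print(f"MSD: {msd:.2f} mm, HD95: {hd95:.2f} mm")
```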

In [None]:
from utils import compute_dice_coefficient, compute_surface_distances, compute_average_surface_distance, compute_robust_hausdorff

distance_measures = {name : {} for name in structure_names[1:]}
organ_labels = [2,3,4,5]

for name, organ_label in zip(structure_names[1:], organ_labels):

    cnn_prediction = test_mask == organ_label
    gt = all_gt == organ_label

    if np.sum(gt) == 0:
        print(f"Skipping {name} as there are no ground truth labels in this slice")
        continue

    if np.sum(cnn_prediction) == 0:
        print(f"Skipping {name} as there are no CNN predictions in this slice")
        continue

    surface_distances = compute_surface_distances(gt, cnn_prediction, spacing_mm=in_spacing)

    distance_measures[name]['msd'] = compute_average_surface_distance(surface_distances)
    distance_measures[name]["HD95"] = compute_robust_hausdorff(surface_distances, 95)
    distance_measures[name]["DSC"] = compute_dice_coefficient(gt, cnn_prediction)

    print(f"{name}\n\tMSD: {distance_measures[name]['msd']:.3f} mm\n\tHD95: {distance_measures[name]['HD95']:.3f} mm\n\tDSC: {distance_measures[name]['DSC']:.3f}\n")



Let's now run the evaluation over the entire test set!

In [None]:
MSDs = {name : [] for name in structure_names[1:]}
HD95s = {name : [] for name in structure_names[1:]}
DSCs = {name : [] for name in structure_names[1:]}
failures = {name : 0 for name in structure_names[1:]}
organ_labels = [2,3,4,5]

## Now create the test dataset and dataloader
spacings_test = np.array(list(test_pixel_sizes.values()))
test_dataset = LCTSCDataGen(window_levelled_slices_test, test_mask_slices, val_transforms, spacings_test)
## Create a dataloader from the test dataset
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1)


for in_image, in_mask, in_spacing in test_dataloader:
    # take the first example from the batch
    in_image = in_image[0].unsqueeze(0)
    all_gt = in_mask[0].numpy().astype(float)
    in_spacing = in_spacing[0].numpy()

    test = pl_model.model.predict(in_image).detach() ## the model doesn't do softmax activation, so we have to do it ourselves
    probs = torch.nn.functional.softmax(test, dim=1)[0].numpy().squeeze()
    test_mask = np.argmax(probs, axis=0).astype(float)

    for name, organ_label in zip(structure_names[1:], organ_labels):
        cnn_prediction = test_mask == organ_label
        gt = all_gt == organ_label

        if np.sum(gt) == 0:
            continue

        if np.sum(cnn_prediction) == 0:
            failures[name] += 1
            continue

        surface_distances = compute_surface_distances(gt, cnn_prediction, spacing_mm=in_spacing)
        MSDs[name].append(compute_average_surface_distance(surface_distances))
        HD95s[name].append(compute_robust_hausdorff(surface_distances, 95))
        DSCs[name].append(compute_dice_coefficient(gt, cnn_prediction))


print("Mean MSDs")
for name in structure_names[1:]:
    print(f"\t{name}: {np.mean(MSDs[name]):.3f} mm")

print("\nMean HD95s")
for name in structure_names[1:]:
    print(f"\t{name}: {np.mean(HD95s[name]):.3f} mm")

print("\nMean DSCs")
for name in structure_names[1:]:
    print(f"\t{name}: {np.mean(DSCs[name]):.3f}")

print("\nFailures")
for name in structure_names[1:]:
    print(f"\t{name}: {failures[name]}")

# 6. Visualisation

At this point, take some time visualising the results using matplotlib. We'd like to see some boxplots of the results for each organ...

In [None]:
fig, ax = plt.subplots()

### Your code here...

# 7. A new dataset

Let's now look at a new anatomical site ... the head and neck (HnN)

First of all let's try running our previous model on this new dataset.

In [None]:
## See if we're in colab...
if IN_COLAB:
  hn_datapath = "/content/HnN_data/"
else:
  hn_datapath = "/home/ed/autoseg_workshop_2023/HnN_data/"  ## <---- You will need to change this if running locally

hn_fnames = sorted(getFiles(join(hn_datapath, 'train', 'ims')))
hn_image = window_level(np.load(join(hn_datapath, 'train', 'ims', hn_fnames[23])))
hn_mask = np.load(join(hn_datapath, 'train', 'masks', hn_fnames[23]))
hn_image = val_transforms(image=hn_image[np.newaxis, np.newaxis])['image']
hn_mask[hn_mask == 0.0] = np.nan ## make the background invisible
test = pl_model.model.predict(torch.tensor(hn_image)).detach() ## the model doesn't do softmax activation, so we have to do it ourselves
probs = torch.nn.functional.softmax(test, dim=1)[0].numpy().squeeze()
test_mask = np.argmax(probs, axis=0).astype(float)
test_mask[test_mask == 0.0] = np.nan

## Show the results
fig =  plt.figure(figsize=(10,10)) 
ax_gt = fig.add_subplot(121)
ax_gt.set_title("Ground Truth")
ax_cnn = fig.add_subplot(122)
ax_cnn.set_title("CNN contour")

ax_gt.imshow(hn_image.squeeze(), cmap='Greys_r')
ax_gt.imshow(hn_mask.squeeze(), alpha=0.75, cmap='viridis', vmin=0, vmax=5)
ax_gt.invert_yaxis()

ax_cnn.imshow(hn_image.squeeze(), cmap='Greys_r')
ax_cnn.imshow(test_mask, alpha=0.75, cmap='viridis', vmin=0, vmax=5)
ax_cnn.invert_yaxis()


As we can see - the abdominal segmentation model does terribly in the head and neck images. This is because the model has never been trained on this anatomical site. We'll have to train a new Head and Neck specific model to handle this new data.

Make a copy of this notebook and retrain the CNN on the new HnN data.