# Train a classifier to select the appropriate segmentation model

**The following cells are used to load data into the session. You shouldn't have to edit them, just make sure they work!**

In [None]:
# torchvision provides the ResNet classifier used later in this notebook.
%pip install torchvision

In [None]:
import segmentation_models_pytorch as smp
import pytorch_lightning as pl
import matplotlib.pyplot as plt
from utils import getFiles
from os.path import join
import numpy as np
import random
import google.colab
import torch
import torchvision

In [None]:
# Load some of the abdominal and HnN data from the directories already created.

# setup paths
# NOTE: on Colab the slices may instead live under "/content/FLARE_data/train/ims/"
# and "/content/HnN_data/train/ims/" — edit these two paths if your data is elsewhere.
abdo_datapath = "./FLARE_data/"
hn_datapath = "./HnN_data/"

# configuration
N_PER_CLASS = 400   # number of slices sampled from each dataset
IMAGE_SIZE = 256    # side length of every stored slice (abdominal slices are loaded as-is)
HNN_PAD = 32        # padding added on every side of each HnN slice to reach IMAGE_SIZE
AIR_HU = -1024      # pad value (the usual CT value for air, in HU)
SEED = 1234         # seed for both the file sampling and the shuffle

# get fnames
abdo_fnames = sorted(getFiles(abdo_datapath))
hn_fnames = sorted(getFiles(hn_datapath))

# sample N_PER_CLASS images of each dataset (seeded so the selection is reproducible)
random.seed(SEED)
abdo_fnames_to_use = random.sample(abdo_fnames, k=N_PER_CLASS)
hn_fnames_to_use = random.sample(hn_fnames, k=N_PER_CLASS)

# load the images: rows [0, N_PER_CLASS) are abdomen, rows [N_PER_CLASS, n_total) are HnN
n_total = 2 * N_PER_CLASS
train_ct_slices = np.zeros((n_total, IMAGE_SIZE, IMAGE_SIZE), dtype=float)
# abdomen first
for adx, fname in enumerate(abdo_fnames_to_use):
    train_ct_slices[adx] = np.load(join(abdo_datapath, fname))
# now the head and neck, padded with air so they match the abdominal slice size
for hdx, fname in enumerate(hn_fnames_to_use):
    im = np.load(join(hn_datapath, fname))
    padded_im = np.pad(im, pad_width=HNN_PAD, constant_values=AIR_HU)
    if padded_im.shape != (IMAGE_SIZE, IMAGE_SIZE):
        raise ValueError(
            f"{fname}: padded HnN slice has shape {padded_im.shape}, "
            f"expected {(IMAGE_SIZE, IMAGE_SIZE)}"
        )
    train_ct_slices[N_PER_CLASS + hdx] = padded_im

# create a labels array (0 = abdominal CT, 1 = HnN CT)
ground_truth_labels = np.zeros(n_total)
ground_truth_labels[N_PER_CLASS:] = 1

# shuffle the data! (images and labels with the same permutation so they stay aligned)
np.random.seed(SEED)
shuffle_indices = np.random.permutation(n_total)
train_ct_slices = train_ct_slices[shuffle_indices]
ground_truth_labels = ground_truth_labels[shuffle_indices]

**Quick sanity check**:

In [None]:

# Show one shuffled training slice with its label in the title, to confirm that
# images and labels are still aligned after the shuffle.
LABEL_NAMES = {0: "Abdomen", 1: "Head & neck"}  # matches the label convention used above
n = 47  # arbitrary example index
label = int(ground_truth_labels[n])

fig, ax = plt.subplots()
ax.imshow(train_ct_slices[n], cmap='Greys_r')
ax.invert_yaxis()
ax.set_title(f"Slice {n}: label {label} ({LABEL_NAMES[label]})")
plt.show()

#### Download the unlabelled test data - you will need this later

In [None]:
## Download the test data
try:
    
    IN_COLAB = True
    !wget https://www.dropbox.com/s/9950rhyrs9kb8kj/classification_test_data.tar.gz?dl=0 -O /content/ClassificationProcessedData.tar
    !tar -xf /content/ClassificationProcessedData.tar -C /content/
    !rm /content/ClassificationProcessedData.tar
except:
    IN_COLAB = False
    !wget https://www.dropbox.com/s/9950rhyrs9kb8kj/classification_test_data.tar.gz?dl=0 -O ./ClassificationProcessedData.tar
    !tar -xf ./ClassificationProcessedData.tar -C ./
    !rm ClassificationProcessedData.tar

---

# Your work starts here!


## Pre-processing

Think about what pre-processing to apply before passing the CT slices to the classifier.

## Dataset

Define a dataset class, **split the data into training and validation sets**, and think about which augmentations you should use.

## Create your model

Try using a simple classification model from [torchvision](https://pytorch.org/vision/stable/models.html) (e.g. ResNet18). Think about what loss and optimiser to use.

In [None]:
# Build a ResNet-18 classifier from torchvision, initialised with ImageNet-pretrained weights.
# The weights enum below is the typed equivalent of passing weights='IMAGENET1K_V1'.
# TODO: the pretrained network keeps its ImageNet input/output layers (see the torchvision
# docs); adapt them (e.g. `model.fc`) for single-channel CT slices and the two-class task.
pretrained_weights = torchvision.models.ResNet18_Weights.IMAGENET1K_V1
model = torchvision.models.resnet18(weights=pretrained_weights)

## Loops

We encourage you to write your own training loops in PyTorch, as you'll get a real understanding of what is happening "under the hood".

_You can always use `pytorch-lightning` if you don't want to..._

## Training!

Train your model for a few epochs. Remember to monitor both the training and validation losses!

# Pipeline

### Classify & segment the testing images

1. Run the classifier on each image in the unlabelled test dataset.
2. Based on the output from step 1, select the appropriate segmentation model.
3. Segment the images using the saved models from Part 1 & 2.
4. Measure segmentation performance.
   - What other results can you extract? Does segmentation performance depend on structure size? Is segmentation performance better or worse in head and neck? Why might this be?

### Load the segmentation models from part 1 & 2 

In [None]:
class LightningFPN(pl.LightningModule):
  """Lightning wrapper around a segmentation_models_pytorch FPN with a ResNet-18 encoder.

  Args:
    structure_names: names of the foreground structures to segment. The network
      predicts ``len(structure_names) + 1`` channels (the extra one presumably
      being background, label 0).
  """

  def __init__(self, structure_names):
    super().__init__()
    # Store `structure_names` in future checkpoints so they can be reloaded with
    # `LightningFPN.load_from_checkpoint(path)` without re-supplying it.
    self.save_hyperparameters()
    ## Create the pytorch model (single-channel input, one output channel per class)
    self.model = smp.FPN("resnet18", in_channels=1, classes=len(structure_names)+1, encoder_weights='imagenet')

    ## Dice loss configured for multiple classes. With the default `classes=None`,
    ## every output channel (including background) contributes to the loss.
    self.loss_fcn = smp.losses.DiceLoss("multiclass", from_logits=True)

    ## Specify which optimiser to use here (it is instantiated in configure_optimizers)
    self.optimizer = torch.optim.Adam

  def forward(self, x):
    """Return the raw per-class logits produced by the FPN for input batch `x`."""
    return self.model(x)

  def configure_optimizers(self):
    """Create the optimiser and a scheduler that lowers the LR when `val_loss` plateaus."""
    optimizer = self.optimizer(self.parameters(), lr=1e-4)  ## May need to handle other kwargs here!
    # Lightning only attaches a scheduler supplied under the "lr_scheduler" key, and
    # ReduceLROnPlateau additionally needs the name of the metric to watch; here that is
    # the "val_loss" value logged in validation_step.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min")
    return {
      "optimizer": optimizer,
      "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"},
    }

  def training_step(self, batch, batch_idx):
    """Compute and log the Dice loss for one training batch."""
    # Separate batch into input, mask and spacing (although we ignore the spacing)
    img, msk, _ = batch
    # Pass the input through the model and get a prediction (msk_hat)
    msk_hat = self(img)
    # Calculate average prediction error on this batch (the loss expects integer labels)
    loss = self.loss_fcn(msk_hat, msk.long())
    # Log the error
    self.log("loss", loss)
    return loss

  def validation_step(self, batch, batch_idx):
    """Compute and log the Dice loss for one validation batch as `val_loss`."""
    # Identical to the training step
    img, msk, _ = batch
    msk_hat = self(img)
    val_loss = self.loss_fcn(msk_hat, msk.long())
    self.log("val_loss", val_loss)
    return val_loss

In [None]:
## Load the segmentation models from part 1 & 2
# !! Update these two paths to point at the checkpoints you saved in parts 1 & 2.
ABDO_WEIGHTS_PATH = "PATH_TO_ABDOMINAL_WEIGHTS.ckpt"
HNN_WEIGHTS_PATH = "PATH_TO_HNN_WEIGHTS.ckpt"

ABDO_STRUCTURES = ['Body', 'Liver', 'Kidneys', 'Spleen', 'Pancreas']
HNN_STRUCTURES = ['Body', 'Brainstem', 'Mandible', 'Parotids', 'Spinalcord']

# `load_from_checkpoint` is a classmethod: call it on the class (not on a throwaway
# instance) and pass `structure_names`, which is needed to rebuild the network
# before the saved weights are loaded. `.eval()` switches to inference mode and
# returns the model itself.

# Abdominal model
abdominal_model = LightningFPN.load_from_checkpoint(
    ABDO_WEIGHTS_PATH, structure_names=ABDO_STRUCTURES).eval()

# Head and neck (hnn) model
hnn_model = LightningFPN.load_from_checkpoint(
    HNN_WEIGHTS_PATH, structure_names=HNN_STRUCTURES).eval()



In [1]:
## 1. Inference loop: iterate over the unlabelled test data and use the classifier to
##    predict the anatomical site of each slice (0 = abdomen, 1 = head & neck, matching
##    the training labels built earlier in this notebook).
## Remember to apply an activation to your model predictions (e.g. sigmoid for a single
## output logit, softmax for one logit per class) before turning them into a decision.

    ## 2 + 3. Based on the prediction, pass the image to the correct segmentation model
    ##        (`abdominal_model` or `hnn_model`).

    ## 4. Measure performance