![banner](./assets/banner.jpeg)

> You can open this notebook in Colab by clicking the Colab icon. Colab provides a free GPU. You can also run this notebook locally by installing the dependencies listed in `requirements.txt`.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/osbm/unet_explainer/blob/main/tutorial-part2-solutions.ipynb)

### Preamble

The second segment of this hands-on workshop aims to provide an in-depth understanding of the renowned U-Net deep learning architecture, specifically tailored for the segmentation of 2D multi-slice prostate MR images. For this exercise, we will be utilizing the [prostate-158 train dataset](https://zenodo.org/record/6481141), which comprises 139 MRI images, and the [prostate-158 test dataset](https://zenodo.org/record/6592345), containing 19 MRI images.

**Note for Medical Professionals:** The U-Net architecture is particularly advantageous in medical imaging for its efficiency in segmenting intricate structures, which is a crucial step in the diagnostic and treatment planning process.

### Objectives

* To acquaint participants with the process of coding a deep learning segmentation method using the PyTorch library.
* To evaluate the performance impact of various key hyper-parameters on a U-Net model.

**Clinical Relevance:** Understanding the influence of hyper-parameters can be instrumental in fine-tuning the model for specific clinical applications, thereby potentially improving diagnostic accuracy and patient outcomes.

### Our works

- Karagoz, Ahmet, et al. "Anatomically guided self-adapting deep neural network for clinically significant prostate cancer detection on bi-parametric MRI: a multi-center study." Insights into Imaging 14.1 (2023): 1-11. https://doi.org/10.1186/s13244-023-01439-0

- Karagoz, Ahmet, et al. "Prostate Lesion Estimation using Prostate Masks from Biparametric MRI." arXiv preprint arXiv:2301.09673 (2023). https://doi.org/10.48550/arXiv.2301.09673

#### Installing Required Software and Libraries:

For your convenience, I have created a Python package that encapsulates all the necessary dependencies and libraries required for this tutorial. You can install this package using the command provided below. By doing so, you will automatically install all the libraries that are essential for this workshop.

For those interested in understanding the underlying implementation of the functions and classes used in this tutorial, the source code is publicly available in the [unet_explainer GitHub repository](https://github.com/osbm/unet_explainer/).

#### Note to Participants:

1. **Why a Custom Python Package?**: Packaging the required libraries and dependencies into a single installable unit simplifies the setup process, allowing you to focus more on the tutorial's content rather than troubleshooting installation issues.

2. **Transparency and Extensibility**: The availability of the source code in a public repository offers transparency and provides an opportunity for future customization and improvement. You are encouraged to explore the repository to gain deeper insights into the functionalities provided.

Following the installation step in the code cell below ensures a smooth and efficient setup, allowing us to dive straight into the core topics of this workshop.

## Reminder on the U-Net architecture

U-Net is based on a two-stage convolutional network architecture. The first part, known as the encoder, is similar to conventional CNNs and extracts high-level information. The second part is the decoder, which uses information from the encoder and applies a set of convolutions and upsampling operations to gradually transform feature maps, with the purpose of reconstructing segmentation maps at the resolution of the input image. The U-Net architecture also integrates skip connections between the encoder and decoder parts, with the goal of retrieving details that were potentially lost during downsampling while also stabilizing the learning procedure. An illustration of the network architecture is given below.


![unet-architecture](https://github.com/osbm/unet_explainer/blob/main/assets/unet-architecture.png?raw=1)
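To make the encoder, decoder, and skip connections concrete, the following sketch traces feature-map shapes through a small U-Net. The input size (128), the 16 feature maps at the first level, the 4 levels, and the function name `trace_unet_shapes` are illustrative assumptions. The up-convolution that halves the channel count follows the original U-Net design, not necessarily the library implementation used later in this notebook.

```python
def trace_unet_shapes(input_size=128, first_features=16, levels=4):
    """Trace feature-map shapes through a hypothetical U-Net."""
    encoder = []
    size, channels = input_size, first_features
    for level in range(levels):
        encoder.append((channels, size, size))
        if level < levels - 1:
            size //= 2      # downsampling halves the resolution
            channels *= 2   # while the number of feature maps doubles

    decoder = []
    # Walk back up, pairing each decoder level with its encoder skip connection.
    for skip_channels, skip_size, _ in reversed(encoder[:-1]):
        upsampled = channels // 2                  # up-convolution halves the channels
        concatenated = upsampled + skip_channels   # skip connection adds encoder features
        channels = skip_channels                   # convolutions fuse them back down
        decoder.append((concatenated, channels, skip_size))
    return encoder, decoder


encoder, decoder = trace_unet_shapes()
for channels, height, width in encoder:
    print(f"encoder: {channels:4d} x {height} x {width}")
for concatenated, channels, size in decoder:
    print(f"decoder: concat {concatenated:4d} -> {channels:4d} x {size} x {size}")
```

The last decoder level returns to the input resolution, which is what allows the network to output one prediction per pixel.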

The U-Net architecture can be defined through the following main parameters:
- the number of feature maps at the first level
- the number of levels
- the use of batch normalization at each level
- the type of activation functions
- the use of dropout operations
- the use of data augmentation

The performance of a deep learning model also depends on the optimization conditions used during the learning process, the main ones being:
- the optimization algorithm (*ADAM* and *RMSprop* being among the most popular)
- the learning rate


In [None]:
# This is a public repository, please inspect the code if you are curious about the implementation details.
!pip install git+https://github.com/osbm/unet_explainer.git

# Imports

Import the built-in, custom, and third-party libraries used in this notebook.

In [None]:
# let's import our functions and classes
from unet_pytorch import (
    ProstateDataset,          # our dataset class
    print_model_info,         # function to print pytorch model info
    fit_model,                # function to train and validate pytorch model
    predict,                  # function to predict on pytorch model
    set_seed,                 # helper function to set seed for reproducibility
    plot_overlay_4x4,         # helper function to plot 4x4 grid of images
    plot_predictions,         # helper function to plot predictions
    plot_one_example,         # helper function to plot one example
    plot_history,             # helper function to plot training history
    plot_comparison_examples, # helper function to compare augmented and original images
)
# third party libraries
import torch                                  # pytorch deep learning framework
from torch.utils.data import DataLoader       # dataloader class from pytorch to load data

import monai                                  # monai medical imaging framework (built on top of pytorch)
from monai.networks.nets import UNet          # unet model from monai (there are other models that you can use with a single line)

import albumentations as A                    # albumentations library for image augmentation and preprocessing
from albumentations.pytorch import ToTensorV2 # albumentations class to convert images to tensors

# built-in libraries
import os

### Let's download the data

I have uploaded a preprocessed version of the data to a Hugging Face dataset repository. You can download it by running the following cell.

In [None]:
if not os.path.exists('data'): # download data if it doesn't exist
    # download data
    !wget -q https://huggingface.co/datasets/osbm/unet-explainer-data/resolve/main/data.zip
    # unzip data
    !unzip -q data.zip

# -q flag means quiet, so you won't see any output

# Reproducibility

We should set an RNG seed for reproducibility so that we get the same results on each run. This is important for debugging and for comparing different models. It is also useful if you want to show that your results were not forged.

> Warning: Total deterministic behavior is not guaranteed between PyTorch releases, individual commits or different platforms. Furthermore, results may not be reproducible between CPU and GPU executions, even when using identical seeds. For this reason it is recommended to also share python version, exact PyTorch version and platform (OS, GPU etc.) when reporting results.
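As a minimal illustration of what seeding does, the sketch below uses only Python's built-in `random` module. The helper name `draw_numbers` is hypothetical. The `set_seed` function used in this notebook presumably also seeds the other libraries involved (such as PyTorch), but that is an assumption about its implementation.

```python
import random


def draw_numbers(seed, count=3):
    """Return `count` pseudo-random numbers from a generator seeded with `seed`."""
    rng = random.Random(seed)
    return [rng.random() for _ in range(count)]


first_run = draw_numbers(42)
second_run = draw_numbers(42)
other_seed = draw_numbers(7)

print(first_run == second_run)  # same seed, same sequence
print(first_run == other_seed)  # different seed, different sequence
```

The same idea applies to weight initialization, data shuffling, and random augmentations: with a fixed seed, each run draws the same "random" values.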

In [None]:
set_seed(42)

# Data augmentation

This step may seem counterintuitive: we deliberately alter, and sometimes degrade, the training data. Doing so makes the model more robust to noise and other artifacts. The augmentations are applied randomly, so the model sees slightly different images in each epoch, which helps reduce overfitting. In effect, we increase the size of the dataset for free.

![augmentation](./assets/data-augmentation.jpeg)

Look at all these images. You would want your model to be able to detect the cat in all of these images. This is what data augmentation does. It makes the model more robust to noise and other artifacts.
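For segmentation there is one extra detail: a geometric augmentation has to be applied to the image and its label mask together, otherwise the labels no longer line up with the pixels. The NumPy sketch below illustrates this with random flips. The function name `random_flip` and the toy 4x4 arrays are made up for illustration and are not part of this tutorial's package.

```python
import numpy as np


def random_flip(image, mask, rng, p=0.5):
    """Flip image and mask together, each direction with probability p."""
    if rng.random() < p:
        image, mask = np.fliplr(image), np.fliplr(mask)  # horizontal flip
    if rng.random() < p:
        image, mask = np.flipud(image), np.flipud(mask)  # vertical flip
    return image, mask


rng = np.random.default_rng(0)
image = np.arange(16, dtype=np.float32).reshape(4, 4)  # toy "image"
mask = (image > 7).astype(np.uint8)                     # toy "segmentation mask"

aug_image, aug_mask = random_flip(image, mask, rng)

# The mask still marks exactly the pixels with intensity > 7 after augmentation.
print(np.array_equal(aug_mask, (aug_image > 7).astype(np.uint8)))
```

Intensity augmentations such as brightness or contrast changes, by contrast, should only touch the image and leave the mask unchanged.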


### Data Preprocessing and Augmentation Code

Here's the code snippet for data preprocessing and augmentation using Albumentations library:

```python
image_size = 128
train_transforms = A.Compose([ # data augmentation and preprocessing pipeline
    A.Resize(height=image_size, width=image_size),
    A.HorizontalFlip(p=0.5), # probability of 0.5
    A.VerticalFlip(p=0.5),
    A.Rotate(limit=5, p=0.5),
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=0, p=0.5),
    A.RandomBrightnessContrast(p=0.5),
    ToTensorV2(),
])

valid_transforms = A.Compose([ # only data preprocessing pipeline
    A.Resize(height=image_size, width=image_size),
    ToTensorV2(),
])
```

**Technical Note:** The `train_transforms` pipeline includes a series of data augmentation techniques such as horizontal flip, vertical flip, and rotation, among others. These augmentations are critical for enhancing the model's robustness to variations in the data. The `valid_transforms` pipeline, on the other hand, is solely for data resizing and tensor conversion and does not include data augmentation.

**Clinical Insight:** Data augmentation methods like flipping and rotation are especially useful in medical imaging to simulate different orientations and positions of anatomical structures. This helps in training a more generalized model, which is crucial for real-world applications where variability is inherent.

In [None]:
image_size = 128
train_transforms = A.Compose([ # data augmentation and preprocessing pipeline
    A.Resize(height=image_size, width=image_size),
    A.HorizontalFlip(p=0.5), # probability of 0.5
    A.VerticalFlip(p=0.5),
    A.Rotate(limit=5, p=0.5),
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=0, p=0.5),
    A.RandomBrightnessContrast(p=0.5),
    ToTensorV2(),
])

valid_transforms = A.Compose([ # only data preprocessing pipeline
    A.Resize(height=image_size, width=image_size),
    ToTensorV2(),
])

In [None]:
# let's create our dataset objects
train_ds = ProstateDataset(folder='data/train', transform=train_transforms)
valid_ds = ProstateDataset(folder='data/valid', transform=valid_transforms)
test_ds = ProstateDataset(folder='data/test', transform=valid_transforms)

### DataLoader Configuration Code

Here is the code for setting up DataLoader objects for the training, validation, and testing datasets:

```python
# Create DataLoader objects
train_loader = DataLoader(train_ds, batch_size=16, shuffle=True)
valid_loader = DataLoader(valid_ds, batch_size=16, shuffle=False)
test_loader = DataLoader(test_ds, batch_size=16, shuffle=False)
```

**Technical Note:** The `DataLoader` objects are configured with a batch size of 16. The training dataset (`train_loader`) is set to shuffle the data before each epoch, which helps in breaking any correlations in the sequence of input data and thereby improves model training. For the validation (`valid_loader`) and test (`test_loader`) datasets, shuffling is disabled as it is not generally required during the evaluation phase.

**Clinical Insight:** The choice of batch size can have a significant impact on both the training speed and the performance of the model. A larger batch size gives a less noisy estimate of the gradient but needs more memory, while a smaller batch size gives noisier updates and requires more iterations per epoch. Clinical practitioners should be aware that these hyperparameters can be adjusted based on the specific computational and clinical requirements.
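The effect of `batch_size` and `shuffle` can be sketched in plain Python. This is not how `DataLoader` is implemented; it is a simplified illustration, and the function name `iterate_batches` and the dataset size of 100 samples are assumptions made for the example.

```python
import random


def iterate_batches(num_samples, batch_size, shuffle, seed=0):
    """Yield lists of sample indices, one list per batch."""
    indices = list(range(num_samples))
    if shuffle:
        random.Random(seed).shuffle(indices)  # new order of samples
    for start in range(0, num_samples, batch_size):
        yield indices[start:start + batch_size]


batches = list(iterate_batches(num_samples=100, batch_size=16, shuffle=True))

print(len(batches))      # number of iterations per epoch
print(len(batches[-1]))  # the last batch may be smaller than batch_size
```

Every sample is still visited exactly once per epoch; shuffling only changes the order in which samples are grouped into batches.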

In [None]:
# and also create our dataloader objects
train_loader = DataLoader(train_ds, batch_size=16, shuffle=True)
valid_loader = DataLoader(valid_ds, batch_size=16, shuffle=False)
test_loader = DataLoader(test_ds, batch_size=16, shuffle=False)

## Let's look at the augmented training data and compare it with the un-augmented validation data

In [None]:
for x, y in train_loader:
    print(x.shape) # shape = [batch_size, channels, height, width]
    print(y.shape) # shape = [batch_size, channels, height, width]
    plot_one_example(x[0], y[0])
    print("Overlay examples:")
    plot_overlay_4x4((x, y))
    break

Now that we know what the training set looks like, let us compare a random example from the training set with one from the validation set.

In [None]:
train_example = train_ds[42]
valid_example = valid_ds[42]

plot_comparison_examples(train_example, valid_example)

### Model Configuration Code

Below is the code snippet for configuring the U-Net model with specific hyperparameters:

```python
model = UNet( # these are the hyperparameters we can change
    spatial_dims=2, # 2D image
    in_channels=1,  # we only used T2-weighted MRI images
    out_channels=3, # 3 labels
    channels=[16, 32, 64, 128, 256, 512],
    strides=(2, 2, 2, 2, 2), # CNN strides
    num_res_units=4, # residual connections
    dropout=0.15, # dropout rate
)
print_model_info(model)
```

**Technical Note:** The `UNet` model is initialized with various hyperparameters like spatial dimensions, input and output channels, the number of filters at each layer, strides for the convolutions, number of residual units, and the dropout rate. Each of these parameters can significantly impact the model's performance and should be carefully chosen.

**Clinical Insight:** The model is tailored for 2D T2-weighted MRI images with three distinct labels. This specific configuration can be useful for segmenting different zones or tissues within the prostate gland. Understanding the role of each hyperparameter can help in customizing the model for specific clinical tasks, thereby enhancing its utility in practice.
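As a quick sanity check on this configuration, the sketch below computes the spatial resolution at each level for a 128x128 input. It assumes that each stride-2 step halves the height and width, which holds for even sizes with padded convolutions. The helper `level_resolutions` is hypothetical and not part of MONAI.

```python
image_size = 128
channels = [16, 32, 64, 128, 256, 512]
strides = (2, 2, 2, 2, 2)


def level_resolutions(image_size, strides):
    """Return the feature-map side length at every level of the network."""
    sizes = [image_size]
    for stride in strides:
        if sizes[-1] % stride != 0:
            raise ValueError(f"size {sizes[-1]} is not divisible by stride {stride}")
        sizes.append(sizes[-1] // stride)
    return sizes


sizes = level_resolutions(image_size, strides)
for n_channels, size in zip(channels, sizes):
    print(f"{n_channels:4d} feature maps at {size} x {size}")
```

With six channel entries there are five transitions between levels, hence five strides. An input size that is a multiple of 32 (2 to the power of 5) keeps every level's resolution an integer, which is one reason the images are resized to 128.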

In [None]:
model = UNet( # these are the hyperparameters we can change
    spatial_dims=2, # 2D image
    in_channels=1,  # we only used T2-weighted MRI images
    out_channels=3, # 3 labels
    channels=[16, 32, 64, 128, 256, 512],
    strides=(2, 2, 2, 2, 2), # CNN strides
    num_res_units=4, # residual connections
    dropout=0.15, # dropout rate
)
print_model_info(model)

# Optimizer

An optimizer is an algorithm that updates the weights of the neural network. Popular optimizers include Adam, SGD, and RMSprop.

<img src="./assets/optimizer.jpeg" alt="drawing" width="500"/>

Consider the image above. This example neural network has only 2 parameters. Each combination of parameter values has a loss value, represented here by height. The optimizer tries to find the lowest point, ideally the global minimum, by taking small steps in the direction of steepest descent. This is called gradient descent. There are many variants of gradient descent; Adam is one of the most popular.

This is just a visualization for a model with 2 parameters. Now imagine 60 million parameters. That is a huge space to search, and we can't simply search all of it. Instead, the optimizer takes small steps. If the steps are too small, it takes a long time to reach a minimum. If the steps are too big, we may overshoot and bounce back and forth around it.

This is one of the reasons learning rate schedulers exist. They change the learning rate over time or depending on the current loss, so that we can take big steps in the beginning and small steps at the end and reach a good minimum faster.
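The effect of the step size can be reproduced with plain gradient descent on a toy loss with 2 parameters. The quadratic loss, the starting point, and the three learning rates below are arbitrary choices for illustration; this is not Adam and not the loss surface of a real network.

```python
def loss_fn(x, y):
    """Toy loss surface: a bowl that is much steeper along y than along x."""
    return x**2 + 10 * y**2


def gradient(x, y):
    """Partial derivatives of loss_fn with respect to x and y."""
    return 2 * x, 20 * y


def gradient_descent(lr, steps=50, start=(3.0, 2.0)):
    """Run plain gradient descent and return the final loss."""
    x, y = start
    for _ in range(steps):
        grad_x, grad_y = gradient(x, y)
        x -= lr * grad_x  # step against the gradient
        y -= lr * grad_y
    return loss_fn(x, y)


initial_loss = loss_fn(3.0, 2.0)
loss_small = gradient_descent(lr=0.001)  # steps too small: slow progress
loss_good = gradient_descent(lr=0.05)    # reasonable steps: close to the minimum
loss_large = gradient_descent(lr=0.11)   # steps too big: overshoots and diverges

print(initial_loss, loss_small, loss_good, loss_large)
```

After the same number of steps, the small learning rate has only partly descended, the moderate one is close to the minimum, and the large one ends up with a higher loss than it started with.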


In [None]:
# we will use the Adam optimizer with a learning rate of 0.0001
optimizer = torch.optim.Adam(model.parameters(), 1e-4)

# Loss function

A loss function measures the error of the model. It answers the question "How much did the model get wrong?"

There are different types of loss functions, and the choice usually depends on the task. For example, classification tasks typically use cross-entropy loss, regression tasks use mean squared error loss, and segmentation tasks often use Dice loss.

## Dice Loss

Dice loss is a loss function that is used for segmentation tasks. It is defined as follows:

![loss](./assets/dice-loss.jpeg)
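As a worked example, the sketch below computes the Dice loss for a pair of tiny binary masks, assuming the standard definition: Dice = 2|A∩B| / (|A| + |B|) and loss = 1 - Dice. The function `dice_loss` and the small smoothing term `eps` are illustrative choices and are not MONAI's implementation.

```python
import numpy as np


def dice_loss(pred, target, eps=1e-5):
    """1 - Dice coefficient for two binary masks of the same shape."""
    intersection = np.sum(pred * target)     # pixels labeled 1 in both masks
    total = np.sum(pred) + np.sum(target)    # mask sizes added together
    return 1.0 - (2.0 * intersection + eps) / (total + eps)


target = np.array([[0, 1, 1],
                   [0, 1, 0]])
pred = np.array([[0, 1, 0],
                 [0, 1, 1]])

partial_overlap = dice_loss(pred, target)    # 2 of 3 foreground pixels agree
perfect_overlap = dice_loss(target, target)  # identical masks

print(partial_overlap, perfect_overlap)
```

Here both masks contain 3 foreground pixels and share 2 of them, so the Dice coefficient is 4/6 and the loss is about 1/3. A perfect prediction gives a loss of 0.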



In [None]:
# we will use the Dice loss function from monai. You can also use other compound loss functions from monai
loss = monai.losses.DiceLoss(include_background=True, to_onehot_y=True, softmax=True)

### Model Training Code

Here is the code for transferring the model weights to the GPU (if available) and initiating the training process:

```python
# Move model weights to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Start training
model, history = fit_model(
    model=model,
    train_loader=train_loader,
    valid_loader=valid_loader,
    optimizer=optimizer,
    loss=loss,
    device=device,
    epochs=30,
)
```

**Technical Note:** `torch.device` selects the GPU if one is available and falls back to the CPU otherwise; `model.to(device)` then moves the model weights to that device, which speeds up training when a GPU is present. Training is then started with the `fit_model` function, which receives the data loaders, optimizer, loss function, device, and number of epochs.

**Clinical Insight:** Training on a GPU can significantly shorten training time, which makes it practical to retrain and compare models as new data or requirements arise. The choice of loss function and optimizer also has implications for the model's ability to generalize well to new, unseen medical data. Practitioners should be aware that the number of epochs is another tunable parameter that needs to be chosen based on the performance requirements and available computational resources.

In [None]:
# Move model weights to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Start training
# epochs=2 keeps this demo run short; the snippet above uses epochs=30 for a full run
model, history = fit_model(
    model=model,
    train_loader=train_loader,
    valid_loader=valid_loader,
    optimizer=optimizer,
    loss=loss,
    device=device,
    epochs=2,
)


In [None]:
# let's see our training history
plot_history(history)

In [None]:
# our function is keeping track of the best model based on validation loss
# just so that we can load the model in its best state, not after we realized that it started overfitting
model.load_state_dict(torch.load('best_model.pth', map_location=device))
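The idea of keeping the best model can be sketched in a few lines. This is not the actual `fit_model` implementation; the function `track_best_epoch` is hypothetical, and the validation losses below are made-up numbers, not results from this notebook.

```python
def track_best_epoch(valid_losses):
    """Return (epoch, loss) of the lowest validation loss seen so far."""
    best_loss = float("inf")
    best_epoch = None
    for epoch, valid_loss in enumerate(valid_losses, start=1):
        if valid_loss < best_loss:
            # In a real training loop, this is where the weights would be saved,
            # for example with torch.save(model.state_dict(), "best_model.pth").
            best_loss, best_epoch = valid_loss, epoch
    return best_epoch, best_loss


# Made-up validation losses: improvement at first, then overfitting sets in.
valid_losses = [0.62, 0.48, 0.41, 0.43, 0.47]
best_epoch, best_loss = track_best_epoch(valid_losses)

print(best_epoch, best_loss)
```

Loading the saved checkpoint afterwards restores the weights from the best epoch, not from the last one.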

In [None]:
# let's see how our model performs on the test set
x, y, y_hat = predict(model, test_loader=test_loader, device=device, final_activation="softmax")
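To clarify what a softmax final activation does, the NumPy sketch below converts raw per-class scores into probabilities and then into a label map. The toy scores and the `[classes, height, width]` layout are assumptions for illustration; whether `predict` returns probabilities or labels depends on its implementation in the tutorial package.

```python
import numpy as np


def softmax(logits, axis=0):
    """Turn raw scores into probabilities that sum to 1 along `axis`."""
    shifted = logits - logits.max(axis=axis, keepdims=True)  # for numerical stability
    exponentials = np.exp(shifted)
    return exponentials / exponentials.sum(axis=axis, keepdims=True)


# Hypothetical raw scores for 3 classes on a 2x2 image: [classes, height, width]
logits = np.array([
    [[2.0, 0.1], [0.3, 0.2]],  # class 0
    [[0.5, 1.5], [0.1, 0.4]],  # class 1
    [[0.1, 0.2], [3.0, 0.3]],  # class 2
])

probabilities = softmax(logits, axis=0)
label_map = probabilities.argmax(axis=0)  # most likely class for each pixel

print(label_map)
```

Each pixel ends up with the class whose probability is highest, which is how a 3-channel network output becomes a single segmentation mask.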

In [None]:
# here is the entire test set with predictions
print(x.shape, y.shape, y_hat.shape)

In [None]:
plot_predictions(x, y, y_hat, num_examples_to_plot=10)

# DONE

Now you know roughly how a U-Net training workflow works.
You can learn much more by experimenting with:
- other segmentation models
- other data augmentation techniques
- other loss functions
- other optimizers
- other learning rate schedulers

But I think this is enough for now. I hope you enjoyed this tutorial. If you have any questions or suggestions, please feel free to contact me.