> You can open this notebook in Colab by clicking the Colab icon. Colab provides free GPU access. You can also run this notebook locally by installing the dependencies listed in `requirements.txt`.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/osbm/unet_explainer/blob/main/tutorial-part1.ipynb)

### Preamble

The first part of this hands-on tutorial shows a working example of a pretrained deep learning model for semantic segmentation of prostate images, in the context of 2D multi-slice prostate MR image segmentation. You will use the [prostate-158 test dataset](https://zenodo.org/record/6592345) (19 MRI images).

### Objectives of this part

* Understand the basics of semantic segmentation
* Understand the basics of U-Net architecture
* Understand how to predict on prostate images using a pretrained model (in the next part, you will learn how to train a model from scratch)


## Install requirements

I have made a Python package that you can install with the command below. It also automatically installs all the libraries needed for this tutorial. If you want to see how the functions and classes are implemented, you can check the source code in the [unet_explainer](https://github.com/osbm/unet_explainer/) repository.

In [None]:
# This is a public repository, please inspect the code if you are curious about the implementation details.
!pip install git+https://github.com/osbm/unet_explainer.git

# Imports

Import all the built-in, custom, and third-party libraries used in this notebook.

In [None]:
# let's import our functions and classes
from unet_pytorch import (
    ProstateDataset,     # Our custom dataset class
    print_model_info,    # Function to print pytorch model info
    predict,             # Function to predict on a single image
    set_seed,            # Function to set seed for reproducibility
    plot_predictions,    # Function to plot predictions
    plot_one_example,    # Function to plot one example
)

# third party libraries
import torch                                  # pytorch deep learning framework
from torch.utils.data import DataLoader       # dataloader class from pytorch to load data

import monai                                  # monai medical imaging framework (built on top of pytorch)
from monai.networks.nets import UNet          # unet model from monai (there are other models that you can use with a single line)

import albumentations as A                    # albumentations library for image augmentation and preprocessing
from albumentations.pytorch import ToTensorV2 # albumentations class to convert images to tensors

# built-in libraries
import os  # os library to work with files and directories

# Download the dataset and the pretrained model

I have put the preprocessed version of the dataset and the pretrained model in a public Hugging Face dataset repository. You can download them with the commands below.

One thing you need to know about this dataset is that I have applied some preprocessing to it using [this script](https://github.com/osbm/unet_explainer/blob/main/scripts/preprocess_data.py). To summarize: the original prostate158 dataset contains 3D image volumes and masks for each patient, but our model performs 2D segmentation. So I removed every slice whose segmentation mask covers less than 6 percent of the slice. This threshold is arbitrary; you can play around with it to create new datasets.

The point is if a slice does not contain *enough* of the prostate, we do not want to use it for training.
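The filtering idea can be sketched in a few lines of NumPy. This is not the code from the linked script; it is a minimal illustration that assumes each volume and its mask are arrays of shape `(slices, height, width)` and that "coverage" means the fraction of non-zero mask pixels in a slice. `filter_slices` is a hypothetical helper name.

```python
import numpy as np


def filter_slices(volume: np.ndarray, mask: np.ndarray, min_fraction: float = 0.06):
    """Keep only the 2D slices whose mask covers at least `min_fraction` of the slice.

    Hypothetical helper: assumes `volume` and `mask` have shape (slices, height, width).
    """
    if volume.shape != mask.shape:
        raise ValueError("volume and mask must have the same shape")
    # fraction of labelled (non-zero) pixels in each slice
    foreground_fraction = (mask > 0).reshape(mask.shape[0], -1).mean(axis=1)
    keep = foreground_fraction >= min_fraction
    return volume[keep], mask[keep]


# toy volume: 3 slices of 10x10 pixels with 0, 5 and 20 labelled pixels respectively
volume = np.random.rand(3, 10, 10)
mask = np.zeros((3, 10, 10), dtype=np.uint8)
mask[1, :5, 0] = 1    # 5 pixels  -> 5% of the slice, dropped
mask[2, :10, :2] = 2  # 20 pixels -> 20% of the slice, kept
kept_images, kept_masks = filter_slices(volume, mask)
print(kept_images.shape)  # (1, 10, 10)
```

Changing `min_fraction` is the "play around with this number" knob mentioned above: a lower value keeps more slices near the apex and base of the prostate, a higher value keeps only slices where the gland is large.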


In [None]:
if not os.path.exists('data'): # download data if it doesn't exist
    # download data
    !wget -q https://huggingface.co/datasets/osbm/unet-explainer-data/resolve/main/data.zip
    # unzip data
    !unzip -q data.zip

# -q flag means quiet, so you won't see any output

In [None]:
# also download our pretrained model
!wget -q https://huggingface.co/datasets/osbm/unet-explainer-data/resolve/main/best_model.pth

# Reproducibility

We should set an RNG seed for reproducibility, so that we get the same results on each run. This is important for debugging and for comparing different models. It is also useful if you want others to be able to verify that you didn't forge your results.

> Warning: Total deterministic behavior is not guaranteed between PyTorch releases, individual commits or different platforms. Furthermore, results may not be reproducible between CPU and GPU executions, even when using identical seeds. For this reason it is recommended to also share python version, exact PyTorch version and platform (OS, GPU etc.) when reporting results.

In [None]:
set_seed(42)
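We have not inspected how the package's `set_seed` is implemented. A typical seeding helper looks roughly like the sketch below; the name `seed_everything` is hypothetical (chosen so it does not shadow the imported `set_seed`), and the exact set of RNGs the package seeds may differ.

```python
import random

import numpy as np
import torch


def seed_everything(seed: int) -> None:
    """Hypothetical seeding helper, similar in spirit to set_seed."""
    random.seed(seed)                 # Python's built-in RNG
    np.random.seed(seed)              # NumPy's global RNG (used by many augmentation libraries)
    torch.manual_seed(seed)           # PyTorch RNG on the CPU
    torch.cuda.manual_seed_all(seed)  # PyTorch RNG on every CUDA device (no-op without CUDA)
    # ask cuDNN for deterministic algorithms (can be slower)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_everything(42)
a = torch.rand(3)
seed_everything(42)
b = torch.rand(3)
print(torch.equal(a, b))  # True: same seed, same random numbers
```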

# Transformations

There are a couple of transformations that we need to apply to our images and masks: we need to resize them to a fixed size, convert them into tensors, and normalize them. We will use the albumentations library for this purpose; albumentations is a popular library for image augmentation.

The most common transformation libraries are:
- Torchvision (native to PyTorch)
- Albumentations
- MONAI (mostly for medical images)

In [None]:
image_size = 256
valid_transforms = A.Compose([
    A.Resize(height=image_size, width=image_size),
    ToTensorV2(),
])
# all we are doing here is resizing the image to 256x256 and converting it to a tensor
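The transform cell above only resizes and converts to tensors; it contains no normalization step, and we have not checked whether `ProstateDataset` normalizes intensities internally. As an illustration only, one common choice for MRI slices (whose intensities have no fixed scale) is per-image z-score normalization. `zscore_normalize` is a hypothetical helper, not part of the package.

```python
import numpy as np


def zscore_normalize(image: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Rescale one image to zero mean and unit standard deviation (hypothetical helper)."""
    image = image.astype(np.float32)
    # eps avoids division by zero on constant (e.g. all-background) slices
    return (image - image.mean()) / (image.std() + eps)


slice_ = np.array([[0, 100], [200, 300]], dtype=np.uint16)
normalized = zscore_normalize(slice_)
print(normalized.mean(), normalized.std())  # approximately 0 and 1
```

Albumentations also has an `A.Normalize` transform, but by default it uses fixed mean and standard deviation values rather than per-image statistics, so check its parameters before using it on MRI data.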

# Dataset and dataloader

In PyTorch, the common pattern is to create a custom dataset class and then wrap it in a `DataLoader`, which draws samples from that dataset.

A PyTorch dataset defines the number of samples in the dataset (`__len__`) and how to get just ONE sample from it (`__getitem__`). It usually also applies the transformations to that sample. This is where you want to put most of your data processing logic.
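To make the two required methods concrete, here is a minimal sketch of a dataset that holds slices in memory. `InMemorySliceDataset` is hypothetical; the real `ProstateDataset` reads files from a folder and may be organized differently. The `transform(image=..., mask=...)` call follows the albumentations convention used in this notebook.

```python
import numpy as np
import torch
from torch.utils.data import Dataset


class InMemorySliceDataset(Dataset):
    """Hypothetical dataset holding 2D slices and masks as NumPy arrays."""

    def __init__(self, images: np.ndarray, masks: np.ndarray, transform=None):
        # images: (N, H, W) float array, masks: (N, H, W) integer label array
        self.images = images
        self.masks = masks
        self.transform = transform

    def __len__(self) -> int:
        # number of samples; the DataLoader uses this to know which indices exist
        return len(self.images)

    def __getitem__(self, idx: int):
        image, mask = self.images[idx], self.masks[idx]
        if self.transform is not None:
            # albumentations-style call: returns a dict with "image" and "mask"
            out = self.transform(image=image, mask=mask)
            return out["image"], out["mask"]
        # without a transform: add a channel dimension to the image -> (1, H, W)
        return torch.from_numpy(image).unsqueeze(0), torch.from_numpy(mask).long()


ds = InMemorySliceDataset(
    images=np.random.rand(5, 64, 64).astype(np.float32),
    masks=np.random.randint(0, 3, size=(5, 64, 64)),
)
image, mask = ds[0]
print(len(ds), image.shape, mask.shape)  # 5 torch.Size([1, 64, 64]) torch.Size([64, 64])
```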

In [None]:
test_ds = ProstateDataset(folder='data/test', transform=valid_transforms)

In [None]:
len(test_ds) # number of samples in the dataset

In [None]:
example_image, example_mask = test_ds[24]
example_image.shape, example_mask.shape # shapes will be [channels, height, width] for both image and mask

In [None]:
# let's look at our example sample
plot_one_example(example_image, example_mask)

In [None]:
test_loader = DataLoader(test_ds, batch_size=16, shuffle=True)


A PyTorch dataloader defines how to get a batch of samples from the dataset.
You can:
- specify the batch size
- set the number of workers used to load the data
- choose whether or not to shuffle the data
- etc.
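The options listed above map directly onto `DataLoader` arguments. The sketch below uses random tensors wrapped in a `TensorDataset` instead of `ProstateDataset`, so it runs on its own. With 40 samples and `batch_size=16`, the last batch is smaller, and the number of batches is the ceiling of 40 / 16.

```python
import math

import torch
from torch.utils.data import DataLoader, TensorDataset

# 40 fake single-channel 64x64 images with matching label maps
images = torch.randn(40, 1, 64, 64)
masks = torch.randint(0, 3, (40, 64, 64))

loader = DataLoader(
    TensorDataset(images, masks),
    batch_size=16,   # samples per batch
    shuffle=False,   # keep the original order
    num_workers=0,   # load in the main process; > 0 uses worker subprocesses
)

batch_sizes = [x.shape[0] for x, y in loader]
print(len(loader), batch_sizes)                    # 3 [16, 16, 8]
print(len(loader) == math.ceil(len(images) / 16))  # True
```

If you prefer every batch to have exactly the same size, `DataLoader` also accepts `drop_last=True`, which discards the final incomplete batch.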

### What is a batch?

You can train a model on one example at a time, but that is not efficient. Instead, we use batches: we compute the gradients for all the samples in the batch at once and then update the model weights with the average of those gradients. Averaging also reduces the noise in the gradient estimate, which makes the weight updates more stable.

A batch is a collection of samples. For example, if you have 1000 samples and you set the batch size to 10, then you will have 100 batches. Each batch will contain 10 samples.
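The statement that a batch update uses the average of the per-sample gradients can be checked directly when the loss is the mean over the batch (the default `reduction="mean"` in PyTorch loss functions). The toy linear model below is only an illustration, not part of the U-Net pipeline.

```python
import torch

torch.manual_seed(0)
w = torch.randn(3, requires_grad=True)  # weights of a tiny linear model
x = torch.randn(8, 3)                   # a batch of 8 samples
y = torch.randn(8)                      # their targets

# 1) one backward pass on the mean squared error of the whole batch
loss = ((x @ w - y) ** 2).mean()
(batch_grad,) = torch.autograd.grad(loss, w)

# 2) gradients computed one sample at a time, then averaged
per_sample = [torch.autograd.grad((x[i] @ w - y[i]) ** 2, w)[0] for i in range(len(x))]
mean_grad = torch.stack(per_sample).mean(dim=0)

print(torch.allclose(batch_grad, mean_grad, atol=1e-5))  # True
```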


# The Model

Finally, the juicy part. We will use a pretrained U-Net model to predict on our dataset. U-Net is a popular convolutional neural network architecture for semantic segmentation. It was first introduced by Olaf Ronneberger, Philipp Fischer, and Thomas Brox in the 2015 paper *U-Net: Convolutional Networks for Biomedical Image Segmentation*.

A U-Net consists of two parts:
- Contracting path (left side of the U)
- Expanding path (right side of the U)

![unet-architecture](./assets/unet-architecture.png)

This picture looks complicated, but it is not. Let's break it down.

The contracting path is a typical convolutional network that consists of repeated application of convolutions, each followed by a rectified linear unit (ReLU) and a max pooling operation. During the contraction, the spatial information is reduced while feature information is increased. The expanding path combines the feature and spatial information through a sequence of up-convolutions and concatenations with high-resolution features from the contracting path.

As we go deeper into the network, the number of channels increases while the height and width of the feature maps decrease.

But the spatial information is not completely lost. The skip connections carry the high-resolution feature maps from the contracting path over to the expanding path, where they are concatenated with the upsampled feature maps. This way, the expanding path combines the coarse but semantically rich features from deep in the network with the fine spatial detail from the early layers.
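The contracting path, expanding path, and skip concatenations described above can be sketched as a tiny two-level U-Net in PyTorch. `TinyUNet` is a hypothetical illustration, not the model used in this notebook: it uses `padding=1` convolutions so no cropping is needed (the original paper used unpadded convolutions and cropped the skip features), and the MONAI `UNet` below is configured differently (for example, it downsamples with strided convolutions and uses residual units).

```python
import torch
import torch.nn as nn


def conv_block(in_ch: int, out_ch: int) -> nn.Sequential:
    # two 3x3 convolutions, each followed by ReLU; padding=1 keeps height and width unchanged
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
    )


class TinyUNet(nn.Module):
    """Hypothetical two-level U-Net for illustration only."""

    def __init__(self, in_channels: int = 1, out_channels: int = 3, base: int = 16):
        super().__init__()
        # contracting path: more channels, lower resolution (via max pooling)
        self.enc1 = conv_block(in_channels, base)
        self.enc2 = conv_block(base, base * 2)
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.bottleneck = conv_block(base * 2, base * 4)
        # expanding path: up-convolutions double the resolution again
        self.up2 = nn.ConvTranspose2d(base * 4, base * 2, kernel_size=2, stride=2)
        self.dec2 = conv_block(base * 4, base * 2)  # input = upsampled (base*2) + skip (base*2)
        self.up1 = nn.ConvTranspose2d(base * 2, base, kernel_size=2, stride=2)
        self.dec1 = conv_block(base * 2, base)      # input = upsampled (base) + skip (base)
        # 1x1 convolution: one score (logit) per class for every pixel
        self.head = nn.Conv2d(base, out_channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # shapes below assume the default base=16
        s1 = self.enc1(x)                   # (B, 16, H, W)      saved for a skip connection
        s2 = self.enc2(self.pool(s1))       # (B, 32, H/2, W/2)  saved for a skip connection
        b = self.bottleneck(self.pool(s2))  # (B, 64, H/4, W/4)
        # skip connections: concatenate encoder features along the channel axis
        d2 = self.dec2(torch.cat([self.up2(b), s2], dim=1))   # (B, 32, H/2, W/2)
        d1 = self.dec1(torch.cat([self.up1(d2), s1], dim=1))  # (B, 16, H, W)
        return self.head(d1)                # (B, out_channels, H, W)


logits = TinyUNet()(torch.randn(2, 1, 64, 64))
print(logits.shape)  # torch.Size([2, 3, 64, 64])
```

The output has the same height and width as the input and one channel per class, which is what a per-pixel segmentation needs.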

One less obvious benefit of skip connections is that they help with the vanishing gradient problem, which mostly affects deep neural networks.


<img src="./assets/vanishing-gradients.jpeg" alt="vanishing-gradient-problem" width="500"/>

During backpropagation, the gradient signal travels backward through many layers, and it can become smaller and smaller by the time it reaches the first layers (closest to the input). This is called the vanishing gradient problem, and it makes deep networks hard to train.
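A toy scalar network makes the effect visible. Each plain sigmoid layer multiplies the gradient by the sigmoid's derivative, which is at most 0.25, so 20 layers shrink it to at most 0.25^20 (about 9e-13). Adding the layer's input back to its output gives each factor a "+1", so the gradient cannot vanish. Note that this sketch uses additive (residual-style) skips because they are the simplest to analyze; U-Net's skips concatenate feature maps instead, but they likewise give gradients a shorter path back to the early layers.

```python
import torch


def input_gradient(depth: int, residual: bool) -> float:
    """d(output)/d(input) after `depth` sigmoid layers in a toy scalar network."""
    x = torch.tensor(0.5, requires_grad=True)
    h = x
    for _ in range(depth):
        f = torch.sigmoid(h)
        # plain layer:    h = f(h),     gradient factor sigmoid'(h) <= 0.25
        # residual layer: h = h + f(h), gradient factor 1 + sigmoid'(h) >= 1
        h = h + f if residual else f
    h.backward()
    return x.grad.item()


print(input_gradient(depth=20, residual=False))  # tiny: at most 0.25**20
print(input_gradient(depth=20, residual=True))   # at least 1.0
```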

In [None]:
model = UNet(
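    spatial_dims=2, # 2D images
    in_channels=1,  # we only use T2-weighted MRI images (a single channel)
    out_channels=3, # 3 output classes (labels)
    channels=[16, 32, 64, 128, 256, 512], # number of feature channels at each level of the contracting path
    strides=(2, 2, 2, 2, 2), # downsampling stride between levels (one fewer entry than channels)
    num_res_units=4, # residual units per block (0 would disable residual connections)
    dropout=0.15, # dropout rate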
)
# the monai library returns a pytorch model, so we can use it as a pytorch model
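As a quick sanity check on the `strides` argument: each stride of 2 halves the height and width once, so five of them shrink a 256×256 slice down to 8×8 at the deepest level, and the input size should be divisible by the product of the strides (32 here). This is plain arithmetic, not a call into MONAI; exactly which entry of `channels` is paired with each resolution is a MONAI implementation detail that this sketch does not model.

```python
import math

image_size = 256
strides = (2, 2, 2, 2, 2)

sizes = []
size = image_size
for s in strides:
    # each stride-s level divides height and width by s (assuming it divides evenly)
    assert size % s == 0, f"{size} is not divisible by stride {s}"
    size //= s
    sizes.append(size)

total_downsampling = math.prod(strides)
print(sizes)               # [128, 64, 32, 16, 8]
print(total_downsampling)  # 32
```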

In [None]:
print_model_info(model)

In [None]:
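# the model object only defines the architecture (with random weights), so we load the pretrained weights
# map_location='cpu' lets weights saved on a GPU load on a CPU-only machine too
model.load_state_dict(torch.load('best_model.pth', map_location='cpu'))
model.eval()  # evaluation mode: disables dropout for inference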
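A state dict is just a dictionary mapping parameter names to tensors, so "loading the weights" means copying those tensors into a model with the same architecture. The round trip below uses a small stand-in model and an in-memory buffer instead of `best_model.pth`, so it runs anywhere; the same pattern applies to the UNet above.

```python
import io

import torch
import torch.nn as nn

# stand-in model; the same pattern applies to the UNet above
net = nn.Sequential(nn.Conv2d(1, 4, kernel_size=3), nn.ReLU())
buffer = io.BytesIO()                 # in-memory file instead of 'best_model.pth'
torch.save(net.state_dict(), buffer)  # save only the weights (a dict of tensors)
buffer.seek(0)

fresh = nn.Sequential(nn.Conv2d(1, 4, kernel_size=3), nn.ReLU())  # same architecture, new random weights
state = torch.load(buffer, map_location="cpu")  # map tensors to the CPU regardless of where they were saved
fresh.load_state_dict(state)

same = all(
    torch.equal(a, b) for a, b in zip(net.state_dict().values(), fresh.state_dict().values())
)
print(list(state.keys()), same)  # ['0.weight', '0.bias'] True
```

If the architecture does not match the saved weights (for example, different `channels`), `load_state_dict` raises an error listing the missing or unexpected keys, which is why the `UNet(...)` arguments must match the ones used during training.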

# Selecting a device to run the model

A device such as a GPU is used to accelerate computation. You need to be aware of your device's memory capacity. This notebook runs smoothly on a **Colab T4 GPU**.

PyTorch supports both CPU and GPU, and you can select which device to use with the code below. If you have a GPU, you should use it, because it is much faster than a CPU. If you don't have one, you can still run this notebook on the CPU; we are only running inference, after all.

> You can also use Apple MPS (Metal Performance Shaders) to run models on Apple devices, especially if you have Apple silicon (M1 or M2 chips) instead of a GPU card. Simply replace the device line below with:

```python
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
```
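If the same notebook should work on an NVIDIA GPU, an Apple machine, or a plain CPU without editing, the two checks can be combined into one fallback chain. `pick_device` is a hypothetical helper name; both availability checks are standard PyTorch calls.

```python
import torch


def pick_device() -> torch.device:
    # prefer an NVIDIA GPU, then Apple's MPS backend, and fall back to the CPU
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


device = pick_device()
print(device)
```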

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# move the model's weights and biases to the selected device's memory
model = model.to(device)

In [None]:
# let's run the model on the test set using the predict function
x, y, y_hat = predict(
    model,
    test_loader=test_loader,
    device=device,
    final_activation="softmax",
    calculate_scores=True,
)
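We have not checked how `predict` post-processes the outputs or which scores it reports. As a rough sketch of the usual steps: `final_activation="softmax"` suggests the raw outputs (logits) are turned into per-pixel class probabilities, the predicted label is the most probable class, and a common segmentation score is the per-class Dice coefficient, 2·|P ∩ T| / (|P| + |T|). Both helper functions below are hypothetical.

```python
import torch


def logits_to_labels(logits: torch.Tensor) -> torch.Tensor:
    # logits: (B, C, H, W); softmax over the class dimension gives per-pixel probabilities
    probs = torch.softmax(logits, dim=1)
    # predicted label = class with the highest probability -> (B, H, W)
    return probs.argmax(dim=1)


def dice_per_class(pred: torch.Tensor, target: torch.Tensor, num_classes: int,
                   eps: float = 1e-7) -> torch.Tensor:
    """Dice score 2|P & T| / (|P| + |T|) for each class label (hypothetical helper)."""
    scores = []
    for c in range(num_classes):
        p, t = pred == c, target == c
        intersection = (p & t).sum().float()
        total = p.sum().float() + t.sum().float()
        # eps keeps the score defined when a class is absent from both masks
        scores.append((2 * intersection + eps) / (total + eps))
    return torch.stack(scores)


target = torch.tensor([[[0, 0], [1, 1]]])  # (B=1, H=2, W=2) ground-truth labels
pred = torch.tensor([[[0, 1], [1, 1]]])    # one pixel wrongly predicted as class 1
print(dice_per_class(pred, target, num_classes=2))  # about 0.667 for class 0 and 0.8 for class 1
```

Since softmax preserves the ordering of the logits, the argmax would be the same without it; the softmax is only needed when you want the probabilities themselves (for example, to threshold or visualize them).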

In [None]:
plot_predictions(x, y, y_hat, num_examples_to_plot=5)