
## Semantic Segmentation with U-Net

In this notebook, we will learn about semantic segmentation by looking at the U-Net architecture. Specifically, we will cover the following topics:

- Exploring the U-Net architecture
- Implementing semantic segmentation using U-Net

The following image (source: https://arxiv.org/pdf/1405.0312.pdf) succinctly illustrates what we are trying to achieve through image segmentation:

<img src='https://github.com/rahiakela/img-repo/blob/master/object-detection-images/image-segmentation.png?raw=1' width='800'/>

## Setup

```python
!pip install -q torch_snippets pytorch_model_summary
```

```python
import os
import torch
import torch.nn as nn
from torch_snippets import *
from torchvision import transforms
from torchvision.models import vgg16_bn
from sklearn.model_selection import train_test_split

device = 'cuda' if torch.cuda.is_available() else 'cpu'
```

```
%%shell

wget -q https://www.dropbox.com/s/0pigmmmynbf9xwq/dataset1.zip
unzip -q dataset1.zip
rm dataset1.zip
```

## Exploring the U-Net architecture

Imagine a scenario where you've been given an image and asked to predict which object each pixel belongs to.

So far, when predicting the class of an object and its bounding box, we passed the image through a backbone architecture (such as VGG or ResNet), flattened the output at a certain layer, and connected additional dense layers before predicting the class and the bounding box offsets.

However, in image segmentation the output has the same height and width as the input image, so flattening the convolutional outputs and then reconstructing an image from them can discard the spatial information we need.

Furthermore, in image segmentation the contours and shapes present in the original image are preserved in the output, so the networks we have dealt with so far (which flatten the last layer and connect additional dense layers) are not well suited to segmentation.

The two aspects that we need to keep in mind while performing segmentation are as follows:

- The shape and structure of the objects in the original image remain the same in the segmented output.
- Leveraging a fully convolutional architecture (rather than one where we flatten a certain layer) can help here, since we are using one image as input and producing another as output.

The U-Net architecture helps us achieve this. A typical representation of U-Net is shown below (the input image has the shape `3 x 96 x 128` and there are 21 classes, so the output contains 21 channels):

<img src='https://github.com/rahiakela/img-repo/blob/master/object-detection-images/unet.png?raw=1' width='800'/>

**The preceding architecture is called a U-Net architecture because of its "U"-like shape.**

In the left half of the preceding diagram, the image passes through convolution layers, and its spatial size keeps shrinking while the number of channels keeps increasing. In the right half, we upscale the downscaled feature maps back to the original height and width, ending with as many channels as there are classes.

**In addition, while upscaling, we also leverage information from the corresponding layers in the left half through skip connections, so that we can preserve the structure/objects in the original image.**

**This way, the U-Net architecture learns to preserve the structure (and the shapes of objects) of the original image while leveraging convolutional features to predict the class of each pixel.**
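The description above can be sketched as a toy two-level U-Net. This is a minimal illustration under stated assumptions, not the model implemented later in this notebook. `TinyUNet` and `conv_block` are hypothetical names. The padded `3 x 3` convolutions, the channel widths, and concatenation (`torch.cat` along the channel dimension) as the skip-connection mechanism are all choices made here for simplicity. The original U-Net paper uses unpadded convolutions with cropping. Height and width must be divisible by 4 for the skip connections to line up.

```python
import torch
import torch.nn as nn

def conv_block(in_ch, out_ch):
    # Two 3 x 3 convolutions with padding=1 keep height and width unchanged
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
    )

class TinyUNet(nn.Module):
    """Hypothetical two-level U-Net, for illustration only."""
    def __init__(self, in_ch=3, num_classes=21):
        super().__init__()
        # Left half: spatial size shrinks, channels grow
        self.enc1 = conv_block(in_ch, 32)
        self.enc2 = conv_block(32, 64)
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = conv_block(64, 128)
        # Right half: ConvTranspose2d with kernel 2, stride 2 doubles H and W
        self.up2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec2 = conv_block(128, 64)  # 64 upsampled + 64 skipped channels
        self.up1 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.dec1 = conv_block(64, 32)   # 32 upsampled + 32 skipped channels
        # 1 x 1 convolution: one output channel per class
        self.head = nn.Conv2d(32, num_classes, kernel_size=1)

    def forward(self, x):
        e1 = self.enc1(x)                   # (N, 32, H, W)
        e2 = self.enc2(self.pool(e1))       # (N, 64, H/2, W/2)
        b = self.bottleneck(self.pool(e2))  # (N, 128, H/4, W/4)
        # Skip connections: concatenate encoder features with upsampled features
        d2 = self.dec2(torch.cat([self.up2(b), e2], dim=1))   # (N, 64, H/2, W/2)
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))  # (N, 32, H, W)
        return self.head(d1)                # (N, num_classes, H, W)

tiny = TinyUNet()
x = torch.randn(2, 3, 96, 128)
print(tiny(x).shape)  # torch.Size([2, 21, 96, 128])
```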

In general, we have as many channels in the output as the number of classes we want to predict.
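With one channel per class, the predicted class of each pixel is simply the channel with the highest score. The sketch below uses random tensors as stand-ins for real network outputs and masks (an assumption for illustration). It also shows one common training setup: `torch.nn.functional.cross_entropy` accepts `(N, C, H, W)` logits with an `(N, H, W)` integer mask and averages the per-pixel losses. The loss used later in this notebook may differ.

```python
import torch
import torch.nn.functional as F

num_classes = 21
# Stand-in for a network's raw output: 2 images, one channel per class, 96 x 128 pixels
logits = torch.randn(2, num_classes, 96, 128)

# Per-pixel prediction: index of the highest-scoring channel
pred_mask = logits.argmax(dim=1)
print(pred_mask.shape)  # torch.Size([2, 96, 128])

# Per-pixel cross-entropy against an integer class mask of the same height and width
target = torch.randint(0, num_classes, (2, 96, 128))
loss = F.cross_entropy(logits, target)
print(loss.item())
```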

### Performing upscaling

In the U-Net architecture, upscaling is performed using the `nn.ConvTranspose2d` layer, which takes the number of input channels, the number of output channels, the kernel size, and the stride as input parameters.

<img src='https://github.com/rahiakela/img-repo/blob/master/object-detection-images/upscaling.png?raw=1' width='800'/>

In the preceding example, we took an input array of shape `3 x 3` (Input array), applied a stride of 2 by spreading the input values apart and inserting zeros between them (Input array adjusted for stride), padded the array with zeros (Input array adjusted for stride and padding), and convolved the padded input with a filter (Filter/Kernel) to obtain the output array.
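The three steps in the figure can be reproduced by hand and compared with PyTorch. This sketch assumes a single input and output channel, `padding=0`, and `dilation=1`; the input values and the random kernel are arbitrary. One detail to note: PyTorch's `conv2d` computes a cross-correlation, so the kernel has to be flipped in both spatial dimensions (rotated by 180 degrees) for the manual result to match `F.conv_transpose2d` exactly.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.arange(1., 10.).reshape(1, 1, 3, 3)  # 3 x 3 input array
w = torch.randn(1, 1, 2, 2)                    # 2 x 2 kernel (single channel)
stride, k, padding, dilation = 2, 2, 0, 1

# Step 1: adjust for stride by inserting (stride - 1) zeros between input values
n = x.shape[-1]
size = stride * (n - 1) + 1                    # 5 for a 3 x 3 input and stride 2
dilated = torch.zeros(1, 1, size, size)
dilated[..., ::stride, ::stride] = x

# Step 2: pad with dilation * (kernel_size - 1) - padding zeros on each side
pad = dilation * (k - 1) - padding             # 1 * (2 - 1) - 0 = 1
padded = F.pad(dilated, (pad, pad, pad, pad))  # 7 x 7

# Step 3: ordinary stride-1 convolution with the flipped kernel
manual = F.conv2d(padded, torch.flip(w, dims=[2, 3]))

reference = F.conv_transpose2d(x, w, stride=stride)
print(manual.shape)                        # torch.Size([1, 1, 6, 6])
print(torch.allclose(manual, reference))   # True
```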

To understand how `nn.ConvTranspose2d` upscales an array, let's go through the following code:

```python
# Initialize a network, m, with nn.ConvTranspose2d: 1 input channel, 1 output channel,
# a 2 x 2 kernel, a stride of 2, and no padding
m = nn.ConvTranspose2d(1, 1, kernel_size=(2, 2), stride=2, padding=0)
```

Internally, `nn.ConvTranspose2d` pads the stride-adjusted input with `dilation * (kernel_size - 1) - padding` zeros on each side.

Hence, `1 * (2 - 1) - 0 = 1`: we add one row/column of zeros on each side of both spatial dimensions of the input array.
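This padding rule is part of the output-size formula given in the PyTorch documentation for `nn.ConvTranspose2d`: `H_out = (H_in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1`. For our example, `(3 - 1) * 2 - 0 + 1 * (2 - 1) + 0 + 1 = 6`. The helper `conv_transpose_out_size` below is a hypothetical name introduced here; the sketch compares the formula with the shapes the layer actually produces for a few configurations.

```python
import torch
import torch.nn as nn

def conv_transpose_out_size(h_in, kernel_size, stride=1, padding=0,
                            output_padding=0, dilation=1):
    """Output height/width of nn.ConvTranspose2d, per the formula in the PyTorch docs."""
    return ((h_in - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)

# Compare the formula with what the layer actually produces
for h_in, k, s, p in [(3, 2, 2, 0), (24, 2, 2, 0), (5, 3, 2, 1)]:
    layer = nn.ConvTranspose2d(1, 1, kernel_size=k, stride=s, padding=p)
    actual = layer(torch.ones(1, 1, h_in, h_in)).shape[-1]
    print(h_in, k, s, p, conv_transpose_out_size(h_in, k, s, p), actual)
```

With `kernel_size=2` and `stride=2`, the formula reduces to `2 * H_in`, which is why this configuration is a convenient way to exactly double the height and width in the right half of a U-Net.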

```python
# Initialize an input array and pass it through the model
input = torch.ones(1, 1, 3, 3)
output = m(input)
print(output.shape)  # torch.Size([1, 1, 6, 6])
```

```python
output
```

```
tensor([[[[ 0.5472, -0.1862,  0.5472, -0.1862,  0.5472, -0.1862],
          [-0.0158,  0.4669, -0.0158,  0.4669, -0.0158,  0.4669],
          [ 0.5472, -0.1862,  0.5472, -0.1862,  0.5472, -0.1862],
          [-0.0158,  0.4669, -0.0158,  0.4669, -0.0158,  0.4669],
          [ 0.5472, -0.1862,  0.5472, -0.1862,  0.5472, -0.1862],
          [-0.0158,  0.4669, -0.0158,  0.4669, -0.0158,  0.4669]]]],
       grad_fn=<SlowConvTranspose2DBackward>)
```
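The repeating `2 x 2` pattern in this output follows from the setup. The input is all ones and the stride equals the kernel size, so each input value writes into its own non-overlapping `2 x 2` output block, and every block equals the kernel weights plus the bias. The sketch below checks this with a freshly initialized layer (named `up` here so it does not overwrite `m`); its weights are random, so the exact numbers will differ from those printed above.

```python
import torch
import torch.nn as nn

up = nn.ConvTranspose2d(1, 1, kernel_size=(2, 2), stride=2, padding=0)
out = up(torch.ones(1, 1, 3, 3))

# Each input value (here 1) contributes value * weight + bias to its own 2 x 2 block,
# so the output is the 2 x 2 kernel tiled 3 x 3 times, plus the bias
expected = up.weight[0, 0].repeat(3, 3) + up.bias
print(torch.allclose(out[0, 0], expected))  # True
```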

## Implementing semantic segmentation using U-Net