# Segmentation

This notebook covers semantic segmentation using the [PascalVOC dataset](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/) and the PyTorch library. The goal of semantic image segmentation is to label each pixel of an image with the class of the object it represents.

Let's first install and import the libraries we will use throughout the notebook.

In [None]:
!pip install torchmetrics
!pip install monai
import os
import math
import matplotlib.pyplot as plt
import numpy as np
import torch
import time
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchvision.transforms.functional import InterpolationMode
import torchmetrics
from torch.utils.data import DataLoader
from PIL import Image
import torchvision
from torchvision.models.segmentation import FCN_ResNet50_Weights
import torch.nn.functional as F
import monai
import torchvision.transforms.functional as TF

# Data

In this notebook we will use the [PascalVOC dataset](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/) for semantic segmentation. It defines 20 object classes (+ background), and an additional label (255) is used for the void category, i.e. object borders (a 5px band) and masks for difficult objects. In the following we treat these void pixels as background, so we will have 21 classes in total.

We define here some auxiliary functions to:
<ul>
    <li> Replace the pixel values of the void category with background</li>
    <li> Plot a group of images </li>
</ul>

In [None]:
def replace_tensor_value_(tensor, a, b):
    tensor[tensor == a] = b
    return tensor


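def plot_images(images, num_per_row=8, title=None, vmax=None):
    num_rows = int(math.ceil(len(images) / num_per_row))

    # squeeze=False always returns a 2D array of axes, even for a single row/column
    fig, axes = plt.subplots(num_rows, num_per_row, dpi=150, squeeze=False)
    fig.subplots_adjust(wspace=0, hspace=0)
    # Hide the axes of every subplot, including the empty ones in the last row
    for ax in axes.flat:
        ax.axis('off')
    for image, ax in zip(images, axes.flat):
        ax.imshow(image, vmax=vmax)
    if title is not None:
        fig.suptitle(title)

    return fig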

Now we define the classes for PascalVOC semantic segmentation (indices 0 to 20) and the RGB color palette used to represent them. (These variables are not used in the rest of the notebook, but they can be useful if you want to experiment further with this dataset.)

In [None]:
VOC_CLASSES = [
    "background",
    "aeroplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "diningtable",
    "dog",
    "horse",
    "motorbike",
    "person",
    "potted plant",
    "sheep",
    "sofa",
    "train",
    "tv/monitor",
]
# Color palette for segmentation masks in RGB
PALETTE = np.array(
    [
        [0, 0, 0],
        [128, 0, 0],
        [0, 128, 0],
        [128, 128, 0],
        [0, 0, 128],
        [128, 0, 128],
        [0, 128, 128],
        [128, 128, 128],
        [64, 0, 0],
        [192, 0, 0],
        [64, 128, 0],
        [192, 128, 0],
        [64, 0, 128],
        [192, 0, 128],
        [64, 128, 128],
        [192, 128, 128],
        [0, 64, 0],
        [128, 64, 0],
        [0, 192, 0],
        [128, 192, 0],
        [0, 64, 128],
    ]
    + [[0, 0, 0] for i in range(256 - 22)]  # indices 21-254: unused, black
    + [[255, 255, 255]],  # index 255: void label, white
    dtype=np.uint8,
)
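As a sketch of how a palette like PALETTE can be used, the hypothetical helper below converts a mask of class indices into an RGB image with NumPy fancy indexing: indexing a `(256, 3)` array with an `(H, W)` integer array yields an `(H, W, 3)` array. To stay self-contained it builds a small stand-in palette with only a few entries colored.

```python
import numpy as np


def colorize_mask(mask, palette):
    """Map an (H, W) array of class indices to an (H, W, 3) RGB uint8 image.

    `palette` needs one RGB row per possible index (256 rows for uint8 masks).
    """
    mask = np.asarray(mask, dtype=np.int64)
    return palette[mask]


# Stand-in palette: 256 black entries, with two classes and the void label colored
demo_palette = np.zeros((256, 3), dtype=np.uint8)
demo_palette[1] = [128, 0, 0]        # aeroplane
demo_palette[2] = [0, 128, 0]        # bicycle
demo_palette[255] = [255, 255, 255]  # void

demo_mask = np.array([[0, 1], [2, 255]])
rgb = colorize_mask(demo_mask, demo_palette)
print(rgb.shape)
```

With the notebook's own variables, `plt.imshow(colorize_mask(label.cpu(), PALETTE))` would show a ground-truth mask in the VOC colors.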

# Download and visualize data

In the following we first download the data using the VOCSegmentation dataset class provided by the torchvision module of PyTorch (https://pytorch.org/vision/stable/generated/torchvision.datasets.VOCSegmentation.html#torchvision.datasets.VOCSegmentation).

In [None]:
# Creating (and downloading) the dataset
train_dataset = datasets.VOCSegmentation(
    './datasets/',
    year='2012',
    download=True,
    image_set='train'
)
val_dataset = datasets.VOCSegmentation(
    './datasets/',
    year='2012',
    download=True,
    image_set='val'
)

In [None]:
# Some info regarding the dataset
print(f"Number of training images: {len(train_dataset)}")
print(f"Number of validation images: {len(val_dataset)}")
# Plot some samples: input images first, then ground-truth masks
inputs, ground_truths = map(list, zip(*[val_dataset[i] for i in range(32)]))
_ = plot_images(inputs)
_ = plot_images(ground_truths)

# Data Transformation 

PyTorch provides several functions to transform and augment images in the module [torchvision.transforms](https://pytorch.org/vision/stable/transforms.html). Here, we first resize the images to a fixed size (bilinear interpolation for the images, nearest-neighbor for the labels) and then normalize them with the standard ImageNet mean and standard deviation, commonly used in semantic segmentation. For the label images we also replace the void label (255) with background (0). Finally, we load the datasets again, passing the transformations for inputs and targets.
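transforms.ToTensor scales pixel values to [0, 1], and transforms.Normalize then computes `(x - mean) / std` channel by channel. The sketch below reproduces that formula in plain PyTorch and adds a hypothetical `denormalize` helper that inverts it, which is handy if you want to display transformed input images with `plot_images` (imshow expects channels last and values in [0, 1]).

```python
import torch

imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]


def normalize(x, mean, std):
    """Per-channel (x - mean) / std for a [C, H, W] tensor, like transforms.Normalize."""
    mean = torch.tensor(mean).view(-1, 1, 1)
    std = torch.tensor(std).view(-1, 1, 1)
    return (x - mean) / std


def denormalize(x, mean, std):
    """Invert normalize() and return an [H, W, C] tensor in [0, 1], ready for imshow."""
    mean = torch.tensor(mean).view(-1, 1, 1)
    std = torch.tensor(std).view(-1, 1, 1)
    return (x * std + mean).clamp(0, 1).permute(1, 2, 0)


fake_image = torch.rand(3, 4, 4)  # a fake image already scaled to [0, 1] by ToTensor
x_norm = normalize(fake_image, imagenet_mean, imagenet_std)
x_back = denormalize(x_norm, imagenet_mean, imagenet_std)
print(x_back.shape)
```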

In [None]:
image_size = 520 # We will resize to a standard size
image_mean = [0.485, 0.456, 0.406]  # Standard ImageNet statistics, commonly used
image_std = [0.229, 0.224, 0.225]  # for image normalization in deep learning
input_transform = transforms.Compose(
    [
        transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BILINEAR),
        transforms.ToTensor(),
        transforms.Normalize(image_mean, image_std),
    ]
)
target_transform = transforms.Compose(
    [
        transforms.Resize((image_size, image_size), interpolation=InterpolationMode.NEAREST),
        transforms.PILToTensor(),
        transforms.Lambda(lambda x: replace_tensor_value_(x.squeeze(0).long(), 255, 0)),
    ]
)
# Creating (and downloading) the dataset
train_dataset = datasets.VOCSegmentation(
    './datasets/',
    year='2012',
    download=True,
    image_set='train',
    transform=input_transform,
    target_transform=target_transform,
)
val_dataset = datasets.VOCSegmentation(
    './datasets/',
    year='2012',
    download=True,
    image_set='val',
    transform=input_transform,
    target_transform=target_transform,
)
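The labels are resized with InterpolationMode.NEAREST rather than bilinear interpolation because label values are class indices, not intensities: interpolating between class 0 and class 15 invents values that are not valid labels. The sketch below illustrates this with `torch.nn.functional.interpolate` on a tiny 2×2 mask; it shows the principle, not the exact PIL code path that transforms.Resize uses on PIL images.

```python
import torch
import torch.nn.functional as F

# A tiny 2x2 label mask with two classes: background (0) and person (15)
mask = torch.tensor([[0, 15], [15, 0]], dtype=torch.float32)[None, None]  # [1, 1, 2, 2]

nearest = F.interpolate(mask, size=(4, 4), mode="nearest")
bilinear = F.interpolate(mask, size=(4, 4), mode="bilinear", align_corners=False)

print(torch.unique(nearest))   # only the original classes
print(torch.unique(bilinear))  # also in-between values that are not valid classes
```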

# Let's try some pre-trained models

PyTorch provides several pre-trained models in the module [torchvision.models](https://pytorch.org/vision/stable/models.html). In this notebook, we are interested in semantic segmentation models, which are listed at https://pytorch.org/vision/stable/models.html#semantic-segmentation. Let's try the Fully Convolutional Network (FCN) model with a ResNet-50 backbone (from the [Fully Convolutional Networks for Semantic Segmentation paper](https://arxiv.org/abs/1411.4038)); its default torchvision weights were trained on a subset of COCO, using only the 20 categories that are also present in PascalVOC. We first plot some predictions next to the ground truth and then use a metric to evaluate the quality of the segmentation.

#### Note: the cell below should print cuda:0; otherwise, change the runtime type to GPU on Google Colab

In [None]:
# Set device 
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('The current processor is ...', device)

In [None]:
# DataLoaders are iterable objects that automatically group data samples into batches
# and return them with the transformations specified in the dataset
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
valid_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
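To see what the loaders yield, the sketch below uses a small TensorDataset of random tensors as a stand-in for VOCSegmentation (an assumption for illustration only). The default collate function stacks the images into a `[B, 3, H, W]` float tensor and the masks into a `[B, H, W]` integer tensor, and the last batch is smaller when the dataset size is not a multiple of `batch_size` (unless `drop_last=True`).

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in dataset: 10 fake RGB images of 8x8 pixels and their class-index masks
images = torch.randn(10, 3, 8, 8)
masks = torch.randint(0, 21, (10, 8, 8))
demo_loader = DataLoader(TensorDataset(images, masks), batch_size=4, shuffle=False)

batch_shapes = []
for data, label in demo_loader:
    batch_shapes.append((tuple(data.shape), tuple(label.shape)))
    print(data.shape, label.dtype, label.shape)
```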

# Exercise:

Check the model description below and try to understand its components. Focus on the Bottleneck component(s): which of the two residual blocks shown in Figure 5 of the [ResNet paper](https://arxiv.org/pdf/1512.03385.pdf) is implemented?

Note: in PyTorch, a convolution with $I$ input channels and $O$ output channels is written as Conv2d($I$, $O$, ...)

# Answer

Write here

In [None]:
# Load pre-trained weights and create the model
weights = FCN_ResNet50_Weights.DEFAULT
model = torchvision.models.segmentation.fcn_resnet50(weights=weights)
print(model)

Now we try to feed some samples and visualize the resulting segmentation.

In [None]:
# Move model to device
model.to(device)
model.eval()
# During validation, disable gradient computation to save memory
with torch.no_grad():
  for batch in valid_loader:
      data, label = batch
      # Move data to device
      data, label = data.to(device), label.to(device)
      output = model(data)
      # The model returns a dict: 'out' is the main segmentation output, while 'aux'
      # (if present) comes from the auxiliary classifier used during training. We use just 'out'.
      prediction = output['out'].argmax(dim=1).squeeze()
      print("Predictions")
      _ = plot_images(prediction.cpu(), vmax=21, title='prediction')
      plt.show()
      print("Label")
      _ = plot_images(label.cpu(), vmax=21, title='label')
      plt.show()
      break
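The 'out' entry has shape `[B, 21, H, W]` (one score per class and pixel), and `argmax(dim=1)` picks the highest-scoring class for every pixel, giving a `[B, H, W]` mask. The sketch below checks this with random logits. It also shows that the `.squeeze()` call used above is not needed after `argmax(dim=1)`, and that with a batch of a single image it removes the batch dimension too.

```python
import torch

num_classes = 21
demo_logits = torch.randn(2, num_classes, 4, 4)  # fake model output: [B, C, H, W]
demo_prediction = demo_logits.argmax(dim=1)      # class index per pixel: [B, H, W]
print(demo_prediction.shape)

# With a single image, .squeeze() also drops the batch dimension
single = torch.randn(1, num_classes, 4, 4).argmax(dim=1)
print(single.shape, single.squeeze().shape)
```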

# Dice accuracy
Let $T_P$, $F_N$ and $F_P$ be the numbers of true positives, false negatives and false positives. We define the Dice accuracy (also known as the Dice score, or F1 score) as:

$$ Dice = \frac{2 T_P}{2 T_P + F_N + F_P}$$

We use the multi-class implementation of Dice accuracy provided by the torchmetrics library. Check the documentation here https://torchmetrics.readthedocs.io/en/stable/classification/dice.html.
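A minimal from-scratch sketch of the formula, assuming predictions and labels are integer class-index arrays of the same shape (the helper names are hypothetical, not part of torchmetrics). It computes the Dice score of each class and the micro-averaged score, where $T_P$, $F_P$ and $F_N$ are summed over all classes before applying the formula once. When every pixel gets exactly one predicted and one true class, each wrong pixel counts once as a false positive and once as a false negative, so the micro-averaged Dice equals the plain pixel accuracy. According to its documentation, torchmetrics.Dice defaults to `average='micro'`, so check its `average` and `ignore_index` arguments if you want per-class scores or want to exclude the background.

```python
import numpy as np


def dice_per_class(pred, target, num_classes):
    """Dice = 2TP / (2TP + FN + FP) for each class; NaN if the class never appears."""
    scores = []
    for c in range(num_classes):
        tp = np.sum((pred == c) & (target == c))
        fp = np.sum((pred == c) & (target != c))
        fn = np.sum((pred != c) & (target == c))
        denom = 2 * tp + fn + fp
        scores.append(2 * tp / denom if denom > 0 else float("nan"))
    return np.array(scores)


def dice_micro(pred, target, num_classes):
    """Sum TP, FP and FN over all classes, then apply the Dice formula once."""
    tp = fp = fn = 0
    for c in range(num_classes):
        tp += np.sum((pred == c) & (target == c))
        fp += np.sum((pred == c) & (target != c))
        fn += np.sum((pred != c) & (target == c))
    return 2 * tp / (2 * tp + fn + fp)


target = np.array([[0, 0, 1], [1, 2, 2]])
pred = np.array([[0, 1, 1], [1, 2, 0]])
print(dice_per_class(pred, target, 3))
print(dice_micro(pred, target, 3), np.mean(pred == target))
```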

In [None]:
# Start metric
dice = torchmetrics.Dice().to(device)
model.eval()
with torch.no_grad():
  for idx, batch in enumerate(valid_loader):
      start = time.time()
      data, label = batch
      data, label = data.to(device), label.to(device)
      output = model(data)
      logits = output['out']
      # Add this batch to metric computation
      dice.update(logits, label)
      print(f"Batch {idx} inference in {time.time() - start} seconds")
# Compute the total Dice on the whole validation set
print(f"Average DICE accuracy on validation set: {dice.compute().item()*100} %")
# Important to reset and free memory
dice.reset()

# Exercise:

Based on the previous cells, try some other pre-trained models from the torchvision.models module. Check the available models and weights at https://pytorch.org/vision/stable/models.html#semantic-segmentation. Are the models you tried better or worse than the Fully Convolutional Network with a ResNet-50 backbone? Do you see any difference in the predictions? Why?

In [None]:
# TODO: write the code here to test another pre-trained model as previously done

# Let's train a custom model (UNet)

In the following we manually implement the UNet model for image segmentation and train it from scratch using the standard PyTorch pipeline.

First we define the modules for each part of the UNet, then the complete model, both adapted from https://github.com/milesial/Pytorch-UNet.

In [None]:
class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)
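The F.pad call in Up.forward handles inputs whose size is not divisible by 16: each MaxPool2d(2) floors odd sizes, so after upsampling the decoder feature map can be one pixel smaller than the skip connection it is concatenated with. With the 520×520 images used earlier, the encoder produces maps of 520, 260, 130, 65 and 32 pixels, so in the first Up block the upsampled 32×32 map becomes 64×64 while the skip connection is 65×65. The standalone sketch below reproduces that situation with random tensors (the channel counts are arbitrary) and applies the same padding arithmetic before concatenating.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)

x1 = torch.randn(1, 8, 32, 32)  # deeper decoder feature map
x2 = torch.randn(1, 8, 65, 65)  # skip connection from the encoder

x1 = up(x1)  # [1, 8, 64, 64]: one pixel short of x2
diffY = x2.size(2) - x1.size(2)
diffX = x2.size(3) - x1.size(3)
# F.pad takes (left, right, top, bottom) for the last two dimensions
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
x = torch.cat([x2, x1], dim=1)
print(x.shape)
```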

In [None]:
class UNet(nn.Module):
    def __init__(self, channel_in, n_classes, bilinear=False):
        super(UNet, self).__init__()
        self.channel_in = channel_in
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = (DoubleConv(channel_in, 64))
        self.down1 = (Down(64, 128))
        self.down2 = (Down(128, 256))
        self.down3 = (Down(256, 512))
        factor = 2 if bilinear else 1
        self.down4 = (Down(512, 1024 // factor))
        self.up1 = (Up(1024, 512 // factor, bilinear))
        self.up2 = (Up(512, 256 // factor, bilinear))
        self.up3 = (Up(256, 128 // factor, bilinear))
        self.up4 = (Up(128, 64, bilinear))
        self.outc = (OutConv(64, n_classes))

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

Now we create and print our UNet model: the input has 3 channels (RGB) and the output has 21 classes (20 objects + background).

In [None]:
# Create the model and print it
model = UNet(channel_in=3, n_classes=21, bilinear=True)
print(model)

We load the data again with a smaller image size (224 instead of 520) and a smaller batch size (16 instead of 32), so that training does not fill the GPU memory and cause a 'CUDA out of memory' error.

In [None]:
image_size = 224 # We will resize to a standard size
image_mean = [0.485, 0.456, 0.406]  # Standard ImageNet statistics, commonly used
image_std = [0.229, 0.224, 0.225]  # for image normalization in deep learning
input_transform = transforms.Compose(
    [
        transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BILINEAR),
        transforms.ToTensor(),
        transforms.Normalize(image_mean, image_std),
    ]
)
target_transform = transforms.Compose(
    [
        transforms.Resize((image_size, image_size), interpolation=InterpolationMode.NEAREST),
        transforms.PILToTensor(),
        transforms.Lambda(lambda x: replace_tensor_value_(x.squeeze(0).long(), 255, 0)),
    ]
)
# Creating (and downloading) the dataset
train_dataset = datasets.VOCSegmentation(
    './datasets/',
    year='2012',
    download=True,
    image_set='train',
    transform=input_transform,
    target_transform=target_transform,
)
val_dataset = datasets.VOCSegmentation(
    './datasets/',
    year='2012',
    download=True,
    image_set='val',
    transform=input_transform,
    target_transform=target_transform,
)
batch_size = 16
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
valid_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
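As a rough back-of-the-envelope estimate (assuming activation memory grows linearly with the number of pixels per image and with the batch size, and ignoring the fixed cost of the weights), the changes in the cell above (224 instead of 520 pixels per side, batches of 16 instead of 32) shrink the activations of a batch to roughly a tenth. Note also that 224 is divisible by 16, so the padding branch of the UNet's Up blocks is not needed at this size.

```python
old_size, new_size = 520, 224
old_batch, new_batch = 32, 16

pixel_ratio = (new_size / old_size) ** 2  # fewer pixels per image
batch_ratio = new_batch / old_batch       # fewer images per batch
print(f"pixels per image: {pixel_ratio:.3f}x, activations per batch: {pixel_ratio * batch_ratio:.3f}x")
```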

Below we define all the parameters used for training, namely:

<ul>
  <li> Learning rate </li>
  <li> Number of epochs </li>
  <li> Optimizer: Adam </li>
  <li> Loss function:  Cross Entropy (CE) loss.</li>
  <li> Function to compute the validation accuracy (DICE as before)</li>
</ul>


In [None]:
learning_rate = 1e-5
epochs = 20
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = torch.nn.CrossEntropyLoss()
dice = torchmetrics.Dice().to(device)
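nn.CrossEntropyLoss works directly on segmentation outputs: it takes raw logits of shape `[B, C, H, W]` (no softmax needed, since it applies log-softmax internally) and integer class-index targets of shape `[B, H, W]`, which is what the target transform returns, and by default it averages the per-pixel losses. The sketch below checks this with random tensors and compares the result with an explicit log-softmax plus negative log-likelihood computation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

ce_loss = nn.CrossEntropyLoss()
demo_logits = torch.randn(2, 21, 4, 4)         # model output: [B, C, H, W]
demo_labels = torch.randint(0, 21, (2, 4, 4))  # class indices: [B, H, W], int64

loss_value = ce_loss(demo_logits, demo_labels)

# Same value by hand: minus the log-probability of the true class, averaged over pixels
log_probs = F.log_softmax(demo_logits, dim=1)
manual = -log_probs.gather(1, demo_labels.unsqueeze(1)).mean()
print(loss_value.item(), manual.item())
```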

# Exercise:

In the next cell we implement the classic PyTorch training pipeline for the UNet. Fill in the missing code blocks (TODO). Then try changing the learning rate and the number of epochs: what is the impact on the final result? If you want, also change the model or the loss function to try to improve the results. Finding a combination of hyperparameters that gives good results is not easy, but you should still see the loss decreasing and the accuracy increasing during training.

In [None]:
model = model.to(device)
train_losses = []
val_losses = []
val_accuracies = []

for epoch in range(epochs):  # loop over the dataset multiple times

    # One train set iteration
    running_loss = 0.0
    last_train_loss = 0.0
    model.train()

    for idx, batch in enumerate(train_loader):

      # TODO: Fetch batch data as in the previous validation loops

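      # The dataset already returns labels of shape [B, H, W], which is what
      # CrossEntropyLoss expects. Use labels.unsqueeze(1) ([B, 1, H, W]) only if you switch
      # to a loss such as monai.losses.DiceLoss(to_onehot_y=True, softmax=True).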

      # zero the parameter gradients
      optimizer.zero_grad()

      # forward
      # TODO: Feed data into the network and store the output
      # TODO: Compute loss using the criterion previously defined

      # backward + optimize
      loss.backward()
      optimizer.step()

      # print statistics
      running_loss += loss.item()
      if idx % 10 == 9:    # print every 10 mini-batches
          print(f'[{epoch + 1}, {idx + 1:5d}] loss: {running_loss / 10:.3f}')
          last_train_loss = running_loss / 10
          running_loss = 0.0

    train_losses.append(last_train_loss)
    print(f"Training loss (last 10 batches) for epoch {epoch + 1}: {last_train_loss:.3f}")

    # Validation loop
    model.eval()
    val_loss = 0.0

    # TODO: Validation loop (see previous cells), accumulate the loss and the accuracy over the entire validation loader
    
    # After the validation loop ends, compute the total Dice on the whole validation set
    val_loss = val_loss / len(valid_loader)
    val_acc = dice.compute().item()*100
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)
    print(f"Loss on validation set for epoch {epoch + 1}: {val_loss:.3f}")
    print(f"Average DICE accuracy on validation set for epoch {epoch + 1}: {val_acc} %")
    dice.reset()

print('Finished Training')

# Save the model weights (create the folder first if it does not exist)
os.makedirs('models', exist_ok=True)
torch.save(model.state_dict(), 'models/unet_v0.pt')

# To load the trained model later, see https://pytorch.org/tutorials/beginner/saving_loading_models.html
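The saving and loading tutorial linked in the cell above describes the usual pattern: save only the state_dict, then rebuild the model with the same constructor arguments and load the weights into it. A minimal round-trip sketch, using a small nn.Conv2d as a stand-in for the UNet and a temporary file instead of models/unet_v0.pt (the variable names are deliberately different so the trained `model` is not overwritten):

```python
import os
import tempfile

import torch
import torch.nn as nn

tiny_model = nn.Conv2d(3, 21, kernel_size=1)  # stand-in for UNet(channel_in=3, n_classes=21, ...)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model.pt")
    torch.save(tiny_model.state_dict(), path)

    # Re-create the architecture, then load the saved weights into it
    restored = nn.Conv2d(3, 21, kernel_size=1)
    restored.load_state_dict(torch.load(path, map_location="cpu"))
    restored.eval()  # switch to inference mode before predicting

same = all(
    torch.equal(a, b)
    for a, b in zip(tiny_model.state_dict().values(), restored.state_dict().values())
)
print(same)
```

For this notebook, the same pattern would be `model = UNet(channel_in=3, n_classes=21, bilinear=True)` followed by `model.load_state_dict(torch.load('models/unet_v0.pt', map_location=device))`.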

# Plot the loss and accuracy graphs of the training

In [None]:
# Plot Train vs Val loss over epochs:
plt.figure()
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='train')
plt.plot(val_losses, label='val')
plt.ylabel("Loss")
plt.xlabel("Epochs")
plt.legend()
plt.title("Loss vs Epochs")

plt.subplot(1, 2, 2)
plt.plot(val_accuracies)
plt.ylabel("Accuracy [%]")
plt.xlabel("Epochs")
plt.title("Validation accuracy")
plt.show()

# Show segmentation of one validation batch

In [None]:
model.eval()
# During validation, disable gradient computation to save memory
with torch.no_grad():
  for batch in valid_loader:
      data, labels = batch
      # Move data to device
      data, labels = data.to(device), labels.to(device)
      output = model(data)
      # Unlike the torchvision FCN, our UNet returns the logits tensor directly,
      # with shape [B, n_classes, H, W], so there is no 'out' key to select.
      prediction = output.argmax(dim=1).squeeze()
      print("Predictions")
      _ = plot_images(prediction.cpu(), vmax=21, title='prediction')
      plt.show()
      print("Label")
      _ = plot_images(labels.cpu(), vmax=21, title='label')
      plt.show()
      break

# Additional Exercise:

Add the parameter n_channels to the UNet initialization and make the numbers of input and output channels of the convolutions proportional to this parameter.

```
class UNet(nn.Module):
    def __init__(self, channel_in, n_classes, n_channels, bilinear=False):
        super(UNet, self).__init__()
        ...
```

#### Note: in the UNet above, the number of features (i.e. output channels) doubles at each downsampling step of the encoder. Try to use the parameter n_channels as the output size of the first convolutions and make the following ones proportional to it.

In [None]:
# TODO: write the new model class here