# Semantic Segmentation using SegNet 
## Introduction
This notebook is a summary of a personal project demonstrating how convolutional neural networks (CNNs) can be used for an image segmentation task. Specifically, we describe semantic labelling of specific targets within a dataset. Semantic segmentation is a common segmentation technique in which objects are distinguished from each other by class only; individual instances within a class are not delineated. 

<br>
<img src="typesofseg.png" width=1000 height=1000 />


##  Dataset 
I've been particularly interested in how supervised learning can be used in digital pathology and other medical/research imaging modalities. For this project, I used a **darkfield microscopy dataset** populated with labelled red blood cells (RBCs) and Spirochaeta bacteria. My goal was to construct and test a CNN architecture from the ground up that could distinguish RBCs from bacteria with reasonable accuracy. For me, automating the labelling of bacterial species seemed like a great starting point for exploring the capabilities of supervised learning in a diagnostic capacity. 

I later discuss the results of my PyTorch implementation. 

### Images and mask labels

A few examples from the dataset of colorized cells and bacteria (left) and the corresponding labelled masks used for training (right):
<br>
<img src="1.png" width=400 height=400 />
<img src="2.png" width=400 height=400 />
<img src="3.png" width=400 height=400 />

### Overview
<img src="overview.png" width=700 height=700 />

## SegNet 
SegNet is a CNN architecture originally developed at the University of Cambridge for semantic segmentation. At the time of its publication, SegNet's *encoder-decoder architecture* was notable compared to its contemporaries, which used fully connected layers (and consequently had many more parameters to train). Its decoder upsamples using the max-pooling indices stored by the corresponding encoder layers. I chose this architecture mainly out of personal interest in its implementation. 

### Architecture
The encoder-decoder strategy aims to reduce the number of learnable parameters by successively reducing spatial resolution in the encoder stage, then restoring this resolution for the final label projection. At each encoder stage we trade spatial resolution for feature-map depth, encouraging the network to infer more complex abstractions (i.e., patterns in the image data). The decoder inverts this process, restoring resolution and collapsing feature-map depth to project a final prediction. 

The complete architecture is illustrated below from the original SegNet paper: 
<img src="segnet_architecture.png" width=800 height=800 />

- **Convolutional layer**: 
A layer where the input image is convolved with a filter (kernel) of size K, stride S, and padding P. Each output value is a linear combination of the input values under the filter, with the filter providing the weights. 
<img src="05_convolutions_example.gif" width="500" align="center"> _Stride_ refers to how many units the kernel moves vertically or horizontally across the input at each step. _Padding_ refers to how many rows and columns of zeros are added around the input matrix prior to convolution. Since convolutions without padding shrink the input, we can pad when we want the output to keep the same spatial dimensions. The filter weights are learnable parameters, meaning they are updated as part of the network's learning process (backpropagation). 


- **Batch normalization layer**: 
A layer that normalizes the activations passed to it. During training it receives a 'mini-batch' and normalizes each channel using that channel's mean and standard deviation over the mini-batch, then applies a learnable scale and shift. This keeps feature values in comparable ranges, which helps guard against *vanishing and exploding gradients* during backpropagation (i.e., gradients that are too small or too large, so that training slows down or effectively stops). 


- **ReLU**: 
Short for rectified linear unit. It is an element-wise activation function that passes positive inputs through unchanged and sets negative inputs to zero, i.e. $ReLU(x) = \max(0, x)$: <img src="relu.png" width=300 height=300 />
Every element of the input is passed through this function. ReLU is only piecewise linear (note the elbow at the origin), so it adds a non-linearity that allows the network to learn non-linear relationships in the data.


- **Max Pooling and Max Unpooling**: 
_Max Pooling_ is a downsampling operation that takes the maximum element within a sliding window of size K, with stride S and padding P, over the input feature map. In SegNet, max pooling is how the encoder stage reduces spatial resolution. _Max Unpooling_ partially reverses this: each pooled value is written back to the position its maximum came from (SegNet records these positions, the pooling indices, in the encoder), and all other positions are filled with zeros. The decoder uses it to restore spatial resolution before the pixel-wise class predictions. <img src="pooling.png" width="500" align="center">


- **Softmax**:  
A function that turns an input vector of $n$ values into positive values that sum to 1. It is best demonstrated with a simple example: assume some vector $K = [a, b, c]$. Then

  $Softmax(K) = \left[\frac{e^a}{e^a + e^b + e^c}, \frac{e^b}{e^a + e^b + e^c}, \frac{e^c}{e^a + e^b + e^c}\right]$

  The softmax has the effect of scaling values into a probability distribution: larger inputs receive larger probabilities, and smaller (more negative) inputs receive probabilities close to zero. As such, we can use it to make classification decisions among any number of mutually exclusive classes. Since the output of SegNet contains one channel per class, we apply the softmax across the channels at each pixel, giving every pixel a probability distribution over the classes. Each pixel is then labelled with the class whose channel assigned it the highest probability. 
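The building blocks above can be checked directly in PyTorch. The following is a minimal sketch on a random tensor; the batch size, channel counts and 8×8 spatial size are arbitrary toy choices, not the values used in this project. It confirms that a 3×3 convolution with stride 1 and padding 1 preserves spatial size, that batch normalization (in training mode) centres each channel over the mini-batch, that ReLU removes negative values, that max unpooling writes values back at the positions recorded by max pooling, and that a softmax over the channel dimension gives each pixel a distribution over classes.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# A mini-batch of 4 RGB "images" of size 8x8 (arbitrary toy sizes).
x = torch.randn(4, 3, 8, 8)

# 3x3 convolution, stride 1, padding 1: spatial size is preserved.
conv = nn.Conv2d(in_channels=3, out_channels=5, kernel_size=3, stride=1, padding=1)
y = conv(x)
print(y.shape)  # torch.Size([4, 5, 8, 8])

# Batch norm in training mode normalizes each channel over (batch, H, W).
# At initialization its learnable scale is 1 and its shift is 0.
bn = nn.BatchNorm2d(5)
z = bn(y)
print(z.mean(dim=(0, 2, 3)))  # close to 0 for every channel

# ReLU: max(0, x), applied element-wise.
r = torch.relu(z)

# Max pooling halves each spatial dimension and records where each maximum came from.
pool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
pooled, indices = pool(r)
print(pooled.shape)  # torch.Size([4, 5, 4, 4])

# Max unpooling puts each value back at its recorded position; all other entries are zero.
unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)
restored = unpool(pooled, indices)
print(restored.shape)  # torch.Size([4, 5, 8, 8])

# Softmax over the channel dimension (dim=1): a class distribution for every pixel.
probs = torch.softmax(restored, dim=1)
labels = probs.argmax(dim=1)  # predicted class index per pixel
print(labels.shape)  # torch.Size([4, 8, 8])
```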

## Training and validation using PyTorch
The complete code can be found under the SegNet folder. Here, I describe the complete workflow of my PyTorch implementation, including the transformations applied to the training data, the performance of the model, and the decisions I made in tuning the model. 

### Pytorch overview
Our PyTorch SegNet project consists of several key components: 
1. Dataset
2. DataLoaders
3. Data transformations
4. Prepping for data imbalances
5. Model architecture
6. Loss function and optimizer
7. Hyperparameters
8. Validating accuracy

#### 1. Dataset
I wrote a custom dataset named *SegNetDataSet*, inheriting from PyTorch's *Dataset* class, to load my images and target masks. The dataset class accepts separate transform arguments for the images and the target masks, since some transformations should not be applied to the masks (e.g., the masks are not normalized).  

```python
# Load custom dataset
dataset = SegNetDataSet(r'C:\Users\vajra\Documents\GitHub\ML_playground\PyTorch\segnet\archive', 
                        data_transforms=data_transforms, target_transforms=target_transforms)
```
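The *SegNetDataSet* class itself is not shown here. As a rough sketch only, a dataset of this kind could look like the following. The class name `SegNetDataSetSketch`, the `images/` subfolder, and the assumption that each image and its mask share a file name are all hypothetical; only the `flattened_masks/` folder name comes from the class-counting code later in this notebook.

```python
import os

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset


class SegNetDataSetSketch(Dataset):
    """Hypothetical image/mask dataset: assumes root/images and root/flattened_masks
    hold files with matching names."""

    def __init__(self, root, data_transforms=None, target_transforms=None):
        self.image_dir = os.path.join(root, "images")            # assumed folder name
        self.mask_dir = os.path.join(root, "flattened_masks")
        self.names = sorted(os.listdir(self.mask_dir))
        self.data_transforms = data_transforms
        self.target_transforms = target_transforms

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        name = self.names[idx]
        image = Image.open(os.path.join(self.image_dir, name)).convert("RGB")
        mask = Image.open(os.path.join(self.mask_dir, name))
        if self.data_transforms is not None:
            image = self.data_transforms(image)
        if self.target_transforms is not None:
            mask = self.target_transforms(mask)
        # Keep integer class labels (0, 1, 2); CrossEntropyLoss expects a LongTensor target.
        mask = torch.from_numpy(np.array(mask)).long()
        return image, mask
```

Converting the mask with `np.array` rather than `ToTensor` keeps the class indices as integers instead of rescaling them to [0, 1].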

#### 2. DataLoaders
I use the standard DataLoader class to create two loaders, namely a training and test loader for the training and test sets respectively. 

```python
# Produce test and train sets
train_set, test_set = torch.utils.data.random_split(dataset, [329, 37]) # 90% 10% split between train and test 

train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(dataset=test_set, batch_size=batch_size, shuffle=False)
```
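To sanity-check the split sizes and batch shapes without the real images, the sketch below uses a hypothetical stand-in dataset of random tensors with the same number of samples as this dataset (366). Passing a seeded `torch.Generator` to `random_split` is an optional addition, not part of the code above, that makes the split reproducible across runs; `shuffle=True` on the training loader, which reshuffles the mini-batches every epoch, is likewise shown as a common alternative.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Hypothetical stand-in: 366 random 3x32x32 "images" with 32x32 integer masks.
images = torch.randn(366, 3, 32, 32)
masks = torch.randint(0, 3, (366, 32, 32))
dataset = TensorDataset(images, masks)

# Seeded split so the same images land in the test set on every run.
generator = torch.Generator().manual_seed(42)
train_set, test_set = random_split(dataset, [329, 37], generator=generator)

batch_size = 16
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

x, y = next(iter(train_loader))
print(len(train_set), len(test_set))  # 329 37
print(x.shape, y.shape)  # torch.Size([16, 3, 32, 32]) torch.Size([16, 32, 32])
print(len(train_loader))  # 21 batches: 20 full batches of 16 plus one of 9
```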

#### 3. Data Transformations
I pass a composition of transformations into my custom dataset to transform the images and target masks when they are loaded.

```python
# Compose transformations 
data_transforms = transforms.Compose([
    transforms.Resize((256,256)),   
    transforms.ToTensor(),  
    transforms.Normalize( mean = [0.1600, 0.1959, 0.2559], 
                          std=[0.2209, 0.2456, 0.2530] )
    ])   
```
For the darkfield images, we resize to 256×256 and then normalize each of the three color channels (RGB) using the per-channel mean and standard deviation of the dataset. I compute these values separately by iterating through the images:

```python
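# Load the whole training set as one batch. The dataset used here should apply only
# Resize and ToTensor (not Normalize), so that raw pixel statistics are measured.
train_loader = DataLoader(dataset=train_set, batch_size=len(train_set))
data, targets = next(iter(train_loader))
# Per-channel mean and standard deviation over all images and pixels (data is N x 3 x H x W)
means = data[:, 0, :, :].mean(), data[:, 1, :, :].mean(), data[:, 2, :, :].mean()
stds = data[:, 0, :, :].std(), data[:, 1, :, :].std(), data[:, 2, :, :].std()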
```
Note that for this train loader, I set the batch size equal to the total number of training images, so these means and standard deviations are computed for each color channel over the entire training set. 
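Loading the whole training set as a single batch works for a dataset of this size, but it requires every image in memory at once. A memory-light alternative (a different technique from the one used above) accumulates per-channel sums and sums of squares batch by batch. The sketch below uses random data as a stand-in for the resized, `ToTensor`-converted training images and checks that the result matches the all-at-once computation.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
# Stand-in for the training images (values in [0, 1], shape N x 3 x H x W).
data = torch.rand(40, 3, 16, 16)
loader = DataLoader(TensorDataset(data), batch_size=16)

channel_sum = torch.zeros(3, dtype=torch.float64)
channel_sq_sum = torch.zeros(3, dtype=torch.float64)
pixel_count = 0
for (batch,) in loader:
    batch = batch.double()
    channel_sum += batch.sum(dim=(0, 2, 3))
    channel_sq_sum += (batch ** 2).sum(dim=(0, 2, 3))
    pixel_count += batch.shape[0] * batch.shape[2] * batch.shape[3]

mean = channel_sum / pixel_count
# Population variance E[x^2] - E[x]^2, rescaled to the unbiased estimate Tensor.std() uses by default.
var = (channel_sq_sum / pixel_count - mean ** 2) * pixel_count / (pixel_count - 1)
std = var.sqrt()
print(mean, std)
```

Accumulating in `float64` limits the rounding error that the sum-of-squares formula can otherwise suffer from over many pixels.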


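I pass a separate transform for the target masks: 
```python
# Nearest-neighbour interpolation keeps the integer class labels intact when resizing
# (InterpolationMode requires torchvision >= 0.9)
target_transforms = transforms.Compose([
    transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.NEAREST),
    ]) 
```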
The masks only need to match the size of our image data; we do not normalize them, because their pixel values are class labels that must be retained. 

#### 4. Prepping for data imbalances
The darkfield microscopy dataset consists of 3 classes: background, cells and bacteria, labelled 0, 1 and 2 respectively. The class distribution is highly imbalanced (i.e., far from uniform).
I quantify this imbalance by counting the total number of pixels belonging to each class across all the masks: 

```python
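import numpy as np
from PIL import Image

# Pixel counts for background (0), cells (1) and bacteria (2)
class_1 = class_2 = class_3 = 0
for i in range(1, 367):  # masks are named 001.png ... 366.png
    image = Image.open('archive/flattened_masks/' + str(i).zfill(3) + '.png')
    im2arr = np.array(image)
    class_1 += np.count_nonzero(im2arr == 0)
    class_2 += np.count_nonzero(im2arr == 1)
    class_3 += np.count_nonzero(im2arr == 2)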
```
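The same per-class pixel totals can be accumulated with one call per mask using `np.bincount`, which counts how often each non-negative integer occurs. In the sketch below, small synthetic arrays stand in for the loaded mask PNGs; `minlength=3` ensures a class that is absent from a mask still receives a (zero) count.

```python
import numpy as np

num_classes = 3
class_counts = np.zeros(num_classes, dtype=np.int64)

# Synthetic stand-ins for the loaded mask arrays (values 0, 1, 2).
masks = [
    np.array([[0, 0, 1], [0, 2, 1]], dtype=np.uint8),
    np.array([[0, 0, 0], [0, 0, 1]], dtype=np.uint8),
]
for im2arr in masks:
    class_counts += np.bincount(im2arr.ravel(), minlength=num_classes)

print(class_counts)  # [8 3 1]
```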

Iterating through all the masks, we obtain the following pixel counts per class: 
```python
class_1 = 161442628
class_2 = 24498688
class_3 = 2006574
```

It's clear from these numbers that the background and cell classes dominate, while the bacteria class is much sparser.
This presents an issue from a learning standpoint: the network will tend to perform worse at segmenting sparse classes, since it has fewer examples of them to learn from. We can account for this imbalance by **weighting the classes in the loss function.** PyTorch loss functions such as `CrossEntropyLoss` accept per-class weights that scale how heavily errors on each class are penalized, so rarer classes can be given larger weights. Here I use a simple scheme that subtracts each class's share of the pixels from 1 (note that this gives rarer classes larger weights, but is not strictly inversely proportional to class frequency): 

```python
cFreq = [class_1, class_2, class_3]
normalized_weights = [1 - (x / sum(cFreq)) for x in cFreq]
```
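To see how strongly this scheme separates the classes, the sketch below plugs in the pixel counts reported above and compares it with a common alternative: weights proportional to 1 / frequency, rescaled to sum to 1 (a swapped-in scheme, not the one used in this project). With the one-minus-frequency scheme the RBC and bacteria weights come out close to each other (roughly 0.87 and 0.99), whereas inverse-frequency weighting gives the bacteria class by far the largest weight. Either list has to be converted to a float tensor before it is passed as the `weight` argument of `CrossEntropyLoss`.

```python
import torch
import torch.nn as nn

# Pixel counts per class from the masks above: background, RBC, bacteria.
cFreq = [161442628, 24498688, 2006574]
total = sum(cFreq)

# Scheme used in this notebook: one minus each class's share of the pixels.
normalized_weights = [1 - (x / total) for x in cFreq]

# Alternative: inverse frequency, rescaled so the weights sum to 1.
inverse = [total / x for x in cFreq]
inverse_weights = [w / sum(inverse) for w in inverse]

print([round(w, 3) for w in normalized_weights])
print([round(w, 3) for w in inverse_weights])

# CrossEntropyLoss expects the weights as a float tensor with one entry per class.
criterion = nn.CrossEntropyLoss(weight=torch.tensor(normalized_weights, dtype=torch.float32))
```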

We can then pass these normalized weights into the loss function: 
```python
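# The weights must be a float tensor (on the same device as the model output when using a GPU)
criterion = nn.CrossEntropyLoss(weight=torch.tensor(normalized_weights, dtype=torch.float32))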
```

#### 5. Model architecture
For my model, I use 3 encoder and decoder stages rather than 4. The images do not contain very complex structures, so I decided to try a smaller model first and see what kind of performance I could get. 
I first define a class that inherits from _nn.Module._ Layers assigned as attributes of an _nn.Module_ subclass are registered automatically, so their parameters are tracked and updated during training.

```python
import torch.nn as nn


class SegNet(nn.Module):
    
    def __init__(self, in_channels, num_classes):
        super(SegNet, self).__init__()
        
```

Within this class, I define my encoder and decoder blocks. I demonstrate only the middle encoder and decoder blocks below for simplicity: 

```python
# ... 
# Encoder block 3
self.encoder_conv5 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.encoder_bn5 = nn.BatchNorm2d(256)
self.encoder_conv6 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.encoder_bn6 = nn.BatchNorm2d(256)

self.encoder_mp3 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), return_indices=True)
# ============================ # 
# Decoder block 1
self.decoder_mup1 = nn.MaxUnpool2d(kernel_size=(2, 2), stride=(2, 2))

self.decoder_conv1 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.decoder_bn1 = nn.BatchNorm2d(128)
self.decoder_conv2 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.decoder_bn2 = nn.BatchNorm2d(128)
# ...
```

For each block, I define 2 convolutional layers and 2 batch norm layers. We set the kernel size, stride and padding so that our *output does not change size between convolutions (changing the size is the job of the max pool / max unpool layers).* The output size is determined by a simple formula: 

$ Output~Size = \left\lfloor \frac{W - K + 2P}{S} \right\rfloor + 1 $

W = input size

K = kernel size

P = padding

S = Stride
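As a quick check of the formula, a small helper (the name `conv_output_size` is hypothetical) evaluates it with integer floor division for the layer settings used here on 256×256 inputs.

```python
def conv_output_size(w: int, k: int, p: int, s: int) -> int:
    """floor((W - K + 2P) / S) + 1 for one spatial dimension."""
    return (w - k + 2 * p) // s + 1


# 3x3 convolution, padding 1, stride 1: size is unchanged.
print(conv_output_size(256, k=3, p=1, s=1))  # 256
# 2x2 max pooling, stride 2, no padding: size is halved.
print(conv_output_size(256, k=2, p=0, s=2))  # 128
# The three pooling stages of this model take 256 -> 128 -> 64 -> 32.
size = 256
for _ in range(3):
    size = conv_output_size(size, k=2, p=0, s=2)
print(size)  # 32
```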

The max pool and max unpool layers respectively decrease and increase the spatial resolution by a factor of 2 in each dimension (i.e., a quarter of the area after max pooling, and four times the area after max unpooling). 
Notice that the number of output channels at the end of encoder block 3 equals the number of input channels at the start of decoder block 1. This holds between all consecutive blocks: the output channel count of each block must match the input channel count of the next.
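To see how the pieces fit together, the following is a deliberately small sketch with a single encoder block and a single decoder block. The class name `TinySegNet` and the 16-channel width are illustrative choices, not the three-stage model trained here. It shows the SegNet-specific detail that the decoder's `MaxUnpool2d` reuses the indices returned by the encoder's `MaxPool2d`, and that a final convolution projects to one channel per class.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinySegNet(nn.Module):
    """Illustrative one-block encoder/decoder; not the model trained in this notebook."""

    def __init__(self, in_channels=3, num_classes=3, width=16):
        super().__init__()
        # Encoder: (conv -> BN -> ReLU) x 2, then 2x2 max pooling that returns indices.
        self.enc_conv1 = nn.Conv2d(in_channels, width, kernel_size=3, padding=1)
        self.enc_bn1 = nn.BatchNorm2d(width)
        self.enc_conv2 = nn.Conv2d(width, width, kernel_size=3, padding=1)
        self.enc_bn2 = nn.BatchNorm2d(width)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
        # Decoder: unpool with the stored indices, conv -> BN -> ReLU, then class projection.
        self.unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)
        self.dec_conv1 = nn.Conv2d(width, width, kernel_size=3, padding=1)
        self.dec_bn1 = nn.BatchNorm2d(width)
        self.classifier = nn.Conv2d(width, num_classes, kernel_size=3, padding=1)

    def forward(self, x):
        x = F.relu(self.enc_bn1(self.enc_conv1(x)))
        x = F.relu(self.enc_bn2(self.enc_conv2(x)))
        x, indices = self.pool(x)      # spatial size halves
        x = self.unpool(x, indices)    # spatial size restored using the encoder's indices
        x = F.relu(self.dec_bn1(self.dec_conv1(x)))
        return self.classifier(x)      # raw scores (logits), one channel per class


model = TinySegNet()
logits = model(torch.randn(2, 3, 64, 64))
print(logits.shape)  # torch.Size([2, 3, 64, 64])
```

The model returns raw scores rather than softmax probabilities because PyTorch's `nn.CrossEntropyLoss` applies the (log-)softmax itself.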

#### 6. Loss function and optimizer
***Cross Entropy Loss***

For our loss function, we employ **Cross Entropy Loss**, which is commonly used for classification. PyTorch's `nn.CrossEntropyLoss` takes the raw model output (one channel per class) together with the target mask, and effectively performs three operations: 
1. A softmax across the class channels at each pixel (computed internally as a log-softmax for numerical stability), giving each pixel a probability for every class.
2. For each pixel, the probability assigned to its true class is selected using the target mask. This is equivalent to weighting each class's log-probability by a binary truth label $t_{i}$ that is 1 for the true class and 0 otherwise. (Labelling each pixel with its highest-probability class is only needed when making predictions, not when computing the loss.)
3. Finally, the loss for each pixel is the negative log of that probability, and the per-pixel losses are averaged (a weighted average when class weights are supplied). We summarize these steps below: 

<img src="CE1.png" width="550" align="center">
<img src="CE2.png" width="550" align="center">

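Cross Entropy Loss penalizes heavily when the true class ($t_{i}~=~1$) is predicted with low confidence ($p_{i}~\ll~1$). For example, with $p_{i} = 0.01$:

$ -(1) \cdot \ln(0.01) \approx 4.61 $ (high loss)

whereas a confident correct prediction with $p_{i} = 0.99$ costs only $-\ln(0.99) \approx 0.01$.
Note that the loss is typically averaged over a **mini-batch** of images: with PyTorch's default `reduction='mean'`, the per-pixel losses of every image in the mini-batch are averaged into a single value. 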
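The three steps can be reproduced by hand and compared with `nn.CrossEntropyLoss`. The sketch below uses random logits for a mini-batch of two tiny 4×4 "images" with three classes and arbitrary class weights. With class weights, the default `'mean'` reduction divides by the sum of the weights of each pixel's true class rather than by the pixel count.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(2, 3, 4, 4)          # (batch, classes, H, W) raw scores
target = torch.randint(0, 3, (2, 4, 4))   # (batch, H, W) class index per pixel
weights = torch.tensor([0.2, 0.8, 1.0])   # arbitrary per-class weights

# 1. Log-softmax across the class channels at every pixel.
log_probs = F.log_softmax(logits, dim=1)
# 2. Log-probability assigned to each pixel's true class.
true_log_probs = log_probs.gather(1, target.unsqueeze(1)).squeeze(1)
# 3. Negative log-likelihood, weighted by each pixel's true class and averaged.
pixel_weights = weights[target]
manual = -(pixel_weights * true_log_probs).sum() / pixel_weights.sum()

builtin = nn.CrossEntropyLoss(weight=weights)(logits, target)
print(manual.item(), builtin.item())  # the two values agree
```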

***Adam***

Adam is a common choice as an optimizer strategy for gradient descent. At its core, it involves combining two optimizer methods, namely *momentum* and *Root Mean Square propagation (RMS Prop)*. We do not delve into the finer details of the math in this notebook, but the main idea of each is described below: 

*Momentum* keeps an exponentially weighted average of the computed gradients. At each step, the running average is multiplied by a constant $\beta$ between 0 and 1, the new gradient is weighted by $1 - \beta$, and the two are added; the result becomes the current 'weighted average' gradient used for the update. Momentum has the effect of stabilizing the updates, since no single mini-batch can drastically change this average. 

*RMS Prop* is similar to momentum in that it keeps an exponentially weighted average, but of the *squared* gradients. Unlike momentum, however, it is an adaptive learning-rate method: RMS Prop divides the learning rate by the square root of this average (plus a small constant for numerical stability), separately for each parameter. This increases the effective step size when the gradients are small, which helps the network keep moving through flat regions (for example, near a local minimum that may not be the global minimum). Conversely, RMS Prop reduces the step size when the gradients are large, making it less likely for the network to overshoot a minimum. 
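The two ideas combine into Adam's update rule. The sketch below writes that update by hand, including the bias-correction terms that Adam adds on top of momentum and RMS Prop (which the description above does not cover), and checks it against `torch.optim.Adam` on a toy quadratic loss. The values `beta1 = 0.9`, `beta2 = 0.999` and `eps = 1e-8` are PyTorch's defaults; the learning rate matches the one used in this notebook.

```python
import torch

lr, beta1, beta2, eps = 0.01, 0.9, 0.999, 1e-8

# Reference: PyTorch's Adam minimizing sum(p^2).
p_torch = torch.tensor([1.0, -2.0, 0.5], requires_grad=True)
optimizer = torch.optim.Adam([p_torch], lr=lr, betas=(beta1, beta2), eps=eps)

# Manual version of the same update.
p = torch.tensor([1.0, -2.0, 0.5])
m = torch.zeros_like(p)  # momentum: running average of gradients
v = torch.zeros_like(p)  # RMS Prop: running average of squared gradients

for t in range(1, 6):
    optimizer.zero_grad()
    (p_torch ** 2).sum().backward()
    optimizer.step()

    g = 2 * p                                   # gradient of sum(p^2)
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g ** 2
    m_hat = m / (1 - beta1 ** t)                # bias correction for the zero initialization
    v_hat = v / (1 - beta2 ** t)
    p = p - lr * m_hat / (v_hat.sqrt() + eps)   # per-parameter scaled step

print(p, p_torch.detach())  # the two agree to float precision
```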

#### 7. Hyperparameters
Our model relies on a few hyperparameters which we can tune.

```python
learning_rate = 0.01
batch_size = 16
num_epochs = 50
```
I was limited in GPU memory, so I kept a small batch size of 16, which also seemed reasonable given the relatively small dataset. I used a base learning rate of 0.01 (note that PyTorch's default for Adam is 0.001). Since Adam adapts the effective step size of each parameter as the network trains, I did not experiment much with different base learning rates. 

I did, however, try many different epoch values and compared how the segmentation accuracy changed with higher epoch numbers. We discuss our findings later on in the notebook. 
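The hyperparameters plug into a standard PyTorch training loop. The sketch below is hypothetical: a one-layer stand-in model and a single fixed random mini-batch replace the real SegNet and DataLoader so that it runs on its own, and `num_epochs` is reduced to keep it fast. It shows the order of operations for each mini-batch (forward pass, loss, `zero_grad`, `backward`, `step`) with Adam at `learning_rate = 0.01`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
learning_rate = 0.01
batch_size = 16
num_epochs = 20  # reduced from 50 for this toy example

# Hypothetical stand-ins for the real model and training data.
model = nn.Conv2d(3, 3, kernel_size=1)            # 3 input channels -> 3 class scores
images = torch.randn(batch_size, 3, 8, 8)
masks = torch.randint(0, 3, (batch_size, 8, 8))
train_loader = [(images, masks)]                   # one mini-batch per "epoch"

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

losses = []
for epoch in range(num_epochs):
    model.train()
    for x, y in train_loader:
        scores = model(x)                # (batch, classes, H, W)
        loss = criterion(scores, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    losses.append(loss.item())

print(losses[0], losses[-1])
```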

#### 8. Validating accuracy
I compute segmentation accuracy on a per class basis. For accuracy, I focus on two key criteria:

For each class: 
1. Proportion of pixels classified correctly (amount of mutual overlap).
2. Proportion of pixels classified incorrectly (predictions extraneous to the true label areas).

While we aim to maximize the proportion of correctly classified pixels, this criterion alone is a poor indicator of segmentation performance. For example, we could attain very high accuracy for the bacteria class simply by classifying every pixel as bacteria. We therefore introduce a penalty: we subtract the fraction of all pixels in the image where the prediction and the true label disagree for that class (an XOR, which counts both extraneous and missed pixels). For a class absent from the true mask, the score is 1 minus this fraction. I illustrate this approach below: 


<table><tr>
<td> <img src="A1.png" alt="Drawing" style="width: 500px;"/> </td>
<td> <img src="A2.png" alt="Drawing" style="width: 500px;"/> </td>
</tr></table>

```python
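# prediction and true_label: integer label maps of the same shape (one class index per pixel)
class_accuracies = []
for n in range(num_classes):
    prediction_class_pos = prediction == n
    true_label_class_pos = true_label == n
    matching = prediction_class_pos & true_label_class_pos     # correctly labelled pixels
    unmatching = prediction_class_pos ^ true_label_class_pos   # xor: extraneous or missed pixels
    penalty = unmatching.sum() / len(unmatching.flatten())     # fraction of mismatched pixels

    if true_label_class_pos.sum() != 0:
        accuracy = matching.sum() / true_label_class_pos.sum() - penalty
    else:
        # Class absent from the true mask: any prediction of it is extraneous
        accuracy = 1 - penalty
    class_accuracies.append(float(accuracy))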
```
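A small worked example makes the scoring concrete. The helper below (the name `class_scores` is hypothetical) wraps the same per-class logic as the loop above, collecting the values into a list, and applies it to 2×2 toy label maps. The second call shows the failure case described above: labelling every pixel as bacteria gives perfect overlap with the single true bacteria pixel, but the XOR penalty pulls the score down to 0.25.

```python
import torch


def class_scores(prediction, true_label, num_classes):
    """Per-class score: overlap with the true mask minus the fraction of mismatched pixels."""
    scores = []
    for n in range(num_classes):
        prediction_class_pos = prediction == n
        true_label_class_pos = true_label == n
        matching = prediction_class_pos & true_label_class_pos
        unmatching = prediction_class_pos ^ true_label_class_pos  # xor
        penalty = unmatching.sum().item() / unmatching.numel()
        if true_label_class_pos.sum() != 0:
            scores.append(matching.sum().item() / true_label_class_pos.sum().item() - penalty)
        else:
            scores.append(1 - penalty)
    return scores


true_label = torch.tensor([[0, 0], [1, 1]])
prediction = torch.tensor([[0, 1], [1, 1]])
print(class_scores(prediction, true_label, num_classes=3))  # [0.25, 0.75, 1.0]

# Degenerate case: every pixel predicted as class 2 when only one pixel is class 2.
true_label = torch.tensor([[0, 0], [1, 2]])
prediction = torch.full((2, 2), 2)
print(class_scores(prediction, true_label, num_classes=3)[2])  # 0.25
```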

### Evaluating model performance 
As expected, loss follows a downward trajectory with greater epoch number: 

<img src="loss.png" width="500" align="center">

We pay specific attention to any divergence between training and test accuracy as the epoch number rises, which would indicate overfitting. Note that we only show the RBC and bacteria classes, omitting the background class. 

<table><tr>
<td> <img src="rbc_eval.png" alt="Drawing" style="width: 500px;"/> </td>
<td> <img src="bact_eval.png" alt="Drawing" style="width: 500px;"/> </td>
</tr></table>

From our data, we can deduce that training and test accuracy begin to diverge around epoch 35, after which the model generalizes less well from the training set to the test set. Thus, we save the model around epoch 35. 

```python
import torch


def save_model_at_checkpoint(state, epoch):
    filename = "model_at_epoch_" + str(epoch+1) + ".pth.tar"
    torch.save(state, filename)
    
 
# ... after epoch 35 in the training loop # 
state = {"model_state": model.state_dict(), "optim_state": optimizer.state_dict()}
save_model_at_checkpoint(state, epoch)
```
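To reuse the checkpoint, the saved state dictionaries are loaded back into freshly constructed objects. A minimal sketch, assuming the same dictionary layout as `save_model_at_checkpoint` above (`"model_state"` and `"optim_state"` keys); a tiny `nn.Conv2d` and a temporary directory stand in for the real model and working directory.

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Conv2d(3, 3, kernel_size=1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Save in the same layout used above.
path = os.path.join(tempfile.mkdtemp(), "model_at_epoch_35.pth.tar")
torch.save({"model_state": model.state_dict(), "optim_state": optimizer.state_dict()}, path)

# Rebuild the objects, then restore their states from the checkpoint.
restored_model = nn.Conv2d(3, 3, kernel_size=1)
restored_optimizer = torch.optim.Adam(restored_model.parameters(), lr=0.01)
checkpoint = torch.load(path, map_location="cpu")
restored_model.load_state_dict(checkpoint["model_state"])
restored_optimizer.load_state_dict(checkpoint["optim_state"])
restored_model.eval()  # for the real model, this makes batch norm use its running statistics
```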

We can plot the images and prediction masks from the SegNet model: 

Examples for predictions on training set: 
<img src="train_1.png" width="400" align="center">
<img src="train_2.png" width="400" align="center">

Examples for predictions on test set: 
<img src="test_1.png" width="400" align="center">
<img src="test_2.png" width="400" align="center">

### Possible improvements 
