 # <div align = 'center'><u> Understanding TTA in Image Segmentation</u></div>

This notebook explains the idea behind TTA in image segmentation in a simple way.

# Table of contents <a id='0.1'></a>

1. [What is TTA?](#1)
2. [TTA in Image Classification](#2)
3. [TTA in Image Segmentation](#3)

In [None]:
!pip install segmentation-models-pytorch

# Copying diagrams to the current working directory
!cp ../input/tta-segmentation-diagram/TTA_classification.png .
!cp ../input/tta-segmentation-diagram/TTA_not_right_way.png .
!cp ../input/tta-segmentation-diagram/TTA_segmentation.png .

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from albumentations import Lambda, HorizontalFlip, VerticalFlip
import torch
import torch.nn as nn
from segmentation_models_pytorch import Unet
from segmentation_models_pytorch.encoders import get_preprocessing_fn

PATH_DATA = '../input/hubmap-256x256/'
PATH_TRAINED_MODEL = '../input/tta-segmentation-diagram/hubmap_8core_model_V25.pth'

# 1. <a id='1'>What is TTA?</a>
[Table of contents](#0.1)

TTA stands for **Test Time Augmentation**. 

In TTA, a single test sample is augmented multiple times. The model's predictions for these augmented samples are then averaged to get the final prediction. TTA often improves the performance of the model.
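
Here is the same idea written as a formula. The notation is introduced only for illustration. Let $x$ be a test sample, $T_1, \dots, T_K$ the augmentations ($T_1$ can be the identity, i.e. the original sample), and $f$ the model. The TTA prediction is the mean of the model outputs:

$$
\hat{y} = \frac{1}{K} \sum_{k=1}^{K} f\big(T_k(x)\big)
$$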

Refer to this blog, if you are new to TTA : 
[How to Use Test-Time Augmentation to Make Better Predictions](https://machinelearningmastery.com/how-to-use-test-time-augmentation-to-improve-model-performance-for-image-classification/)

# 2. <a id='2'>TTA in Image Classification</a>
[Table of contents](#0.1)

We will understand TTA in segmentation by first understanding TTA in classification.

In image classification, the model outputs a probability (score) for an input image. 
We simply average the predicted probabilities of the different augmented versions of a test image.
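
Below is a minimal NumPy sketch of this averaging. `tta_classify` and `toy_classifier` are hypothetical names made up for illustration. `toy_classifier` only stands in for a real model that returns a probability, and `np.fliplr` / `np.flipud` play the roles of horizontal and vertical flips.

```python
import numpy as np

def tta_classify(predict_proba, image, augmentations):
    """Average the predicted probability over the original image and its augmented copies.

    predict_proba: callable mapping an HxWxC array to a probability in [0, 1].
    augmentations: callables mapping an image to an augmented image.
    """
    # The original image is always included as one of the views.
    views = [image] + [augment(image) for augment in augmentations]
    scores = [predict_proba(view) for view in views]
    # A single class score does not depend on where pixels are,
    # so the scores can be averaged directly.
    return float(np.mean(scores))


# Hypothetical stand-in for a classifier: mean intensity of the left half.
def toy_classifier(img):
    return float(img[:, : img.shape[1] // 2].mean())


rng = np.random.default_rng(0)
img = rng.random((8, 8, 3))  # random values in [0, 1)
score = tta_classify(toy_classifier, img, [np.fliplr, np.flipud])
print(round(score, 3))
```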

Let's take melanoma detection as an example. The diagram below explains it.

<center><img align="right" src="TTA_classification.png"></center>

If we had used only the original image for inference, we would have got a probability of 0.88. By applying TTA, we got a score of 0.913.

Similarly, TTA can boost performance in segmentation. 
**But in segmentation we have to be careful: in some cases, we need to do an extra step while applying TTA.**

# 3. <a id='3'>TTA in Image Segmentation</a>
[Table of contents](#0.1)

In image segmentation, the model outputs a grid of probabilities (scores), one per pixel.

*Note: we are talking about binary segmentation, where we need to detect a blob in an input image.*

Let's take an example where we use Horizontal Flip (HF) and Vertical Flip (VF) as our TTA augmentations.

`im` is our test image whose mask needs to be predicted. `im_hf` and `im_vf` are the augmented images obtained by applying HF and VF to `im`, respectively.

`pred1`, `pred2_hf` and `pred3_vf` are the predicted masks of the `im`, `im_hf` and `im_vf` images respectively.

Here's the catch: we cannot directly average these predicted masks (by averaging, we mean averaging the corresponding pixels of the masks). Two of them are predictions for augmented images, so these predicted masks are augmented too.

If we average these augmented masks as they are, we average pixels that do not correspond to each other. We end up with the result below: three detected blobs instead of one in the final prediction.

<center><img height="500" width="500" src="TTA_not_right_way.png"></center>
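
The same failure can be reproduced with a toy example. This sketch makes a few assumptions: `toy_segmenter` is a hypothetical "perfect" model that marks bright pixels as foreground, and `np.fliplr` / `np.flipud` play the roles of HF and VF.

```python
import numpy as np

# Hypothetical "perfect" segmenter: every bright pixel is foreground,
# so the predicted blob always moves together with the image.
def toy_segmenter(img):
    return (img > 0.5).astype(np.float32)

# 8x8 test image with one blob in the top-left quadrant.
im = np.zeros((8, 8), dtype=np.float32)
im[1:3, 1:3] = 1.0

pred1 = toy_segmenter(im)
pred2_hf = toy_segmenter(np.fliplr(im))  # blob lands in the top-right quadrant
pred3_vf = toy_segmenter(np.flipud(im))  # blob lands in the bottom-left quadrant

# Wrong: average the masks without undoing the flips.
naive = (pred1 + pred2_hf + pred3_vf) / 3

# Mark which 4x4 quadrants contain any foreground.
quadrants = [naive[:4, :4], naive[:4, 4:], naive[4:, :4], naive[4:, 4:]]
print([int(q.max() > 0) for q in quadrants])  # [1, 1, 1, 0]: three blobs
print(naive.max())  # each blob keeps only a third of its score
```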



So, we need to **deaugment the augmented predicted masks**, i.e. apply the inverse of each augmentation to its predicted mask, so that their pixels align with the original predicted mask. Then we can average them to get the correct final prediction.
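
Here is a minimal sketch of deaugmentation on the same kind of toy setup. `toy_segmenter` is again a hypothetical stand-in for a model. Each TTA step pairs an augmentation with its inverse. Since a flip is its own inverse, the deaugmentation is simply the same flip applied to the predicted mask.

```python
import numpy as np

def toy_segmenter(img):
    # Hypothetical stand-in for a segmentation model: bright pixels are foreground.
    return (img > 0.5).astype(np.float32)

im = np.zeros((8, 8), dtype=np.float32)
im[1:3, 1:3] = 1.0

# (augment, deaugment) pairs; each flip is undone by flipping again.
tta_steps = [
    (lambda x: x, lambda m: m),  # original image, nothing to undo
    (np.fliplr, np.fliplr),      # horizontal flip
    (np.flipud, np.flipud),      # vertical flip
]

# Predict on each augmented image, then map the mask back to the original frame.
aligned = [deaug(toy_segmenter(aug(im))) for aug, deaug in tta_steps]
pred = np.mean(aligned, axis=0)  # pixel-wise average

# All masks now line up, so the single blob keeps its full score.
print(np.array_equal(pred, toy_segmenter(im)))  # True
```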

The correct procedure is demonstrated in the diagram below.

<center><img align="right" src="TTA_segmentation.png"></center>

Now, we will code this.

Utilities

In [None]:
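def get_pred_mask(image):
    # Normalize the image the way the ImageNet-pretrained encoder expects
    preprocessing_fn = Lambda(image = get_preprocessing_fn(encoder_name = ENCODER_NAME,
                                                       pretrained = 'imagenet'))
    im = preprocessing_fn(image=image)['image']
    im = np.moveaxis(im, -1, 0)            # HWC -> CHW
    im = torch.from_numpy(im)
    with torch.no_grad():                  # inference only, no gradients needed
        logits = model(im.float()[None])   # add a batch dimension
    # The Unet is built with activation = None, so turn logits into probabilities
    return torch.sigmoid(logits).numpy()[0][0]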


def show(xs, titles, cmap = None):
    _, axs = plt.subplots(1, len(xs))
    
    for i, x in enumerate(xs):
        axs[i].imshow(x, cmap = cmap)
        axs[i].set_title(titles[i])

    plt.show()

We will be using a model trained in this [notebook](https://www.kaggle.com/joshi98kishan/training-pytorch-tpu-8-cores).

In [None]:
ENCODER_NAME = 'se_resnext50_32x4d'

class HuBMAPModel(nn.Module):
    def __init__(self):
        super(HuBMAPModel, self).__init__()
        self.model = Unet(encoder_name = ENCODER_NAME, 
                          encoder_weights = None,
                          classes = 1,
                          activation = None)
        
    def forward(self, images):
        img_masks = self.model(images)
        return img_masks

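state_dict = torch.load(PATH_TRAINED_MODEL, map_location = 'cpu')  # load the weights on CPU
model = HuBMAPModel()
model.load_state_dict(state_dict)
model.eval()  # inference mode: BatchNorm uses its stored running statistics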

In [None]:
# Reading a sample image. 
# Consider 'im' as a test image and 'mask' as ground truth.

im = plt.imread(os.path.join(PATH_DATA, 'train/095bf7a1f_861.png'))
mask = plt.imread(os.path.join(PATH_DATA, 'masks/095bf7a1f_861.png'))

show([im, mask], ['im', 'mask'], cmap = 'gray')

In [None]:
# Defining the augmentations

horizontal_flip = HorizontalFlip(p = 1.0)
vertical_flip = VerticalFlip(p = 1.0)

In [None]:
# Augmentation

im_hf = horizontal_flip(image = im)['image']
im_vf = vertical_flip(image = im)['image']

show([im, im_hf, im_vf], ['im', 'im_hf', 'im_vf'])

In [None]:
# Prediction

pred1 = get_pred_mask(im)
pred2_hf = get_pred_mask(im_hf)
pred3_vf = get_pred_mask(im_vf)

show([pred1, pred2_hf, pred3_vf], 
     ['pred1', 'pred2_hf', 'pred3_vf'],
     cmap = 'gray')

In [None]:
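# Deaugmentation: a flip is its own inverse, so flipping each predicted mask
# again puts its pixels back where they are in `pred1`.
# A dummy image of the same height and width is passed along with the mask.
dummy_image = np.zeros((*pred2_hf.shape, 3), dtype = np.uint8)

pred2 = horizontal_flip(image = dummy_image, mask = pred2_hf)['mask']
pred3 = vertical_flip(image = dummy_image, mask = pred3_vf)['mask']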


show([pred1, pred2, pred3], 
     ['pred1', 'pred2', 'pred3'],
     cmap = 'gray')

In [None]:
# Averaging

pred = (pred1 + pred2 + pred3)/3

plt.imshow(pred, cmap = 'gray')
plt.title('pred')
plt.show()
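
The augment, predict, deaugment and average steps above can be wrapped in one reusable function. This is a sketch, not the notebook author's implementation. `tta_segment`, `FLIP_TTA` and `toy_model` are hypothetical names. In this notebook, `get_pred_mask` could serve as `predict_mask`.

```python
import numpy as np

def tta_segment(predict_mask, image, tta_steps):
    """Test-time augmentation for binary segmentation (a sketch).

    predict_mask: callable mapping an HxWxC image to an HxW score map.
    tta_steps:    list of (augment, deaugment) pairs; deaugment must undo
                  the spatial effect of augment on an HxW mask.
    """
    aligned = []
    for augment, deaugment in tta_steps:
        # np.fliplr / np.flipud return views with negative strides, which
        # torch.from_numpy rejects, so copy into a contiguous array first.
        augmented = np.ascontiguousarray(augment(image))
        pred_aug = predict_mask(augmented)   # mask in the augmented frame
        aligned.append(deaugment(pred_aug))  # mask back in the original frame
    return np.mean(aligned, axis=0)          # pixel-wise average


identity = lambda x: x
FLIP_TTA = [
    (identity, identity),
    (np.fliplr, np.fliplr),  # horizontal flip, undone by flipping again
    (np.flipud, np.flipud),  # vertical flip, undone by flipping again
]

# Usage with a hypothetical toy model: foreground = red channel above 0.5.
toy_model = lambda img: (img[..., 0] > 0.5).astype(np.float32)
image = np.zeros((6, 6, 3), dtype=np.float32)
image[0:2, 0:3, 0] = 1.0
pred = tta_segment(toy_model, image, FLIP_TTA)
print(pred.shape)  # (6, 6)
```

With the trained model loaded, `tta_segment(get_pred_mask, im, FLIP_TTA)` should give the same average as `pred` above, up to floating-point differences. `np.fliplr` and `np.flipud` flip the width and height axes, as HF and VF do. Adding a new augmentation only requires adding its (augment, deaugment) pair to the list.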

Here, we have used augmentations (flips, which are [affine transformations](https://en.wikipedia.org/wiki/Affine_transformation)) that change the location of the original pixels. We can also use [pixel-level augmentations](https://github.com/albumentations-team/albumentations#pixel-level-transforms) in TTA, such as brightness or contrast changes. These do not move any pixel, so their predicted masks need no deaugmentation.
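
Here is a small sketch contrasting the two cases, using the same toy assumptions (`toy_segmenter` is hypothetical). The brightness/contrast change is written by hand with NumPy rather than with an albumentations transform. It needs no deaugmentation. A 90-degree rotation does need one, and unlike a flip, its inverse is a rotation in the opposite direction rather than the same transform.

```python
import numpy as np

def toy_segmenter(img):
    # Hypothetical model: a pixel is foreground if it is brighter than the image mean.
    return (img > img.mean()).astype(np.float32)

im = np.zeros((8, 8), dtype=np.float32)
im[1:3, 1:4] = 0.8  # one blob, not symmetric under rotation

identity = lambda m: m
steps = [
    # Pixel-level: brightness/contrast change keeps every pixel in place,
    # so the predicted mask is already aligned.
    (lambda x: np.clip(1.2 * x + 0.05, 0.0, 1.0), identity),
    # Spatial: rotating by 90 degrees is undone by rotating by -90 degrees.
    (lambda x: np.rot90(x, k=1), lambda m: np.rot90(m, k=-1)),
]

reference = toy_segmenter(im)
for augment, deaugment in steps:
    aligned = deaugment(toy_segmenter(augment(im)))
    print(np.array_equal(aligned, reference))  # True for both steps
```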

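To find out whether an augmentation needs deaugmentation, plot the predicted mask of the augmented test image next to the prediction for the original image. If the blob has moved, the mask must be deaugmented before averaging.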
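
This check can also be made numeric instead of only visual. Deaugment the prediction of the augmented image and measure its overlap with the prediction of the original image. The sketch below uses the Dice coefficient. `alignment_score`, `dice` and `toy_model` are hypothetical names, and the 0.5 threshold is an assumption.

```python
import numpy as np

def dice(a, b, eps=1e-7):
    """Dice overlap between two boolean masks (1.0 means identical)."""
    inter = np.logical_and(a, b).sum()
    return (2.0 * inter + eps) / (a.sum() + b.sum() + eps)

def alignment_score(predict_mask, image, augment, deaugment, threshold=0.5):
    """Compare the original prediction with the deaugmented prediction
    of the augmented image, after thresholding both."""
    ref = predict_mask(image) > threshold
    other = deaugment(predict_mask(augment(image))) > threshold
    return dice(ref, other)

toy_model = lambda img: (img > 0.5).astype(np.float32)  # hypothetical model
im = np.zeros((8, 8), dtype=np.float32)
im[1:3, 1:3] = 1.0

identity = lambda m: m
print(alignment_score(toy_model, im, np.fliplr, identity))   # near 0: flip not undone
print(alignment_score(toy_model, im, np.fliplr, np.fliplr))  # near 1: masks line up
```

With a real model, the score will not reach exactly 1.0 because predictions on augmented images differ slightly. Even so, a correct deaugmentation should score clearly higher than a missing one.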

TTA in segmentation is implemented with clear code in this notebook:

### [[Inference] PyTorch-TTA-Sub-0.84](https://www.kaggle.com/joshi98kishan/inference-pytorch-tta-sub-0-84)

So, that's it. We have understood the idea behind TTA in segmentation.

If you liked it, then please upvote it. :)