# <div align = 'center'><u> PyTorch Inference with TTA </u></div>

Our ultimate PyTorch pipeline for this competition:
* Training on TPU - **[[FoldTraining] PyTorch-TPU🔥-8-Cores](https://www.kaggle.com/joshi98kishan/training-pytorch-tpu-8-cores)**
* Inference with TTA (This notebook)

In [None]:
# Offline installation of segmentation-models-pytorch and its dependencies.
# Approach taken from:
# https://www.kaggle.com/vineeth1999/hubmap-pytorch-efficientunet-offline

# Stage the packages in a local directory that pip can use as its only index.
!mkdir -p /tmp/pip/cache/
# The source archives are stored in the input dataset with an `.xyz` suffix;
# copy them under the `.tar.gz` names that pip expects.
!cp ../input/segmentationmodelspytorch/segmentation_models/efficientnet_pytorch-0.6.3.xyz /tmp/pip/cache/efficientnet_pytorch-0.6.3.tar.gz
!cp ../input/segmentationmodelspytorch/segmentation_models/pretrainedmodels-0.7.4.xyz /tmp/pip/cache/pretrainedmodels-0.7.4.tar.gz
!cp ../input/segmentationmodelspytorch/segmentation_models/segmentation-models-pytorch-0.1.2.xyz /tmp/pip/cache/segmentation_models_pytorch-0.1.2.tar.gz
# The timm wheels are copied unchanged.
!cp ../input/segmentationmodelspytorch/segmentation_models/timm-0.1.20-py3-none-any.whl /tmp/pip/cache/
!cp ../input/segmentationmodelspytorch/segmentation_models/timm-0.2.1-py3-none-any.whl /tmp/pip/cache/
# --no-index forbids network access; --find-links resolves everything from the staged directory.
!pip install --no-index --find-links /tmp/pip/cache/ efficientnet-pytorch
!pip install --no-index --find-links /tmp/pip/cache/ segmentation-models-pytorch

In [None]:
import numpy as np
import pathlib
import pandas as pd
import numba, cv2, gc, os, glob
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

from albumentations import *

import torch
import torch.nn as nn

from segmentation_models_pytorch.unet import Unet
from segmentation_models_pytorch.encoders import get_preprocessing_fn

import rasterio
from rasterio.windows import Window

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' 

import warnings
# Suppress all Python warnings for the rest of the notebook.
warnings.filterwarnings('ignore')

# Show which device was selected.
print(DEVICE)

I forked the training notebook and trained one model per fold, each for 60 epochs.

In [None]:
# Root directory of the competition data; the test images are globbed from its `test/` subfolder.
DATA_PATH = '../input/hubmap-kidney-segmentation'

# Output directory of our training notebook; the fold checkpoints (*.pth) are loaded from here.
PATH_FOLD_MODELS = '../input/test-training-pytorch-tpu-8-cores'

Utilities (Hidden)

In [None]:
@numba.njit()
def rle_numba(pixels):
    """Run-length encode a flat binary sequence.

    Args:
        pixels: 1-D sequence of 0/1 values.

    Returns:
        A flat list ``[start_1, length_1, start_2, length_2, ...]`` in which
        every ``start`` is the 1-based position of the first pixel of a run
        of ones and ``length`` is the number of pixels in that run. An empty
        or all-zero input yields an empty list.
    """
    size = len(pixels)
    points = []
    # in_run is True while the scan is inside a run of ones, i.e. a start has
    # been emitted and its matching length is still pending.
    in_run = False
    if size > 0 and pixels[0] == 1:
        # A run that begins on the very first pixel starts at position 1
        # (positions are 1-based) and is already open when the loop begins.
        in_run = True
        points.append(1)
    for i in range(1, size):
        if pixels[i] != pixels[i - 1]:
            if in_run:
                # Run ended just before index i: length = end - start.
                points.append(i + 1 - points[-1])
                in_run = False
            else:
                # Run starts at index i, which is 1-based position i + 1.
                points.append(i + 1)
                in_run = True
    if in_run:
        # Close a run that extends to the last pixel.
        points.append(size - points[-1] + 1)
    return points

def rle_numba_encode(image):
    """Return the run-length encoding of ``image`` as a space-separated string.

    The array is flattened in column-major ('F') order before being passed to
    ``rle_numba``; the resulting numbers are joined with single spaces.
    """
    column_major = image.flatten(order='F')
    runs = rle_numba(column_major)
    return ' '.join(map(str, runs))

def _axis_tiles(length, window, min_overlap):
    """Return (starts, stops) of the tiles covering one axis of size ``length``."""
    count = length // (window - min_overlap) + 1
    starts = np.linspace(0, length, num=count, endpoint=False, dtype=np.int64)
    # Pin the last tile to the far edge so the axis is fully covered. Clamp at
    # zero so an axis shorter than ``window`` does not get a negative start.
    starts[-1] = max(length - window, 0)
    stops = (starts + window).clip(0, length)
    return starts, stops


def make_grid(shape, window=256, min_overlap=32):
    """Split a 2-D extent into overlapping tiles.

    Args:
        shape: ``(x, y)`` size of the area to cover.
        window: Side length of a tile.
        min_overlap: Minimum overlap between neighbouring tiles; together with
            ``window`` it determines how many tiles are placed per axis.

    Returns:
        An int64 array of shape ``(N, 4)``, where ``N`` is the number of
        tiles and the second axis holds the slice bounds ``x1, x2, y1, y2``.
        Tiles are ordered with the x index varying slowest. When an axis is
        shorter than ``window`` the tiles on that axis span the whole axis.
    """
    x, y = shape
    x1, x2 = _axis_tiles(x, window, min_overlap)
    y1, y2 = _axis_tiles(y, window, min_overlap)
    nx, ny = len(x1), len(y1)

    slices = np.zeros((nx, ny, 4), dtype=np.int64)
    # Broadcast the per-axis bounds over the (nx, ny) grid.
    slices[:, :, 0] = x1[:, None]
    slices[:, :, 1] = x2[:, None]
    slices[:, :, 2] = y1[None, :]
    slices[:, :, 3] = y2[None, :]
    return slices.reshape(nx * ny, 4)

In [None]:
# Encoder backbone shared by the model definition and the input preprocessing.
ENCODER_NAME = 'se_resnext50_32x4d'


class HuBMAPModel(nn.Module):
    """Thin wrapper around a single-class U-Net built on ``ENCODER_NAME``.

    The U-Net is created with ``encoder_weights=None`` and ``activation=None``;
    the trained weights are supplied afterwards through ``load_state_dict``.
    """

    def __init__(self):
        super().__init__()
        unet_config = {
            'encoder_name': ENCODER_NAME,
            'encoder_weights': None,
            'classes': 1,
            'activation': None,
        }
        self.model = Unet(**unet_config)

    def forward(self, images):
        """Run the wrapped U-Net on ``images`` and return its output unchanged."""
        return self.model(images)

In [None]:
# One checkpoint per fold; sorted so the load order is deterministic across runs.
fold_models_paths = sorted(glob.glob(os.path.join(PATH_FOLD_MODELS, '*.pth')))
if not fold_models_paths:
    # Fail early with a clear message: with no models the ensemble average
    # in the inference loop would divide by zero.
    raise FileNotFoundError(f'No *.pth checkpoints found in {PATH_FOLD_MODELS!r}')

fold_models = []

for path in fold_models_paths:
    # map_location remaps the stored tensors onto DEVICE, so checkpoints saved
    # on a different device still load (e.g. on a CPU-only machine).
    state_dict = torch.load(path, map_location=DEVICE)
    model = HuBMAPModel()
    model.load_state_dict(state_dict)
    model.float()
    model.to(DEVICE)
    # Inference only: switch the module to evaluation mode.
    model.eval()

    fold_models.append(model)

In [None]:
# Sanity check: number of fold models that were loaded.
len(fold_models)

In [None]:
# Input preprocessing matching the encoder, wrapped as an albumentations
# transform so it is called the same way as the augmentations below.
_encoder_preprocessing = get_preprocessing_fn(encoder_name=ENCODER_NAME,
                                              pretrained='imagenet')
preprocess_input = Lambda(image=_encoder_preprocessing)

# No-op transform: returns the image untouched.
identity_trfm = Lambda(image=lambda x, cols=None, rows=None: x)

# Geometric transforms, always applied (p = 1.0).
horizontal_flip = HorizontalFlip(p=1.0)
vertical_flip = VerticalFlip(p=1.0)
# The rotation limit is pinned to a single angle, so each transform always
# rotates by exactly that amount; the two are used as an augment/undo pair.
rotate_cw = Rotate(limit=(-90, -90), p=1.0)
rotate_acw = Rotate(limit=(90, 90), p=1.0)

# Pixel-level transformations: exactly one of these is picked on every call.
_pixel_level_choices = [
    HueSaturationValue(10, 15, 10),
    CLAHE(clip_limit=2),
    RandomBrightnessContrast(),
]
pixel_level_trfms = OneOf(_pixel_level_choices, p=1.0)

# Each TTA augmentation paired with the transform that undoes it on the
# predicted mask. None means the prediction needs no de-augmentation.
_tta_pairs = [
    (identity_trfm, None),
    (horizontal_flip, horizontal_flip),
    (vertical_flip, vertical_flip),
    (rotate_cw, rotate_acw),
    (pixel_level_trfms, None),
]

# List of augmentations for TTA
tta_augs = [aug for aug, _ in _tta_pairs]

# List of deaugmentations corresponding to the above aug list
tta_deaugs = [deaug for _, deaug in _tta_pairs]

In [None]:
# Side length, in pixels, of each tile read from the full-resolution image.
WINDOW=1024
# Passed to make_grid as the minimum overlap between neighbouring tiles.
MIN_OVERLAP=32
# Side length each tile is resized to before it is fed to the models.
NEW_SIZE=256

We will be doing TTA here. 

I have explained TTA in the simplest way possible in this notebook: "[Let's Understand TTA in Segmentation](https://www.kaggle.com/joshi98kishan/let-s-understand-tta-in-segmentation)".

In [None]:
# Tiled inference with test-time augmentation (TTA) over every test image.
# For each tile, every fold model predicts on every TTA variant; the scores
# are de-augmented, averaged over the TTA variants and then over the folds,
# and thresholded at 0 into a binary mask.
identity = rasterio.Affine(1, 0, 0, 0, 1, 0)
p = pathlib.Path(DATA_PATH)
subm = {}

# List the test files once so the progress-bar total and the loop use the
# same set of files.
test_files = list(p.glob('test/*.tiff'))

for i, filename in tqdm(enumerate(test_files), total=len(test_files)):

    print(f'{i+1} Predicting {filename.stem}')

    # The context manager closes the file handle when the image is done.
    with rasterio.open(filename.as_posix(), transform=identity) as dataset:
        slices = make_grid(dataset.shape, window=WINDOW, min_overlap=MIN_OVERLAP)
        preds = np.zeros(dataset.shape, dtype=np.uint8)

        for (x1, x2, y1, y2) in slices:
            image = dataset.read([1, 2, 3],
                                 window=Window.from_slices((x1, x2), (y1, y2)))
            # Channels-first (as read) -> channels-last for the augmentations.
            image = np.moveaxis(image, 0, -1)
            # Actual tile size; plain ints because cv2 expects Python integers.
            tile_rows, tile_cols = int(x2 - x1), int(y2 - y1)
            pred = 0

            for fold_model in fold_models:
                tta_pred = None

                for j, tta_aug in enumerate(tta_augs):
                    # Augmentation
                    aug_img = tta_aug(image=image)['image']
                    aug_img = preprocess_input(image=aug_img)['image']
                    aug_img = cv2.resize(aug_img, (NEW_SIZE, NEW_SIZE))
                    aug_img = np.moveaxis(aug_img, -1, 0)
                    aug_img = torch.from_numpy(aug_img)

                    with torch.no_grad():
                        # [None] adds the batch dimension.
                        score = fold_model(aug_img.float().to(DEVICE)[None])
                    # Drop the batch and channel dimensions.
                    score = score.cpu().numpy()[0][0]

                    # Deaugmentation
                    if tta_deaugs[j] is not None:
                        score = tta_deaugs[j](image=image,
                                              mask=score)['mask']

                    # Resize back to the tile's real size (cv2 takes
                    # (width, height)) so the result always fits the
                    # preds[x1:x2, y1:y2] slice, even for tiles smaller
                    # than WINDOW.
                    score = cv2.resize(score, (tile_cols, tile_rows))

                    if tta_pred is None:
                        tta_pred = score
                    else:
                        tta_pred += score

                # Average over the TTA variants, then accumulate over folds.
                tta_pred = tta_pred / len(tta_augs)
                pred += tta_pred

            pred = pred / len(fold_models)
            # Where tiles overlap, the tile processed later overwrites the earlier one.
            preds[x1:x2, y1:y2] = (pred > 0).astype(np.uint8)

    subm[i] = {'id': filename.stem, 'predicted': rle_numba_encode(preds)}
    # Free the full-resolution mask before the next image is allocated.
    del preds
    gc.collect()

In [None]:
# `subm` maps the image index to its {'id', 'predicted'} record; with
# orient='index' every record becomes one row of the frame.
submission = pd.DataFrame.from_dict(data=subm, orient='index')
# The integer index is not written to the submission file.
submission.to_csv(path_or_buf='submission.csv', index=False)

# Preview the first rows.
submission.head()