![title](https://camo.githubusercontent.com/51eea85ed59f27be0485cc5774d09b522ea8e77cd3f0753c085cacd18d4a41a0/68747470733a2f2f692e6962622e636f2f4774784753386d2f5365676d656e746174696f6e2d4d6f64656c732d56312d536964652d332d312e706e67)

# Inference with Keras Segmentation Models Library
This notebook makes predictions on the HuBMAP data with an FPN model from the [Segmentation Models library](https://github.com/qubvel/segmentation_models). This library is Keras based and very simple to use. It has four different segmentation models (Unet, Linknet, FPN and PSPNet), and a whopping 25 different pretrained backbones that can be used with each model.  

The training data has been converted into TFRecords in [HuBMAP Image 2 TFRecords 256,512,1024](https://www.kaggle.com/mistag/hubmap-image-2-tfrecords-256-512-1024). 

Training of the model is done in [this notebook](https://www.kaggle.com/mistag/train-fpn-segmentation-model-no-43e), with validation on three images and training on the other 12 in a K-fold cross-validation scheme.

The Feature Pyramid Network (FPN) architecture was used in training:
![FPN](https://raw.githubusercontent.com/qubvel/segmentation_models/master/images/fpn.png)


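First, install the libraries that Segmentation Models needs (Keras Applications and EfficientNet). They come from the attached Kaggle input datasets; `--no-index` means nothing is fetched from PyPI.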

In [None]:
!pip install ../input/kerasapplications/keras-team-keras-applications-3b180cb -f ./ --no-index -q
!pip install ../input/efficientnet/efficientnet-1.1.0/ -f ./ --no-index -q

In [None]:
import os
os.environ['SM_FRAMEWORK'] = 'tf.keras'
import sys
import numpy as np
import cv2
import glob
import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import model_from_json
from tensorflow.keras.utils import CustomObjectScope
from tensorflow.keras import backend as K
from tensorflow.keras.utils import get_custom_objects
import efficientnet as efn
import efficientnet.tfkeras
import rasterio
from rasterio.windows import Window
import gdal
import json
from tqdm import tqdm
import matplotlib.pyplot as plt

In [None]:
# Reuse the hyperparameters exported by the training notebook so that
# inference runs with the same settings; the dict is echoed as cell output.
with open('../input/train-fpn-segmentation-model-no-43e/hparams.json') as fh:
    hparams = json.load(fh)
hparams

In [None]:
# Training-time settings reused at inference, read from the hparams dict.
IMG_SIZE, SCALE_FACTOR, K_SPLITS = (hparams[key] for key in ('IMG_SIZE', 'SCALE_FACTOR', 'K_SPLITS'))
# Identity affine transform (unit scale, zero offset); it is handed to
# rasterio.open when the test images are opened.
IDNT = rasterio.Affine(1, 0, 0, 0, 1, 0)

Using the RLE encoder from [this notebook](https://www.kaggle.com/iafoss/hubmap-pytorch-fast-ai-starter-sub):

In [None]:
##https://www.kaggle.com/bguberfain/memory-aware-rle-encoding
def rle_encode_less_memory(img):
    """Run-length encode a 2-D binary mask in column-major order.

    Args:
        img: 2-D mask; non-zero / True pixels are foreground. It is not modified.

    Returns:
        A space-separated ``"start length start length ..."`` string. Starts are
        1-based positions in ``img.T`` flattened, i.e. counted down the columns.
        An all-background or empty mask yields ``""``.
    """
    pixels = img.T.flatten()  # column-major copy of the mask
    if pixels.size == 0:
        return ''
    # Index i marks a value change between pixels[i] and pixels[i + 1]; the
    # pixel that starts (or follows) the run sits at 1-based position i + 2.
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 2
    # Treat the area outside the mask as background. Add the implicit
    # transitions at both ends explicitly instead of overwriting the first and
    # last pixel, which silently dropped foreground there. Only the small
    # `runs` array grows; no padded copy of the full pixel array is made.
    if pixels[0]:
        runs = np.concatenate(([1], runs))
    if pixels[-1]:
        runs = np.concatenate((runs, [pixels.size + 1]))
    runs[1::2] -= runs[::2]  # (start, end) pairs -> (start, length) pairs

    return ' '.join(str(x) for x in runs)

## TTA
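Test-time augmentation (TTA) helpers: `create_TTA_batch` builds eight rotated/flipped views of each patch, and `mask_TTA` transforms the predicted masks back to the original orientation.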

In [None]:
def create_TTA_batch(img):
    """Build the 8-view test-time-augmentation batch for one image or a batch.

    Args:
        img: array of shape (H, W, C), or (N, H, W, C). A 3-D input is treated
            as a batch of one. Values are divided by 255, which assumes 0..255
            pixel data.

    Returns:
        float32 array of shape (N * 8, H, W, C). For input image i:
        - slots i*8 .. i*8+3 hold the image rotated by np.rot90 with
          k = 0, 1, 2, 3;
        - slots i*8+4 .. i*8+7 hold the same four rotations of the image
          flipped along its width axis.
        mask_TTA undoes these transforms in the same slot order.
    """
    if len(img.shape) < 4:
        img = np.expand_dims(img, 0)
    # A plain float32 cast. It replaces the legacy
    # tf.keras.preprocessing.image.img_to_array, which did only this for
    # channels-last input.
    scaled = np.asarray(img, dtype=np.float32) / 255.
    flipped = scaled[:, :, ::-1]  # flip the width axis of every image
    views = [np.rot90(src, k=k, axes=(1, 2)) for src in (scaled, flipped) for k in range(4)]
    # (N, 8, H, W, C) -> (N*8, H, W, C) keeps each image's 8 views contiguous.
    stacked = np.stack(views, axis=1)
    return stacked.reshape((-1,) + stacked.shape[2:])

def mask_TTA(masks):
    """Map a batch of TTA predictions back to the original orientation.

    ``masks`` is indexed as a 4-D array (N, H, W, C). The eight views of each
    image are stored consecutively in create_TTA_batch order: the original,
    rot90 k=1..3, the width-flipped image, then the flipped image with rot90
    k=1..3. Each view is rotated back with the complementary k, and the
    flipped half is also flipped back.

    Returns a float32 array with the first four dimensions of ``masks``. If N
    is not a multiple of 8, the leftover trailing slots stay zero.
    """
    out_shape = tuple(masks.shape[axis] for axis in range(4))
    restored = np.zeros(out_shape, dtype=np.float32)
    for start in range(0, (out_shape[0] // 8) * 8, 8):
        for view in range(8):
            # undo rot90(k=view % 4) with the complementary rotation
            aligned = np.rot90(masks[start + view], axes=(0, 1), k=(-view) % 4)
            if view >= 4:
                aligned = aligned[:, ::-1]  # undo the width flip
            restored[start + view] = aligned
    return restored

# Inference
The images are processed in tiles that overlap by half the tile size, so each interior pixel is covered by four tiles (pixels near the image border are covered by fewer). To make an ensemble of several models, just add them to the MODELS list, and check that the mask threshold is still appropriate for the number of models.

In [None]:
MEANING_OF_LIFE = 42

PATH = '../input/hubmap-kidney-segmentation/test/'
filelist = glob.glob(PATH+'*.tiff')
if len(filelist) == 5: # save time during commit phase by only processing one image
    filelist = filelist[:1]
    COMMIT = True
else:
    COMMIT = False

SUB_FILE = './submission.csv'
with open(SUB_FILE, 'w') as f:
    f.write("id,predicted\n")

MODELS = ['../input/train-fpn-segmentation-model-no-43e/FPN-model43e-4']
s_th = MEANING_OF_LIFE+K_SPLITS # saturation blanking threshold
p_th = IMG_SIZE*IMG_SIZE//32   # pixel count threshold
size = int(IMG_SIZE * SCALE_FACTOR)  # tile size in full-resolution pixels

OVERLAP = size//2
STEP = size-OVERLAP
TTA_VIEWS = 8  # views per patch built by create_TTA_batch / undone by mask_TTA
# STEP >= size/2, so each pixel lies in at most 2x2 tiles. Per model it can
# therefore collect at most 4 * TTA_VIEWS = 32 votes; use uint16 only when the
# ensemble could overflow uint8.
VOTE_DTYPE = np.uint8 if len(MODELS) * 4 * TTA_VIEWS <= 255 else np.uint16


def _n_tiles(length):
    """Number of tiles (one every STEP pixels, `size` long) needed to cover `length` pixels."""
    return max(1, -(-(length - OVERLAP) // STEP))


def _axis_coverage(length):
    """Return a uint8 array giving how many tiles cover each pixel index along one axis."""
    cov = np.zeros(length, dtype=np.uint8)
    for t in range(_n_tiles(length)):
        cov[t * STEP:t * STEP + size] += 1
    return cov


def _load_model(path):
    """Build a Keras model from `path`.json (architecture) and load `path`.h5 (weights)."""
    with open(path + '.json', 'r') as m:
        model = model_from_json(m.read())
    model.load_weights(path + '.h5')
    return model


def _read_tile(img_data, layers, x1, x2, y1, y2, crop):
    """Read rows x1:x2 and cols y1:y2 as RGB, zero-padded into a (size, size, 3) uint8 array."""
    window = Window.from_slices((x1, x2), (y1, y2))
    if img_data.count == 3: # normal
        tile = np.moveaxis(img_data.read([1, 2, 3], window=window), 0, -1)
    else: # with subdatasets/layers
        tile = np.zeros((crop[0], crop[1], 3), dtype=np.uint8)
        for fl in range(3):
            tile[:, :, fl] = layers[fl].read(window=window)
    # Always copy into a zeroed uint8 buffer. This pads tiles that run past the
    # image border and hands cv2 a uint8 array for either read path.
    padded = np.zeros((size, size, 3), dtype=np.uint8)
    padded[:tile.shape[0], :tile.shape[1]] = tile
    return padded


def _tile_votes(model, patch):
    """Predict `patch` with TTA and return the rounded sum of the de-augmented masks.

    The sum is upscaled to (size, size) with nearest-neighbour interpolation.
    It lies in 0..TTA_VIEWS per pixel, assuming the model outputs
    probabilities in [0, 1].
    """
    preds = model.predict(create_TTA_batch(patch))
    summed = np.rint(np.sum(mask_TTA(preds), axis=0))
    return cv2.resize(summed.astype(int), dsize=(size, size), interpolation=cv2.INTER_NEAREST)


def _majority_mask(votes, cov_rows, cov_cols, n_models, chunk=256):
    """Binarise accumulated votes against each pixel's own maximum.

    A pixel is foreground when its votes are at least half of
    TTA_VIEWS * n_models * (number of tiles covering it). For an interior
    pixel (2x2 tiles) and one model this is the `votes >= 16` rule. Border
    pixels are covered by fewer tiles and are judged against their smaller
    maximum, so they are no longer systematically rejected. Rows are handled in
    chunks so no full-size coverage array is materialised.
    """
    col_max = cov_cols.astype(np.int32) * (TTA_VIEWS * n_models)
    mask = np.empty(votes.shape, dtype=bool)
    for r0 in range(0, votes.shape[0], chunk):
        r1 = min(r0 + chunk, votes.shape[0])
        max_votes = cov_rows[r0:r1, None].astype(np.int32) * col_max[None, :]
        mask[r0:r1] = 2 * votes[r0:r1].astype(np.int32) >= max_votes
    return mask


# load every model once instead of once per image
models = [(modl, _load_model(modl)) for modl in MODELS]

for file in filelist:
    fid = file.replace('\\','.').replace('/','.').split('.')[-2]
    img_data = rasterio.open(file, transform=IDNT)
    # channels stored as subdatasets when the file does not have 3 bands
    layers = [] if img_data.count == 3 else [rasterio.open(subd) for subd in img_data.subdatasets]
    dims = [img_data.shape[0], img_data.shape[1]]
    votes = np.zeros(dims, dtype=VOTE_DTYPE)
    n_x, n_y = _n_tiles(dims[0]), _n_tiles(dims[1])
    for modl, model in models:
        print('image: {}, model: {}'.format(fid, modl))
        for x in tqdm(range(n_x)):
            for y in range(n_y):
                x1, y1 = x*STEP, y*STEP
                x2, y2 = x1+size, y1+size
                # tiles on the far edges are clipped to the image
                crop = [min(size, dims[0]-x1), min(size, dims[1]-y1)]
                tile = _read_tile(img_data, layers, x1, x2, y1, y2, crop)
                # downscale tile
                patch = cv2.resize(tile,
                                   dsize=(IMG_SIZE, IMG_SIZE),
                                   interpolation = cv2.INTER_AREA)
                # simple check of saturation if prediction is worthwhile
                _, s, _ = cv2.split(cv2.cvtColor(patch, cv2.COLOR_RGB2HSV))
                if (s>s_th).sum() > p_th:
                    pint = _tile_votes(model, patch)
                    votes[x1:x1+crop[0], y1:y1+crop[1]] += pint[0:crop[0], 0:crop[1]].astype(VOTE_DTYPE)
    for ds in layers:
        ds.close()
    img_data.close()
    pmask = _majority_mask(votes, _axis_coverage(dims[0]), _axis_coverage(dims[1]), len(models))
    del votes  # free the vote buffer before the RLE encoder makes its own full-size copy
    with open(SUB_FILE, 'a') as f:
        f.write("{},".format(fid))
        f.write(rle_encode_less_memory(pmask))
        f.write("\n")

Sanity check: plot the predicted mask (`pmask`) of the last processed image. This is only done in the commit run.

In [None]:
# Visual check of the final binary mask, shown only in the commit run,
# where a single test image is processed.
if COMMIT:
    fig, ax = plt.subplots(figsize=(20, 20))
    ax.imshow(pmask);