![title](https://camo.githubusercontent.com/51eea85ed59f27be0485cc5774d09b522ea8e77cd3f0753c085cacd18d4a41a0/68747470733a2f2f692e6962622e636f2f4774784753386d2f5365676d656e746174696f6e2d4d6f64656c732d56312d536964652d332d312e706e67)

# Inference with Keras Segmentation Models Library
This notebook makes predictions on the HuBMAP data with an FPN model from the [Segmentation Models library](https://github.com/qubvel/segmentation_models). This library is Keras-based and really simple to use. It has four different segmentation models (Unet, Linknet, FPN and PSPNet), and a whopping 25 different pretrained backbones that can be used with each model.  

The training data has been converted into TFRecords in [[data] HuBMAP Image 2 TFRecords 128,256,512,1024](https://www.kaggle.com/mistag/data-hubmap-image-2-tfrecords-128-256-512-1024). The TFRecords are also available in a [dataset](https://www.kaggle.com/mistag/hubmap-tfrecords) (needed for TPU training).   

Training of the model is done in [this notebook](https://www.kaggle.com/mistag/train-fpn-efficientnetb2) (not public yet!), with validation on one image and training on the other 7 in a K-fold cross-validation scheme.


First, install a few libraries needed by Segmentation Models.

In [None]:
!pip install ../input/kerasapplications/keras-team-keras-applications-3b180cb -f ./ --no-index -q
!pip install ../input/efficientnet/efficientnet-1.1.0/ -f ./ --no-index -q

In [None]:
import os
os.environ['SM_FRAMEWORK'] = 'tf.keras'
import sys
import numpy as np
import cv2
import glob
import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import model_from_json
from tensorflow.keras.utils import CustomObjectScope
from tensorflow.keras import backend as K
from tensorflow.keras.utils import get_custom_objects
import efficientnet as efn
import efficientnet.tfkeras
from skimage import io
import json

## File functions
Input images are really big, so to save memory each image is written to disk and accessed through numpy.memmap(). This is a bit slow, but it frees a lot of memory.

In [None]:
def read_tif_file(fname):
    """Load an image file and return it in channels-last layout.

    Parameters
    ----------
    fname : str
        Path of the image, passed straight to ``skimage.io.imread``.

    Returns
    -------
    numpy.ndarray
        The image with singleton dimensions squeezed out. An array with at
        least three dimensions whose first axis has length 3 is treated as
        channels-first and its first axis is moved to position 2, giving
        (rows, cols, 3) for a 3-D input. Any other array is returned as read.
    """
    img = np.squeeze(io.imread(fname))
    # Guard on ndim: a 2-D greyscale image that happens to have exactly three
    # rows is not channels-first, and swapping its axes would raise.
    if img.ndim >= 3 and img.shape[0] == 3:
        # Equivalent to swapaxes(0, 1) followed by swapaxes(1, 2).
        img = np.moveaxis(img, 0, 2)
    return img

# map image to file(s)
def map_img2file(fname):
    """Write an image to disk as one raw uint8 file per channel.

    The image is loaded with ``read_tif_file`` and every channel is stored
    in the working directory as ``img<channel>.dat`` through ``np.memmap``.

    Parameters
    ----------
    fname : str
        Path of the image to load.

    Returns
    -------
    numpy.ndarray
        The shape of the loaded image, as an array.
    """
    image = read_tif_file(fname)
    dims = np.array(image.shape)
    n_channels = dims[2] if len(dims) != 2 else 1
    rows, cols = dims[0], dims[1]
    for channel in range(n_channels):
        if n_channels > 1:
            plane = image[:, :, channel]
        else:
            plane = image[:, :]
        target = np.memmap('img{}.dat'.format(channel), dtype=np.uint8,
                           mode='w+', shape=(rows, cols))
        target[:] = plane
        # Dropping the last reference releases the mapping for this channel.
        del target
    return dims

# read part of image from file
def get_patch_from_file(dims, pos, psize):
    """Read a rectangular patch from the per-channel ``img<channel>.dat`` files.

    Parameters
    ----------
    dims : sequence of int
        Shape of the mapped image: (rows, cols) or (rows, cols, channels).
    pos : sequence of int
        (row, col) of the patch's top-left corner in the image.
    psize : sequence of int
        (height, width) of the requested patch.

    Returns
    -------
    patch : numpy.ndarray
        uint8 array of shape ``psize`` (plus a channel axis for
        multi-channel images). Parts that fall outside the image stay zero.
    crop : tuple of int
        Shape of the region that was actually read from the image.
    """
    n_channels = dims[2] if len(dims) != 2 else 1
    row0, col0 = pos[0], pos[1]
    height, width = psize[0], psize[1]

    if n_channels == 1:
        patch = np.zeros([height, width], dtype=np.uint8)
    else:
        patch = np.zeros([height, width, n_channels], dtype=np.uint8)

    for channel in range(n_channels):
        plane = np.memmap('img{}.dat'.format(channel), dtype=np.uint8,
                          mode='r', shape=(dims[0], dims[1]))
        # Slicing clips at the image border, so the window may be smaller
        # than the requested patch.
        window = plane[row0:row0 + height, col0:col0 + width]
        crop = window.shape
        destination = patch if n_channels == 1 else patch[:, :, channel]
        destination[:crop[0], :crop[1]] = window
        del plane
    return patch, crop

Using the RLE-encoder from [this notebook](https://www.kaggle.com/iafoss/hubmap-pytorch-fast-ai-starter-sub):

In [None]:
##https://www.kaggle.com/bguberfain/memory-aware-rle-encoding
def rle_encode_less_memory(img):
    """Run-length encode a binary mask in column-major order.

    Parameters
    ----------
    img : numpy.ndarray
        Mask to encode. It is copied, so the caller's array is not modified.

    Returns
    -------
    str
        Space-separated ``start length`` pairs with 1-based starts.
        The string is empty when no run is found.
    """
    flat = img.flatten(order='F')

    # The edge detection below only works when the sequence starts and ends
    # with background, so the first and last pixel are forced to zero.
    flat[0] = 0
    flat[-1] = 0

    # Positions where the value changes: even entries are run starts,
    # odd entries are the matching run ends.
    edges = np.flatnonzero(flat[1:] != flat[:-1]) + 2
    # Turn each end position into a run length.
    edges[1::2] = edges[1::2] - edges[::2]

    return ' '.join(map(str, edges))

## TTA
Test-time augmentation functions.

In [None]:
def create_TTA_batch(img):
    """Build a test-time-augmentation batch of 8 variants per input image.

    Each image is converted with ``img_to_array`` and divided by 255. The
    variants are the image rotated by 0, 1, 2 and 3 quarter turns, followed
    by the same four rotations of the image flipped along axis 1.

    Parameters
    ----------
    img : numpy.ndarray
        One image (3-D) or a stack of images (4-D, image index first).

    Returns
    -------
    numpy.ndarray
        float32 array with ``8 * n_images`` entries along the first axis.
        The 8 variants of image ``i`` occupy indices ``8*i`` to ``8*i + 7``.
    """
    if len(img.shape) < 4:
        img = np.expand_dims(img, 0)

    n_images = img.shape[0]
    batch = np.zeros((n_images * 8, img.shape[1], img.shape[2], img.shape[3]),
                     dtype=np.float32)
    for index in range(n_images):
        plain = tf.keras.preprocessing.image.img_to_array(img[index, :, :, :]) / 255.
        mirrored = plain[:, ::-1]
        first = index * 8
        for turns in range(4):
            # Slots 0-3 hold the rotations of the plain image,
            # slots 4-7 the rotations of the mirrored one.
            batch[first + turns, :, :, :] = np.rot90(plain, axes=(0, 1), k=turns)
            batch[first + 4 + turns, :, :, :] = np.rot90(mirrored, axes=(0, 1), k=turns)
    return batch

def mask_TTA(masks):
    """Undo the test-time augmentations on a batch of predicted masks.

    The batch is read in groups of 8, laid out as four rotations (0 to 3
    quarter turns) followed by the same rotations of the image flipped
    along axis 1. Each mask is rotated back and, for the second half of a
    group, flipped back, so all 8 masks of a group line up.

    Parameters
    ----------
    masks : numpy.ndarray
        4-D array of predictions, with the batch index first.

    Returns
    -------
    numpy.ndarray
        float32 array with the same shape as ``masks``. Trailing entries
        that do not fill a complete group of 8 are left at zero.
    """
    n_masks, height, width, channels = (masks.shape[d] for d in range(4))
    restored = np.zeros((n_masks, height, width, channels), dtype=np.float32)
    for group in range(n_masks // 8):
        first = group * 8
        for turns in range(4):
            # A rotation by `turns` quarter turns is undone by 4 - turns more.
            undo = (4 - turns) % 4
            plain = first + turns
            mirrored = plain + 4
            restored[plain, :, :, :] = np.rot90(masks[plain], axes=(0, 1), k=undo)
            # Rotate back first, then undo the flip.
            restored[mirrored, :, :, :] = np.rot90(masks[mirrored], axes=(0, 1), k=undo)[:, ::-1]
    return restored

# Inference
The images are processed with an overlap of half the patch size, which means that pixels are processed four times. The mask of image afa5e8098 is shifted according to [this discussion](https://www.kaggle.com/c/hubmap-kidney-segmentation/discussion/207517). To make an ensemble of several models, just add them to the MODELS list. The mask threshold must be adjusted when adding more models though.

In [None]:
# Base constants that the thresholds and the mask shift below are derived
# from: 42 and its digit reversal, 24.
MEANING_OF_LIFE = 42
MEANING_OF_LIFE_REV = int(str(MEANING_OF_LIFE)[::-1])

PATH = '../input/hubmap-kidney-segmentation/test/'
filelist = glob.glob(PATH+'*.tiff')
SUB_FILE = './submission.csv'
# (Re)create the submission file with only the CSV header; one row per image
# is appended further down.
with open(SUB_FILE, 'w') as f:
    f.write("id,predicted\n")
    
# Model path stems: '<stem>.json' holds the architecture and '<stem>.h5' the
# weights. Every entry contributes to the same vote mask.
MODELS = ['../input/train-fpn-efficientnetb2/FPN+ENetB4-2']
# NOTE(review): K_SPLITS, IMG_SIZE and SCALE_FACTOR are not defined in any
# cell visible in this notebook - confirm they are set before this cell runs.
s_th = MEANING_OF_LIFE+K_SPLITS # saturation blanking threshold: S value a pixel must exceed to be counted
p_th = IMG_SIZE*IMG_SIZE//32   # pixel count threshold: 1/32 of the pixels in a downscaled patch
size = int(IMG_SIZE * SCALE_FACTOR) # side length of the tile read from the image before it is resized to IMG_SIZE

# Tiles advance by STEP (about half a tile), so neighbouring tiles overlap.
OVERLAP = size//2
STEP = size-OVERLAP

for file in filelist:
    # image id: last path component without its extension (path separators
    # are turned into '.' so one split handles both)
    fid = file.replace('\\','.').replace('/','.').split('.')[-2]
    print(fid)
    # dump the image to per-channel memmap files; dims is the image shape
    dims = map_img2file(file)
    # per-pixel vote accumulator for the whole image
    # NOTE(review): uint8 - confirm it cannot overflow when MODELS grows
    pmask = np.zeros(dims[:2], dtype=np.uint8)
    x_shft, y_shft = 0,0
    if fid == 'afa5e8098': # mask correction
        # tiles are read at this offset (axis 0, axis 1) but written back
        # to pmask without it, which displaces the predicted mask
        x_shft, y_shft = MEANING_OF_LIFE, MEANING_OF_LIFE_REV
    for modl in range(len(MODELS)):
        print(MODELS[modl])
        # load pre-trained model (rebuilt and reloaded for every image)
        mname = MODELS[modl]
        with open(mname+'.json', 'r') as m:
            lm = m.read()
            model = model_from_json(lm)
        model.load_weights(mname+'.h5')
        # process image
        # each range() below is ceil((dim - OVERLAP - shift) / STEP) tile positions
        for x in range((dims[0]-OVERLAP-x_shft)//STEP + min(1,(dims[0]-OVERLAP-x_shft) % STEP)):
            for y in range((dims[1]-OVERLAP-y_shft)//STEP + min(1,(dims[1]-OVERLAP-y_shft) % STEP)):
                # tile is zero-padded to (size, size); crop is the shape of
                # the part that was actually read from the image
                tile, crop = get_patch_from_file(dims, [x*STEP+x_shft, y*STEP+y_shft], [size,size])
                # downscale tile
                patch = cv2.resize(tile,
                                   dsize=(IMG_SIZE, IMG_SIZE),
                                   interpolation = cv2.INTER_AREA)
                # simple check of saturation if prediction is worthwhile:
                # only the S channel of the HSV conversion is kept
                _, s, _ = cv2.split(cv2.cvtColor(patch, cv2.COLOR_BGR2HSV))
                # predict only when more than p_th pixels exceed s_th
                if (s>s_th).sum() > p_th:
                    # 8 augmented variants in, 8 predictions out, then the
                    # augmentations are undone so the predictions line up
                    batch = create_TTA_batch(patch)
                    preds = model.predict(batch)
                    pred = mask_TTA(preds)
                    # per-pixel sum over the 8 aligned predictions, rounded
                    # (presumably 0..8 if the model outputs values in [0, 1] - verify)
                    mask = np.rint(np.sum(pred, axis=0))
                    # upscale tile mask before adding to total mask
                    pint =cv2.resize(mask.astype(int), dsize=(size, size), interpolation = cv2.INTER_NEAREST)
                    # add only the part of the tile that lies inside the image
                    pmask[x*STEP:x*STEP+crop[0], y*STEP:y*STEP+crop[1]] += pint[0:crop[0], 0:crop[1]].astype(np.uint8)
    # binarise: a pixel is foreground once its accumulated votes reach the threshold
    pmask = pmask >= MEANING_OF_LIFE_REV - K_SPLITS
    # append one "<id>,<rle>" row for this image
    with open(SUB_FILE, 'a') as f:
        f.write("{},".format(fid))
        f.write(rle_encode_less_memory(pmask))
        f.write("\n")

In [None]:
# clean up intermediate files
%rm -f *.dat