# Inference with Keras U-Net+MobileNetV2 
This is the final notebook in a series of three; the first two are:
  * [[data] HuBMAP TIF 2 JPG+TFRecords 128,256,512,1024](https://www.kaggle.com/mistag/data-hubmap-tif-2-jpg-tfrecords-128-256-512-1024), generating training data
  * [[train] Keras U-Net+MobileNetV2](https://www.kaggle.com/mistag/train-keras-u-net-mobilenetv2/output?scriptVersionId=48294593), training a U-Net model

In [None]:
import numpy as np
import cv2
import glob
import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model, model_from_json
from tensorflow.keras.utils import CustomObjectScope
from skimage import io
import json
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Hyperparameters exported by the training notebook. The cells below take
# their scale factor, image size and fold count from this dictionary.
with open('../input/train-keras-u-net-mobilenetv2/hparams.json') as hparams_file:
    hparams = json.loads(hparams_file.read())
hparams

## Helper functions
Input images are downscaled with a scaling factor, and the predicted mask will be upscaled again. The TIF files are really big, and to save memory the images are mapped to disk using numpy.memmap(). A little slower than keeping the whole image in memory, but frees up memory for other things!

In [None]:
# Settings inherited from the training run (read from hparams.json).
IMG_SIZE = hparams['IMG_SIZE']          # side length of the square patches fed to the model
SCALE_FACTOR = hparams['SCALE_FACTOR']  # IMG_SIZE * SCALE_FACTOR is the tile size cut from the full image
K_SPLITS = hparams['K_SPLITS']          # number of folds
# Probability cut-off used to binarize the predicted masks.
P_THRESHOLD = 0.5

def read_tif_file(fname):
    """Read a TIFF image and return it in channels-last layout.

    Singleton dimensions are squeezed away first. A 3-D result whose first
    axis has length 3 is treated as channels-first and rearranged to
    (rows, cols, 3). Any other array (for instance a 2-D single-channel
    image) is returned as squeezed.

    Parameters
    ----------
    fname : path of the image file, passed straight to ``skimage.io.imread``.

    Returns
    -------
    numpy.ndarray with the image data.
    """
    img = np.squeeze(io.imread(fname))
    # Only a 3-D array can be channels-first. Without the ndim check, a 2-D
    # image that happens to have exactly 3 rows would crash on the axis move.
    if img.ndim == 3 and img.shape[0] == 3:
        img = np.moveaxis(img, 0, -1)
    return img

# map image to file(s)
def map_img2file(fname):
    """Write each channel of an image to its own raw file ``img<k>.dat``.

    The image is read with ``read_tif_file`` and every channel plane is stored
    as a uint8 memory-mapped file in the current working directory, so that
    patches can later be read back without holding the whole image in memory.

    Returns the image shape as a numpy array (2 entries for a single-channel
    image, otherwise rows, cols, channels).
    """
    image = read_tif_file(fname)
    dims = np.array(image.shape)
    n_channels = dims[2] if len(dims) != 2 else 1
    plane_shape = (dims[0], dims[1])
    for k in range(n_channels):
        if n_channels > 1:
            plane = image[:, :, k]
        else:
            plane = image[:, :]
        mapped = np.memmap('img{}.dat'.format(k), dtype=np.uint8, mode='w+', shape=plane_shape)
        mapped[:] = plane
        # dropping the last reference releases the mapping
        del mapped
    return dims

# read part of image from file
def get_patch_from_file(dims, pos, psize):
    """Read a rectangular patch from the ``img<k>.dat`` channel files.

    Parameters
    ----------
    dims : shape of the mapped image, as returned by ``map_img2file``
        (2 entries for single-channel, otherwise rows, cols, channels).
    pos : (row, col) of the patch's top-left corner; assumed non-negative.
    psize : (rows, cols) of the requested patch.

    Returns
    -------
    uint8 array of shape ``psize`` (single-channel) or ``psize + (channels,)``.
    If the requested window extends past the bottom or right image border,
    the part outside the image is filled with zeros instead of raising a
    broadcasting error.
    """
    ch = 1 if len(dims) == 2 else dims[2]
    out_shape = (psize[0], psize[1]) if ch == 1 else (psize[0], psize[1], ch)
    patch = np.zeros(out_shape, dtype=np.uint8)
    for i in range(ch):
        f = np.memmap('img{}.dat'.format(i), dtype=np.uint8, mode='r', shape=(dims[0], dims[1]))
        p = f[pos[0]:pos[0]+psize[0], pos[1]:pos[1]+psize[1]]
        # slicing clips at the image border, so p may be smaller than psize
        rows, cols = p.shape
        if ch == 1:
            patch[:rows, :cols] = p
        else:
            patch[:rows, :cols, i] = p
        del p, f
    return patch

Using the RLE-encoder from [this notebook](https://www.kaggle.com/iafoss/hubmap-pytorch-fast-ai-starter-sub):

In [None]:
##https://www.kaggle.com/bguberfain/memory-aware-rle-encoding
#with bug fix
def rle_encode_less_memory(img):
    """Run-length encode a binary mask in column-major order.

    Parameters
    ----------
    img : 2-D array; zero means background, non-zero means mask.

    Returns
    -------
    str of space-separated ``start length`` pairs with 1-based starts, or an
    empty string when the mask has no foreground (or no pixels at all).

    Unlike the simplified variant that forces the first and last pixel to
    zero, runs touching either end of the flattened mask are encoded
    correctly, and the input is never modified.
    """
    pixels = np.asarray(img).T.ravel()
    if pixels.size == 0:
        return ''

    # 1-based positions at which the value differs from the previous pixel
    runs = np.flatnonzero(pixels[1:] != pixels[:-1]) + 2
    # A run that starts on the very first pixel has no preceding transition,
    # and a run that ends on the very last pixel has no trailing one: add them
    # explicitly rather than padding (which would copy the whole mask).
    if pixels[0]:
        runs = np.concatenate(([1], runs))
    if pixels[-1]:
        runs = np.concatenate((runs, [pixels.size + 1]))
    # turn (start, end) pairs into (start, length) pairs
    runs[1::2] -= runs[::2]

    return ' '.join(str(x) for x in runs)

## Inference time
The models were trained using the K-fold technique. Here we create an ensemble of all the trained models and also add some (optional) TTA (Test Time Augmentation).
The predicted masks are big, so we write them to disk immediately after each image - no point making a DataFrame here. 

In [None]:
AUGS = 3 # number of augmentations

def create_TTA_batch(img):
    """Build the test-time-augmentation batch for one patch.

    The patch is converted to a float array scaled by 1/255 and stacked as
    three variants: unchanged, rotated 90 degrees clockwise, and rotated
    90 degrees counter-clockwise.

    Returns a float32 array of shape (AUGS, IMG_SIZE, IMG_SIZE, 3).
    """
    scaled = tf.keras.preprocessing.image.img_to_array(img)/255.
    variants = (
        scaled,
        cv2.rotate(scaled, cv2.ROTATE_90_CLOCKWISE),
        cv2.rotate(scaled, cv2.ROTATE_90_COUNTERCLOCKWISE),
    )
    batch = np.zeros((AUGS,IMG_SIZE,IMG_SIZE,3), dtype=np.float32)
    for slot, variant in enumerate(variants):
        batch[slot,:,:,:] = variant
    return batch

def create_TTA_mask(preds):
    """Merge the predictions of a TTA batch into one boolean mask.

    Entries 1 and 2 of ``preds`` are rotated back (in place) to undo the
    clockwise / counter-clockwise rotation applied by ``create_TTA_batch``;
    the predictions are then averaged over the batch axis and compared
    against P_THRESHOLD.

    NOTE(review): ``preds`` is presumably (AUGS, H, W, 1) - the channel axis
    is re-added after each rotation; confirm against the model output.
    """
    # rotation that reverts each augmented slot
    inverse_rotation = {
        1: cv2.ROTATE_90_COUNTERCLOCKWISE,
        2: cv2.ROTATE_90_CLOCKWISE,
    }
    for slot, rotation in inverse_rotation.items():
        restored = cv2.rotate(preds[slot,:,:], rotation)
        preds[slot,:,:] = np.expand_dims(restored, axis = 2)
    # average over the augmentations
    mean_pred = np.sum(preds, axis=0) / AUGS
    return mean_pred > P_THRESHOLD

In [None]:
PATH = '../input/hubmap-kidney-segmentation/test/'
MODEL_PATH = '../input/train-keras-u-net-mobilenetv2/'
# sorted() gives a reproducible processing order (glob order is arbitrary)
filelist = sorted(glob.glob(PATH+'*.tiff'))
# create submission file
SUB_FILE = './submission.csv'
with open(SUB_FILE, 'w') as f:
    f.write("id,predicted\n")

size = int(IMG_SIZE * SCALE_FACTOR) # tile size that will be processed

s_th = 45  # saturation threshold
p_th = IMG_SIZE*IMG_SIZE//32 # pixel count threshold
TTA = True


def _load_fold_model(fold):
    """Rebuild one fold's model from its JSON architecture and H5 weights."""
    with open(MODEL_PATH + 'model{}.json'.format(fold), 'r') as m:
        model = model_from_json(m.read())
    model.load_weights(MODEL_PATH + 'model{}.h5'.format(fold))
    print("Model {}".format(fold))
    return model


def _predict_patch_mask(model, patch):
    """Return the boolean mask one model predicts for one downscaled patch."""
    if TTA:
        predictions = model.predict(create_TTA_batch(patch))
        return create_TTA_mask(predictions)
    predictions = model.predict(np.array([patch])/255.)
    return predictions[0] >= P_THRESHOLD


# Load every fold's model once; previously each model was re-loaded from disk
# for every test image.
models = [_load_fold_model(fold) for fold in range(K_SPLITS)]

for file in filelist:
    # file name without directory and extension (safe for names containing dots)
    fid = file.replace('\\', '/').rsplit('/', 1)[-1].rsplit('.', 1)[0]
    print(fid)
    dims = map_img2file(file)
    height, width = int(dims[0]), int(dims[1])
    # per-pixel count of folds that voted "mask"
    pmask = np.zeros((height, width), dtype=np.uint8)
    # Step over the image in tile-sized strides. The last row/column of tiles
    # may be partial; it is zero-padded so the image border is predicted too.
    for row0 in range(0, height, size):
        for col0 in range(0, width, size):
            h = min(size, height - row0)
            w = min(size, width - col0)
            tile = get_patch_from_file(dims, [row0, col0], [h, w])
            if h < size or w < size:
                padded = np.zeros((size, size) + tile.shape[2:], dtype=tile.dtype)
                padded[:h, :w] = tile
                tile = padded
            patch = cv2.resize(tile,
                               dsize=(IMG_SIZE, IMG_SIZE),
                               interpolation = cv2.INTER_AREA)
            # determine if patch should be predicted or not:
            # skip tiles with too few saturated pixels
            saturation = cv2.cvtColor(patch, cv2.COLOR_BGR2HSV)[:, :, 1]
            if (saturation > s_th).sum() <= p_th:
                continue
            # the tile is read and filtered once, then scored by every fold
            for model in models:
                mask = _predict_patch_mask(model, patch)
                # upsample to original tile size and add this fold's vote
                pint = cv2.resize(mask.astype(np.uint8), dsize=(size, size), interpolation = cv2.INTER_NEAREST)
                pmask[row0:row0+h, col0:col0+w] += pint[:h, :w]
    # save mask to submission file
    pmask = pmask > K_SPLITS/2 # threshold across folds
    with open(SUB_FILE, 'a') as f:
        f.write("{},{}\n".format(fid, rle_encode_less_memory(pmask)))

In [None]:
# remove the per-channel scratch files (img0.dat, img1.dat, ...) written by map_img2file
%rm -f *.dat