## Keras FPN+EfficientNet training with TFRecords input
In this notebook we will train an FPN (Feature Pyramid Network) segmentation model with an EfficientNet encoder, implemented in Keras/TensorFlow. The TFRecords of the dataset were already created in this notebook:
  * [HuBMAP TIF 2 TFRecords](https://www.kaggle.com/mistag/hubmap-image-2-tfrecords-256-512-1024)  
  


In [None]:
# Install the segmentation_models package quietly; it is imported as `sm` below.
!pip install segmentation_models -q

In [None]:
import numpy as np
import pandas as pd
from pylab import *
import os
os.environ['SM_FRAMEWORK'] = 'tf.keras'
import sys
import segmentation_models as sm
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.utils import get_custom_objects
from tensorflow.keras.models import model_from_json
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, CSVLogger, TensorBoard
import tensorflow_addons as tfa
from functools import partial
import json
from sklearn.model_selection import KFold
import matplotlib.pyplot as plt
%matplotlib inline  

In [None]:
# Input directory holding the TFRecord files and the record_stats.pkl file read below.
DPATH = '../input/fork-of-data-hubmap-image-2-tfrecords-256-512-10'
# Optional (disabled): load the parameters saved by the data creation notebook.
#with open(DPATH+'/dparams.json') as json_file:
#    dparams = json.load(json_file)
#dparams

## Dataset creation
We will use TFRecords with a resolution of 256x256 (downscaled from 1024x1024).

In [None]:
# Sort the matches so the file order (and therefore the K-fold split that is
# built from it) does not depend on the order in which the filesystem happens
# to list the files.
FILENAMES = sorted(tf.io.gfile.glob(DPATH+"/*-256.tfrecord"))
K_SPLITS = 5 # number of folds

In [None]:
# augmentation
def data_augment(image, mask):
    """Apply random augmentations to one (image, mask) training pair.

    Geometric transforms (90-degree rotations and flips) are applied to both
    tensors so they stay aligned. Photometric and occlusion transforms
    (grayscale, brightness/contrast, hue/saturation, noise/JPEG artefacts,
    black pixels, cutout) modify the image only. The image is clipped to
    [0, 1] before it is returned.

    Args:
        image: float32 RGB image tensor (3 channels are required by the
            grayscale and hue/saturation ops).
        mask: mask tensor aligned with `image`.

    Returns:
        Tuple `(image, mask)` of augmented tensors.
    """
    p_rotate = tf.random.uniform([], 0, 1.0, dtype=tf.float32)

    # grayscale (about 10% of samples); converted back so 3 channels remain
    if tf.random.uniform(()) > 0.9:
        image = tf.image.rgb_to_grayscale(image)
        image = tf.image.grayscale_to_rgb(image)

    # x*90deg rotation
    if p_rotate > .5:
        image, mask = data_augment_rotate(image, mask)

    # flip
    image, mask = data_augment_spatial(image, mask)

    # brightness/contrast
    if tf.random.uniform(()) > 0.5:
        if tf.random.uniform(()) > 0.5:
            image = tf.image.random_contrast(image, 0.8, 1.2)
        else:
            image = tf.image.random_brightness(image, 0.1)

    # hue/saturation
    if tf.random.uniform(()) > 0.5:
        if tf.random.uniform(()) > 0.5:
            image = tf.image.random_saturation(image, 0.7, 1.3)
        else:
            a = tf.random.uniform((), -0.1, 0.1)
            image = tf.image.adjust_hue(image, a)

    # noise: either additive gaussian noise or JPEG compression artefacts
    if tf.random.uniform(()) > 0.5:
        if tf.random.uniform(()) > 0.5:
            gnoise = tf.random.normal(shape=tf.shape(image), mean=0.0, stddev=0.1, dtype=tf.float32)
            image = tf.add(image, gnoise)
        else:
            image = tf.image.random_jpeg_quality(image, 40, 80, seed=None)

    # black pixels: zero a random fraction (up to 20%) of the individual values.
    # The mask is sized from the image itself rather than from a global constant.
    if tf.random.uniform(()) > 0.75:
        r = tf.random.uniform(tf.shape(image), minval=0, maxval=1, dtype=tf.dtypes.float32)
        r = (r > tf.random.uniform([], 0., .2, dtype=tf.float32))
        image = tf.math.multiply(image, tf.cast(r, dtype=tf.float32))

    # cutout
    if tf.random.uniform(()) > 0.75:
        image = data_augment_cutout(image, 10, 10, 48)

    # noise/brightness may have pushed values outside the valid range
    image = tf.clip_by_value(image, 0.0, 1.0)

    return image, mask

def one_cut(image, min_size, max_size):
    """Blank out one randomly placed rectangle of a single image.

    Args:
        image: single image tensor (height, width, channels).
        min_size: smallest rectangle side length (Python int).
        max_size: exclusive upper bound on the side length (Python int).

    Returns:
        The image with one rectangle replaced by a random constant value.
    """
    # tfa documents that mask sizes should be divisible by 2, so draw the
    # half-size and double it instead of drawing the size directly.
    half_lo = min_size // 2
    half_hi = max(half_lo + 1, (max_size + 1) // 2)
    mask_height = 2 * tf.random.uniform((), minval=half_lo, maxval=half_hi, dtype=tf.dtypes.int32)
    mask_width = 2 * tf.random.uniform((), minval=half_lo, maxval=half_hi, dtype=tf.dtypes.int32)

    # random_cutout works on batches: add a batch axis and remove only that
    # axis afterwards, so a size-1 channel axis would survive.
    batch = tf.expand_dims(image, axis=0)
    batch = tfa.image.random_cutout(batch,
                                    (mask_height, mask_width),
                                    constant_values=tf.random.uniform(()))
    return tf.squeeze(batch, axis=0)

def data_augment_cutout(image, max_cuts, min_size, max_size):
    """Apply between 1 and `max_cuts` random cutouts to an image.

    Args:
        image: single image tensor passed on to `one_cut`.
        max_cuts: maximum number of cutouts, inclusive (Python int >= 1).
        min_size: smallest rectangle side length, forwarded to `one_cut`.
        max_size: upper bound on the side length, forwarded to `one_cut`.

    Returns:
        The image after all cutouts have been applied.
    """
    # maxval is exclusive for integer dtypes, hence the +1 to make
    # `max_cuts` itself reachable (and to keep max_cuts == 1 valid).
    cuts = tf.random.uniform((), minval=1, maxval=max_cuts + 1, dtype=tf.dtypes.int32)
    # The first cut is unconditional; each further cut is applied only while
    # fewer than `cuts` cuts have been made so far.
    image = one_cut(image, min_size, max_size)
    for cuts_done in range(1, max_cuts):
        if cuts > cuts_done:
            image = one_cut(image, min_size, max_size)
    return image

def data_augment_spatial(image, mask):
    """Randomly mirror an (image, mask) pair.

    A left-right flip and an up-down flip are decided independently; each is
    applied when its own uniform draw exceeds 0.5, and always to both tensors.

    Returns:
        Tuple `(image, mask)` after the selected flips.
    """
    mirror_left_right = tf.random.uniform(()) > 0.5
    if mirror_left_right:
        image, mask = tf.image.flip_left_right(image), tf.image.flip_left_right(mask)

    mirror_up_down = tf.random.uniform(()) > 0.5
    if mirror_up_down:
        image, mask = tf.image.flip_up_down(image), tf.image.flip_up_down(mask)

    return image, mask

def data_augment_rotate(image, mask):
    """Rotate an (image, mask) pair by a random multiple of 90 degrees.

    One uniform draw picks the number of quarter turns: 3 when it exceeds
    0.66, 2 when it exceeds 0.33, otherwise 1. The same rotation is applied
    to both tensors.

    Returns:
        Tuple `(image, mask)` after the rotation.
    """
    draw = tf.random.uniform([], 0, 1.0, dtype=tf.float32)

    # Start from one quarter turn and add one for each threshold the draw passes.
    quarter_turns = 1 + tf.cast(draw > .33, tf.int32) + tf.cast(draw > .66, tf.int32)

    rotated_image = tf.image.rot90(image, k=quarter_turns)
    rotated_mask = tf.image.rot90(mask, k=quarter_turns)
    return rotated_image, rotated_mask

def data_augment_crop(image, mask):
    """Randomly crop an (image, mask) pair, then resize back to IMG_SIZE.

    One uniform draw selects the crop:
      * above 0.5: a random square window with side in [0.8*IMG_SIZE, IMG_SIZE)
      * above 0.4: central crop keeping 70%
      * above 0.2: central crop keeping 80%
      * otherwise: central crop keeping 90%

    The image is resized with the default method and the mask with
    nearest-neighbour interpolation.

    Returns:
        Tuple `(image, mask)`, each resized to IMG_SIZE x IMG_SIZE.
    """
    selector = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    side = tf.random.uniform([], int(IMG_SIZE*.8), IMG_SIZE, dtype=tf.int32)

    if selector > .5:
        # random window: the same offsets are used for image and mask
        left = tf.random.uniform([], 0, IMG_SIZE-side, dtype=tf.int32)
        top = tf.random.uniform([], 0, IMG_SIZE-side, dtype=tf.int32)
        image = tf.image.crop_to_bounding_box(image, top, left, side, side)
        mask = tf.image.crop_to_bounding_box(mask, top, left, side, side)
    elif selector > .4:
        image = tf.image.central_crop(image, central_fraction=.7)
        mask = tf.image.central_crop(mask, central_fraction=.7)
    elif selector > .2:
        image = tf.image.central_crop(image, central_fraction=.8)
        mask = tf.image.central_crop(mask, central_fraction=.8)
    else:
        image = tf.image.central_crop(image, central_fraction=.9)
        mask = tf.image.central_crop(mask, central_fraction=.9)

    target_size = [IMG_SIZE, IMG_SIZE]
    image = tf.image.resize(image, size=target_size)
    mask = tf.image.resize(mask, size=target_size, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    return image, mask

In [None]:
# Input pipeline constants
IMG_SIZE = 256
IMAGE_SIZE = [IMG_SIZE, IMG_SIZE]
AUTOTUNE = tf.data.experimental.AUTOTUNE
BATCH_SIZE = 16

# Hyperparameters are written to disk so the inference notebook can reuse them.
hparams = {
    "IMG_SIZE": IMG_SIZE,
    # TFRecords dataset uses tiles of 1024x1024
    "SCALE_FACTOR": 1024//IMG_SIZE,
    "BATCH_SIZE": BATCH_SIZE,
    "K_SPLITS": K_SPLITS,
}
with open("hparams.json", "w") as json_file:
    json.dump(hparams, json_file, indent=4)

# decode image or mask
def decode_image(image, isjpeg=True):
    """Decode an encoded image string into a float32 tensor scaled by 1/255.

    Args:
        image: scalar string tensor with the encoded bytes.
        isjpeg: True decodes a 3-channel JPEG, False a 1-channel PNG.

    Returns:
        float32 tensor of shape [*IMAGE_SIZE, channels].
    """
    if isjpeg:
        ch = 3
        image = tf.image.decode_jpeg(image, channels=ch)
    else:
        ch = 1
        # decode_png(channels=1) already returns a trailing channel axis,
        # so no extra expand_dims is needed before the reshape below.
        image = tf.image.decode_png(image, channels=ch)
    image = tf.cast(image, tf.float32) / 255.
    # The reshape pins the static shape to the expected tile size.
    return tf.reshape(image, [*IMAGE_SIZE, ch])

# read a single record
def read_tfrecord(example):
    """Parse one serialized example into a decoded (image, mask) pair.

    Only the two encoded-bytes features are extracted; the image is decoded
    as JPEG and the mask as PNG.
    """
    encoded_bytes = tf.io.FixedLenFeature([], tf.string)
    features = tf.io.parse_single_example(
        example,
        {
            "image/encoded": encoded_bytes,
            "mask/encoded": encoded_bytes,
        },
    )
    image = decode_image(features["image/encoded"], True)   # jpeg format
    mask = decode_image(features["mask/encoded"], False)    # png format
    return image, mask

def load_dataset(filenames, IsTrain=True):
    """Build an unbatched dataset of (image, mask) pairs from TFRecord files.

    Args:
        filenames: list of TFRecord file paths.
        IsTrain: when True, `data_augment` is applied to every pair.

    Returns:
        A `tf.data.Dataset` yielding decoded (image, mask) pairs.
    """
    # Element order does not matter here; relaxing it increases speed.
    options = tf.data.Options()
    options.experimental_deterministic = False

    # Reads from multiple files are interleaved automatically.
    records = tf.data.TFRecordDataset(filenames, num_parallel_reads = AUTOTUNE)
    records = records.with_options(options)

    pairs = records.map(read_tfrecord, num_parallel_calls=AUTOTUNE)
    if IsTrain:
        # augmentation only for the training pipeline
        pairs = pairs.map(data_augment, num_parallel_calls=AUTOTUNE)
    return pairs

def get_dataset(filenames, IsTrain=True):
    """Build the batched, endlessly repeating dataset fed to `model.fit`.

    Args:
        filenames: list of TFRecord file paths.
        IsTrain: True enables augmentation and shuffling (training set);
            False yields the plain decoded pairs (validation set).

    Returns:
        A `tf.data.Dataset` of (image_batch, mask_batch) that repeats
        forever, so callers must pass explicit step counts.
    """
    dataset = load_dataset(filenames, IsTrain)
    if IsTrain:
        # Only training benefits from shuffling; skipping it for validation
        # avoids filling a large shuffle buffer for no gain.
        dataset = dataset.shuffle(BATCH_SIZE*256)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.repeat()
    # Prefetch last so whole batches are prepared while the model trains.
    dataset = dataset.prefetch(buffer_size=AUTOTUNE)
    return dataset

Plot a few images from the dataset to check that everything is OK (including augmentation):

In [None]:
# Build the augmented training pipeline over all TFRecord files (visual check only).
train_dataset = get_dataset(FILENAMES, True)

# Pull a single batch of images and masks for plotting.
image_batch, mask_batch = next(iter(train_dataset))

def show_batch(image_batch, mask_batch):
    """Plot up to 16 images of a batch in a 4x4 grid with masks overlaid.

    Args:
        image_batch: array of images, indexed along the first axis.
        mask_batch: array of masks aligned with `image_batch`.
    """
    plt.figure(figsize=(16, 16))
    # Size the grid from the batch actually passed in, so a batch with fewer
    # than 16 images does not index out of range.
    n_tiles = min(len(image_batch), 16)
    for n in range(n_tiles):
        plt.subplot(4, 4, n + 1)
        plt.imshow(image_batch[n])
        # the mask is drawn semi-transparently on top of the image
        plt.imshow(np.squeeze(mask_batch[n]), alpha=0.25)
        plt.axis("off")

# Convert the tensors to NumPy arrays before plotting.
show_batch(image_batch.numpy(), mask_batch.numpy())

There is a Pandas pickle file accompanying the TFRecords files containing the number of images per TFRecord.

In [None]:
# calculate number of images in train/val sets for each fold
kf = KFold(n_splits=K_SPLITS)
# NOTE(review): read_pickle unpickles the file; only load it from a trusted dataset.
df = pd.read_pickle(DPATH+'/record_stats.pkl')
tcnt, vcnt = np.zeros(K_SPLITS, dtype=int), np.zeros(K_SPLITS, dtype=int)


def _count_images(file_indices):
    """Sum the ImgCount entries of record_stats for the given FILENAMES indices."""
    total = 0
    for i in file_indices:
        # strip the directory part, whichever path separator is used
        fname = FILENAMES[i].split('/')[-1].split('\\')[-1]
        total += df[df.File == fname].ImgCount.iloc[0]
    return total


for idx, (train_index, test_index) in enumerate(kf.split(FILENAMES)):
    tcnt[idx] = _count_images(train_index)
    vcnt[idx] = _count_images(test_index)

print('Train images: {}, Validation images: {}'.format(tcnt, vcnt))

Let's have a look at some statistics from the [quick reference](https://www.kaggle.com/mistag/hubmap-quick-reference):

In [None]:
# Load the per-image statistics table from the quick-reference dataset.
dff = pd.read_pickle('../input/hubmap-quick-reference/image_stats.pkl')
# Sum of the `glomeruli` column over the rows whose dataset is 'train'
# (not used again in this notebook).
tot = dff[dff.dataset == 'train'].glomeruli.sum()
# Show the first 15 rows as the cell output.
dff.head(15)

## Loss Functions
There are several loss functions to choose from; a few are defined below. The Focal Tversky loss is known to perform well on many segmentation tasks. We could also use the Dice coefficient loss or the Tversky loss (experiment to find the best one). The training cell below uses the Dice coefficient loss.

In [None]:
# metrics and loss functions
# Smoothing constant added to numerator and denominator of the Dice coefficient.
smooth = 1.

def dice_coef(y_true, y_pred):
    """Return the smoothed soft Dice coefficient over all elements.

    Both tensors are flattened, so the value is computed over the whole
    batch at once; the module-level `smooth` is added to numerator and
    denominator.
    """
    truth = K.flatten(y_true)
    prediction = K.flatten(y_pred)
    overlap = K.sum(truth * prediction)
    numerator = 2. * overlap + smooth
    denominator = K.sum(truth) + K.sum(prediction) + smooth
    return numerator / denominator


def dice_coef_loss(y_true, y_pred):
    """Return the Dice loss, i.e. one minus the Dice coefficient."""
    coefficient = dice_coef(y_true, y_pred)
    return 1.-coefficient


def iou(y_true, y_pred):
    """Return the soft intersection-over-union over all elements.

    Computed with TensorFlow ops only, so the metric stays inside the graph
    and needs no NumPy round trip.

    Args:
        y_true: ground-truth mask tensor.
        y_pred: predicted mask tensor of the same shape.

    Returns:
        Scalar float32 tensor; 1.0 when both inputs are all zero.
    """
    y_true = tf.cast(y_true, tf.float32)
    y_pred = tf.cast(y_pred, tf.float32)
    intersection = tf.reduce_sum(y_true * y_pred)
    union = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) - intersection
    # the tiny epsilon only matters for an empty union (avoids 0/0)
    return (intersection + 1e-15) / (union + 1e-15)

def tversky(y_true, y_pred, smooth=1, alpha=0.7):
    """Return the smoothed Tversky index over all elements.

    Args:
        y_true: ground-truth mask tensor.
        y_pred: predicted mask tensor.
        smooth: constant added to numerator and denominator.
        alpha: weight of the false negatives; false positives get 1 - alpha.
    """
    truth = K.flatten(y_true)
    prediction = K.flatten(y_pred)
    tp = K.sum(truth * prediction)
    fn = K.sum(truth * (1 - prediction))
    fp = K.sum((1 - truth) * prediction)
    denominator = tp + alpha * fn + (1 - alpha) * fp + smooth
    return (tp + smooth) / denominator

def tversky_loss(y_true, y_pred):
    """Return the Tversky loss, i.e. one minus the Tversky index."""
    index = tversky(y_true, y_pred)
    return 1 - index

def focal_tversky_loss(y_true, y_pred, gamma=0.75):
    """Return the focal Tversky loss: (1 - Tversky index) raised to `gamma`."""
    complement = 1 - tversky(y_true, y_pred)
    return K.pow(complement, gamma)

## Compile & train model
Note that we save the architecture of each built model as a .json file for use during inference, one for each fold; the best model of each fold is checkpointed to a matching .h5 file.

In [None]:
# Base name for the per-fold model files (.json and .h5).
MNAME = 'FPN-model43e'

def get_callbacks(idx):
    """Build the Keras callbacks used to train fold `idx`.

    Returns:
        List with a best-only checkpoint written to "<MNAME>-<idx>.h5", a
        learning-rate reduction on val_loss plateaus, and early stopping on
        val_loss.
    """
    checkpoint = ModelCheckpoint(f"{MNAME}-{idx}.h5", save_best_only=True)
    reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=0.00001)
    # Weights are not restored on stop; the checkpoint above keeps the best model.
    early_stop = EarlyStopping(monitor='val_loss', patience=6, restore_best_weights=False)
    # A CSVLogger or TensorBoard callback could be appended here for logging.
    return [checkpoint, reduce_lr, early_stop]

In [None]:
#%%capture
# Train one model per selected fold and record its final/best metrics.
history = []
loss, vloss, dice, vdice = np.zeros(K_SPLITS),np.zeros(K_SPLITS),np.zeros(K_SPLITS),np.zeros(K_SPLITS)
lr = 5e-4
# Select a few folds only to reduce execution time. Histories are appended in
# ascending fold order, so keep this list sorted for the plots further down.
selected_folds = [4]
for idx, (train_index, test_index) in enumerate(kf.split(FILENAMES)):
    if idx not in selected_folds:
        continue

    # Drop Keras state left over from a previously trained fold so that
    # memory does not accumulate when several folds are selected.
    K.clear_session()

    train_dataset = get_dataset([FILENAMES[i] for i in train_index], True)
    valid_dataset = get_dataset([FILENAMES[i] for i in test_index], False)

    model = sm.FPN('efficientnetb4', classes=1, encoder_weights='imagenet', activation = 'sigmoid')
    # The architecture is saved as JSON for the inference notebook.
    with open(MNAME+"-{}.json".format(idx), "w") as json_file:
        json_file.write(model.to_json())

    opt = tf.keras.optimizers.Adam(lr)
    metrics = ["acc", iou, dice_coef, tversky]
    model.compile(loss=dice_coef_loss, optimizer=opt, metrics=metrics)

    # The datasets repeat forever, so the number of steps per epoch is
    # derived from the per-fold image counts.
    hist = model.fit(train_dataset,
                     validation_data=valid_dataset,
                     epochs=25,
                     steps_per_epoch=1+tcnt[idx]//BATCH_SIZE,
                     validation_steps=1+vcnt[idx]//BATCH_SIZE,
                     callbacks=get_callbacks(idx))

    # Training metrics: value of the last epoch; validation metrics: best epoch.
    loss[idx] = hist.history['loss'][-1]
    vloss[idx] = min(hist.history['val_loss'])
    dice[idx] = hist.history['dice_coef'][-1]
    vdice[idx] = max(hist.history['val_dice_coef'])
    history.append(hist)

In [None]:
# Per-fold summary table; rows of folds that were not trained stay at zero.
pd.DataFrame({'loss': loss, 'min val.loss': vloss, 'dice coef.': dice, 'max val. dice coef.': vdice})

## Inspect learning curves
Below we plot the training and validation loss, along with the Dice coefficients, for each selected fold. It is important to keep an eye on these curves to verify that our model and training are set up correctly.

In [None]:
# Loss curves: solid lines for training, dashed lines for validation,
# one colour per selected fold.
colors = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple']
plt.figure(figsize=(16,10))
for i, fold in enumerate(selected_folds):
    plt.plot(history[i].history['loss'], linestyle='-', color=colors[i],
             label='Train loss fold #{}'.format(fold))
for i, fold in enumerate(selected_folds):
    plt.plot(history[i].history['val_loss'], linestyle='--', color=colors[i],
             label='Validation loss fold #{}'.format(fold))
plt.title('Model Loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend()
plt.show();

In [None]:
# Dice coefficient curves: solid lines for training, dashed lines for
# validation, one colour per selected fold.
colors = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple']
plt.figure(figsize=(16,10))
for i, fold in enumerate(selected_folds):
    plt.plot(history[i].history['dice_coef'], linestyle='-', color=colors[i],
             label='Train dice coef. fold #{}'.format(fold))
for i, fold in enumerate(selected_folds):
    plt.plot(history[i].history['val_dice_coef'], linestyle='--', color=colors[i],
             label='Validation dice coef. fold #{}'.format(fold))
plt.title('Model Dice Coef.')
plt.ylabel('Dice coef.')
plt.xlabel('Epoch')
plt.legend()
plt.show();

Looks pretty good! We are now ready for the final step - [making predictions with the saved model](https://www.kaggle.com/mistag/inference-hubmap-fpn-single-model-ii). 