## Keras U-Net training with TFRecords input
In this notebook we will train a U-Net architecture implemented in Keras/TensorFlow. The TFRecords of the dataset were created in this notebook:
  * [HuBMAP TIF 2 JPG+TFRecords](https://www.kaggle.com/mistag/data-hubmap-tif-2-jpg-tfrecords-128-256-512-1024)  
  
U-Net code snippets have been reused from [Polyp Segmentation using UNET in TensorFlow 2.0](https://idiotdeveloper.com/polyp-segmentation-using-unet-in-tensorflow-2/) by Nikhil Tomar.

In [None]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import *
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, CSVLogger, TensorBoard
from tensorflow.keras import backend as K
import tensorflow as tf
from functools import partial
import numpy as np
import pandas as pd
import json
import matplotlib.pyplot as plt
%matplotlib inline  

## Dataset creation
We will use TFRecords with a resolution of 512x512 (downscaled from 1024x1024).

In [None]:
# Sort the glob result so the train/validation split below is reproducible:
# tf.io.gfile.glob does not document any particular ordering of its results.
FILENAMES = sorted(tf.io.gfile.glob("/kaggle/input/data-hubmap-tif-2-jpg-tfrecords-128-256-512-1024/Cortex-512-*.tfrecord"))
# With fewer than two files one side of the split would be empty and training
# would fail much later with an unrelated-looking error, so fail fast here.
if len(FILENAMES) < 2:
    raise ValueError(
        "Need at least 2 Cortex-512-*.tfrecord files for a train/validation split, found {}".format(len(FILENAMES)))
# Split by file count (not image count): first 75% of files train, rest validate.
split_ind = int(0.75 * len(FILENAMES))
TRAINING_FILENAMES, VALID_FILENAMES = FILENAMES[:split_ind], FILENAMES[split_ind:]

print("Train TFRecord Files:", len(TRAINING_FILENAMES))
print("Validation TFRecord Files:", len(VALID_FILENAMES))

In [None]:
# Input resolution (square tiles) and derived [height, width] shape.
IMG_SIZE = 512
IMAGE_SIZE = [IMG_SIZE, IMG_SIZE]
AUTOTUNE = tf.data.experimental.AUTOTUNE  # let tf.data pick parallelism/buffer sizes
BATCH_SIZE = 16

# Hyperparameters written to hparams.json so the inference notebook can reuse them.
# SCALE_FACTOR presumably reflects the 1024 -> 512 downscale of the tiles.
hparams = dict(IMG_SIZE=IMG_SIZE, SCALE_FACTOR=2, BATCH_SIZE=BATCH_SIZE)
with open("hparams.json", "w") as json_file:
    json.dump(hparams, json_file, indent=4)

# decode image or mask
def decode_image(image, isjpeg=True):
    """Decode an encoded image or mask into a float32 tensor divided by 255.

    Args:
        image: scalar string tensor holding the encoded bytes.
        isjpeg: True decodes a 3-channel JPEG (the tissue image);
            False decodes a 1-channel PNG (the mask).

    Returns:
        float32 tensor of shape [IMG_SIZE, IMG_SIZE, channels], where
        channels is 3 for JPEG input and 1 for PNG input.
    """
    if isjpeg:
        channels = 3
        decoded = tf.image.decode_jpeg(image, channels=channels)
    else:
        channels = 1
        # decode_png(..., channels=1) already returns [H, W, 1], so no extra
        # expand_dims is needed to get a trailing channel axis.
        decoded = tf.image.decode_png(image, channels=channels)
    # NOTE(review): masks are divided by 255 as well, which assumes they were
    # written as 0/255 PNGs; if they are stored as 0/1, this scaling is wrong.
    # Confirm against the TFRecord writer.
    decoded = tf.cast(decoded, tf.float32) / 255.
    # Pin the static shape so downstream batches have a known spatial size.
    return tf.reshape(decoded, [*IMAGE_SIZE, channels])

# read a single record
def read_tfrecord(example):
    """Parse one serialized tf.train.Example into an (image, mask) pair.

    Only the "image/encoded" (JPEG) and "mask/encoded" (PNG) features are
    requested from the record; no augmentation is applied.
    """
    wanted_features = {
        "image/encoded": tf.io.FixedLenFeature([], tf.string),
        "mask/encoded": tf.io.FixedLenFeature([], tf.string),
    }
    parsed = tf.io.parse_single_example(example, wanted_features)
    return (
        decode_image(parsed["image/encoded"], True),   # JPEG image
        decode_image(parsed["mask/encoded"], False),   # PNG mask
    )

# read a single record and do augmentation
def read_tfrecord_tr(example):
    """Parse one record (as read_tfrecord does) and apply random augmentation.

    Geometric transforms (left/right flip, up/down flip, 90-degree rotation)
    are applied identically to image and mask so they stay aligned.
    Photometric transforms (contrast or brightness, Gaussian noise, hue) touch
    the image only.

    Returns:
        (image, mask), with the image clipped to [0, 1].
    """
    image, mask = read_tfrecord(example)
    # basic augmentation (expand as desired)
    if tf.random.uniform(()) > 0.5:
        image = tf.image.flip_left_right(image)
        mask = tf.image.flip_left_right(mask)
    if tf.random.uniform(()) > 0.5:
        image = tf.image.flip_up_down(image)
        mask = tf.image.flip_up_down(mask)
    if tf.random.uniform(()) > 0.75:
        image = tf.image.rot90(image, k=1)
        mask = tf.image.rot90(mask, k=1)
    if tf.random.uniform(()) > 0.75:  # random contrast/brightness
        if tf.random.uniform(()) > 0.5:
            a = tf.random.uniform((), 0.7, 1.3)
            image = tf.image.adjust_contrast(image, a)
        else:
            a = tf.random.uniform((), 0., 0.5)
            image = tf.image.adjust_brightness(image, a)
    if tf.random.uniform(()) > 0.8:  # add noise
        gnoise = tf.random.normal(shape=tf.shape(image), mean=0.0, stddev=0.1, dtype=tf.float32)
        image = tf.add(image, gnoise)
    if tf.random.uniform(()) > 0.8:  # change hue
        a = tf.random.uniform((), -0.2, 0.2)
        image = tf.image.adjust_hue(image, a)

    # Brightness shifts of up to +0.5, contrast changes and additive noise can
    # push pixels outside [0, 1]. Validation images (read_tfrecord, uint8 / 255)
    # never leave that range, so clip to keep training inputs on the same scale.
    image = tf.clip_by_value(image, 0., 1.)
    return image, mask

def load_dataset(filenames, IsTrain=True, shuffle_buffer=1024):
    """Build an unbatched dataset of (image, mask) pairs from TFRecord files.

    Args:
        filenames: list of TFRecord file paths.
        IsTrain: if True, shuffle the records and decode them with
            read_tfrecord_tr (random augmentation); otherwise decode with
            read_tfrecord (no augmentation, no shuffling).
        shuffle_buffer: number of serialized records held in the shuffle
            buffer when IsTrain is True; 0 disables shuffling.

    Returns:
        tf.data.Dataset of (image, mask) pairs in non-deterministic order.
    """
    ignore_order = tf.data.Options()
    ignore_order.experimental_deterministic = False  # disable order, increase speed
    # Without num_parallel_reads, TFRecordDataset reads the files one after
    # another; with it, reads from several files are interleaved.
    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
    # use data as soon as it streams in, rather than in its original order
    dataset = dataset.with_options(ignore_order)
    if IsTrain:
        if shuffle_buffer:
            # Shuffle the still-encoded records (much smaller than decoded
            # float tensors) so consecutive tiles, presumably from the same
            # slide, are spread across batches. Re-shuffled on every pass.
            dataset = dataset.shuffle(shuffle_buffer)
        dataset = dataset.map(read_tfrecord_tr, num_parallel_calls=AUTOTUNE)  # augmentation
    else:
        dataset = dataset.map(read_tfrecord, num_parallel_calls=AUTOTUNE)  # no augmentation
    return dataset

def get_dataset(filenames, IsTrain=True):
    """Return a batched, endlessly repeating, prefetched (image, mask) dataset.

    Args:
        filenames: list of TFRecord file paths.
        IsTrain: passed to load_dataset (augmentation on/off).

    Returns:
        tf.data.Dataset of BATCH_SIZE batches. It repeats forever, so callers
        must bound it (e.g. steps_per_epoch). The last batch of each pass may
        be smaller, since drop_remainder is not set.
    """
    dataset = load_dataset(filenames, IsTrain)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.repeat()
    # Prefetch last so whole batches are prepared while the model trains.
    dataset = dataset.prefetch(buffer_size=AUTOTUNE)
    return dataset

Plot a few images from the dataset to check that everything is OK (including augmentation):

In [None]:
# Training pipeline with augmentation, validation pipeline without.
train_dataset = get_dataset(TRAINING_FILENAMES, IsTrain=True)
valid_dataset = get_dataset(VALID_FILENAMES, IsTrain=False)

# Pull a single (augmented) training batch for visual inspection.
batch_iterator = iter(train_dataset)
image_batch, mask_batch = next(batch_iterator)

def show_batch(image_batch, mask_batch):
    """Plot up to 16 images in a 4x4 grid with their masks overlaid.

    Args:
        image_batch: array of images, indexable as image_batch[n].
        mask_batch: array of masks aligned with image_batch; each mask keeps
            a trailing single-channel axis (see decode_image).
    """
    plt.figure(figsize=(16, 16))
    # Use the real batch length: a batch may be smaller than BATCH_SIZE.
    for n in range(min(len(image_batch), 16)):
        plt.subplot(4, 4, n + 1)
        plt.imshow(image_batch[n])
        # squeeze drops the single channel axis so imshow gets a 2-D array
        plt.imshow(np.squeeze(mask_batch[n]), alpha=0.25)
        plt.axis("off")

show_batch(image_batch.numpy(), mask_batch.numpy())

The TFRecord files are accompanied by a pandas pickle file that stores the number of images in each TFRecord.

In [None]:
# Calculate the number of images in the train/val sets from the per-file stats.
df = pd.read_pickle('/kaggle/input/data-hubmap-tif-2-jpg-tfrecords-128-256-512-1024/record_stats.pkl')


def _count_images(paths):
    """Sum df.ImgCount over the given TFRecord paths, matching on the base file name.

    For each path the first matching row of df is used.
    """
    return sum(df[df.File == path.split('/')[-1]].ImgCount.iloc[0] for path in paths)


vcnt = _count_images(VALID_FILENAMES)
tcnt = _count_images(TRAINING_FILENAMES)

print('Train images: {}, Validation images: {}'.format(tcnt, vcnt))

## Build U-Net model
The U-Net model is really simple to build, and easy to modify as well.

In [None]:
def conv_block(x, num_filters):
    """Apply two Conv2D(3x3, "same") -> BatchNormalization -> ReLU stages.

    Args:
        x: input tensor.
        num_filters: number of filters in each of the two convolutions.

    Returns:
        The output tensor (same spatial size as x, thanks to "same" padding).
    """
    for _ in range(2):
        x = Conv2D(num_filters, (3, 3), padding="same")(x)
        x = BatchNormalization()(x)
        x = Activation("relu")(x)
    return x

In [None]:
def build_model(size=None, num_filters=(12, 24, 48, 96)):
    """Build a U-Net that maps RGB tiles to a single-channel sigmoid mask.

    Args:
        size: input height/width in pixels; defaults to IMG_SIZE (read at
            call time).
        num_filters: filters per encoder level, shallowest first (adjust
            quantity and sizes as desired). The decoder mirrors it in reverse.

    Returns:
        keras Model with input (size, size, 3) and output (size, size, 1).

    Raises:
        ValueError: if num_filters is empty, or size is not divisible by
            2 ** len(num_filters) (the skip connections would not line up).
    """
    if size is None:
        size = IMG_SIZE
    num_filters = list(num_filters)  # local copy; never mutate the caller's sequence
    if not num_filters:
        raise ValueError("num_filters must contain at least one level")
    if size % (2 ** len(num_filters)):
        raise ValueError(
            "size ({}) must be divisible by 2 ** len(num_filters) = {}".format(
                size, 2 ** len(num_filters)))

    inputs = Input((size, size, 3))

    skip_x = []
    x = inputs

    ## Encoder: conv block, keep the result as a skip connection, then halve resolution
    for f in num_filters:
        x = conv_block(x, f)
        skip_x.append(x)
        x = MaxPool2D((2, 2))(x)

    ## Bridge
    x = conv_block(x, num_filters[-1])

    ## Decoder: deepest level first; upsample, concatenate the matching skip, conv
    for f, xs in zip(reversed(num_filters), reversed(skip_x)):
        x = UpSampling2D((2, 2))(x)
        x = Concatenate()([x, xs])
        x = conv_block(x, f)

    ## Output: 1x1 conv to one channel, sigmoid for a per-pixel probability
    x = Conv2D(1, (1, 1), padding="same")(x)
    x = Activation("sigmoid")(x)

    return Model(inputs, x)

## Loss Functions
There are several loss functions to choose from; a few are defined below. The Focal Tversky loss is known to perform well on many segmentation tasks, but we could also use the Dice coefficient loss or the plain Tversky loss (experiment to find the best one). 

In [None]:
# metrics and loss functions

# Additive smoothing term for dice_coef: keeps the ratio defined (and equal to 1)
# when both y_true and y_pred are all zeros.
smooth = 1.

def dice_coef(y_true, y_pred):
    """Soft Dice coefficient computed over all elements of the batch.

    Returns (2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth), where t
    and p are the flattened y_true and y_pred.
    """
    t = K.flatten(y_true)
    p = K.flatten(y_pred)
    numerator = 2. * K.sum(t * p) + smooth
    denominator = K.sum(t) + K.sum(p) + smooth
    return numerator / denominator


def dice_coef_loss(y_true, y_pred):
    """Dice loss: 1 minus the soft Dice coefficient (0 for a perfect overlap)."""
    coefficient = dice_coef(y_true, y_pred)
    return 1. - coefficient


def iou(y_true, y_pred):
    """Soft intersection-over-union (Jaccard index) over the whole batch.

    intersection = sum(y_true * y_pred) and
    union = sum(y_true) + sum(y_pred) - intersection, computed on the
    unthresholded predictions. 1e-15 is added to both terms so an all-empty
    batch yields 1 instead of 0/0.

    Implemented with backend ops (like dice_coef and tversky) instead of
    tf.numpy_function. A numpy_function body runs in the Python interpreter
    under the GIL and is not serialized with the TF graph.
    """
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    union = K.sum(y_true_f) + K.sum(y_pred_f) - intersection
    return (intersection + 1e-15) / (union + 1e-15)

def tversky(y_true, y_pred, smooth=1, alpha=0.7):
    """Soft Tversky index over the whole batch.

    TI = (TP + smooth) / (TP + alpha * FN + (1 - alpha) * FP + smooth), with
    soft counts TP = sum(t * p), FN = sum(t * (1 - p)) and
    FP = sum((1 - t) * p) on the flattened tensors. alpha weights false
    negatives and (1 - alpha) weights false positives. With alpha = 0.5 the
    unsmoothed index equals the Dice coefficient.
    """
    t = K.flatten(y_true)
    p = K.flatten(y_pred)
    tp = K.sum(t * p)
    fn = K.sum(t * (1 - p))
    fp = K.sum((1 - t) * p)
    denominator = tp + alpha * fn + (1 - alpha) * fp + smooth
    return (tp + smooth) / denominator

def tversky_loss(y_true, y_pred):
    """Tversky loss: 1 minus the Tversky index (default smooth and alpha)."""
    index = tversky(y_true, y_pred)
    return 1 - index

def focal_tversky_loss(y_true, y_pred, gamma=0.75):
    """Focal Tversky loss: (1 - tversky(y_true, y_pred)) ** gamma.

    gamma is the focal exponent applied to the Tversky loss; the Tversky
    index uses its default smooth and alpha.
    """
    return K.pow(1 - tversky(y_true, y_pred), gamma)

## Compile model
Note that we save the model architecture (not the weights) as a .json file for use during inference.

In [None]:
lr = 5e-4  # initial Adam learning rate

model = build_model()
# Persist the architecture only (to_json holds no weights) for the inference notebook.
with open("model.json", "w") as json_file:
    json_file.write(model.to_json())

metrics = ["acc", iou, dice_coef, tversky]
opt = tf.keras.optimizers.Adam(learning_rate=lr)
model.compile(optimizer=opt, loss=focal_tversky_loss, metrics=metrics)
model.summary()

## Train model

In [None]:
callbacks = [
    # Keep the best weights on disk (ModelCheckpoint monitors val_loss by default).
    ModelCheckpoint("model.h5", save_best_only=True),
    # NOTE(review): a patience of 8 here vs 10 for EarlyStopping leaves only
    # about 2 epochs at the reduced learning rate before training stops;
    # consider a smaller patience here.
    ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=8, min_lr=0.00001),
    CSVLogger("data.csv"),
    TensorBoard(),
    # restore_best_weights=True rolls the in-memory model back to its best
    # val_loss epoch when early stopping triggers. The later model.save then
    # matches model.h5 instead of holding weights from `patience` epochs past
    # the optimum.
    EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True),
]

In [None]:
%%capture
# Both datasets repeat forever (get_dataset), so Keras needs explicit step counts.
train_steps = tcnt // BATCH_SIZE
valid_steps = vcnt // BATCH_SIZE
history = model.fit(
    train_dataset,
    validation_data=valid_dataset,
    epochs=75,
    steps_per_epoch=train_steps,
    validation_steps=valid_steps,
    callbacks=callbacks,
)

In [None]:
# Save the full model (architecture, weights and optimizer state) to the working directory.
model.save(filepath='.')

## Inspect learning curves
Below we plot the loss for both training and validation, along with the Dice coefficients. It is important to keep an eye on these curves to verify that our model and training are set up correctly.

In [None]:
# (history key, legend label) pairs, plotted in this order.
curves = [
    ('loss', 'Train loss'),
    ('val_loss', 'Validation loss'),
    ('dice_coef', 'Train Dice coef.'),
    ('val_dice_coef', 'Validation Dice coef.'),
]

plt.figure(figsize=(16, 8))
epochs = np.arange(1, len(history.history['loss']) + 1)  # 1-based epoch numbers
for key, label in curves:
    plt.plot(epochs, history.history[key], label=label)
plt.xlabel('Epoch')
plt.suptitle('Learning curves')
plt.legend();

Looks pretty good! We are now ready for the final step — making predictions with the saved model. Coming up soon!