This notebook trains a TensorFlow Keras model on a TPU. A summary of the setup follows:

* Model: FPN in [Segmentation Models](https://github.com/qubvel/segmentation_models)
* Backbone: EfficientNet B4
* Image Size: 768
* Learning Rate: maximum 1e-3, cosine decay with warmup
* Epochs: 30 (2 for warmup; the rest use cosine decay)
* Batch Size: 16
* Folds: 5 (trains 1 fold only)
* Loss: 0.5 * BCE + 0.5 * Dice
* Data Augmentations: [Albumentations](https://albumentations.ai/)-like transforms (flips, brightness/contrast, grid/optical distortion, shift-scale-rotate, cutout)
* Oversampling: 3 times for hard samples (Dice coefficient <= 0.05 for any class)
* 2.5D input: 3 slices stacked as channels, with a slice stride of 2

# Reference

Thanks a lot to the authors for sharing the valuable information.

* [UWMGI: UNet Keras [Train] with EDA](https://www.kaggle.com/code/ammarnassanalhajali/uwmgi-unet-keras-train-with-eda)
* [UWMGI: Unet [Train] [PyTorch]](https://www.kaggle.com/code/awsaf49/uwmgi-unet-train-pytorch)

The following are my related notebooks.

* [UWMGI Image Segmentation EDA](https://www.kaggle.com/code/tt195361/uwmgi-image-segmentation-eda)
* [UWMGI Image Segmentation Make TFRecords](https://www.kaggle.com/code/tt195361/uwmgi-image-segmentation-make-tfrecords)
* [UWMGI Image Segmentation Inference](https://www.kaggle.com/code/tt195361/uwmgi-image-segmentation-inference)

# Preparation

In [None]:
!git clone https://github.com/tt195361/TfDataAugmentation.git

import sys
sys.path.append('TfDataAugmentation')

import TfDataAugmentation as Tfda

In [None]:
%env SM_FRAMEWORK=tf.keras
!pip install ../input/segmentation-models-keras/Keras_Applications-1.0.8-py3-none-any.whl --quiet
!pip install ../input/segmentation-models-keras/image_classifiers-1.0.0-py3-none-any.whl --quiet
!pip install ../input/segmentation-models-keras/efficientnet-1.0.0-py3-none-any.whl --quiet
!pip install ../input/segmentation-models-keras/segmentation_models-1.0.1-py3-none-any.whl --quiet

print("Segmentation Models installed.")

In [None]:
# Debug switch: when True, the notebook uses fewer epochs and stops the
# prediction loop after a few batches.
DEBUG = False

In [None]:
import numpy as np
import pandas as pd
import segmentation_models as sm
import tensorflow as tf
import tensorflow.keras.layers as L
import tensorflow.keras.backend as K
from kaggle_datasets import KaggleDatasets
import os
import matplotlib.pyplot as plt
from PIL import Image
import math
import joblib

print(tf.__version__)

In [None]:
# Model architecture and input geometry.
SEG_MODEL = sm.FPN
BACKBONE = 'efficientnetb4'
IMAGE_SIZE = 768
BATCH_SIZE = 16

# Learning-rate schedule lengths; DEBUG shortens both phases.
INIT_LR = 1e-4
WARMUP_EPO = 1 if DEBUG else 2
COSINE_EPO = 2 if DEBUG else 28
N_EPOCHS = WARMUP_EPO + COSINE_EPO

# Cross-validation, oversampling and 2.5D settings.
N_FOLDS = 5
OVERSAMPLE_DICE_THRESHOLD = 0.05
OVERSAMPLE_COUNT = 3
STRIDE_25D = 2

# Version tag embedded in output file names, and the folds to train.
VID = 'V61'
FOLD_I_LIST = [0]
if DEBUG:
    # Keep at most the first two folds in debug runs.
    FOLD_I_LIST = FOLD_I_LIST[:2]

print("N_EPOCHS:   ", N_EPOCHS)
print("FOLD_I_LIST:", FOLD_I_LIST)

In [None]:
# Name of the Kaggle dataset that is resolved to a GCS path below.
DATA_SRC = 'uwmgi-image-segmentation-tfrecords'
# Lets tf.data pick the parallelism / prefetch buffer sizes.
AUTOTUNE = tf.data.experimental.AUTOTUNE

# TPU

In [None]:
# Try to connect to a TPU; a ValueError from the resolver is treated as
# "no TPU available" and a MirroredStrategy is used instead.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect() 
    strategy = tf.distribute.TPUStrategy(tpu)
except ValueError: # otherwise detect GPUs
    strategy = tf.distribute.MirroredStrategy() # single-GPU or multi-GPU
    
# Number of replicas the strategy keeps in sync; used when compiling the model.
REPLICAS = strategy.num_replicas_in_sync

print(f"Running on {REPLICAS} replicas")

In [None]:
# GCS location of the dataset; the TFRecords and train_data.csv are read
# from this path.
GCS_DS_PATH = KaggleDatasets().get_gcs_path(DATA_SRC)

GCS_DS_PATH

# Dataset

In [None]:
def decode_image(image_bytes, height, width):
    """Decode a raw uint8 byte string into a float32 image scaled by 1/255.

    Only the trailing ``height * width`` bytes are kept; any bytes before
    them are discarded. The result is reshaped to ``[width, height, 1]``.
    """
    n_pixels = height * width
    raw = tf.io.decode_raw(image_bytes, out_type=tf.uint8)
    pixels = tf.reshape(raw[-n_pixels:], [width, height, 1])
    return tf.cast(pixels, dtype=tf.float32) / 255.0
    
def decode_mask(mask_bytes, height, width):
    """Decode a PNG-encoded mask into a float32 tensor of 3 channels.

    The decoded PNG is reshaped to ``[width, height, 3]`` (the original
    author notes that the loaded image is laid out as
    [width, height, channel]) and cast to float32 without rescaling.
    """
    decoded = tf.image.decode_png(mask_bytes)
    shaped = tf.reshape(decoded, [width, height, 3])
    return tf.cast(shaped, dtype=tf.float32)

def resize_image(image):
    """Resize ``image`` to IMAGE_SIZE x IMAGE_SIZE with bilinear interpolation."""
    target_size = [IMAGE_SIZE, IMAGE_SIZE]
    return tf.image.resize(
        image, target_size, method=tf.image.ResizeMethod.BILINEAR)

In [None]:
def read_tfrecord(example):
    """Parse one serialized example into ``(image, rest)``.

    Returns:
        A pair ``(image, rest)`` where ``image`` is the decoded image resized
        to IMAGE_SIZE, and ``rest`` is the tuple
        ``(mask, (id, height, width), fold, dice_coefs, (slice_no, slice_count))``.
        ``dice_coefs`` lists the large bowel, small bowel and stomach values
        in that order. ``height`` and ``width`` are the values stored in the
        record, not the resized dimensions.
    """
    # All features are scalars; group the keys by dtype to build the spec.
    keys_by_dtype = (
        (tf.string, ('id', 'image', 'mask')),
        (tf.int64, ('case_no', 'day_no', 'slice_no', 'fold',
                    'height', 'width', 'slice_count')),
        (tf.float32, ('space_h', 'space_w', 'large_bowel_dice_coef',
                      'small_bowel_dice_coef', 'stomach_dice_coef')),
    )
    feature_spec = {}
    for dtype, keys in keys_by_dtype:
        for key in keys:
            feature_spec[key] = tf.io.FixedLenFeature([], dtype)

    parsed = tf.io.parse_single_example(example, feature_spec)
    height, width = parsed['height'], parsed['width']

    image = resize_image(decode_image(parsed['image'], height, width))
    mask = resize_image(decode_mask(parsed['mask'], height, width))

    size_info = (parsed['id'], height, width)
    dice_coefs = [
        parsed['large_bowel_dice_coef'],
        parsed['small_bowel_dice_coef'],
        parsed['stomach_dice_coef'],
    ]
    slice_info = (parsed['slice_no'], parsed['slice_count'])
    return image, (mask, size_info, parsed['fold'], dice_coefs, slice_info)

def make_raw_ds(filenames):
    """Build the dataset of parsed records from the given TFRecord files.

    Both the file reads and the parsing map are left non-parallel
    (``None``); the 2.5D windowing later in the notebook relies on the
    records arriving in file order.
    """
    records = tf.data.TFRecordDataset(filenames, num_parallel_reads=None)
    return records.map(read_tfrecord, num_parallel_calls=None)

In [None]:
# Collect every *.tfrec file under the dataset path, sorted by name, and
# build the raw (one record per slice) dataset from them.
tfrec_file_pattern = os.path.join(GCS_DS_PATH, '*.tfrec')
tfrec_file_names = sorted(tf.io.gfile.glob(tfrec_file_pattern))
raw_ds = make_raw_ds(tfrec_file_names)

print(raw_ds)

In [None]:
# Data in TFRecords are in order.
# Print the id, slice number and slice count of the first 20 records so
# the ordering can be checked by eye.
for ds_item in raw_ds.take(20):
    _, (_, (sample_id, _, _), _, _, (slice_no, slice_count)) = ds_item
    sample_id = sample_id.numpy().decode('utf-8')
    slice_no = slice_no.numpy()
    slice_count = slice_count.numpy()
    print("{0}, {1:3d}, {2:3d}".format(sample_id, slice_no, slice_count))

# 2.5D

In [None]:
def make_image_ds(ds):
    """Project a dataset of ``(image, rest)`` pairs down to the images only."""
    def _keep_image(image, rest):
        return image

    return ds.map(_keep_image, num_parallel_calls=None)

def make_skip_ds(ds, n_skips):
    """Rotate ``ds`` by ``n_skips`` elements.

    The first ``n_skips`` elements are moved to the end, so the result has
    the same length as ``ds`` and starts at element ``n_skips``.
    """
    head = ds.take(n_skips)
    tail = ds.skip(n_skips)
    return tail.concatenate(head)

In [None]:
def make_25D_ds(image_m2, image_m1, image_zero_rest, image_p1, image_p2):
    """Build one 2.5D sample from a window of five consecutive records.

    Args:
        image_m2, image_m1: Images of the records 2 and 1 positions before
            the center record.
        image_zero_rest: The full center record, ``(image, rest)``, where
            ``rest`` is ``(mask, size_info, fold, dice_coefs, slice_info)``.
        image_p1, image_p2: Images of the records 1 and 2 positions after
            the center record.

    Returns:
        ``(image_25D, mask, size_info, fold, dice_coefs)``. ``image_25D``
        concatenates, along the last axis, the slice STRIDE_25D after the
        center, the center slice, and the slice STRIDE_25D before it.
        Neighbour slice numbers are clipped to ``[1, slice_count]``, so at
        the first/last slices the nearest valid slice is reused.

    Raises:
        ValueError: If STRIDE_25D does not fit in the five-slice window.
    """
    image_zero, rest = image_zero_rest
    mask, size_info, fold, dice_coefs, slice_info = rest
    slice_zero_no, slice_count = slice_info
    window = tf.stack(
        [image_m2, image_m1, image_zero, image_p1, image_p2])

    # Index of the center slice inside `window`. It is fixed by the window
    # layout (two neighbours on each side) and must NOT be derived from the
    # stride: using STRIDE_25D here only works when the stride equals 2.
    center_index = 2
    if not 0 <= STRIDE_25D <= center_index:
        raise ValueError(
            "STRIDE_25D must be between 0 and {0} for a five-slice window, "
            "got {1}".format(center_index, STRIDE_25D))

    def _get_image(slice_no_offset):
        # Clip in slice-number space, then convert back to a window index.
        slice_no = slice_zero_no + slice_no_offset
        clipped_slice_no = tf.clip_by_value(slice_no, 1, slice_count)
        slice_index = clipped_slice_no - slice_zero_no + center_index
        return window[slice_index]

    image_m = _get_image(-STRIDE_25D)
    image_p = _get_image(STRIDE_25D)
    image_25D = tf.concat([image_p, image_zero, image_m], axis=-1)
    return image_25D, mask, size_info, fold, dice_coefs

In [None]:
# Five views of the record stream, rotated by 0..4 records, are zipped so
# that every element carries a window of five consecutive records. The view
# rotated by 2 (image_zero_ds) keeps the full record and acts as the center;
# the other four views carry images only.
image_m2_ds = make_image_ds(raw_ds)
image_m1_ds = make_skip_ds(make_image_ds(raw_ds), 1)
image_zero_ds = make_skip_ds(raw_ds, 2)
image_p1_ds = make_skip_ds(make_image_ds(raw_ds), 3)
image_p2_ds = make_skip_ds(make_image_ds(raw_ds), 4)

raw_25D_ds = tf.data.Dataset.zip((
    image_m2_ds, image_m1_ds, image_zero_ds, image_p1_ds, image_p2_ds)) \
    .map(make_25D_ds, num_parallel_calls=None)

raw_25D_ds

# Folding and Oversampling

In [None]:
# Read train_data.csv and count the rows per fold: {fold: number of rows},
# ordered by fold index.
train_data_file_path = os.path.join(GCS_DS_PATH, 'train_data.csv')
train_data_df = pd.read_csv(train_data_file_path)
fold_count_dict = \
    train_data_df['fold'] \
        .value_counts() \
        .sort_index() \
        .to_dict()

fold_count_dict

In [None]:
# Flag every row that has at least one of the three dice coefficients at or
# below the threshold; this is the same condition oversample() applies.
oversample_count_df = pd.DataFrame()
oversample_count_df['fold'] = train_data_df['fold']
oversample_count_df['oversampled'] = \
    (train_data_df['large_bowel_dice_coef'] <= OVERSAMPLE_DICE_THRESHOLD) \
    | (train_data_df['small_bowel_dice_coef'] <= OVERSAMPLE_DICE_THRESHOLD) \
    | (train_data_df['stomach_dice_coef'] <= OVERSAMPLE_DICE_THRESHOLD)

oversample_count_df

In [None]:
# Number of flagged (to-be-oversampled) rows per fold: {fold: count}.
oversample_count_dict = \
    oversample_count_df \
        .groupby('fold') \
        .sum() \
        ['oversampled'] \
        .to_dict()

oversample_count_dict

In [None]:
def get_train_count(fold_i):
    """Return the number of training samples in one pass for fold ``fold_i``.

    The training set is every fold except ``fold_i``. Hard samples are
    emitted OVERSAMPLE_COUNT times by the oversampling step, i.e. each adds
    ``OVERSAMPLE_COUNT - 1`` copies on top of the one already included in
    ``fold_count_dict``. (Adding only one extra copy per hard sample
    undercounts the pass whenever OVERSAMPLE_COUNT is not 2.)

    Args:
        fold_i: Index of the fold held out for validation.

    Returns:
        Total sample count over the training folds, oversampling included.
    """
    extra_copies = OVERSAMPLE_COUNT - 1
    return sum(
        fold_count_dict[i] + extra_copies * oversample_count_dict[i]
        for i in range(N_FOLDS)
        if i != fold_i)

def get_val_count(fold_i):
    """Return the number of samples that belong to fold ``fold_i``."""
    val_count = fold_count_dict[fold_i]
    return val_count

In [None]:
def pick_image_mask_dice_coefs(image, mask, info, fold, dice_coefs):
    """Keep the image, mask and dice coefficients; drop ``info`` and ``fold``."""
    selected = (image, mask, dice_coefs)
    return selected

def pick_image_mask(image, mask, info, fold, dice_coefs):
    """Keep only the image and the mask from a 2.5D sample."""
    selected = (image, mask)
    return selected

def pick_image_mask_info(image, mask, info, fold, dice_coefs):
    """Keep the image, mask and ``info``; drop ``fold`` and ``dice_coefs``."""
    selected = (image, mask, info)
    return selected

def select_train(ds, fold_i):
    """Keep the samples whose fold differs from ``fold_i``."""
    def _is_train(image, mask, info, fold, dice_coefs):
        return fold != fold_i

    return ds.filter(_is_train)
    
def select_val(ds, fold_i):
    """Keep the samples whose fold equals ``fold_i``."""
    def _is_val(image, mask, info, fold, dice_coefs):
        return fold == fold_i

    return ds.filter(_is_val)

In [None]:
# https://stackoverflow.com/questions/47236465/
# oversampling-functionality-in-tensorflow-dataset-api

def oversample(image, mask, dice_coefs):
    """Return a dataset that repeats ``(image, mask)`` according to hardness.

    A sample is repeated OVERSAMPLE_COUNT times when any of its dice
    coefficients is at or below OVERSAMPLE_DICE_THRESHOLD, and once
    otherwise. Intended for use with ``Dataset.flat_map``.
    """
    is_hard = tf.math.reduce_any(dice_coefs <= OVERSAMPLE_DICE_THRESHOLD)
    n_repeats = tf.cond(
        is_hard,
        lambda: OVERSAMPLE_COUNT,
        lambda: 1)
    # Dataset.repeat expects an int64 count.
    n_repeats = tf.cast(n_repeats, dtype=tf.int64)
    single = tf.data.Dataset.from_tensors((image, mask))
    return single.repeat(n_repeats)

In [None]:
# Maximum side length of a Cutout hole: one tenth of the image size.
cut_size = IMAGE_SIZE // 10

# Augmentation pipeline applied jointly to image and mask: flips,
# brightness/contrast jitter, one of two geometric distortions,
# shift/scale/rotate, and cutout.
transforms = Tfda.Compose([
    Tfda.HorizontalFlip(p=0.5),
    Tfda.VerticalFlip(p=0.5),
    Tfda.RandomBrightnessContrast(
        brightness_limit=0.2, contrast_limit=0.2, p=0.75),
    Tfda.OneOf([
        Tfda.GridDistortion(
            num_steps=10, distort_limit=0.5,
            interpolation='bilinear', border_mode='constant', p=0.5),        
        Tfda.OpticalDistortion(
            distort_limit=1.0, shift_limit=0.05,
            interpolation='bilinear', border_mode='constant', p=0.5),
        ], p=0.75),
    Tfda.ShiftScaleRotate(
        shift_limit=0.125, scale_limit=0.1, rotate_limit=20,
        interpolation='bilinear', border_mode='constant', p=0.75),
    Tfda.Cutout(
        num_holes=8, max_h_size=cut_size, max_w_size=cut_size, p=0.75),
])

def data_augment(image, mask):
    """Run the augmentation pipeline and return ``(image, mask)``."""
    augmented = transforms(image=image, mask=mask)
    return augmented["image"], augmented["mask"]

In [None]:
def make_datasets(fold_i):
    """Build the training and validation pipelines for fold ``fold_i``.

    Returns:
        ``(train_ds, val_ds, train_steps, val_steps)``. ``train_ds`` repeats
        indefinitely, so ``train_steps`` defines the epoch length;
        ``val_ds`` is a single pass. Both step counts are floor divisions
        of the sample counts by BATCH_SIZE.
    """
    # NOTE(review): an earlier comment here mentioned casting to uint8 to
    # save memory when caching, but no cast is performed in this pipeline.
    train_ds = select_train(raw_25D_ds, fold_i)
    train_ds = train_ds.map(
        pick_image_mask_dice_coefs, num_parallel_calls=AUTOTUNE)
    train_ds = train_ds.flat_map(oversample)
    # Cache before augmentation so each pass draws fresh random transforms.
    train_ds = train_ds.cache().repeat().shuffle(1024)
    train_ds = train_ds.map(data_augment, num_parallel_calls=AUTOTUNE)
    train_ds = train_ds.batch(BATCH_SIZE).prefetch(AUTOTUNE)

    val_ds = select_val(raw_25D_ds, fold_i)
    val_ds = val_ds.map(pick_image_mask, num_parallel_calls=AUTOTUNE)
    val_ds = val_ds.batch(BATCH_SIZE).cache().prefetch(AUTOTUNE)

    train_steps = get_train_count(fold_i) // BATCH_SIZE
    val_steps = get_val_count(fold_i) // BATCH_SIZE
    return train_ds, val_ds, train_steps, val_steps

In [None]:
def make_pred_dataset(fold_i):
    """Build a batched ``(image, mask, info)`` dataset of fold ``fold_i``."""
    fold_ds = select_val(raw_25D_ds, fold_i)
    with_info = fold_ds.map(pick_image_mask_info, num_parallel_calls=AUTOTUNE)
    return with_info.batch(BATCH_SIZE).prefetch(AUTOTUNE)

# Visualization

In [None]:
def draw_images_masks(ds):
    """Show the first 30 ``(image, mask)`` pairs of ``ds`` on a 6 x 5 grid.

    Each mask is drawn half-transparent on top of its image.
    """
    n_rows, n_cols = 6, 5
    plt.figure(figsize=(12, 2.5 * n_rows))
    pairs = ds.take(n_cols * n_rows)
    for position, (image, mask) in enumerate(pairs, start=1):
        plt.subplot(n_rows, n_cols, position)
        plt.imshow(image, cmap="gray")
        plt.imshow(mask, alpha=0.5)
        plt.axis("off")

    plt.tight_layout()
    plt.show()

In [None]:
# train_ds, val_ds, _, _ = make_datasets(0)

# print("train_ds")
# draw_images_masks(train_ds.unbatch().skip(95))

# print("val_ds")
# draw_images_masks(val_ds.unbatch().skip(80))

# Model

In [None]:
# Loss objects from segmentation_models, combined in bce_dice_loss below.
dice_loss_fun = sm.losses.DiceLoss()
bce_loss_fun = sm.losses.BinaryCELoss()

def bce_dice_loss(y_true, y_pred):
    """Return the equally weighted sum of Dice loss and binary cross-entropy."""
    weighted_dice = 0.5 * dice_loss_fun(y_true, y_pred)
    weighted_bce = 0.5 * bce_loss_fun(y_true, y_pred)
    return weighted_dice + weighted_bce

In [None]:
# https://www.kaggle.com/code/ammarnassanalhajali/uwmgi-unet-keras-train-with-eda
def dice_coef(y_true, y_pred, smooth=1):
    """Compute a smoothed Dice coefficient over all elements of the inputs.

    Both tensors are flattened, so the value is computed jointly over every
    element passed in. ``smooth`` is added to both numerator and
    denominator.
    """
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    overlap = K.sum(truth * pred)
    numerator = 2. * overlap + smooth
    denominator = K.sum(truth) + K.sum(pred) + smooth
    return numerator / denominator

In [None]:
def make_model():
    """Create and compile the segmentation model.

    The SEG_MODEL network (BACKBONE encoder, ImageNet weights, 3 sigmoid
    output channels) is wrapped in a Keras model with a fixed
    IMAGE_SIZE x IMAGE_SIZE x 3 input, then compiled with Adam,
    ``bce_dice_loss`` and the ``dice_coef`` metric.
    """
    segmenter = SEG_MODEL(
        BACKBONE, encoder_weights='imagenet',
        classes=3, activation='sigmoid')

    image_input = tf.keras.Input(
        shape=(IMAGE_SIZE, IMAGE_SIZE, 3), name="inputs")
    model = tf.keras.Model(
        inputs=image_input, outputs=segmenter(image_input), name="seg_model")

    # 'steps_per_execution' sends several batches to the TPU per call; the
    # author's target is 128 elements per core.
    per_replica_batch = BATCH_SIZE // REPLICAS
    steps_per_execution = 128 // per_replica_batch
    print("steps_per_execution: ", steps_per_execution)

    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss=bce_dice_loss,
        metrics=[dice_coef],
        steps_per_execution=steps_per_execution)
    return model

In [None]:
# Build and compile the model inside the distribution strategy scope.
with strategy.scope():
    model = make_model()
    
# Snapshot of the freshly built weights; restored before each fold's training.
initial_weights = model.get_weights()
model.summary()

In [None]:
# Learning-rate schedule parameters read by lrfn(): linear ramp from
# LR_START to LR_MAX, optional hold at LR_MAX, then cosine decay to LR_MIN.
LR_START = INIT_LR
LR_MAX = 1e-3
LR_MIN = 1e-5
LR_RAMPUP_EPOCHS = WARMUP_EPO
LR_SUSTAIN_EPOCHS = 0
EPOCHS = N_EPOCHS

def lrfn(epoch):
    """Return the learning rate for the given 0-based ``epoch``.

    The schedule has three phases:
      * ramp-up: linear from LR_START towards LR_MAX over LR_RAMPUP_EPOCHS;
      * sustain: LR_MAX for LR_SUSTAIN_EPOCHS;
      * decay: cosine from LR_MAX at the first decay epoch down to LR_MIN
        at the last epoch (``EPOCHS - 1``).

    Args:
        epoch: 0-based epoch index.

    Returns:
        The learning rate as a float.
    """
    if epoch < LR_RAMPUP_EPOCHS:
        return (LR_MAX - LR_START) / LR_RAMPUP_EPOCHS * epoch + LR_START

    decay_start = LR_RAMPUP_EPOCHS + LR_SUSTAIN_EPOCHS
    if epoch < decay_start:
        return LR_MAX

    # Number of intervals between the first and the last decay epoch. With a
    # single decay epoch this would be 0 and the division below would fail,
    # so it is floored at 1; that lone epoch then runs at LR_MAX.
    decay_total_epochs = max(1, EPOCHS - decay_start - 1)
    decay_epoch_index = epoch - decay_start
    phase = math.pi * decay_epoch_index / decay_total_epochs
    cosine_decay = 0.5 * (1 + math.cos(phase))
    return (LR_MAX - LR_MIN) * cosine_decay + LR_MIN

# Evaluate the schedule for every epoch, plot it, and print the first,
# highest and last learning rates.
rng = [i for i in range(EPOCHS)]
lr_y = [lrfn(x) for x in rng]
plt.figure(figsize=(10, 4))
plt.plot(rng, lr_y)
print("Learning rate schedule: {:.3g} to {:.3g} to {:.3g}". \
      format(lr_y[0], max(lr_y), lr_y[-1]))

In [None]:
def make_callbacks(best_model_file_name):
    """Return the Keras callbacks used for training one fold.

    The list holds a checkpoint that writes the full model to
    ``best_model_file_name`` whenever ``val_loss`` reaches a new minimum,
    followed by a scheduler that sets the learning rate from ``lrfn``.
    """
    save_best = tf.keras.callbacks.ModelCheckpoint(
        best_model_file_name,
        monitor='val_loss', mode='min', verbose=1,
        save_best_only=True, save_weights_only=False)
    schedule_lr = tf.keras.callbacks.LearningRateScheduler(
        lrfn, verbose=False)
    return [save_best, schedule_lr]

In [None]:
def fit_one_fold(fold_i, best_model_file_name):
    """Train the global ``model`` on fold ``fold_i`` and return the history.

    Args:
        fold_i: Index of the fold held out for validation.
        best_model_file_name: Path the checkpoint callback saves to.
    """
    train_ds, val_ds, n_train_steps, n_val_steps = make_datasets(fold_i)
    fit_options = dict(
        epochs=EPOCHS,
        verbose=1,
        callbacks=make_callbacks(best_model_file_name),
        steps_per_epoch=n_train_steps,
        validation_data=val_ds,
        validation_steps=n_val_steps)
    return model.fit(train_ds, **fit_options)

In [None]:
def plot_history(history, title, labels, subplot):
    """Plot the ``labels`` series of ``history`` into one subplot.

    Args:
        history: Object whose ``history`` attribute maps names to series.
        title: Subplot title.
        labels: Names of the series to draw; also used as legend entries.
        subplot: Arguments for ``plt.subplot``, e.g. ``(1, 2, 1)``.
    """
    plt.subplot(*subplot)
    plt.title(title)
    curves = history.history
    for name in labels:
        plt.plot(curves[name], label=name)
    plt.legend()
    
def plot_fit_result(history):
    """Show loss and dice-coefficient curves (train and validation) side by side."""
    panels = (
        ("Loss", ['loss', 'val_loss']),
        ("dice_coef", ['dice_coef', 'val_dice_coef']),
    )
    plt.figure(figsize=(12, 4))
    for position, (title, labels) in enumerate(panels, start=1):
        plot_history(history, title, labels, (1, len(panels), position))
    plt.show()

In [None]:
def resize_image_to(image, height, width):
    """Resize ``image`` to ``[width, height]`` with bilinear interpolation.

    The size is passed to ``tf.image.resize`` in ``[width, height]`` order,
    matching the layout used when the records are decoded in this notebook.
    """
    target_size = [width, height]
    return tf.image.resize(
        image, target_size, method=tf.image.ResizeMethod.BILINEAR)

In [None]:
def encode_pred(one_pred):
    """Run-length encode a binary mask.

    The mask is flattened and every maximal run of 1s is reported as a
    ``start length`` pair, where ``start`` is the 0-based index into the
    flattened array. Pairs are joined with single spaces; a mask without
    any 1s yields an empty string.
    """
    flat = one_pred.flatten()

    # Neighbour values, with 0 assumed outside the array on both ends.
    left = np.roll(flat, 1)
    left[0] = 0
    right = np.roll(flat, -1)
    right[-1] = 0

    is_one = flat == 1
    run_starts = np.flatnonzero((left == 0) & is_one)
    run_ends = np.flatnonzero(is_one & (right == 0))

    return ' '.join(
        "{0} {1}".format(start, end - start + 1)
        for start, end in zip(run_starts, run_ends))

In [None]:
def make_predictions(raw_pred, height, width):
    """Turn one raw prediction into three run-length-encoded strings.

    The prediction is resized with ``resize_image_to``, thresholded at 0.5,
    and each of the three channels is encoded separately.

    Returns:
        ``(large_bowel, small_bowel, stomach)`` encodings, taken from
        channels 0, 1 and 2 respectively.
    """
    resized = resize_image_to(raw_pred, height, width)
    binary = np.where(resized >= 0.5, 1, 0)
    return tuple(encode_pred(binary[:, :, channel]) for channel in range(3))

In [None]:
def predict_one_fold(fold_i):
    """Predict every sample of fold ``fold_i`` with the global ``model``.

    In DEBUG mode only the first three batches are processed. A dot is
    printed per batch as a progress indicator.

    Returns:
        ``(preds, pred_df)``: the raw model outputs concatenated along the
        batch axis, and a DataFrame with one row per sample holding its id,
        the three encoded masks and the fold index.
    """
    raw_batches = []
    rows = []
    batches = make_pred_dataset(fold_i)
    for batch_no, (images, _masks, infos) in enumerate(batches):
        if DEBUG and batch_no >= 3:
            break
        print('.', end='', flush=True)

        batch_out = model(images, training=False)
        raw_batches.append(batch_out)

        ids, heights, widths = infos
        for one_out, id_t, height_t, width_t in \
                zip(batch_out, ids, heights, widths):
            # Encode at the size stored in the record for this sample.
            encoded = make_predictions(
                one_out, height_t.numpy(), width_t.numpy())
            rows.append([id_t.numpy().decode('utf-8'), *encoded, fold_i])
    print()

    preds = np.concatenate(raw_batches, axis=0)
    pred_df = pd.DataFrame(
        rows,
        columns=['id', 'large_bowel', 'small_bowel', 'stomach', 'fold'])
    return preds, pred_df

In [None]:
def save_binary(name, bin_file, file_name_format):
    """Dump ``bin_file`` with joblib and report where it was written.

    ``file_name_format`` is formatted with VID and the module-level
    ``fold_i``, so this must be called while ``fold_i`` is defined (e.g.
    inside the per-fold loop).
    """
    target = file_name_format.format(VID, fold_i)
    joblib.dump(bin_file, target)
    print("{0} are saved to {1}.".format(name, target))
    
def save_df(name, df, file_name_format):
    """Write ``df`` to CSV (without the index) and report where it went.

    ``file_name_format`` is formatted with VID and the module-level
    ``fold_i``, so this must be called while ``fold_i`` is defined (e.g.
    inside the per-fold loop).
    """
    target = file_name_format.format(VID, fold_i)
    df.to_csv(target, index=False)
    print("{0} is saved to {1}.".format(name, target))

In [None]:
# Train each selected fold from the same initial weights; the checkpoint
# callback writes the best model of the fold to best_model_file_name.
for fold_i in FOLD_I_LIST:
    print("####################")
    print("# Fold {0}".format(fold_i))
    # Reset to the untrained weights so folds do not leak into each other.
    model.set_weights(initial_weights)
    best_model_file_name = "seg_model_{0}_{1}.hdf5".format(VID, fold_i)
    history = fit_one_fold(fold_i, best_model_file_name)
    plot_fit_result(history)
    
    # Out-of-fold prediction and saving are currently disabled.
    # model.load_weights(best_model_file_name)
    # preds, pred_df = predict_one_fold(fold_i)
    
    # save_binary("preds", preds, "preds_{0}_{1}.joblib")
    # save_df("pred_df", pred_df, "pred_{0}_{1}.csv")
    print()

In [None]:
!rm -rf TfDataAugmentation