This notebook trains a model for the <a href=https://www.kaggle.com/c/sartorius-cell-instance-segmentation>Sartorius Competition</a> as Semantic Segmentation (<span style="color: red; ">Not</span> Instance Segmentation)
- Model : Unet (backbone: EfficientNet, package: segmentation_models-1.0.1)  
- Image : png images -> ndarray -> tfrecord (split into tiles)
- Annotation (target) : train.csv (run-length string) -> ndarray -> tfrecord (split into tiles)


Refs.  
https://www.kaggle.com/wrrosa/hubmap-tf-with-tpu-efficientunet-512x512-train  
https://www.kaggle.com/ammarnassanalhajali/sartorius-segmentation-keras-u-net-training

<a class='anchor' id='TOC'></a>
# Table of Contents

1. [Packages](#1)
1. [Accelerator](#2)
1. [Parameters](#3)
1. [Data](#4)
1. [Modeling](#5)
1. [Train](#6)
1. [Evaluation](#7)

<a class='anchor' id='1'></a>
# 1. Packages
[Back to Table of Contents](#TOC)

In [None]:
!pip install segmentation_models==1.0.1 -q

In [None]:
import os, glob, gc, re, yaml, json
from tqdm.notebook import tqdm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import cv2
import seaborn as sns
import tensorflow as tf
from datetime import datetime
from pprint import pprint
from PIL import Image, ImageEnhance

import warnings
warnings.filterwarnings('ignore')

<a class='anchor' id='2'></a>
# 2. Accelerator
[Back to Table of Contents](#TOC)<br>
Select processor (priority: TPU>GPU>CPU)

In [None]:
def set_strategy():
    '''
    Pick a tf.distribute strategy for the best available accelerator
    (priority: TPU > multiple GPUs > single GPU > CPU) and print the choice.
    return: tf.distribute strategy object
    '''
    try:
        # TPU detection; a ValueError here is treated as "no TPU available"
        tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect()
    except ValueError:
        tpu = None
        gpus = tf.config.experimental.list_logical_devices("GPU")

    if tpu:
        strategy = tf.distribute.TPUStrategy(tpu)
        print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])
    else:
        gpu_names = [gpu.name for gpu in gpus]
        if len(gpu_names) > 1:
            strategy = tf.distribute.MirroredStrategy(gpu_names)
            print('Running on multiple GPUs ', gpu_names)
        else:
            # The default strategy works on CPU and on a single GPU
            strategy = tf.distribute.get_strategy()
            if len(gpu_names) == 1:
                print('Running on single GPU ', gpu_names[0])
            else:
                print('Running on CPU')

    print("Number of accelerators: ", strategy.num_replicas_in_sync)

    return strategy

strategy = set_strategy()

<a class='anchor' id='3'></a>
# 3. Parameters
[Back to Table of Contents](#TOC)

In [None]:
# Path
INPUT_PATH = '../input/sartorius-cell-instance-segmentation/'

# Parameters
P = {}
P['DEBUG'] = False # If true, number of epochs and fold calculation are minimized.
P['MODEL'] = 'Unet' # Unet, FPN, Linknet, PSPNet (informational only: the model cells call sm.Unet directly)
P['BASE_TILE'] = [128, 128] # Tile size to split original image
P['RESIZED_TILE'] = [128, 128] # Image compression for quick calculation (If no compression, P['RESIZED_TILE'] = P['BASE_TILE'])
P['MIN_OVERLAP'] = 32 # Overlap width of each tile (Note: Edge image may overlap more than MIN_OVERLAP)
P['BACKBONE'] = 'efficientnetb1' 
P['WEIGHT'] = '../input/efficientnetb0b7-keras-weights/efficientnet-b1_weights_tf_dim_ordering_tf_kernels_autoaugment_notop.h5' # Downloaded from https://github.com/Callidior/keras-applications/releases
P['SEED'] = 2021 # Random seed
P['VERBOSE'] = 1 # Show the detail of train history, or not
P['BATCH_COEF'] = 64 # Batch size per replica
P['BATCH_SIZE'] = P['BATCH_COEF'] * strategy.num_replicas_in_sync # Global batch size
P['LR'] = 5e-4 # Learning rate
P['STEPS_COEF'] = 3 # steps_per_epoch = P['STEPS_COEF'] * #TILES // BATCH_SIZE (Nominal: 1)
P['NFOLDS'] = 6 # Number of folds
# Plain truthiness test with an else branch: EPOCHS and CALC_FOLDS are always
# defined, whatever value the DEBUG flag holds.
if P['DEBUG']:
    P['EPOCHS'] = 1 # Number of epochs
    P['CALC_FOLDS'] = 1 # Only one fold is calculated
else:
    P['EPOCHS'] = 5 # Number of epochs
    P['CALC_FOLDS'] = P['NFOLDS'] # All folds are calculated


AUTO = tf.data.experimental.AUTOTUNE

<a class='anchor' id='4'></a>
# 4. Data
[Back to Table of Contents](#TOC)

## Tabular data

In [None]:
# Load the training table shipped with the competition data
train_csv_path = INPUT_PATH + 'train.csv'
df_train = pd.read_csv(train_csv_path)
display(df_train)

In [None]:
# Annotations are grouped by id: one row per image, with 'annotation' holding
# the list of all run-length strings of that image.
annotations_by_id = df_train.groupby('id')['annotation'].agg(list)
# Sort first and reset the index afterwards, so the labels 0..n-1 follow the
# sorted ids (later cells address rows with df_train.loc[idx] for idx in range(len(df_train))).
df_tmp = df_train.drop_duplicates('id').sort_values('id').reset_index(drop=True)
# Look the list up by id instead of relying on index alignment, so every image
# gets its own annotations whatever the row order of train.csv is.
df_tmp["annotation"] = df_tmp['id'].map(annotations_by_id)
df_train = df_tmp.copy()
display(df_train)

## Feature distributions

In [None]:
def fig_layout(seaborn_plot):
    '''
    Tidy up a bar/count plot and label every bar with its share of the total.
    seaborn_plot: axes object returned by the seaborn plotting call
    return: None (the plot is modified in place)
    '''
    # Axis labels and y ticks are removed on the current pyplot axes
    plt.xlabel('')
    plt.ylabel('')
    plt.yticks([])
    # Hide frame
    for side in ['right', 'top', 'left']:
        seaborn_plot.spines[side].set_visible(False)
    # Count record length (sum of all bar heights)
    record_length = sum(rectangle.get_height() for rectangle in seaborn_plot.patches)
    if not record_length:
        return # No bars (or only empty bars): no ratio can be computed
    # Add annotation of ratio
    for rectangle in seaborn_plot.patches:
        height = rectangle.get_height()
        width = rectangle.get_width()
        ratio = round(height/record_length*100,1)
        # Ratio; annotate the axes passed in, not a global variable
        seaborn_plot.annotate(f'{ratio}%',
                              xy=(rectangle.get_x()+width/2, height),
                              ha='center', va='center', size=8,
                              xytext=(0, 10), textcoords='offset points')


df_tmp = df_train.copy()
# dtype change
# # plate_time: 11h30m00s -> 11.5 (hours as a decimal number)
plate_time = pd.to_datetime(df_train['plate_time'], format='%Hh%Mm%Ss')
df_tmp['plate_time'] = round(plate_time.dt.hour + plate_time.dt.minute/60, 2)
# # sample_date: str -> date
sample_date = pd.to_datetime(df_train['sample_date'], format='%Y-%m-%d')
df_tmp['sample_date'] = sample_date.dt.date

# Plot: one count plot per feature, side by side
fig = plt.figure(figsize=(22, 5))
features = ['cell_type', 'plate_time', 'sample_date']
for position, feature in enumerate(features, start=1):
    plt.subplot(1, 3, position)
    plt.title(feature, size=20)
    sns_plot = sns.countplot(x=feature, data=df_tmp)
    fig_layout(sns_plot)
    fig.autofmt_xdate(rotation=90)

plt.show()

## Image data

In [None]:
# Path of every training image, built from the image ids
train_dir = INPUT_PATH + 'train/'
train_imgs = train_dir + df_train['id'] + '.png'
print(f'train_images: {len(train_imgs)} files')
display(train_imgs)

In [None]:
# One image per cell type is displayed (the first row of each type).
cell_types = df_train['cell_type'].unique()
gs = gridspec.GridSpec(1, 3)
plt.figure(figsize=(25, 20))
for i, cell_type in enumerate(cell_types):
    is_this_type = df_train['cell_type'] == cell_type
    idx = df_train.index[is_this_type][0]
    img_id = df_train['id'][idx]
    img = cv2.imread(f'{INPUT_PATH}train/{img_id}.png')
    ax = plt.subplot(gs[i])
    ax.set_title(f'id: {img_id},  cell_type: {cell_type}')
    ax.imshow(img)
    ax.set_aspect('equal')
    plt.axis('on')

plt.show()

## Decode annotations (string -> ndarray)

In [None]:
# Decode run-length string -> ndarray
def rle_decode(annotation, shape):
    '''
    annotation: string, "start length start length ..." with 1-based starts
    shape: (height, width)
    return: ndarray (uint8), mask: 1, background: 0
    '''
    tokens = annotation.split()
    # Tokens alternate start, length, start, length, ...
    # Run-length starts are numbered from one, array positions from zero.
    run_starts = np.asarray(tokens[::2], dtype=int) - 1
    run_lengths = np.asarray(tokens[1::2], dtype=int)
    run_ends = run_starts + run_lengths

    # Paint the runs on a flat canvas, then fold it into the requested shape
    flat = np.zeros(shape[0]*shape[1], dtype=np.uint8)
    for begin, stop in zip(run_starts, run_ends):
        flat[begin:stop] = 1

    return flat.reshape(shape)

# Build mask image from all annotations with same id
def build_masks(annotations, shape, distinguish_objects=False):
    '''
    annotations: List[string], run-length strings of all objects of one image
    shape: (height, width)
    distinguish_objects: if True, the i-th annotation is painted with label i+1
    return: ndarray
        distinguish_objects=False: uint8, mask: 1, background: 0
        distinguish_objects=True : int32, mask: integer 1,2,3,..., background: 0
    Where objects overlap, the later annotation wins.
    '''
    # Object labels exceed 255 as soon as an image has more than 255 annotations,
    # so the label map gets an explicit wide dtype instead of depending on
    # NumPy's promotion rules. The binary mask stays uint8 (it is later written
    # to the tfrecord with tobytes()).
    masks = np.zeros(shape, dtype=np.int32 if distinguish_objects else np.uint8)
    for i, annotation in enumerate(annotations):
        covered = rle_decode(annotation, shape) != 0
        masks[covered] = i + 1 if distinguish_objects else 1

    return masks

def plot_image_and_mask(img, mask, title=None):
    '''
    Draw four panels side by side: the image, its contrast-enhanced negative,
    the mask, and the mask blended over the enhanced image.
    img: ndarray as returned by cv2.imread
    mask: ndarray (height, width), background: 0
    title: figure title
    '''
    fig, ax = plt.subplots(1, 4, figsize=(20,4))

    # Negative of the image with strongly boosted contrast
    negative = img.max() - img
    img_hc = np.asarray(ImageEnhance.Contrast(Image.fromarray(negative)).enhance(24))

    # Mask as a yellow 3-channel image
    overlay = np.tile(np.expand_dims(mask, 2), 3) # shape: (height, width) -> (height, width, 3)
    overlay = np.clip(overlay, 0, 1)*255 # mask: (255,255,255), background: (0,0,0)
    overlay[:,:,2] = 0 # mask: (255,255,0): yellow
    overlay = overlay.astype(np.uint8) # cv2.addWeighted needs the same dtype as img_hc
    merge_img_mask = cv2.addWeighted(img_hc, 0.80, overlay, 0.20, gamma=0.0)

    panels = [('Original image', img, {}),
              ('High contrasted image', img_hc, {}),
              ('Mask', mask, {'cmap': 'inferno'}),
              ('Image + Mask', merge_img_mask, {})]
    for axis, (panel_title, data, options) in zip(ax, panels):
        axis.set_title(panel_title)
        axis.imshow(data, **options)

    fig.suptitle(title, fontsize=14)
    

In [None]:
# Image, mask and overlay of the first sample of every cell type
for cell_type in cell_types:
    is_this_type = df_train['cell_type'] == cell_type
    idx = df_train.index[is_this_type][0]
    img_id = df_train['id'][idx]
    img = cv2.imread(f'{INPUT_PATH}train/{img_id}.png')
    mask_shape = (df_train.height[idx], df_train.width[idx])
    mask = build_masks(annotations=df_train.annotation[idx],
                       shape=mask_shape,
                       distinguish_objects=True)
    plot_image_and_mask(img, mask, title=f'id: {img_id},  cell_type: {cell_type}')

## Generate TFRecord

In [None]:
# Utility functions

# Wrap a value into one of the tf.train.Feature list types (bytes, float, int64)
def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    eager_tensor_type = type(tf.constant(0))
    if isinstance(value, eager_tensor_type):
        # BytesList won't unpack a string from an EagerTensor.
        value = value.numpy()
    wrapped = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=wrapped)

def _float_feature(value):
    """Returns a float_list from a float / double."""
    wrapped = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=wrapped)

def _int64_feature(value):
    """Returns an int64_list from a bool / enum / int / uint."""
    wrapped = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=wrapped)

# Serialization
def serialize_example(feature0, feature1, feature2):
    """Returns a serialized tf.train.Example holding image, mask and cell_type bytes."""
    # Feature dictionary: the contents of the message
    named_values = (('image', feature0), ('mask', feature1), ('cell_type', feature2))
    feature = {key: _bytes_feature(value) for key, value in named_values}
    # Serialization: convert the features into bytes
    features = tf.train.Features(feature=feature)
    return tf.train.Example(features=features).SerializeToString()

# Count images in tfrecord files
def count_data_items(filenames):
    """Returns the sum of the item counts encoded in the names as '-<count>.'."""
    count_pattern = re.compile(r"-([0-9]*)\.")
    counts = []
    for name in filenames:
        counts.append(int(count_pattern.search(name).group(1)))
    return np.sum(counts)

In [None]:
# Config of tiles
BASE = P['BASE_TILE'][0] # Base tile size (Not original image size)
RESIZE = P['RESIZED_TILE'][0] # Re-sized tile size
reduce = BASE//RESIZE # Reduce base image size

# Path to save tfrecords
TFREC_PATH = f'./tfrec-{len(df_train)}-data_{RESIZE}x{RESIZE}-tile/'
P['DATASET'] = TFREC_PATH
# exist_ok: an existing (or only partially created) output tree must not abort the cell
os.makedirs(TFREC_PATH + 'train', exist_ok=True)
os.makedirs(TFREC_PATH + 'test', exist_ok=True)

# For statistics: per-tile channel means of x and x^2 (pixel values scaled to 0..1)
x_tot, x2_tot  = [], []

tile_px = reduce*RESIZE # Edge length of one tile, in pixels of the original image

print('Generating tfrcords...')
for idx in range(len(df_train)):

    image_id = df_train.loc[idx, 'id']
    # Load image and decode mask
    img = cv2.imread(INPUT_PATH + 'train/' + image_id + '.png')
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) # Channel swap; swapped back before writing (see below)
    mask = build_masks(df_train.loc[idx,'annotation'], (img.shape[0], img.shape[1]))

    # Padding to make the image divisible by tile size (split evenly on both sides)
    pad0 = (tile_px - img.shape[0]%tile_px)%tile_px
    pad1 = (tile_px - img.shape[1]%tile_px)%tile_px
    img  = np.pad(img,
                  [[pad0//2, pad0-pad0//2], [pad1//2, pad1-pad1//2], [0, 0]],
                  constant_values=0)
    mask = np.pad(mask,
                  [[pad0//2, pad0-pad0//2], [pad1//2,pad1-pad1//2]],
                  constant_values=0)

    # Tiling image and mask using the reshape + transpose trick
    # image
    img = cv2.resize(img, (img.shape[1]//reduce, img.shape[0]//reduce),
                     interpolation=cv2.INTER_AREA)
    img = img.reshape(img.shape[0]//RESIZE, RESIZE, img.shape[1]//RESIZE, RESIZE, 3)
    img = img.transpose(0, 2, 1, 3, 4).reshape(-1, RESIZE, RESIZE, 3)
    # mask (nearest neighbour keeps the values binary)
    mask = cv2.resize(mask, (mask.shape[1]//reduce, mask.shape[0]//reduce),
                      interpolation=cv2.INTER_NEAREST)
    mask = mask.reshape(mask.shape[0]//RESIZE, RESIZE, mask.shape[1]//RESIZE, RESIZE)
    mask = mask.transpose(0, 2, 1, 3).reshape(-1, RESIZE, RESIZE)
    # cell_type: the same label is repeated for every tile of the image
    cell_type = df_train.loc[idx, 'cell_type']
    cell_type = [cell_type.encode('utf-8') for _ in range(len(mask))]

    # Generate TFRecord: one file per image, one example per tile
    num_tiles = 0
    filename = TFREC_PATH + f'train/{image_id}.tfrec'
    with tf.io.TFRecordWriter(filename) as writer:

        for img_, mask_, cell_type_ in zip(img, mask, cell_type):

            x_tot.append((img_/255.0).reshape(-1, 3).mean(0))
            x2_tot.append(((img_/255.0)**2).reshape(-1, 3).mean(0))
            img_ = cv2.cvtColor(img_, cv2.COLOR_RGB2BGR) # Back to the channel order cv2.imread returned

            # Create tf.example
            example = serialize_example(img_.tobytes(), mask_.tobytes(), cell_type_)
            writer.write(example)
            num_tiles +=1

    # The tile count is encoded in the file name; count_data_items() reads it back
    os.rename(filename, TFREC_PATH + 'train/'+ image_id + '-' + str(num_tiles) + '.tfrec')
    if (idx+1)%100==0:
        print(f'{idx+1} tfrecords generated')

print(f'Total {len(df_train)} tfrecords generated')
print(f'Saved in {TFREC_PATH}\n')

# Sorted: glob returns files in arbitrary order, and the K-fold split later
# picks files by position, so a fixed order is needed for reproducible folds.
train_tfrecs = sorted(glob.glob(TFREC_PATH + 'train/*.tfrec'))
print(f'Number of TFRecord files: {len(train_tfrecs)}')
print(f'Number of total tiles: {count_data_items(train_tfrecs)}')
print('Statistics of pixel values:')
print('Mean:', np.array(x_tot).mean(0))
print('STD :', np.sqrt(np.array(x2_tot).mean(0) - np.array(x_tot).mean(0)**2))

### Confirm the generated records
Here, confirm 3 tfrecord samples for 3 cell types

In [None]:
from skimage.segmentation import mark_boundaries

DIM = RESIZE
# Dataloader
def _parse_image_function(example_proto):
    '''
    Decode one serialized example with the features 'image', 'mask' and 'cell_type'.
    example_proto: serialized tf.train.Example
    return: (image, mask, cell_type)
        image: (DIM, DIM, 3), pixel values divided by 255
        mask: (DIM, DIM, 1)
        cell_type: string tensor
    '''
    feature_spec = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'mask': tf.io.FixedLenFeature([], tf.string),
        'cell_type': tf.io.FixedLenFeature([], tf.string)
    }
    parsed = tf.io.parse_single_example(example_proto, feature_spec)
    # image
    image_bytes = tf.io.decode_raw(parsed['image'], out_type=np.dtype('uint8'))
    image = tf.reshape(image_bytes, (DIM,DIM,3))
    image = tf.image.resize(image, (DIM, DIM))/255.0
    # mask
    mask_bytes = tf.io.decode_raw(parsed['mask'], out_type='bool')
    mask = tf.reshape(mask_bytes, (DIM,DIM,1))
    mask = tf.image.resize(tf.cast(mask,'uint8'),(DIM, DIM))
    # cell_type is passed through unchanged
    return image, mask, parsed['cell_type']

def load_dataset(filenames):
    '''Returns a dataset of parsed (image, mask, cell_type) examples read from the tfrecord file(s).'''
    records = tf.data.TFRecordDataset(filenames)
    return records.map(_parse_image_function)

def get_dataset(filename, n):
    '''Returns the parsed dataset of filename, grouped into batches of n examples.'''
    return load_dataset(filename).batch(n)

# Utility functions
def get_cell_type_from_tfrec(tfrec):
    '''Returns the cell type (string) stored with the first tile of a tfrecord file.'''
    # One batch as large as the tile count encoded in the file name
    first_batch = get_dataset(tfrec, count_data_items([tfrec])).take(1)
    for _, _, cell_types in first_batch:
        break
    return cell_types[0].numpy().decode() # string

def get_unique_cell_samples(tfrecs, max_files=100, n_types=3):
    '''
    Find the first tfrecord file of each distinct cell type.
    tfrecs: list of tfrecord file paths
    max_files: at most this many leading files are inspected
    n_types: stop as soon as this many distinct cell types are found
    return: (sample_idx, sample_cell_type), positions in tfrecs and the matching cell types
    '''
    sample_idx = []
    sample_cell_type = []
    # Slicing keeps the loop inside the list, so lists shorter than max_files
    # no longer raise an IndexError.
    for idx, tfrec in enumerate(tfrecs[:max_files]):
        cell_type = get_cell_type_from_tfrec(tfrec)
        if cell_type not in sample_cell_type:
            sample_idx.append(idx)
            sample_cell_type.append(cell_type)
        if len(sample_idx) == n_types:
            break
    return sample_idx, sample_cell_type

# Plot 3 samples
sample_idx, sample_cell_type = get_unique_cell_samples(train_tfrecs)
print(f'Sample index: {sample_idx}')
print(f'Sample cell type: {sample_cell_type}\n')
for idx in sample_idx:
    # Read all tiles of the file as one batch
    num_tiles = count_data_items([train_tfrecs[idx]])
    for imgs, masks, cell_types in get_dataset(train_tfrecs[idx], num_tiles).take(1):
        break
    print(f'Sample image: {train_tfrecs[idx].split("/")[-1]}')
    print(f'image shape: {imgs.shape}, mask shape: {masks.shape},  cell type: {cell_types[0].numpy().decode()}')

    # 5 rows; ceil division so every tile gets a grid cell even when the tile
    # count is not a multiple of 5
    n_cols = int(-(-num_tiles // 5))
    gs = gridspec.GridSpec(5, n_cols)
    plt.figure(figsize = (12, 10))
    for i in range(num_tiles):
        ax1 = plt.subplot(gs[i])
        ax1.set_xticklabels([])
        ax1.set_yticklabels([])
        ax1.set_aspect('equal')
        ax1.set_axis_off()
        ax1.imshow(mark_boundaries(imgs[i], masks[i].numpy().squeeze().astype('bool')))
    plt.subplots_adjust(wspace=0.01, hspace=0.01)
    plt.show()

# 

<a class='anchor' id='5'></a>
# 5. Modeling
[Back to Table of Contents](#TOC)

### Build model

In [None]:
# Metric
def iou_coef(y_true, y_pred, smooth=1):
    '''
    Smoothed intersection over union, computed per sample and averaged over the batch.
    y_true, y_pred: 4-D tensors (axes 1-3 are summed per sample, axis 0 is averaged)
    smooth: constant added to numerator and denominator (avoids 0/0 on empty masks)
    return: scalar tensor
    '''
    # Written with tf ops only: the Keras backend alias K is imported in a later
    # cell, so this function must not depend on it.
    y_pred = tf.convert_to_tensor(y_pred)
    y_true = tf.cast(y_true, y_pred.dtype)
    intersection = tf.reduce_sum(tf.abs(y_true * y_pred), axis=[1,2,3])
    union = tf.reduce_sum(y_true, axis=[1,2,3]) + tf.reduce_sum(y_pred, axis=[1,2,3]) - intersection
    iou = tf.reduce_mean((intersection + smooth) / (union + smooth), axis=0)
    return iou

# Loss
def bce_dice_loss(y_true, y_pred):
    '''
    Equal-weight mix of binary cross-entropy and Dice loss:
    0.5 * BCE + 0.5 * Dice, where Dice is computed over the whole flattened batch.
    y_true: ground truth mask (cast to float32 here)
    y_pred: predicted probabilities
    return: loss tensor
    '''
    # Written with tf ops only: the Keras backend alias K is imported in a later
    # cell, so this function must not depend on it.
    def dice_loss(y_true, y_pred):
        smooth = 1.
        y_true_f = tf.reshape(y_true, [-1])
        y_pred_f = tf.reshape(y_pred, [-1])
        intersection = y_true_f * y_pred_f
        score = (2. * tf.reduce_sum(intersection) + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)
        return 1. - score

    y_true = tf.cast(y_true, tf.float32)
    bce_loss_ = tf.keras.losses.binary_crossentropy(y_true, y_pred)
    dice_loss_ = dice_loss(y_true, y_pred)
    return bce_loss_ * 0.5 + dice_loss_*0.5
 

# Build model (here only to report its size; the training loop builds a fresh model per fold)
import segmentation_models as sm
model = sm.Unet(P['BACKBONE'], encoder_weights = P['WEIGHT'])
# learning_rate is the current argument name; lr is a deprecated alias
model.compile(optimizer = tf.keras.optimizers.Adam(learning_rate = P['LR']),
              loss = bce_dice_loss,#'focal_tversky',
              metrics=[iou_coef])

print('Model:', P['MODEL'])
print('Model backbone', P['BACKBONE'])
print('Total params: ', model.count_params())
# Parameter counts: product of the weight shape, summed over all weights
trainable_params = sum(np.prod(w.shape) for w in model.trainable_weights)
print('Trainable params:', trainable_params)
non_trainable_params = sum(np.prod(w.shape) for w in model.non_trainable_weights)
print('Non-trainable params:', non_trainable_params)

### Data parser

In [None]:
DIM = RESIZE
def _parse_image_function_argument(example_proto, augment=True):
    '''
    Decode one serialized example into an (image, mask) training pair.
    example_proto: serialized tf.train.Example with 'image' and 'mask' byte features
    augment: if True, random flips, a 90-degree rotation, saturation and contrast changes are applied
    return: (image, mask), both float32; image (DIM, DIM, 3), mask (DIM, DIM, 1)
    '''
    feature_spec = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'mask': tf.io.FixedLenFeature([], tf.string),
    }
    parsed = tf.io.parse_single_example(example_proto, feature_spec)
    image_bytes = tf.io.decode_raw(parsed['image'], out_type=np.dtype('uint8'))
    mask_bytes = tf.io.decode_raw(parsed['mask'], out_type='bool')
    image = tf.reshape(image_bytes, (DIM,DIM,3))
    mask = tf.reshape(mask_bytes, (DIM,DIM,1))

    if augment: # https://www.kaggle.com/kool777/training-hubmap-eda-tf-keras-tpu

        # Geometric transforms are applied to image and mask alike
        if tf.random.uniform(()) > 0.5:
            image = tf.image.flip_left_right(image)
            mask  = tf.image.flip_left_right(mask)

        if tf.random.uniform(()) > 0.4:
            image = tf.image.flip_up_down(image)
            mask  = tf.image.flip_up_down(mask)

        if tf.random.uniform(()) > 0.5:
            image = tf.image.rot90(image, k=1)
            mask  = tf.image.rot90(mask, k=1)

        # Color transforms touch the image only
        if tf.random.uniform(()) > 0.45:
            image = tf.image.random_saturation(image, 0.7, 1.3)

        if tf.random.uniform(()) > 0.45:
            image = tf.image.random_contrast(image, 0.8, 1.2)

    image = tf.cast(image, tf.float32)
    mask = tf.cast(mask, tf.float32)
    return image, mask

def load_train_dataset(filenames, ordered=False, augment=True):
    '''
    Returns a dataset of parsed (image, mask) pairs read from the tfrecord files.
    ordered: if False, deterministic element order is switched off
    augment: passed on to the parser
    '''
    options = tf.data.Options()
    if not ordered:
        options.experimental_deterministic = False

    def parse(example):
        return _parse_image_function_argument(example, augment=augment)

    records = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
    return records.with_options(options).map(parse, num_parallel_calls=AUTO)

def get_train_dataset(train_filenames, ordered=True, augment=True, batch_size=P['BATCH_SIZE']):
    '''Returns the training pipeline: parse -> repeat -> shuffle -> full batches -> prefetch.'''
    return (
        load_train_dataset(train_filenames, ordered=ordered, augment=augment)
        .repeat()
        .shuffle(128, seed=P['SEED'])
        .batch(batch_size, drop_remainder=True)
        .prefetch(AUTO)
    )

def get_valid_dataset(valid_filenames, ordered=True, augment=False, batch_size=P['BATCH_SIZE']):
    '''Returns the validation pipeline: parse -> full batches -> prefetch (no repeat, no shuffle).'''
    # Caching is left disabled, as before
    return (
        load_train_dataset(valid_filenames, ordered=ordered, augment=augment)
        .batch(batch_size, drop_remainder=True)
        .prefetch(AUTO)
    )



<a class='anchor' id='6'></a>
# 6. Train
[Back to Table of Contents](#TOC)

In [None]:
from sklearn.model_selection import KFold
from tensorflow.keras import backend as K

# Saved model path
MODEL_PATH = './model/'
os.makedirs(MODEL_PATH, exist_ok=True)

# Saved metrics: one list entry per trained fold
metrics = ['loss', 'iou_coef','accuracy']
M = {}
for m in metrics:
    M['train_'+ m] = []
    M['valid_'+ m] = []

# Train
kfold = KFold(n_splits=P['NFOLDS'], shuffle=True, random_state=P['SEED'])
for fold, (train_idx, valid_idx) in enumerate(kfold.split(df_train)):

    print('#'*35); print('############ FOLD ',fold+1,' #############'); print('#'*35);
    print(f'Tile Size: {DIM}, Batch Size: {P["BATCH_SIZE"]}')

    # Split into train and validation (the fold indices select tfrecord files by position)
    train_split = [train_tfrecs[i] for i in train_idx]
    valid_split = [train_tfrecs[i] for i in valid_idx]
    STEPS_PER_EPOCH = P['STEPS_COEF'] * count_data_items(train_split) // P['BATCH_SIZE']

    # Build model (fresh model and optimizer for every fold)
    K.clear_session()
    with strategy.scope():
        model = sm.Unet(P['BACKBONE'], encoder_weights = P['WEIGHT'])
        # learning_rate is the current argument name; lr is a deprecated alias
        model.compile(optimizer = tf.keras.optimizers.Adam(learning_rate = P['LR']),
                      loss = bce_dice_loss,
                      metrics = [iou_coef,'accuracy'])

    # Callbacks
    # The checkpoint keeps the epoch with the LOWEST validation loss, so the mode
    # must be 'min'. The path has the same form as the load_model calls below.
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        filepath = MODEL_PATH + f'model-fold{fold}',
        verbose = P['VERBOSE'],
        monitor ='val_loss',
        mode='min',
        save_best_only=True
    )
    early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_iou_coef',
                                                  mode='max',
                                                  patience=10,
                                                  restore_best_weights=True)
    # Named reduce_lr so the tile reduction factor `reduce` of the tfrecord cell is not overwritten
    reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss',
                                                     factor=0.1,
                                                     patience=8,
                                                     min_lr=0.00001)
    # Fit
    print(f'Training Model Fold {fold+1}...')
    history = model.fit(
        get_train_dataset(train_split),
        validation_data = get_valid_dataset(valid_split),
        epochs = P['EPOCHS'],
        steps_per_epoch = STEPS_PER_EPOCH,
        callbacks = [checkpoint, reduce_lr, early_stop],
        verbose = P['VERBOSE'],
    )

    # Load best model (the checkpoint written during fit)
    with strategy.scope():
        model = tf.keras.models.load_model(MODEL_PATH + f'model-fold{fold}',
                                           custom_objects = {'iou_coef'     : iou_coef,
                                                             'bce_dice_loss': bce_dice_loss})

    # Save metrics (loss, iou, accuracy); the train split is evaluated without augmentation
    train_metric = model.evaluate(get_valid_dataset(train_split), return_dict=True)
    valid_metric = model.evaluate(get_valid_dataset(valid_split), return_dict=True)
    for m in metrics:
        M['train_'+m].append(train_metric[m])
        M['valid_'+m].append(valid_metric[m])


    # Plot train result
    plt.figure(figsize=(15,5))
    plt.xlabel('Epoch', size=14)
    plt.ylabel('IoU', size=14)
    ## Epoch-IoU
    epochs = np.arange(len(history.history['iou_coef']))
    plt.plot(epochs, history.history['iou_coef'],    '-o', label='Train IoU', color='black')
    plt.plot(epochs, history.history['val_iou_coef'],'-o', label='Valid IoU', color='red')
    ## Max IoU point
    iou_max = np.max(history.history['val_iou_coef'] )
    x_iou_max = np.argmax(history.history['val_iou_coef'] )
    plt.scatter(x_iou_max, iou_max, s=200, color='red')
    ## Max IoU text
    xdist = plt.xlim()[1] - plt.xlim()[0]
    ydist = plt.ylim()[1] - plt.ylim()[0]
    plt.text(x_iou_max-0.03*xdist, iou_max-0.13*ydist, 'Max IoU\n%.2f'%iou_max, size=12)
    ## IoU Legend
    plt.legend(loc=2)
    ## Epoch-Loss
    plt2 = plt.gca().twinx() # Secondary axis
    plt.ylabel('Loss', size=14)
    plt2.plot(epochs, history.history['loss'],     '-o', label='Train Loss', color='black', linestyle="dashed")
    plt2.plot(epochs, history.history['val_loss'], '-o', label='Valid Loss', color='red',   linestyle="dashed")
    # Loss legend
    plt.legend(loc=3)
    plt.show()

    # DEBUG: stop after the configured number of folds
    if (fold+1) == P['CALC_FOLDS']:
        break

print('Metrics mean over folds')
for m in metrics:
    M['mean_train_'+m] = np.mean(M['train_'+m])
    M['mean_valid_'+m] = np.mean(M['valid_'+m])
    print('Train '+ m + ': '+ str(round(M['mean_train_'+m], 4)))
    print('Valid '+ m + ': '+ str(round(M['mean_valid_'+m], 4)))

In [None]:
# Save parameters (YAML)
params_path = MODEL_PATH + 'params.yaml'
with open(params_path, 'w') as params_file:
    yaml.dump(P, params_file)

# Save metrics (JSON)
metrics_path = MODEL_PATH + 'metrics.json'
with open(metrics_path, 'w') as metrics_file:
    json.dump(M, metrics_file)

print(f'Model paramaters and metrics saved in {MODEL_PATH}/params.yaml, metrics.json')

<a class='anchor' id='7'></a>
# 7. Evaluation
[Back to Table of Contents](#TOC)<br>
Here, the model and the validation set of the last trained fold are used.

In [None]:
# Load model (the checkpoint of the last trained fold)
last_fold_model_path = MODEL_PATH + f'model-fold{fold}'
# The custom loss and metric must be handed to load_model by name
custom_objects = {'iou_coef': iou_coef, 'bce_dice_loss': bce_dice_loss}
with strategy.scope():
    model = tf.keras.models.load_model(last_fold_model_path, custom_objects=custom_objects)
    print(f'model-fold{fold} is loaded')

In [None]:
def predict(tfrecord, model, threshold=0.5):
    '''
    Predict binary masks for all tiles of one tfrecord file.
    tfrecord: path of a tfrecord file whose name encodes its tile count
    model: trained Keras model
    threshold: predictions below this value become False, all others True
    return: (imgs, masks, preds)
        imgs: tiles divided by 255 (for display)
        masks: ground truth as bool ndarray
        preds: thresholded prediction as bool ndarray
    '''
    # Read tfrecord: all tiles of the file as one batch
    num_tiles = count_data_items([tfrecord])
    dataset = get_valid_dataset(tfrecord, batch_size=num_tiles)
    for imgs, masks in dataset.take(1):
        break

    imgs  = imgs/255 # scaled copy for plotting; the model reads the dataset itself
    masks = masks.numpy().squeeze().astype('bool')
    # Model.predict takes a tf.data dataset directly; predict_generator is deprecated
    preds = model.predict(dataset, verbose=1) # 0 to 1 values
    preds = np.where(preds<threshold, False, True).squeeze() # binarize at threshold

    return imgs, masks, preds

# Plot true mask and prediction
sample_idx, sample_cell_type = get_unique_cell_samples(valid_split) # Use last fold validation set

for idx in sample_idx:
    valid_imgs, valid_masks, valid_preds = predict(valid_split[idx], model)
    # File names have the form '<image_id>-<num_tiles>.tfrec'
    image_id = valid_split[idx].split('/')[-1].split('-')[0]
    record = df_train[df_train['id']==image_id]
    cell_type = record['cell_type'].values[0]
    num_of_objects = len(record['annotation'].values[0])
    print(f'ID: {image_id}, Cell type: {cell_type}, Number of objects: {num_of_objects}')

    # Show true mask and prediction
    # The tile count and the column count come from THIS record, not from a
    # global left over from an earlier cell or a fixed 6-column layout.
    num_tiles = int(valid_imgs.shape[0])
    n_cols = -(-num_tiles // 5) # 5 rows, ceil division
    gs = gridspec.GridSpec(5, 2*n_cols) # e.g. 30 tiles -> 5row x 12col
    fig = plt.figure(figsize = (24, 10.5))
    for i in range(num_tiles):
        row = i//n_cols
        col = i%n_cols

        # true mask (left half), prediction (right half)
        panels = [(col, valid_masks[i]), (col+n_cols, valid_preds[i])]
        for panel_col, boundaries in panels:
            ax = plt.subplot(gs[row, panel_col])
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_aspect('equal')
            ax.set_axis_off()
            ax.imshow(mark_boundaries(valid_imgs[i], boundaries))

    fig.suptitle('Ground Truth' + ' '*90 + 'Prediction',  fontsize=25)
    plt.subplots_adjust(wspace=0.01, hspace=0.01)
    plt.tight_layout()
    plt.show()


## Future work
For instance segmentation, the predicted masks of the "shsy5y" and "cort" cell types can probably be separated into individual objects, for example with cv2.connectedComponentsWithStats, but this may not work for "astro".  
*Please upvote if you find this useful.*