## Tensorflow HuBMAP - Hacking the Kidney competition starter kit:



# Versions

V1: Base kernel copied from [this](https://www.kaggle.com/wrrosa/hubmap-tf-with-tpu-efficientunet-512x512-train) kernel.


# References:
* @marcosnovaes  https://www.kaggle.com/marcosnovaes/hubmap-looking-at-tfrecords and https://www.kaggle.com/marcosnovaes/hubmap-unet-keras-model-fit-with-tpu
* @mgornergoogle https://www.kaggle.com/mgornergoogle/getting-started-with-100-flowers-on-tpu
* @qubvel https://github.com/qubvel/segmentation_models  !! 25 available backbones for each of 4 architectures
* @kool777, @joshi98kishan https://www.kaggle.com/kool777/training-hubmap-eda-tf-keras-tpu
* @cdeotte https://www.kaggle.com/cdeotte/triple-stratified-kfold-with-tfrecords


# Init - parameters, packages, gcs_paths, tpu

In [None]:
# Run configuration. Everything downstream reads its settings from P, and the
# dictionary is dumped to params.yaml at the end of this cell.
P = {
    'EPOCHS': 60,
    'BACKBONE': 'efficientnetb0',
    'NFOLDS': 5,
    'SEED': 2021,
    'VERBOSE': 1,
    'DISPLAY_PLOT': True,
    'BATCH_COE': 24,  # BATCH_SIZE = P['BATCH_COE'] * strategy.num_replicas_in_sync
    'TILING': [1024, 512],  # 1024,512 1024,256 1024,128 1536,512 768,384
}

# TILING is stored as [DIM_FROM, DIM]; DIM is the side length used when the
# TFRecord tensors are reshaped.
P.update(DIM=P['TILING'][1], DIM_FROM=P['TILING'][0])

P.update({
    'LR': 3e-4,
    'OVERLAPP': False,
    'STEPS_COE': 1,
    'FOLDS': [0, 1, 2, 3, 4],
    'smoothing': 1e-7,
    'WANDB': False,
    'SOFT_PROB': 0.7,
})


import yaml
with open(r'params.yaml', 'w') as params_file:
    yaml.dump(P, params_file)

In [None]:
! pip install segmentation_models -q
# ! pip install -q tensorflow-io
# import tensorflow_io as tfio
%matplotlib inline

import os
os.environ['SM_FRAMEWORK'] = 'tf.keras'
import glob
import math
import random
import segmentation_models as sm
from segmentation_models.losses import bce_jaccard_loss

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.model_selection import KFold

import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.utils import get_custom_objects

# from tensorflow.keras import mixed_precision
# mixed_precision.set_global_policy('mixed_float16')

# tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})

from kaggle_datasets import KaggleDatasets
print("Tensorflow version " + tf.__version__)
AUTO = tf.data.experimental.AUTOTUNE

In [None]:
def seed_all(seed):
    """Seed Python's `random`, PYTHONHASHSEED, NumPy and TensorFlow with `seed`."""
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, tf.random.set_seed):
        seeder(seed)


seed_all(P['SEED'])

In [None]:
# Choose the tf.distribute strategy: connect to and initialise a TPU when the
# resolver finds one; on ValueError fall back to TensorFlow's default strategy.
try: # detect TPUs
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver() # TPU detection
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
except ValueError: # no TPU found, detect GPUs
    #strategy = tf.distribute.MirroredStrategy() # for GPU or multi-GPU machines
    strategy = tf.distribute.get_strategy() # default strategy that works on CPU and single GPU
    #strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy() # for clusters of multi-GPU machines

# Global batch size: the per-replica coefficient times the number of replicas.
BATCH_SIZE = P['BATCH_COE'] * strategy.num_replicas_in_sync

print("Number of accelerators: ", strategy.num_replicas_in_sync)
print("BATCH_SIZE: ", str(BATCH_SIZE))

## GCS_PATHS

In [None]:
# Resolve the GCS location of the attached Kaggle dataset and list every
# .tfrec file directly under it.
GCS_PATH = KaggleDatasets().get_gcs_path('hubmap-data-1024-512-tfrecord-soft-mask')
ALL_TRAINING_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/*.tfrec')

# Bare expression: displayed as the cell output when run in a notebook.
ALL_TRAINING_FILENAMES

In [None]:
import re
# Matches the "-<count>." part of a TFRecord file name. Compiled once instead
# of once per file name; "+" requires at least one digit so "-." is skipped.
_ITEM_COUNT_RE = re.compile(r"-([0-9]+)\.")


def count_data_items(filenames):
    """Return the total number of records encoded in the given file names.

    Each file name is expected to contain its record count as ``-<count>.``
    (for example ``train-123.tfrec``); the first such occurrence is used.

    Arguments:
    filenames -- iterable of file name / path strings

    Return:
    Sum of the per-file counts as a NumPy integer (0 for an empty input).

    Raises:
    ValueError -- if a file name does not contain a ``-<count>.`` part
    """
    counts = []
    for filename in filenames:
        match = _ITEM_COUNT_RE.search(filename)
        if match is None:
            raise ValueError(
                f'Cannot read the item count from file name "{filename}"; '
                'expected a "-<count>." part.'
            )
        counts.append(int(match.group(1)))
    # Explicit integer dtype so that an empty list sums to 0 rather than 0.0.
    return np.sum(counts, dtype=np.int64)
# Report the total number of training tiles available.
print('NUM_TRAINING_IMAGES:' )
num_training_images = count_data_items(ALL_TRAINING_FILENAMES)
if P['OVERLAPP']:
    # ALL_TRAINING_FILENAMES2 (the second, overlapping tile set) is not defined
    # in the visible part of this notebook, so look it up defensively instead
    # of failing with a NameError when the flag is switched on.
    overlap_filenames = globals().get('ALL_TRAINING_FILENAMES2')
    if overlap_filenames is None:
        print("WARNING: P['OVERLAPP'] is set but ALL_TRAINING_FILENAMES2 is not defined; "
              "counting ALL_TRAINING_FILENAMES only.")
    else:
        num_training_images += count_data_items(overlap_filenames)
print(num_training_images)

# Datasets pipeline

In [None]:
# https://www.kaggle.com/kool777/training-hubmap-eda-tf-keras-tpu?scriptVersionId=54288693&cellId=16
def make_mask(num_holes,side_length,rows, cols, num_channels):
    
    """Builds the boolean mask for all sprinkles (coarse-dropout holes).

    Arguments:
    num_holes -- number of square holes to place
    side_length -- nominal side of each hole; the hole spans
                   side_length // 2 on either side of its centre
    rows, cols -- height and width of the mask to build
    num_channels -- number of times the 2-D mask is repeated on the last axis

    Return:
    Boolean tensor of shape [rows, cols, num_channels] that is True inside
    at least one hole.
    """
    
    # Row / column coordinates repeated once per hole:
    # shapes [rows, num_holes] and [cols, num_holes].
    row_range = tf.tile(tf.range(rows)[..., tf.newaxis], [1, num_holes])
    col_range = tf.tile(tf.range(cols)[..., tf.newaxis], [1, num_holes])
    # Random centre of every hole.
    r_idx = tf.random.uniform([num_holes], minval=0, maxval=rows-1,
                              dtype=tf.int32)
    c_idx = tf.random.uniform([num_holes], minval=0, maxval=cols-1,
                              dtype=tf.int32)
    # Hole bounds, clipped to the image extent.
    r1 = tf.clip_by_value(r_idx - side_length // 2, 0, rows)
    r2 = tf.clip_by_value(r_idx + side_length // 2, 0, rows)
    c1 = tf.clip_by_value(c_idx - side_length // 2, 0, cols)
    c2 = tf.clip_by_value(c_idx + side_length // 2, 0, cols)
    # Strict comparisons: the bounding rows/columns themselves are excluded.
    row_mask = (row_range > r1) & (row_range < r2)
    col_mask = (col_range > c1) & (col_range < c2)

    # Combine masks into one layer and duplicate over channels.
    # [rows, 1, num_holes] & [cols, num_holes] broadcasts to [rows, cols, num_holes].
    mask = row_mask[:, tf.newaxis] & col_mask
    # A pixel is masked when it falls inside any of the holes.
    mask = tf.reduce_any(mask, axis=-1)
    mask = mask[..., tf.newaxis]
    mask = tf.tile(mask, [1, 1, num_channels])
    return mask
    
def sprinkles(image): 
    """Apply coarse dropout ("sprinkles") to `image`.

    A set of random square holes, built by `make_mask` from the hard-coded
    ``num_holes`` and ``side_length``, is overwritten according to ``mode``:
    zeros for ``'normal'``, ones and zeros (two independent hole sets of half
    the size each) for ``'salt_pepper'``, and normally distributed noise for
    ``'gaussian'``. The augmentation is applied with probability
    ``PROBABILITY``.

    Arguments:
    image -- tensor whose first two axes are rows and columns and whose last
             axis is channels

    Return:
    Tensor of the same shape as `image` with the holes filled in.

    Raises:
    ValueError -- if ``mode`` is not one of the supported names
    """
    num_holes = 20
    side_length = 15
    mode = 'normal'
    PROBABILITY = 1
    
    # 1 when the augmentation should be applied for this call, else 0.
    RandProb = tf.cast( tf.random.uniform([],0,1) < PROBABILITY, tf.int32)
    if (RandProb == 0)|(num_holes == 0): return image
    
    img_shape = tf.shape(image)
    # Strings are compared with ==; `is` tests object identity, only works by
    # accident of interning, and raises a SyntaxWarning on Python 3.8+.
    if mode == 'normal':
        rejected = tf.zeros_like(image)
    elif mode == 'salt_pepper':
        # The hole budget is split between the "salt" and the "pepper" set.
        num_holes = num_holes // 2
        rejected_high = tf.ones_like(image)
        rejected_low = tf.zeros_like(image)
    elif mode == 'gaussian':
        rejected = tf.random.normal(img_shape, dtype=tf.float32)
    else:
        raise ValueError(f'Unknown mode "{mode}" given.')
        
    rows = img_shape[0]
    cols = img_shape[1]
    num_channels = img_shape[-1]
    if mode == 'salt_pepper':
        mask1 = make_mask(num_holes,side_length,rows, cols, num_channels)
        mask2 = make_mask(num_holes,side_length,rows, cols, num_channels)
        filtered_image = tf.where(mask1, rejected_high, image)
        filtered_image = tf.where(mask2, rejected_low, filtered_image)
    else:
        mask = make_mask(num_holes,side_length,rows, cols, num_channels)
        filtered_image = tf.where(mask, rejected, image)
    return filtered_image

def transform_shear(image, height, shear, mask=False):
    
    '''
    Random shear augmentation for a square image or mask.
    --------------------------------
    
    Arguments:
    image -- input tensor (image or mask) to resample
    height -- side length (DIM) of the square input / output
    shear -- upper bound of the shear angle in degrees; the angle actually
             used is this value scaled by a fresh uniform random number
    mask -- flag; when True the result is reshaped to one channel,
            otherwise to three channels
    
    Return:
    sheared tensor of shape [height, height, 1] when mask is True,
    else [height, height, 3]

    NOTE(review): the random angle is drawn inside this function, so separate
    calls for an image and for its mask draw different angles -- confirm
    before re-enabling this augmentation.
    '''
    
    DIM = height
    XDIM = DIM%2 #fix for size 331
    
    # Random angle, converted from degrees to radians.
    shear = shear * tf.random.uniform([1],dtype='float32')
    shear = math.pi * shear / 180.
        
    # SHEAR MATRIX
    one = tf.constant([1],dtype='float32')
    zero = tf.constant([0],dtype='float32')
    c2 = tf.math.cos(shear)
    s2 = tf.math.sin(shear)
    shear_matrix = tf.reshape(tf.concat([one,s2,zero, zero,c2,zero, zero,zero,one],axis=0),[3,3])    

    # LIST DESTINATION PIXEL INDICES
    # Homogeneous (x, y, 1) coordinates, centred on the middle of the image.
    x = tf.repeat( tf.range(DIM//2,-DIM//2,-1), DIM )
    y = tf.tile( tf.range(-DIM//2,DIM//2),[DIM] )
    z = tf.ones([DIM*DIM],dtype='int32')
    idx = tf.stack( [x,y,z] )
    
    # MAP DESTINATION PIXELS ONTO ORIGIN PIXELS
    idx2 = K.dot(shear_matrix,tf.cast(idx,dtype='float32'))
    idx2 = K.cast(idx2,dtype='int32')
    # Clip so that the looked-up source coordinates stay inside the image.
    idx2 = K.clip(idx2,-DIM//2+XDIM+1,DIM//2)
    
    # FIND ORIGIN PIXEL VALUES 
    # Convert the centred coordinates back to array indices and gather.
    idx3 = tf.stack([DIM//2-idx2[0,], DIM//2-1+idx2[1,]] )
    d = tf.gather_nd(image, tf.transpose(idx3))
        
    if mask:
        return tf.reshape(d, [DIM,DIM,1])
    
    return tf.reshape(d, [DIM,DIM,3])

def transform_shift(image, height, h_shift, w_shift, mask=False):
    
    '''
    Random shift (translation) augmentation for a square image or mask.
    --------------------------------
    
    Arguments:
    image -- input tensor (image or mask) to resample
    height -- side length (DIM) of the square input / output
    h_shift -- upper bound of the vertical shift; scaled by a fresh uniform
               random number
    w_shift -- upper bound of the horizontal shift; scaled by a fresh uniform
               random number
    mask -- flag; when True the result is reshaped to one channel,
            otherwise to three channels
    
    Return:
    shifted tensor of shape [height, height, 1] when mask is True,
    else [height, height, 3]

    NOTE(review): the random offsets are drawn inside this function, so
    separate calls for an image and for its mask draw different offsets --
    confirm before re-enabling this augmentation.
    '''
    
    DIM = height
    XDIM = DIM%2 #fix for size 331
    
    # Random offsets for this call.
    height_shift = h_shift * tf.random.uniform([1],dtype='float32') 
    width_shift = w_shift * tf.random.uniform([1],dtype='float32') 
    one = tf.constant([1],dtype='float32')
    zero = tf.constant([0],dtype='float32')
        
    # SHIFT MATRIX
    shift_matrix = tf.reshape(tf.concat([one,zero,height_shift, zero,one,width_shift, zero,zero,one],axis=0),[3,3])

    # LIST DESTINATION PIXEL INDICES
    # Homogeneous (x, y, 1) coordinates, centred on the middle of the image.
    x = tf.repeat( tf.range(DIM//2,-DIM//2,-1), DIM )
    y = tf.tile( tf.range(-DIM//2,DIM//2),[DIM] )
    z = tf.ones([DIM*DIM],dtype='int32')
    idx = tf.stack( [x,y,z] )
    
    # MAP DESTINATION PIXELS ONTO ORIGIN PIXELS
    idx2 = K.dot(shift_matrix,tf.cast(idx,dtype='float32'))
    idx2 = K.cast(idx2,dtype='int32')
    # Clip so that the looked-up source coordinates stay inside the image.
    idx2 = K.clip(idx2,-DIM//2+XDIM+1,DIM//2)
    
    # FIND ORIGIN PIXEL VALUES 
    # Convert the centred coordinates back to array indices and gather.
    idx3 = tf.stack([DIM//2-idx2[0,], DIM//2-1+idx2[1,]] )
    d = tf.gather_nd(image, tf.transpose(idx3))
        
    if mask:
        return tf.reshape(d, [DIM,DIM,1])
    
    return tf.reshape(d, [DIM,DIM,3])

def _sprinkle_pair(image, mask, num_holes=20, side_length=15):
    """Zero the same random square holes in both `image` and `mask`.

    One boolean hole map is built with `make_mask` and applied to both
    tensors, so the label is blanked exactly where the image is blanked.
    The defaults mirror the constants hard-coded in `sprinkles`.
    """
    img_shape = tf.shape(image)
    # Single-channel hole map; tf.where broadcasts it over the channel axis.
    holes = make_mask(num_holes, side_length, img_shape[0], img_shape[1], 1)
    image = tf.where(holes, tf.zeros_like(image), image)
    mask = tf.where(holes, tf.zeros_like(mask), mask)
    return image, mask


def augmentations(image, mask):
    
    '''
    Apply random augmentations to an image and its mask.
    --------------------------------
    Geometric transforms (flips, 90-degree rotations, coarse dropout) are
    applied to both tensors; colour transforms are applied to the image only.
    
    Arguments:
    image -- input image
    mask -- input mask
    
    Return:
    image -- augmented image 
    mask -- augmented mask
    '''
    
    # One uniform draw in [0, 1) per augmentation family.
    spatial = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    rotate = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    # `shear` and `shift` are only consumed by the disabled blocks below.
    shear = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    shift = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    pixel = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    drop_coarse = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    
    # SPATIAL-LEVEL TRANSFORMATIONS
    # NOTE(review): both flips test the same draw, so an up-down flip never
    # happens without a left-right flip -- confirm this is intended.
    ## FLIP LEFT-RIGHT
    if spatial >= .4:
        image = tf.image.flip_left_right(image)
        mask = tf.image.flip_left_right(mask)
    
    ## FLIP UP-DOWN
    if spatial >= .5:   
        image = tf.image.flip_up_down(image)
        mask = tf.image.flip_up_down(mask)
        
    ## ROTATIONS
    if rotate > .8:
        image = tf.image.rot90(image, k=3) # rotate 270º
        mask = tf.image.rot90(mask, k=3) # rotate 270º
    elif rotate > .5:
        image = tf.image.rot90(image, k=2) # rotate 180º
        mask = tf.image.rot90(mask, k=2) # rotate 180º
    elif rotate > .4:
        image = tf.image.rot90(image, k=1) # rotate 90º
        mask = tf.image.rot90(mask, k=1) # rotate 90º
    
#     ## SHEAR 
#     if shear >= .5:
#         image = transform_shear(image, height=P['DIM'], shear=20.)
#         mask = transform_shear(mask, height=P['DIM'], shear=20., mask=True)
    
#     ## SHIFT
#     if shift >= .5:
#         image = transform_shift(image, height=P['DIM'], h_shift=15., w_shift=15.)
#         mask = transform_shift(mask, height=P['DIM'], h_shift=15., w_shift=15., mask=True)

    ## COARSE-DROPOUT
    # The holes must be identical in image and mask. Calling sprinkles() on
    # each tensor separately draws two independent hole sets and erases labels
    # at places where the image is untouched.
    if drop_coarse >= .5:
        image, mask = _sprinkle_pair(image, mask)
    
    # PIXEL-LEVEL TRANSFORMATION
    # Colour changes affect the image only; draws in [.2, .4) change nothing.
    if pixel >= .2:
        
        if pixel >= .7:
            image = tf.image.random_brightness(image, .2)
        elif pixel >= .6:
            image = tf.image.random_hue(image, .2)
        elif pixel >= .5:
            image = tf.image.random_contrast(image, 0.8, 1.2)
        elif pixel >= .4:
            image = tf.image.random_saturation(image, 0.7, 1.3)
        
    return image, mask

In [None]:
DIM = P['DIM']


def _parse_image_function(example_proto,augment = True, soft_label=True):
    """Decode one serialized TFRecord example into an (image, mask) pair.

    The record carries three raw byte strings: 'image' (uint8, DIM x DIM x 3),
    'mask' (bool, DIM x DIM x 1) and 'soft_mask' (float32, DIM x DIM x 1).

    Arguments:
    example_proto -- serialized example to parse
    augment -- when True the pair is passed through `augmentations`
    soft_label -- when True the target is a blend of the hard and soft masks,
                  weighted by P['SOFT_PROB'] and 1 - P['SOFT_PROB']

    Return:
    image -- float32 image divided by 255
    mask -- float32 target mask
    """
    feature_spec = {
        key: tf.io.FixedLenFeature([], tf.string)
        for key in ('image', 'mask', 'soft_mask')
    }
    parsed = tf.io.parse_single_example(example_proto, feature_spec)

    def _decode(key, dtype, channels):
        # Raw bytes -> flat tensor of `dtype` -> (DIM, DIM, channels).
        flat = tf.io.decode_raw(parsed[key], out_type=dtype)
        return tf.reshape(flat, (DIM, DIM, channels))

    image = _decode('image', np.dtype('uint8'), 3)
    mask = _decode('mask', 'bool', 1)
    soft_mask = _decode('soft_mask', np.dtype('float32'), 1)

    if soft_label:
        hard_part = P['SOFT_PROB'] * tf.cast(mask, tf.float32)
        soft_part = (1 - P['SOFT_PROB']) * tf.cast(soft_mask, tf.float32)
        target = hard_part + soft_part
    else:
        target = mask

    if augment:  # https://www.kaggle.com/kool777/training-hubmap-eda-tf-keras-tpu
        image, target = augmentations(image, target)

    scaled_image = tf.cast(image, tf.float32)/255.0
    return scaled_image, tf.cast(target, tf.float32)

def load_dataset(filenames, ordered=False, augment = True, soft_label=True):
    """Build a parsed tf.data pipeline over the given TFRecord files.

    Arguments:
    filenames -- TFRecord file name(s) to read
    ordered -- when False, deterministic ordering is switched off
    augment, soft_label -- forwarded to `_parse_image_function`

    Return:
    Dataset of (image, mask) pairs.
    """
    options = tf.data.Options()
    if not ordered:
        options.experimental_deterministic = False

    def _parse(example):
        return _parse_image_function(example, augment = augment, soft_label = soft_label)

    records = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
    return records.with_options(options).map(_parse, num_parallel_calls=AUTO)

def get_training_dataset():
    """Return the training pipeline over TRAINING_FILENAMES.

    The data is augmented, uses soft labels, repeats indefinitely, is shuffled
    with a 1024-element buffer seeded by P['SEED'], and is batched to
    BATCH_SIZE with prefetching.
    """
    return (
        load_dataset(TRAINING_FILENAMES, augment = True, soft_label = True)
        .repeat()
        .shuffle(1024, seed = P['SEED'])
        .batch(BATCH_SIZE, drop_remainder=False)
        .prefetch(AUTO)
    )

def get_validation_dataset(ordered=True):
    """Return the validation pipeline over VALIDATION_FILENAMES.

    No augmentation is applied; soft labels are used. Batches hold exactly
    BATCH_SIZE elements (an incomplete final batch is dropped).
    """
    batched = (
        load_dataset(VALIDATION_FILENAMES, ordered=ordered, augment = False, soft_label = True)
        .batch(BATCH_SIZE, drop_remainder=True)
    )
    return batched.prefetch(AUTO)

# EDA

In [None]:
# Visual sanity check of the (augmented) training pipeline: take one batch and
# plot its first N*N images next to their masks.
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from skimage.segmentation import mark_boundaries

# Preview from the first TFRecord file only.
TRAINING_FILENAMES = ALL_TRAINING_FILENAMES[0]
print(TRAINING_FILENAMES)
# BATCH_SIZE = 64
train_data = get_training_dataset()

# print(len(train_data))
# The training dataset repeats forever, so stop explicitly; after the loop
# imgs/masks hold the batch with index 9.
for i, (imgs, masks) in enumerate(train_data):
    if i ==9:
        break


print(imgs.shape)

M = 20
N=4

plt.figure(figsize = (M,M))
# NOTE(review): the grid has (N*2)*(N*2) cells but only the first 2*N*N are
# filled below -- confirm the intended layout.
gs1 = gridspec.GridSpec(N*2,N*2)

for i in range(N*N):
   # i = i + 1 # grid spec indexes from 0
    # Even cell: the image.
    ax1 = plt.subplot(gs1[i*2])
    plt.axis('on')
    ax1.set_xticklabels([])
    ax1.set_yticklabels([])
    ax1.set_aspect('equal')
    img = imgs[i]
    mask = masks[i]
    ax1.imshow(img)
    
    # Odd cell: the matching mask.
    ax2 = plt.subplot(gs1[i*2+1])
    plt.axis('on')
    ax2.set_xticklabels([])
    ax2.set_yticklabels([])
    ax2.set_aspect('equal')
    
    ax2.imshow(mask)

plt.show()

In [None]:
# Visual sanity check of the (non-augmented) validation pipeline: take one
# batch and plot its first N*N images next to their masks.
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from skimage.segmentation import mark_boundaries

# Preview from the first TFRecord file only.
VALIDATION_FILENAMES = ALL_TRAINING_FILENAMES[0]
print(VALIDATION_FILENAMES)
# BATCH_SIZE = 64
valid_data = get_validation_dataset()

# print(len(train_data))
# After the loop imgs/masks hold the batch with index 9, or the last batch if
# the dataset yields fewer than ten.
for i ,(imgs, masks) in enumerate(valid_data):
    if i == 9:
        break
    

print(imgs.shape)

M = 20
N=4

plt.figure(figsize = (M,M))
# NOTE(review): the grid has (N*2)*(N*2) cells but only the first 2*N*N are
# filled below -- confirm the intended layout.
gs1 = gridspec.GridSpec(N*2,N*2)

for i in range(N*N):
   # i = i + 1 # grid spec indexes from 0
    # Even cell: the image.
    ax1 = plt.subplot(gs1[i*2])
    plt.axis('on')
    ax1.set_xticklabels([])
    ax1.set_yticklabels([])
    ax1.set_aspect('equal')
    img = imgs[i]
    mask = masks[i]
    ax1.imshow(img)
    
    # Odd cell: the matching mask.
    ax2 = plt.subplot(gs1[i*2+1])
    plt.axis('on')
    ax2.set_xticklabels([])
    ax2.set_yticklabels([])
    ax2.set_aspect('equal')
    
    ax2.imshow(mask)

plt.show()

# Model

In [None]:
# https://tensorlayer.readthedocs.io/en/latest/_modules/tensorlayer/cost.html#dice_coe
def dice_coe(target,output, axis = None, smooth=1e-10):
    """Dice coefficient between `target` and `output` thresholded at 0.5.

    Arguments:
    target -- ground-truth tensor (cast to float32)
    output -- prediction tensor; values greater than 0.5 count as 1, others as 0
    axis -- axis or axes to sum over (None sums over everything)
    smooth -- small constant added to numerator and denominator

    Return:
    Mean of the Dice values.
    """
    prediction = tf.cast(tf.math.greater(output, 0.5), tf.float32)
    truth = tf.cast(target, tf.float32)

    overlap = tf.reduce_sum(prediction * truth, axis=axis)
    prediction_total = tf.reduce_sum(prediction, axis=axis)
    truth_total = tf.reduce_sum(truth, axis=axis)

    numerator = 2. * overlap + smooth
    denominator = prediction_total + truth_total + smooth
    return tf.reduce_mean(numerator / denominator, name='dice_coe')

# https://www.kaggle.com/kool777/training-hubmap-eda-tf-keras-tpu
def tversky(y_true, y_pred, alpha=0.7, beta=0.3, smooth=1):
    """Tversky index over the flattened tensors.

    `alpha` weights the false negatives and `beta` the false positives;
    `smooth` is added to numerator and denominator.
    """
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    tp = K.sum(truth * pred)
    fn = K.sum(truth * (1 - pred))
    fp = K.sum((1 - truth) * pred)
    denominator = tp + alpha * fn + beta * fp + smooth
    return (tp + smooth) / denominator


def tversky_loss(y_true, y_pred):
    """Return one minus the Tversky index."""
    index = tversky(y_true, y_pred)
    return 1 - index


def focal_tversky_loss(y_true, y_pred, gamma=0.75):
    """Return (1 - Tversky index) raised to the power `gamma`."""
    return K.pow(1 - tversky(y_true, y_pred), gamma)


# Register the loss with Keras under the name "focal_tversky".
get_custom_objects()["focal_tversky"] = focal_tversky_loss

In [None]:
# https://github.com/vgarshin/kaggle_kidney/blob/master/kidney_train.ipynb
from tensorflow.keras.losses import binary_crossentropy
def dice_coef(y_true, y_pred, smooth=1):
    """Soft Dice coefficient computed over the flattened tensors."""
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    overlap = K.sum(truth * pred)
    total = K.sum(truth) + K.sum(pred)
    return (2 * overlap + smooth) / (total + smooth)


def dice_loss(y_true, y_pred, smooth=1):
    """Return one minus `dice_coef`."""
    score = dice_coef(y_true, y_pred, smooth)
    return (1 - score)


def bce_dice_loss(y_true, y_pred):
    """Return binary cross-entropy plus `dice_loss`."""
    bce = binary_crossentropy(y_true, y_pred)
    return bce + dice_loss(y_true, y_pred)

# https://www.kaggle.com/elcaiseri/hubmap-pytorch-starter-vit-segmentation-train
from tensorflow.keras.losses import Loss
from segmentation_models.losses import DiceLoss
class DiceBCELoss(Loss):
    """Keras `Loss` whose value is `bce_dice_loss(y_true, y_pred)`."""

    def __init__(self):
        super().__init__()

    def call(self, y_true, y_pred):
        # Swap in bce_jaccard_loss(y_true, y_pred) here to try the alternative.
        loss_value = bce_dice_loss(y_true, y_pred)
        return loss_value

# Model fit

In [None]:
# Optional Weights & Biases logging, enabled through P['WANDB'].
if P['WANDB']:
    # The API key is read from the Kaggle secret named "WANDB_KEY".
    from kaggle_secrets import UserSecretsClient
    user_secrets = UserSecretsClient()
    secret_value_0 = user_secrets.get_secret("WANDB_KEY")

    # Shell escapes: install the pinned client and log in with the key.
    !pip install wandb==0.10.10
    !wandb login $secret_value_0
    
    
    import wandb
    from wandb.keras import WandbCallback
    # One run is created here, with the whole parameter dict as its config.
    wandb.init(config=P, project='hubmap-hacking-the-kidney')
#     wandb.init(config=P, project='hubmap-hacking-the-kidney')
    

In [None]:
# fold = KFold(n_splits=P['NFOLDS'], shuffle=True, random_state=P['SEED'])
# for fold,(tr_idx, val_idx) in enumerate(fold.split(ALL_TRAINING_FILENAMES)):
#     print(*tr_idx, sep=', ')
#     print(*val_idx, sep=', ')

In [None]:
# Hard-coded train / validation file indices, one row per fold (5 folds,
# 12 training and 3 validation indices each).
# NOTE(review): in the visible code these arrays are referenced only by
# commented-out loops; the active training loop uses KFold instead.
trs_idx = np.array([[0, 2, 3, 5, 7, 8, 9, 10, 11, 12, 13, 14],
                   [0, 1, 2, 3, 4, 5, 6, 8, 10, 12, 13, 14],
                   [0, 1, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14],
                   [0, 1, 2, 3, 4, 6, 7, 9, 10, 11, 12, 13],
                   [1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 13, 14]])

vals_idx = np.array([[1, 6, 4],
                    [7, 9, 11],
                    [2, 3, 13],
                    [5, 8, 14],
                    [0, 10, 12]])



# for fold,(tr_idx, val_idx) in zip(P['FOLDS'],zip(trs_idx[P['FOLDS']],vals_idx[P['FOLDS']])):
#     print(fold)
#     print(*tr_idx, sep=', ')
#     print(*val_idx, sep=', ')

In [None]:
# Cross-validated training: for every selected fold build a fresh U-Net, train
# it with checkpointing / LR scheduling / early stopping, evaluate it on the
# fold's validation files and optionally plot the training curves.

# Per-fold validation metrics, keyed 'val_<metric>'.
M = {}
metrics = ['loss','dice_coef','accuracy']
for fm in metrics:
    M['val_'+fm] = []

# The splitter gets its own name so the loop variable `fold` does not shadow it.
kfold = KFold(n_splits=P['NFOLDS'], shuffle=True, random_state=P['SEED'])
# for fold,(tr_idx, val_idx) in zip(P['FOLDS'],zip(trs_idx[P['FOLDS']],vals_idx[P['FOLDS']])):
for fold,(tr_idx, val_idx) in enumerate(kfold.split(ALL_TRAINING_FILENAMES)): 
    # Only train the folds listed in P['FOLDS'].
    if fold not in P['FOLDS']:
        continue
        
    print('#'*35); print('############ FOLD ',fold,' #############'); print('#'*35);
    print(f'Image Size: {DIM}, Batch Size: {BATCH_SIZE}')
    print(f'Valid File Number: {val_idx}' )
    
    # CREATE TRAIN AND VALIDATION SUBSETS
    # These globals are read by get_training_dataset / get_validation_dataset.
    TRAINING_FILENAMES = [ALL_TRAINING_FILENAMES[fi] for fi in tr_idx]
    
    VALIDATION_FILENAMES = [ALL_TRAINING_FILENAMES[fi] for fi in val_idx]
    
    print(f'Valid File Name: {VALIDATION_FILENAMES}')
    
#     if P['OVERLAPP']:
#         VALIDATION_FILENAMES += [ALL_TRAINING_FILENAMES2[fi] for fi in val_idx]
        
    
    # The training dataset repeats forever, so the epoch length is set here.
    STEPS_PER_EPOCH = P['STEPS_COE'] * count_data_items(TRAINING_FILENAMES) // BATCH_SIZE
    
    # BUILD MODEL
    K.clear_session()
    with strategy.scope():   
#         loss_obj = DiceBCELoss()
#         input_shape=(P['DIM'],P['DIM'],3)
        model = sm.Unet(P['BACKBONE'], encoder_weights='imagenet', classes=1, activation='sigmoid')
        # `learning_rate` replaces the deprecated `lr` alias.
        model.compile(optimizer = tf.keras.optimizers.Adam(learning_rate = P['LR']),
                      loss = DiceBCELoss(),
                      metrics=[dice_coef,'accuracy'])
        
    # CALLBACKS
    # All three callbacks track the validation Dice coefficient (higher is better).
    checkpoint = tf.keras.callbacks.ModelCheckpoint(f"/kaggle/working/{P['BACKBONE']}_Unet_model_fold_{fold}_sm_ex_sl.h5",
                                                    verbose=P['VERBOSE'],monitor='val_dice_coef',
                                                    mode='max',save_best_only=True)
    
    early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_dice_coef',mode = 'max',
                                                  patience=13, restore_best_weights=True)
    
    reduce = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_dice_coef', mode = 'max', factor=0.2,
                                                  patience=4, min_lr=1e-8,min_delta=0.0001,
                                                 verbose=1)
    
    if P['WANDB']:
        wandb.run.name= f"{P['BACKBONE']}_fold_{fold}"
        wandb.watch_called = False
        call_back = [WandbCallback(), checkpoint, reduce, early_stop]
    else:
        call_back = [checkpoint, reduce, early_stop]
        
    # NOTE(review): preprocess_input is not used anywhere in the visible code.
    preprocess_input = sm.get_preprocessing(P['BACKBONE'])
    
    print(f'Training Model Fold {fold}...')
    history = model.fit(
        get_training_dataset(),
        epochs = P['EPOCHS'],
        steps_per_epoch = STEPS_PER_EPOCH,
        callbacks = call_back,
        validation_data = get_validation_dataset(),
        verbose=1
    )   
    
    #with strategy.scope():
    #    model = tf.keras.models.load_model('/kaggle/working/model-fold-%i.h5'%fold, custom_objects = {"dice_coe": dice_coe})
    
    # SAVE METRICS
    m = model.evaluate(get_validation_dataset(),return_dict=True)
    for fm in metrics:
        M['val_'+fm].append(m[fm])
    
    # PLOT TRAINING
    # https://www.kaggle.com/cdeotte/triple-stratified-kfold-with-tfrecords
    if P['DISPLAY_PLOT']:        
        plt.figure(figsize=(15,5))
        n_e = np.arange(len(history.history['dice_coef']))
        # Left axis: Dice curves, with the best validation epoch highlighted.
        plt.plot(n_e,history.history['dice_coef'],'-o',label='Train dice_coef',color='#ff7f0e')
        plt.plot(n_e,history.history['val_dice_coef'],'-o',label='Val dice_coef',color='#1f77b4')
        x = np.argmax( history.history['val_dice_coef'] ); y = np.max( history.history['val_dice_coef'] )
        xdist = plt.xlim()[1] - plt.xlim()[0]; ydist = plt.ylim()[1] - plt.ylim()[0]
        plt.scatter(x,y,s=200,color='#1f77b4'); plt.text(x-0.03*xdist,y-0.13*ydist,'max dice_coef\n%.2f'%y,size=14)
        plt.ylabel('dice_coef',size=14); plt.xlabel('Epoch',size=14)
        plt.legend(loc=2)
        # Right (twin) axis: loss curves, with the lowest validation loss highlighted.
        plt2 = plt.gca().twinx()
        plt2.plot(n_e,history.history['loss'],'-o',label='Train Loss',color='#2ca02c')
        plt2.plot(n_e,history.history['val_loss'],'-o',label='Val Loss',color='#d62728')
        x = np.argmin( history.history['val_loss'] ); y = np.min( history.history['val_loss'] )
        ydist = plt.ylim()[1] - plt.ylim()[0]
        plt.scatter(x,y,s=200,color='#d62728'); plt.text(x-0.03*xdist,y+0.05*ydist,'min loss',size=14)
        plt.ylabel('Loss',size=14)
        plt.legend(loc=3)
        plt.show()

In [None]:
### WRITE METRICS
import json
from datetime import datetime

# Timestamp the results, average every per-fold metric into an 'oof_<metric>'
# entry, report it, and persist the whole dictionary as metrics.json.
M['datetime'] = str(datetime.now())
for metric_name in metrics:
    oof_key = 'oof_' + metric_name
    M[oof_key] = np.mean(M['val_' + metric_name])
    print('OOF ' + metric_name + ' ' + str(M[oof_key]))

with open('metrics.json', 'w') as metrics_file:
    json.dump(M, metrics_file)

In [None]:
!rm -r wandb

[1,2,12]
[4,5,9]