![](https://hubmapconsortium.github.io/ccf/img/img-spatial-steps-registration-01.png)

<p style='text-align: center;'><span style="color: #737373; font-family: Segoe UI; font-size: 0.8em;">Figure: Spatial placement of anatomical structures in relation to the HuBMAP Atlas reference system.</span></p>

<span style="color: #0D0D0D; font-family: Segoe UI; font-size: 2.8em; font-weight: 400;">RESUME TPU KERAS TRAINING FOR LARGE MODELS</span>

<p style='text-align: justify;'><span style="color: #005c68; font-family: Segoe UI; font-size: 1.5em;">This notebook shows how to resume TPU training from previously saved model weights. This is useful for large models that need to train for longer than the per-session time limit of Kaggle Notebooks.</span></p>


* <p style='text-align: justify;'><span style="color: #000A0C; font-family: Segoe UI; font-size: 1.2em;">The model weights and the callback files (the current learning rate) are automatically uploaded to GCS at the end of every epoch by a custom ModelCheckpoint callback.</span></p>
    
    

* <p style='text-align: justify;'><span style="color: #000A0C; font-family: Segoe UI; font-size: 1.2em;">On the first run, a GCS bucket is created and the session files are stored there. On subsequent runs, the latest model weights and callback files in the GCS bucket are used to resume training.</span></p>



* <p style='text-align: justify;'><span style="color: #000A0C; font-family: Segoe UI; font-size: 1.2em;">Once the session expires, re-running the whole notebook resumes training from the last saved epoch.</span></p>
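
The resume loop described in the bullets above can be sketched without a TPU or GCS. In this minimal sketch, a local temporary directory stands in for the GCS bucket, each "session" can only run a fixed number of epochs before it expires, and the only persisted state is the next epoch to run. The names `run_session`, `EPOCHS_PER_SESSION` and `state.json` are hypothetical and exist only for illustration.

```python
import json
import os
import tempfile

TOTAL_EPOCHS = 30        # epochs=30, as in the training cell
EPOCHS_PER_SESSION = 8   # hypothetical: epochs that fit within one session's time limit


def run_session(bucket_dir):
    """Simulate one session: resume from bucket_dir if state exists, else start at 0."""
    state_file = os.path.join(bucket_dir, "state.json")
    initial_epoch = 0
    if os.path.exists(state_file):
        with open(state_file) as f:
            initial_epoch = json.load(f)["next_epoch"]

    last_epoch = min(initial_epoch + EPOCHS_PER_SESSION, TOTAL_EPOCHS)
    for epoch in range(initial_epoch, last_epoch):
        # ... one epoch of training would run here ...
        # Persist progress at the end of every epoch, as the checkpoint callback does
        with open(state_file, "w") as f:
            json.dump({"next_epoch": epoch + 1}, f)
    return initial_epoch, last_epoch


bucket_dir = tempfile.mkdtemp()   # stands in for the GCS bucket
sessions = []
while True:
    # Each iteration is one 'Run all' of the notebook in a fresh session
    start, end = run_session(bucket_dir)
    sessions.append((start, end))
    if end >= TOTAL_EPOCHS:
        break

print(sessions)  # [(0, 8), (8, 16), (16, 24), (24, 30)]
```

Persisting after every epoch (rather than only at the end of a session) is what makes an abrupt session timeout safe: at most one epoch of work is lost.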


<span style="color: #005c68; font-family: Segoe UI; font-size: 1.6em;">Useful Datasets:</span>


<span style="color: #000A0C; font-family: Segoe UI; font-size: 1.1em;">The HubMAP TFRecords With Augmentation dataset is used in this notebook. It contains TFRecords with original and augmented images, grouped into 90-130 MB files suited to TPU loading.</span>

 - HubMAP TFRecords With Augmentation Dataset: https://www.kaggle.com/sreevishnudamodaran/hubmap-512x512-tfrecords-with-aug


It was created from the **Augmented images 512x512 tiled Dataset**, which contains augmented images and masks.

 - HuBMAP Augmented 512x512: https://www.kaggle.com/sreevishnudamodaran/hubmap-512x512-augmented

<span style="color: #000A0C; font-family: Segoe UI; font-size: 1.2em;">Please take a look at my previous notebook on creating these datasets and building a Double U-net model:</span>

 - TPU - HubMAP Double U-Net Model + Augmentation : https://www.kaggle.com/sreevishnudamodaran/tpu-hubmap-double-u-net-model-augmentation

<p style='text-align: justify;'><span style="color: #d14800; font-family: Segoe UI; font-size: 1.5em;">I would like to highlight that we should use TPUs judiciously so that they remain available to everyone. Thanks to the Kaggle team for such great perks!</span></p>

<span style="color: #005c68; font-family: Segoe UI; font-size: 1.6em;">Note:</span>
*  <p style='text-align: justify;'><span style="color: #001a1d; font-family: Segoe UI; font-size: 1.1em;font-weight: 500;">It may not be possible to recreate the exact conditions of the previous training session, because the internal state of some callbacks (for example, the patience counters of ReduceLROnPlateau and EarlyStopping) is not tracked at present. Pickling these callbacks is not supported because they hold _thread.RLock, _thread.lock and lambda objects from the TensorFlow code.</span></p>
 
 
*  <p style='text-align: justify;'><span style="color: #001a1d; font-family: Segoe UI; font-size: 1.1em;font-weight: 500;">To resume training, this notebook creates (or reuses) a GCS storage bucket, so a GCP account and project must be set up before running it. Google Cloud SDK and Google Cloud Services must also be enabled from the Add-ons menu.</span></p>
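
The pickling limitation in the first note can be reproduced with the standard library alone. This sketch uses a hypothetical `CallbackLike` class whose attributes include a `threading.Lock`, as TensorFlow objects can: pickling the whole object raises `TypeError`, while pickling only its plain-Python state works. The attribute names `wait` and `best` mirror counters that Keras' `EarlyStopping` and `ReduceLROnPlateau` keep, but restoring them into real Keras callbacks is not shown here.

```python
import pickle
import threading


class CallbackLike:
    """Hypothetical stand-in for a callback whose attributes include a lock."""

    def __init__(self):
        self.wait = 2                  # epochs without improvement so far
        self.best = 0.31               # best monitored value so far
        self._lock = threading.Lock()  # locks cannot be pickled


cb = CallbackLike()
try:
    pickle.dumps(cb)
    pickled_whole_object = True
except TypeError as err:               # e.g. "cannot pickle '_thread.lock' object"
    pickled_whole_object = False
    print("Pickling failed:", err)

# Workaround: persist only the plain-Python state, then copy it into a fresh object
saved_state = pickle.dumps({"wait": cb.wait, "best": cb.best})
restored = CallbackLike()
restored.__dict__.update(pickle.loads(saved_state))
print(restored.wait, restored.best)    # 2 0.31
```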

 
 

[![Ask Me Anything !](https://img.shields.io/badge/Ask%20me-anything-1abc9c.svg?style=flat-square&logo=appveyor)](https://www.kaggle.com/sreevishnudamodaran)



![TPU!](https://img.shields.io/badge/Accelerator-TPU-green?style=flat-square&logo=appveyor)


## Import Libraries

In [None]:
%matplotlib inline

import json
import os
import glob
import re
import datetime
import os.path as osp
from pathlib import Path
import collections
import sys
import uuid
import random
import warnings
from itertools import chain
from tqdm import tqdm

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import rcParams
sns.set(rc={"font.size": 9, "axes.titlesize": 15, "axes.labelsize": 9,
            "axes.titlepad": 2, "axes.labelpad": 9, "legend.fontsize": 7,
            "legend.title_fontsize": 7, "axes.grid": False,
            "figure.titlesize": 35})

from PIL import Image
import cv2

import tensorflow as tf
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Conv2D, Conv2DTranspose
from tensorflow.keras.layers import MaxPooling2D, UpSampling2D
from tensorflow.keras.layers import concatenate
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras import backend as K
from tensorflow.keras import layers
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import get_custom_objects
from tensorflow.keras.losses import binary_crossentropy

from kaggle_datasets import KaggleDatasets
from kaggle_secrets import UserSecretsClient

## Initialize and Get TPU Ready

In [None]:
tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)

strategy = tf.distribute.experimental.TPUStrategy(tpu)


In [None]:
user_secrets = UserSecretsClient()
user_credential = user_secrets.get_gcloud_credential()
user_secrets.set_tensorflow_credential(user_credential)

## Load TFRecords & View Samples

In [None]:
AUTO = tf.data.experimental.AUTOTUNE

image_feature_description = {
    'height': tf.io.FixedLenFeature([], tf.int64),
    'width': tf.io.FixedLenFeature([], tf.int64),
    'num_channels': tf.io.FixedLenFeature([], tf.int64),
    'img_bytes': tf.io.FixedLenFeature([], tf.string),
    'mask': tf.io.FixedLenFeature([], tf.string),
}

def _parse_image_and_masks_function(example_proto):
    single_example = tf.io.parse_single_example(example_proto, image_feature_description)
    img_bytes =  tf.io.decode_raw(single_example['img_bytes'], out_type='uint8')
    img_array = tf.reshape(img_bytes, (512, 512, 3))
    mask_bytes =  tf.io.decode_raw(single_example['mask'], out_type='bool')
    mask = tf.reshape(mask_bytes, (512, 512, 1))
    
    # Images stay uint8 and masks stay bool here, for display only;
    # the training parser (_parse_image_function) normalizes and casts to float32
    return img_array, mask

def read_dataset(storage_file_path):
    encoded_image_dataset = tf.data.TFRecordDataset(storage_file_path, compression_type="GZIP")
    parsed_image_dataset = encoded_image_dataset.map(_parse_image_and_masks_function)
    return parsed_image_dataset
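
`tf.io.decode_raw` returns a flat vector of bytes, and `tf.reshape` reinterprets it in row-major (C) order as `(512, 512, 3)`. The sketch below reproduces that indexing in plain Python on a tiny 2x3 RGB "image", assuming the records were written from C-ordered arrays (for example with NumPy's `tobytes()`); the helper `pixel_at` is hypothetical. It also shows the uncompressed size of one example, since the reshape only succeeds when the byte count matches exactly.

```python
def pixel_at(raw, row, col, channel, width, channels):
    """Return one value from a flat row-major buffer laid out as (height, width, channels)."""
    offset = (row * width + col) * channels + channel
    return raw[offset]


height, width, channels = 2, 3, 3
# Each value encodes its own position as row*100 + col*10 + channel
raw = bytes(r * 100 + c * 10 + ch
            for r in range(height)
            for c in range(width)
            for ch in range(channels))

assert len(raw) == height * width * channels     # reshape needs an exact size match
print(pixel_at(raw, 1, 2, 0, width, channels))   # 120: row 1, column 2, channel 0

# Size of one 512x512 example before GZIP compression
image_bytes = 512 * 512 * 3      # uint8 image, 1 byte per value
mask_bytes = 512 * 512 * 1       # bool mask, 1 byte per value
print(image_bytes + mask_bytes)  # 1048576
```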

In [None]:
# Get the credential from the Cloud SDK
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
user_credential = user_secrets.get_gcloud_credential()

# Set the credentials
user_secrets.set_tensorflow_credential(user_credential)

# Use a familiar call to get the GCS path of the dataset
from kaggle_datasets import KaggleDatasets
GCS_DS_PATH = KaggleDatasets().get_gcs_path('hubmap-512x512-tfrecords-with-aug')
GCS_DS_PATH

In [None]:
train_tf_gcs = GCS_DS_PATH+'/train/*.tfrecords'
val_tf_gcs = GCS_DS_PATH+'/val/*.tfrecords'
train_tf_files = tf.io.gfile.glob(train_tf_gcs)
val_tf_files = tf.io.gfile.glob(val_tf_gcs)
print(val_tf_files[:3])
print("Train TFrecord Files:", len(train_tf_files))
print("Val TFrecord Files:", len(val_tf_files))

In [None]:
train_dataset = read_dataset(train_tf_files[15])
validation_dataset = read_dataset(val_tf_files[15])

# Take the 5th example of one train and one validation record file for a quick visual check
train_image, train_mask = next(iter(train_dataset.skip(4).take(1)))
test_image, test_mask = next(iter(validation_dataset.skip(4).take(1)))
train_mask = np.squeeze(train_mask)
test_mask = np.squeeze(test_mask)

fig, ax = plt.subplots(2, 2, figsize=(20, 10))
ax[0][0].imshow(train_image)
ax[0][1].imshow(train_mask)
ax[1][0].imshow(test_image)
ax[1][1].imshow(test_mask)

## Model Building

In [None]:
from tensorflow.keras.layers import *
from tensorflow.keras.applications import *
from tensorflow.keras.callbacks import *
from tensorflow.keras.optimizers import Adam, Nadam
from tensorflow.keras.metrics import *
from tensorflow.keras.losses import binary_crossentropy

np.random.seed(13)
tf.random.set_seed(13)

In [None]:
def squeeze_excite_block(inputs, ratio=8):
    init = inputs
    channel_axis = -1
    filters = init.shape[channel_axis]
    se_shape = (1, 1, filters)

    se = GlobalAveragePooling2D()(init)
    se = Reshape(se_shape)(se)
    se = Dense(filters // ratio, activation='relu', kernel_initializer='he_normal', use_bias=False)(se)
    se = Dense(filters, activation='sigmoid', kernel_initializer='he_normal', use_bias=False)(se)

    x = Multiply()([init, se])
    return x

def conv_block(inputs, filters):
    x = inputs

    x = Conv2D(filters, (3, 3), padding="same")(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = Conv2D(filters, (3, 3), padding="same")(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    x = squeeze_excite_block(x)

    return x

def encoder1(inputs):
    skip_connections = []

    model = VGG19(include_top=False, weights='imagenet', input_tensor=inputs)
    names = ["block1_conv2", "block2_conv2", "block3_conv4", "block4_conv4"]
    for name in names:
        skip_connections.append(model.get_layer(name).output)

    output = model.get_layer("block5_conv4").output
    return output, skip_connections

def decoder1(inputs, skip_connections):
    num_filters = [256, 128, 64, 32]
    skip_connections.reverse()
    x = inputs
    shape = x.shape

    for i, f in enumerate(num_filters):
        x = Conv2DTranspose(shape[3], (2, 2), activation="relu", strides=(2, 2))(x)
        x = Concatenate()([x, skip_connections[i]])
        x = conv_block(x, f)

    return x

def encoder2(inputs):
    num_filters = [32, 64, 128, 256]
    skip_connections = []
    x = inputs

    for i, f in enumerate(num_filters):
        x = conv_block(x, f)
        skip_connections.append(x)
        x = MaxPool2D((2, 2))(x)

    return x, skip_connections

def decoder2(inputs, skip_1, skip_2):
    num_filters = [256, 128, 64, 32]
    skip_2.reverse()
    x = inputs
    shape = x.shape

    for i, f in enumerate(num_filters):
        x = Conv2DTranspose(shape[3], (2, 2), activation="relu", strides=(2, 2))(x)
        x = Concatenate()([x, skip_1[i], skip_2[i]])
        x = conv_block(x, f)

    return x

def output_block(inputs):
    x = Conv2D(1, (1, 1), padding="same")(inputs)
    x = Activation('sigmoid')(x)
    return x

def ASPP(x, filter):
    shape = x.shape

    y1 = AveragePooling2D(pool_size=(shape[1], shape[2]))(x)
    y1 = Conv2D(filter, 1, padding="same")(y1)
    y1 = BatchNormalization()(y1)
    y1 = Activation("relu")(y1)
    shape2 = y1.shape
    
    y1 = Conv2DTranspose(shape2[3], (8,8), activation="relu", strides=(shape[1], shape[2]))(y1)
    

    y2 = Conv2D(filter, 1, dilation_rate=1, padding="same", use_bias=False)(x)
    y2 = BatchNormalization()(y2)
    y2 = Activation("relu")(y2)

    y3 = Conv2D(filter, 3, dilation_rate=6, padding="same", use_bias=False)(x)
    y3 = BatchNormalization()(y3)
    y3 = Activation("relu")(y3)

    y4 = Conv2D(filter, 3, dilation_rate=12, padding="same", use_bias=False)(x)
    y4 = BatchNormalization()(y4)
    y4 = Activation("relu")(y4)

    y5 = Conv2D(filter, 3, dilation_rate=18, padding="same", use_bias=False)(x)
    y5 = BatchNormalization()(y5)
    y5 = Activation("relu")(y5)

    y = Concatenate()([y1, y2, y3, y4, y5])

    y = Conv2D(filter, 1, dilation_rate=1, padding="same", use_bias=False)(y)
    y = BatchNormalization()(y)
    y = Activation("relu")(y)

    return y

def build_model():
    inputs = Input((512, 512, 3))
    x, skip_1 = encoder1(inputs)
    x = ASPP(x, 64)
    x = decoder1(x, skip_1)
    outputs1 = output_block(x)

    x = inputs * outputs1

    x, skip_2 = encoder2(x)
    x = ASPP(x, 64)
    x = decoder2(x, skip_1, skip_2)
    outputs2 = output_block(x)
    outputs = Concatenate()([outputs1, outputs2])
    
    combine_output = Conv2D(1, (64, 64), activation="sigmoid", padding="same")(outputs)

    model = Model(inputs, combine_output)
    return model
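
`squeeze_excite_block` rescales each channel by a learned gate. As a minimal sketch with made-up weights (plain Python, no Keras, no biases to match `use_bias=False`), the forward pass is: global-average-pool each channel (squeeze), pass the pooled vector through a ReLU dense layer with `filters // ratio` units and a sigmoid dense layer with one unit per channel (excitation), then multiply every pixel of each channel by its gate (the `Multiply()` layer). The function `se_forward` and the weight values are hypothetical.

```python
import math


def relu(v):
    return [max(0.0, x) for x in v]


def sigmoid(v):
    return [1.0 / (1.0 + math.exp(-x)) for x in v]


def dense(v, weights):
    """No-bias dense layer: weights[i][j] connects input i to output j (Keras kernel layout)."""
    return [sum(v[i] * weights[i][j] for i in range(len(v))) for j in range(len(weights[0]))]


def se_forward(feature_map, w1, w2):
    """feature_map is a list of rows; each pixel is a list of channel values."""
    h, w, c = len(feature_map), len(feature_map[0]), len(feature_map[0][0])
    # Squeeze: global average pooling per channel
    pooled = [sum(feature_map[r][k][ch] for r in range(h) for k in range(w)) / (h * w)
              for ch in range(c)]
    # Excitation: ReLU bottleneck, then one sigmoid gate per channel
    gates = sigmoid(dense(relu(dense(pooled, w1)), w2))
    # Scale: multiply each channel of every pixel by its gate
    scaled = [[[px[ch] * gates[ch] for ch in range(c)] for px in row] for row in feature_map]
    return scaled, gates


# 2x2 feature map with 2 channels; ratio=2 gives a 1-unit bottleneck
fmap = [[[1.0, 4.0], [3.0, 0.0]],
        [[2.0, 2.0], [2.0, 2.0]]]
w1 = [[0.5], [0.5]]   # 2 channels -> 1 unit
w2 = [[1.0, -1.0]]    # 1 unit -> 2 gates
out, gates = se_forward(fmap, w1, w2)
print([round(g, 3) for g in gates])   # [0.881, 0.119]
```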

In [None]:
model = build_model()
model.summary(line_length=150)
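
The spatial sizes reported by `model.summary()` can be checked by hand. This sketch assumes the 512x512 input used here: `encoder2` halves the size four times (VGG19's `block5_conv4` is also at 1/16 resolution), each decoder step doubles it with a stride-2 `Conv2DTranspose` so it matches the skip connection it is concatenated with, and the ASPP pooling branch goes from 1x1 back to 32x32 through a `Conv2DTranspose` with kernel 8 and stride 32. The helper names are hypothetical; the rule used for `'valid'` transposed convolutions is `in * stride + max(kernel - stride, 0)`, and a 3x3 kernel with dilation rate d spans `3 + 2(d - 1)` pixels.

```python
def transposed_conv_valid_length(in_len, kernel, stride):
    """Output length of a 'valid' transposed convolution (no output_padding)."""
    return in_len * stride + max(kernel - stride, 0)


def dilated_extent(kernel, rate):
    """Number of input pixels spanned by a dilated kernel along one axis."""
    return kernel + (kernel - 1) * (rate - 1)


input_size = 512
# encoder2: four conv blocks, each followed by 2x2 max pooling
skip_sizes = []
size = input_size
for _ in range(4):
    skip_sizes.append(size)
    size //= 2
bottleneck = size   # 32, same as VGG19 block5_conv4

# decoder: each stride-2 Conv2DTranspose doubles the size; skips are used in reverse
decoder_sizes = []
for skip in reversed(skip_sizes):
    size = transposed_conv_valid_length(size, kernel=2, stride=2)
    assert size == skip   # Concatenate() needs matching spatial sizes
    decoder_sizes.append(size)

# ASPP pooling branch: 1x1 -> 32x32 with kernel 8 and stride 32
aspp_pool_branch = transposed_conv_valid_length(1, kernel=8, stride=bottleneck)

print(skip_sizes, bottleneck, decoder_sizes, aspp_pool_branch)
print([dilated_extent(3, d) for d in (6, 12, 18)])   # [13, 25, 37]
```

Because the kernel (8) is smaller than the stride (32), only an 8x8 block of the pooling branch's 32x32 output depends on the pooled features; the remaining positions receive only the bias before the ReLU.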

## Prepare Training Job Folder

In [None]:
train_job_path = '/kaggle/working/train_job'
if not os.path.exists(train_job_path):
    os.makedirs(train_job_path)

In [None]:
from google.cloud import storage

# Replace with your own GCP project ID
storage_client = storage.Client(project='placesproject-284409')

with strategy.scope():

    def create_bucket(dataset_name):
        """Creates a new bucket. https://cloud.google.com/storage/docs/ """
        bucket = storage_client.create_bucket(dataset_name)
        print('Bucket {} created'.format(bucket.name))

    def upload_blob(bucket_name, source_file_name, destination_blob_name):
        """Uploads a file to the bucket. https://cloud.google.com/storage/docs/ """
        try:
            bucket = storage_client.get_bucket(bucket_name)
            blob = bucket.blob(destination_blob_name)
            blob.upload_from_filename(source_file_name)
        except Exception as E:
            print("Error uploading:", E)

    def list_blobs(bucket_name):
        """Lists the names of all blobs in the bucket. https://cloud.google.com/storage/docs/"""
        return [blob.name for blob in storage_client.list_blobs(bucket_name)]

    def download_to_kaggle(bucket_name, destination_directory, file_name):
        """Downloads one blob from your GCS bucket into the working directory of your Kaggle notebook"""
        os.makedirs(destination_directory, exist_ok=True)
        full_file_path = os.path.join(destination_directory, file_name)
        # Fetch only the requested blob; looping over every blob in the bucket
        # would overwrite the same local file with each object in turn
        blob = storage_client.bucket(bucket_name).blob(file_name)
        blob.download_to_filename(full_file_path)


## Using a GCS Bucket for Persistent Job File Storage
### Create a Bucket if it doesn't already exist

In [None]:
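from google.api_core.exceptions import Conflict

# Bucket names are global across Cloud Storage, so pick a name of your own
bucket_name = 'hubmap_train_job_6'
try:
    create_bucket(bucket_name)
except Conflict:
    # The bucket already exists (e.g. created in a previous session), so reuse it
    print('Bucket {} already exists'.format(bucket_name))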
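
Because bucket names share one global namespace, `create_bucket` can fail for a name that is already taken or malformed. The check below is a minimal sketch of a pre-check covering only a subset of the documented Cloud Storage naming rules, for names without dots: 3-63 characters of lowercase letters, digits, `-` and `_`, starting and ending with a letter or digit, no `goog` prefix and no `google` substring. The helper `looks_like_valid_bucket_name` is hypothetical and is not a substitute for the server-side validation.

```python
import re

# Subset of the Cloud Storage bucket naming rules (names without dots only)
_BUCKET_RE = re.compile(r"[a-z0-9][a-z0-9_-]{1,61}[a-z0-9]")


def looks_like_valid_bucket_name(name):
    """Hypothetical pre-check before calling create_bucket(); not exhaustive."""
    return (_BUCKET_RE.fullmatch(name) is not None
            and not name.startswith("goog")
            and "google" not in name)


for candidate in ["hubmap_train_job_6", "HuBMAP-Train", "ab", "job-", "googjob"]:
    print(candidate, looks_like_valid_bucket_name(candidate))
```

Only the first candidate passes: the others fail on uppercase letters, length, a trailing dash, and the reserved `goog` prefix respectively.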

## Download the Model Weights and Callback Files

### **If Model Weights and Callback Files exist from a previous session, download them to resume training from that point.**

In [None]:
checkpoints_path = '/kaggle/working/train_job/checkpoints'
if not os.path.exists(checkpoints_path):
    os.makedirs(checkpoints_path)
    
files_list = list_blobs(bucket_name)
ckpt_list =  [file for file in files_list if '.hdf5' in file]

if ckpt_list:
    download_to_kaggle(bucket_name, train_job_path, 'lr_value.pickle')
    download_to_kaggle(bucket_name, checkpoints_path, ckpt_list[-1])
    checkpoint_path = os.path.join(checkpoints_path, ckpt_list[-1])
else:
    checkpoint_path = None
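
`ckpt_list[-1]` picks the latest checkpoint only because Cloud Storage lists objects in lexicographic name order and the filename template zero-pads the epoch (`{epoch:02d}`). That holds for up to 99 epochs; with three-digit epochs, `weights_100_...` sorts before `weights_10_...`. A minimal sketch of a numeric alternative, assuming the `weights_{epoch:02d}_{val_loss:.2f}.hdf5` naming used in the training cell (the helper `latest_checkpoint` is hypothetical):

```python
import re

CKPT_RE = re.compile(r"weights_(\d+)_")   # matches the 'weights_{epoch:02d}_...' template


def latest_checkpoint(blob_names):
    """Return the .hdf5 blob with the highest epoch number, or None if there is none."""
    ckpts = [n for n in blob_names if n.endswith(".hdf5") and CKPT_RE.match(n)]
    if not ckpts:
        return None
    return max(ckpts, key=lambda n: int(CKPT_RE.match(n).group(1)))


blobs = ["lr_value.pickle", "weights_09_0.31.hdf5",
         "weights_10_0.29.hdf5", "weights_100_0.20.hdf5"]
print(sorted(b for b in blobs if b.endswith(".hdf5"))[-1])  # weights_10_0.29.hdf5 (lexicographic)
print(latest_checkpoint(blobs))                            # weights_100_0.20.hdf5 (numeric)
```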

## Define Metrics

In [None]:
with strategy.scope():
    def dice_coeff(y_true, y_pred):
        # add epsilon to avoid a divide by 0 error in case a slice has no pixels set
        # we only care about relative value, not absolute so this alteration doesn't matter
        _epsilon = 10 ** -7
        intersections = tf.reduce_sum(y_true * y_pred)
        unions = tf.reduce_sum(y_true + y_pred)
        dice_scores = (2.0 * intersections + _epsilon) / (unions + _epsilon)
        return dice_scores

    def dice_loss(y_true, y_pred):
        loss = 1 - dice_coeff(y_true, y_pred)
        return loss
    
    def bce_dice_loss(y_true, y_pred):
        return binary_crossentropy(y_true, y_pred) + dice_loss(y_true, y_pred)
    
    def tversky(y_true, y_pred, smooth=1, alpha=0.7):
        y_true_pos = tf.reshape(y_true,[-1])
        y_pred_pos = tf.reshape(y_pred,[-1])
        true_pos = tf.reduce_sum(y_true_pos * y_pred_pos)
        false_neg = tf.reduce_sum(y_true_pos * (1 - y_pred_pos))
        false_pos = tf.reduce_sum((1 - y_true_pos) * y_pred_pos)
        return (true_pos + smooth) / (true_pos + alpha * false_neg + (1 - alpha) * false_pos + smooth)

    def tversky_loss(y_true, y_pred):
        return 1 - tversky(y_true, y_pred)

    def focal_tversky_loss(y_true, y_pred, gamma=0.75):
        tv = tversky(y_true, y_pred)
        return K.pow((1 - tv), gamma)

    get_custom_objects().update({"dice": dice_loss})
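
The metric definitions above reduce to simple sums over the flattened masks. This pure-Python sketch mirrors `dice_coeff`, `tversky` and `focal_tversky_loss` on a hypothetical 8-pixel mask to show what `alpha=0.7` does: a missed foreground pixel (false negative) is weighted 0.7, while a falsely flagged background pixel (false positive) is weighted 0.3.

```python
def dice_coeff(y_true, y_pred, eps=1e-7):
    """Mirrors the notebook's dice_coeff on flattened lists."""
    intersection = sum(t * p for t, p in zip(y_true, y_pred))
    total = sum(y_true) + sum(y_pred)      # tf.reduce_sum(y_true + y_pred)
    return (2.0 * intersection + eps) / (total + eps)


def tversky(y_true, y_pred, smooth=1.0, alpha=0.7):
    """alpha weights false negatives, (1 - alpha) weights false positives."""
    tp = sum(t * p for t, p in zip(y_true, y_pred))
    fn = sum(t * (1 - p) for t, p in zip(y_true, y_pred))
    fp = sum((1 - t) * p for t, p in zip(y_true, y_pred))
    return (tp + smooth) / (tp + alpha * fn + (1 - alpha) * fp + smooth)


def focal_tversky_loss(y_true, y_pred, gamma=0.75):
    return (1 - tversky(y_true, y_pred)) ** gamma


y_true = [1, 1, 1, 1, 0, 0, 0, 0]   # 4 foreground pixels
one_fn = [1, 1, 1, 0, 0, 0, 0, 0]   # one foreground pixel missed
one_fp = [1, 1, 1, 1, 1, 0, 0, 0]   # one background pixel flagged

for name, pred in [("one_fn", one_fn), ("one_fp", one_fp)]:
    print(name,
          "dice loss", round(1 - dice_coeff(y_true, pred), 3),
          "tversky loss", round(1 - tversky(y_true, pred), 3),
          "focal tversky", round(focal_tversky_loss(y_true, pred), 3))
```

Dice loss treats the two single-pixel errors similarly (about 0.143 vs 0.111), whereas the Tversky loss is roughly 2.6 times larger for the missed pixel (about 0.149 vs 0.057), which pushes the model toward recall on small structures.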

## Functions to Load Records

In [None]:
AUTO = tf.data.experimental.AUTOTUNE
image_feature_description = {
    'height': tf.io.FixedLenFeature([], tf.int64),
    'width': tf.io.FixedLenFeature([], tf.int64),
    'num_channels': tf.io.FixedLenFeature([], tf.int64),
    'img_bytes': tf.io.FixedLenFeature([], tf.string),
    'mask': tf.io.FixedLenFeature([], tf.string),
}

def _parse_image_function(example_proto):
    single_example = tf.io.parse_single_example(example_proto, image_feature_description)
    image = tf.reshape( tf.io.decode_raw(single_example['img_bytes'],out_type='uint8'), (512, 512, 3))
    mask =  tf.reshape(tf.io.decode_raw(single_example['mask'],out_type='bool'),(512, 512, 1))
    ## normalize images array and cast image and mask to float32
    image = tf.cast(image, tf.float32) / 255.0
    mask = tf.cast(mask, tf.float32)
    return image, mask

def load_dataset(filenames, ordered=False):
    ignore_order = tf.data.Options()
    if not ordered:
        ignore_order.experimental_deterministic = False
    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO, compression_type="GZIP")
    dataset = dataset.with_options(ignore_order)
    dataset = dataset.map(_parse_image_function, num_parallel_calls=AUTO)
    return dataset

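def get_training_dataset():
    dataset = load_dataset(train_tf_files)
    # Repeat so that steps_per_epoch batches are available in every epoch;
    # a single pass over the data can run out before epochs=30 is reached
    dataset = dataset.repeat()
    # A buffer this large holds many decoded 512x512 images in host memory;
    # reduce it if the input pipeline runs out of memory
    dataset = dataset.shuffle(47000)
    dataset = dataset.batch(16, drop_remainder=True)
    dataset = dataset.prefetch(AUTO)
    return dataset

def get_val_dataset():
    dataset = load_dataset(val_tf_files)
    #dataset = dataset.repeat()
    dataset = dataset.shuffle(5000)
    dataset = dataset.batch(16, drop_remainder=True)
    dataset = dataset.prefetch(AUTO)
    return dataset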
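
With `batch(16, drop_remainder=True)`, the sample counts listed in this notebook (47703 training and 5829 validation samples) translate into the per-epoch numbers below. The step formula is copied from the training cell, which divides by 32 even though the batch size here is 16, so each epoch covers roughly 15% of the training set. The last value is how many epochs one non-repeated pass would last at that step count; without `repeat()`, Keras may stop with an "input ran out of data" warning once that pass is used up.

```python
TRAIN_SAMPLES, VAL_SAMPLES = 47703, 5829
BATCH_SIZE = 16   # dataset.batch(16, drop_remainder=True)

# Full batches per pass over the data; drop_remainder discards the leftover samples
train_batches = TRAIN_SAMPLES // BATCH_SIZE
dropped = TRAIN_SAMPLES - train_batches * BATCH_SIZE

# Steps used in the training cell: 30% of (samples // 32)
train_steps = round((TRAIN_SAMPLES // 32) * 0.3)
validation_steps = round((VAL_SAMPLES // 32) * 0.3)

samples_per_epoch = train_steps * BATCH_SIZE
epochs_per_pass = train_batches / train_steps

print(train_batches, dropped, train_steps, validation_steps)   # 2981 7 447 55
print(samples_per_epoch, round(epochs_per_pass, 2))            # 7152 6.67
```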

## Custom ModelCheckpoint Function
### Upload relevant weights and callback files to GCS for persistence


When training resumes, it does not automatically start from the same conditions as when the checkpoint was saved. If a learning-rate schedule or ReduceLROnPlateau has lowered the learning rate, it would restart from its initial value. To prevent this, the learning rate is pickled at every epoch and persisted in GCS alongside the model weights.

    Training Samples         : 47703
    Validation Samples       : 5829
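
The learning-rate persistence boils down to a pickle round trip. A minimal sketch, using a temporary directory in place of `/kaggle/working/train_job` and the GCS bucket, and assuming one plateau reduction with the `factor=0.001` used in the training cell has already happened:

```python
import os
import pickle
import tempfile

initial_lr = 1e-2   # Adam(learning_rate=1e-2) in the training cell
factor = 0.001      # ReduceLROnPlateau(factor=0.001)

# Session 1: the LR was reduced once on a plateau, then persisted at epoch end
lr_after_plateau = initial_lr * factor
lr_file = os.path.join(tempfile.mkdtemp(), "lr_value.pickle")
with open(lr_file, "wb") as f:
    pickle.dump(float(lr_after_plateau), f)

# Session 2: a freshly compiled optimizer would start at initial_lr;
# loading the pickle restores the reduced value instead
with open(lr_file, "rb") as f:
    resumed_lr = pickle.load(f)

print(initial_lr, resumed_lr)
```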

In [None]:
from tensorflow.keras.callbacks import ModelCheckpoint
import pickle

with strategy.scope():

    class ModelCheckpointEnhanced(ModelCheckpoint):
        """ModelCheckpoint that also pickles the current learning rate and
        uploads the LR file and the saved weights file to the GCS bucket."""

        def __init__(self, *args, **kwargs):
            # Added argument: local path of the learning-rate pickle file
            self.lr_epoch_path = kwargs.pop('lr_epoch_path')

            super().__init__(*args, **kwargs)
            # Unformatted filepath template, used to rebuild the saved file name
            self.model_filepath = kwargs['filepath']

        def _persist_to_gcs(self, model_filepath, lr_val):
            # Overwrite the local LR pickle ("wb" truncates any existing file)
            with open(self.lr_epoch_path, "wb") as f:
                pickle.dump(lr_val, f)

            ## Uploading LR Pickle File to GCS
            upload_blob(bucket_name, self.lr_epoch_path, os.path.basename(self.lr_epoch_path))

            ## Uploading Model Weight File to GCS
            upload_blob(bucket_name, model_filepath, os.path.basename(model_filepath))
            print('Model Files Uploaded to GCS')

        def on_epoch_end(self, epoch, logs=None):
            logs = logs or {}
            # Run normal flow (saves the .hdf5 file locally when due)
            super().on_epoch_end(epoch, logs)

            model_filepath = self.model_filepath.format(epoch=epoch + 1, **logs)
            print("\nModel_filepath", model_filepath)
            lr_val = float(K.get_value(self.model.optimizer.lr))

            # The parent resets epochs_since_last_save to 0 whenever a save was due
            if self.epochs_since_last_save == 0:
                if self.save_best_only:
                    current = logs.get(self.monitor)
                    if current == self.best:
                        # Note: in rare cases this may upload on unwanted epochs.
                        # When the monitored value is continuous this is unlikely.
                        self._persist_to_gcs(model_filepath, lr_val)
                else:
                    self._persist_to_gcs(model_filepath, lr_val)

In [None]:
with strategy.scope():
    def get_init_epoch(checkpoint_path):
        # Names follow 'weights_{epoch:02d}_{val_loss:.2f}.hdf5', so the second
        # '_'-separated field is the number of completed epochs
        filename = os.path.basename(checkpoint_path)
        init_epoch = filename.split("_")[1]
        print("init_epoch", init_epoch)
        return int(init_epoch)

    metrics = [
        dice_coeff,
        bce_dice_loss,
        Recall(),
        Precision(),
        tversky_loss,
        focal_tversky_loss
    ]

    # Custom metrics as a dict to pass to load_model()
    metrics_dict = {
        'dice_coeff': dice_coeff,
        'bce_dice_loss': bce_dice_loss,
        'tversky_loss': tversky_loss,
        'focal_tversky_loss': focal_tversky_loss
    }

    ## Calling the Custom Checkpoint Callback.
    ## Passing the weights path and the learning-rate pickle path used to resume training.
    ckpt_callback = ModelCheckpointEnhanced(filepath=train_job_path+'/checkpoints/weights_{epoch:02d}_{val_loss:.2f}.hdf5',
                                            monitor='val_loss', lr_epoch_path=train_job_path+'/lr_value.pickle')

    callbacks = [
        ckpt_callback,
        ReduceLROnPlateau(monitor='val_loss', factor=0.001, patience=3),
        EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=False),
        CSVLogger(train_job_path+"/data.csv")
    ]

    train_dataset = get_training_dataset()
    validation_dataset = get_val_dataset()
    # 30% of (samples // 32) steps per epoch; note that the batch size is 16
    train_steps = round((47703//32)*0.3)
    validation_steps = round((5829//32)*0.3)
#     train_steps = 800
#     validation_steps = 100

    # Load checkpoint:
    if checkpoint_path is not None:
        # Load the full model (weights and optimizer state) from the latest checkpoint
        print("Resuming training from checkpoint:", checkpoint_path)
        model = load_model(checkpoint_path, custom_objects=metrics_dict)

        # Finding the epoch index from which we are resuming
        initial_epoch = get_init_epoch(checkpoint_path)

        # Restore the learning rate persisted by ModelCheckpointEnhanced
        with open(train_job_path+'/lr_value.pickle', "rb") as f:
            loaded_lr = pickle.load(f)
        K.set_value(model.optimizer.lr, loaded_lr)

    else:
        print("Building model and starting training")
        model = build_model()
        model.compile(optimizer=Adam(learning_rate=1e-2), loss='dice', metrics=metrics)
        initial_epoch = 0

    # Start/resume training
    model.fit(train_dataset, epochs=30, steps_per_epoch=train_steps,
              validation_data=validation_dataset, validation_steps=validation_steps,
              callbacks=callbacks,
              initial_epoch=initial_epoch)
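
`get_init_epoch` works because of how the checkpoint name is built. `ModelCheckpoint` fills the `filepath` template with the 1-based epoch number and the epoch's logs, so the file written after epoch index 4 carries `05`, which is exactly the `initial_epoch` that `model.fit` needs in order to continue with the next epoch. The sketch below reproduces this with plain `str.format`; the helper `checkpoint_name` is hypothetical and the log values are made up.

```python
def checkpoint_name(epoch_index, logs, template="weights_{epoch:02d}_{val_loss:.2f}.hdf5"):
    # Fill the template with the 1-based epoch and the epoch logs;
    # str.format ignores log keys that the template does not use
    return template.format(epoch=epoch_index + 1, **logs)


def get_init_epoch(filename):
    # Same parsing as the notebook: the second '_'-separated field is the epoch
    return int(filename.split("_")[1])


logs = {"loss": 0.41, "val_loss": 0.2871, "dice_coeff": 0.63}   # hypothetical values
name = checkpoint_name(4, logs)   # saved at the end of the 5th epoch (index 4)
print(name)                       # weights_05_0.29.hdf5
print(get_init_epoch(name))       # 5 -> fit() resumes at epoch index 5
```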

<span style="color: #005c68; font-family: Segoe UI; font-size: 1.6em;">If the session stops, start a new session and just 'Run all' cells to resume training :)</span>