My main goals for this competition are:
* Integrate a custom TensorFlow training loop in which the images are augmented anew each epoch
* Create a working end-to-end pipeline for preprocessing images and running inference on TPUs
* I'll be updating this notebook regularly.


What still needs to be implemented:
* Make sure that the dataset is reset and re-applies the map function on each step, as TensorFlow does not do this automatically.

In [None]:
#install efficientnet keras model
!pip install --quiet efficientnet

In [None]:
!pip uninstall --quiet albumentations -y
#!pip uninstall --quiet tensorflow -y

In [None]:
!pip install --quiet albumentations==0.5.1
#!pip install --quiet tensorflow==2.3.0
#!pip install --quiet cloud-tpu-client

#import tensorflow as tf
#from cloud_tpu_client import Client
#print(tf.__version__)

#Client().configure_tpu_version(tf.__version__, restart_type='ifNeeded')

In [None]:
#import used libraries
import math, os, re, warnings, random, time
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from kaggle_datasets import KaggleDatasets
from sklearn.utils import class_weight
from sklearn.model_selection import KFold
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
import tensorflow as tf
import tensorflow.keras.layers as L
import tensorflow.keras.backend as K
from tensorflow.keras import optimizers, applications, Sequential, losses, metrics
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, LearningRateScheduler
import efficientnet.tfkeras as efn
from matplotlib.pyplot import imread
import keras
import albumentations
import functools
from multiprocessing.dummy import Pool

# Seed every random number generator so that training runs are reproducible.
def seed_everything(seed=0):
    """Seed Python's ``random``, NumPy and TensorFlow with the same value.

    The environment variables are written first so that anything reading
    them later sees the values. ``PYTHONHASHSEED`` only takes effect for
    interpreters started after it is set (e.g. subprocesses). It does not
    change hash randomisation of the already-running process.

    Args:
        seed: Integer seed shared by all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    os.environ['TF_DETERMINISTIC_OPS'] = '1'
    for seeder in (random.seed, np.random.seed, tf.random.set_seed):
        seeder(seed)


seed = 0
seed_everything(seed)
warnings.filterwarnings('ignore')

In [None]:
# Hardware detection: build a TPUStrategy when a TPU is reachable, otherwise
# fall back to the default (single-device) distribution strategy.
def_strat = tf.distribute.get_strategy()
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print(f'Running on TPU {tpu.master()}')
except ValueError:
    # The resolver is expected to raise ValueError when no TPU is available.
    tpu = None

if not tpu:
    strategy = tf.distribute.get_strategy()
else:
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)

AUTO = tf.data.experimental.AUTOTUNE
REPLICAS = strategy.num_replicas_in_sync
print(f'REPLICAS: {REPLICAS}')

In [None]:
# Per-replica batch size is larger on TPU (8) than on CPU/GPU (4).
BATCH_SIZE = (8 if tpu else 4) * REPLICAS
# Peak learning rate, scaled linearly with the number of replicas.
LEARNING_RATE = 1e-5 * REPLICAS
# Maximum number of training epochs.
EPOCHS = 18
# Geometry of the image tensors used for training.
HEIGHT, WIDTH = 512, 512
IMAGE_SIZE = [HEIGHT, WIDTH]
CHANNELS = 3
# Number of classes to predict.
N_CLASSES = 5
# Epochs to wait without improvement before stopping a training run.
ES_PATIENCE = 4
# Number of cross-validation folds, i.e. how many models are trained.
N_FOLDS = 3
# Number of test-time augmentation passes over the test set.
N_TTA = 5
# Default for whether the train/validation pipelines one-hot encode labels.
ONE_HOT = True
# Default temperatures for the bi-tempered logistic loss, and the label
# smoothing amount (also used by the Keras cross-entropy loss).
T1 = 0.2
T2 = 1.0
LABEL_SMOOTH = 0.15

In [None]:
# Local Kaggle input directory of the competition data (note trailing slash).
database_base_path = '/kaggle/input/cassava-leaf-disease-classification/'

# CSV listing every training image file name with its integer label.
train = pd.read_csv(os.path.join(database_base_path, 'train.csv'))
print('Train samples: %d' % len(train))

# GCS location of the same dataset, used for the TFRecord globs below.
GCS_PATH = KaggleDatasets().get_gcs_path('cassava-leaf-disease-classification')
display(train.head())

TRAIN_FILENAMES = tf.io.gfile.glob(f'{GCS_PATH}/train_tfrecords/ld_train*.tfrec')
TEST_FILENAMES = tf.io.gfile.glob(f'{GCS_PATH}/test_tfrecords/ld_test*.tfrec')
TRAIN_IMAGES = os.listdir(os.path.join(database_base_path, 'train_images'))
TEST_IMAGES = os.listdir(os.path.join(database_base_path, 'test_images'))

# Human-readable names of the 5 classes (presumably indexed by the integer
# label in train.csv — confirm against the competition description).
CLASSES = [
    'Cassava Bacterial Blight',
    'Cassava Brown Streak Disease',
    'Cassava Green Mottle',
    'Cassava Mosaic Disease',
    'Healthy',
]

In [None]:
# Albumentations pipeline applied by aug_fn / aug_fn_label.
transforms = albumentations.Compose([
    albumentations.Flip(always_apply=False, p=0.9),
    albumentations.RandomResizedCrop(height=HEIGHT, width=WIDTH, always_apply=True),
    albumentations.Blur(blur_limit=7, p=0.3),
    albumentations.ColorJitter(brightness=0.1, contrast=0.15, saturation=0.15, hue=0.0, always_apply=False, p=0.5),
    #albumentations.RandomBrightnessContrast(brightness_limit=0.15,contrast_limit=0.2,brightness_by_max=True,always_apply=False,p=0.3),
    albumentations.Cutout(num_holes=4, max_h_size=int(HEIGHT*0.1), max_w_size=int(WIDTH*0.1), fill_value=0, always_apply=False, p=0.2,),
    albumentations.Cutout(num_holes=4, max_h_size=int(HEIGHT*0.1), max_w_size=int(WIDTH*0.1), fill_value=0, always_apply=False, p=0.2,)])


def aug_fn(image):
    """Apply ``transforms`` to one image and return a float32 NumPy array.

    This runs inside ``tf.numpy_function`` (see ``process_data``), whose
    wrapped callable must take and return NumPy arrays. For that reason it
    converts with NumPy instead of calling TensorFlow ops.
    """
    aug_img = transforms(image=image)['image']
    return np.asarray(aug_img, dtype=np.float32)


def aug_fn_label(image, label):
    """Same as ``aug_fn`` but also returns ``label`` unchanged."""
    return aug_fn(image), label


def process_data(image, label):
    """tf.data map function: augment ``image`` with Albumentations on the host.

    ``tf.numpy_function`` drops static shape information; follow this map
    with ``set_shapes``.
    """
    aug_img = tf.numpy_function(func=aug_fn, inp=[image], Tout=tf.float32)
    return aug_img, label


def set_shapes(img, label):
    """Restore the static shapes lost by ``tf.numpy_function``.

    ``RandomResizedCrop`` is always applied, so every augmented image is
    HEIGHT x WIDTH. ``label`` is expected to be a scalar.
    """
    img.set_shape((HEIGHT, WIDTH, CHANNELS))
    label.set_shape([])
    return img, label

In [None]:
def data_augment(image, label):
    """Randomly augment one training image; ``label`` is returned unchanged.

    Each decision draws its own uniform [0, 1) sample, so the steps are
    independent of one another:

    * base size: random HEIGHT x WIDTH crop (p ~ .6) or a plain resize;
    * shear / rotation: each with p ~ .8, magnitude drawn up to 20 / 45
      degrees, with the sign chosen by the same draw;
    * random left-right and up-down flips, plus a transpose (p ~ .25);
    * rotation by 270 / 180 / 90 degrees (p ~ .25 each);
    * saturation, contrast and brightness jitter (p ~ .6 each);
    * central crop (p ~ .4, fraction .5-.8) or random square crop
      (p ~ .3, side in [int(.6 * HEIGHT), HEIGHT)), then resize back to
      HEIGHT x WIDTH;
    * cutout via ``data_augment_cutout`` (p ~ .5).

    The Python ``if`` statements on tensors run directly in eager mode.
    Inside ``tf.data.Dataset.map`` they are presumably converted by AutoGraph.
    NOTE(review): nothing clips pixel values after the brightness/contrast
    jitter, so outputs may leave the input's value range — confirm intended.

    Returns:
        ``(image, label)`` with ``image`` reshaped to (HEIGHT, WIDTH, CHANNELS).
    """
    # One independent uniform draw per augmentation decision.
    p_rotation = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_spatial = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_rotate = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_pixel_1 = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_pixel_2 = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_pixel_3 = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_shear = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_crop = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_cutout = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_randcrop = tf.random.uniform([],0,1.0,dtype = tf.float32)
    
    # Bring the image to HEIGHT x WIDTH. tf.image.random_crop requires the
    # input to be at least that large. NOTE(review): random_crop keeps the
    # input dtype while resize returns float32, so the two branches differ
    # in dtype for non-float inputs.
    if p_randcrop > .4:
        image = tf.image.random_crop(image, size= [HEIGHT,WIDTH,CHANNELS])
        image = tf.reshape(image,shape = (HEIGHT,WIDTH,CHANNELS))
    else:
        image = tf.image.resize(image, size=[HEIGHT, WIDTH])
        image = tf.reshape(image,shape = (HEIGHT,WIDTH,CHANNELS))
    
    # Shear (random magnitude up to 20 degrees, sign picked by p_shear)
    if p_shear > .2:
        if p_shear > .6:
            image = transform_shear(image, HEIGHT, shear=20.)
        else:
            image = transform_shear(image, HEIGHT, shear=-20.)
            
    # Rotation (random angle up to 45 degrees, sign picked by p_rotation)
    if p_rotation > .2:
        if p_rotation > .6:
            image = transform_rotation(image, HEIGHT, rotation=45.)
        else:
            image = transform_rotation(image, HEIGHT, rotation=-45.)
            
    # Flips
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    if p_spatial > .75:
        image = tf.image.transpose(image)
        
    # Rotates by multiples of 90 degrees
    if p_rotate > .75:
        image = tf.image.rot90(image, k=3) # rotate 270 degrees
    elif p_rotate > .5:
        image = tf.image.rot90(image, k=2) # rotate 180 degrees
    elif p_rotate > .25:
        image = tf.image.rot90(image, k=1) # rotate 90 degrees
        
    # Pixel-level transforms
    if p_pixel_1 >= .4:
        image = tf.image.random_saturation(image, lower=.7, upper=1.3)
    if p_pixel_2 >= .4:
        image = tf.image.random_contrast(image, lower=.8, upper=1.2)
    if p_pixel_3 >= .4:
        image = tf.image.random_brightness(image, max_delta=.1)
        
    # Crops (the result is resized back to HEIGHT x WIDTH right after)
    if p_crop > .6:
        if p_crop > .9:
            image = tf.image.central_crop(image, central_fraction=.5)
        elif p_crop > .8:
            image = tf.image.central_crop(image, central_fraction=.6)
        elif p_crop > .7:
            image = tf.image.central_crop(image, central_fraction=.7)
        else:
            image = tf.image.central_crop(image, central_fraction=.8)
    elif p_crop > .3:
        # Square crop; maxval is exclusive, so side is in [int(.6*HEIGHT), HEIGHT).
        crop_size = tf.random.uniform([], int(HEIGHT*.6), HEIGHT, dtype=tf.int32)
        image = tf.image.random_crop(image, size=[crop_size, crop_size, CHANNELS])
            
    image = tf.image.resize(image, size=[HEIGHT, WIDTH])

    if p_cutout > .5:
        image = data_augment_cutout(image)
    # Pin the static shape for downstream batching.
    image = tf.reshape(image,shape = (HEIGHT,WIDTH,CHANNELS))
    return image, label

In [None]:
# data augmentation @cdeotte kernel: https://www.kaggle.com/cdeotte/rotation-augmentation-gpu-tpu-0-96
def _gather_affine(image, DIM, matrix):
    """Resample a square image through a 3x3 matrix (nearest neighbour).

    Every destination pixel, in centred integer coordinates, is mapped through
    ``matrix`` onto the source image. The result is clipped to the border and
    read with ``tf.gather_nd``; there is no interpolation.

    Args:
        image: one image of size [DIM, DIM, 3], not a batch.
        DIM: side length of the square image.
        matrix: float32 tensor of shape [3, 3].

    Returns:
        The resampled image with shape [DIM, DIM, 3].
    """
    XDIM = DIM % 2  # fix for odd sizes such as 331

    # LIST DESTINATION PIXEL INDICES
    x = tf.repeat(tf.range(DIM//2, -DIM//2, -1), DIM)
    y = tf.tile(tf.range(-DIM//2, DIM//2), [DIM])
    z = tf.ones([DIM*DIM], dtype='int32')
    idx = tf.stack([x, y, z])

    # MAP DESTINATION PIXELS ONTO ORIGIN PIXELS
    idx2 = K.dot(matrix, tf.cast(idx, dtype='float32'))
    idx2 = K.cast(idx2, dtype='int32')
    idx2 = K.clip(idx2, -DIM//2 + XDIM + 1, DIM//2)

    # FIND ORIGIN PIXEL VALUES
    idx3 = tf.stack([DIM//2 - idx2[0], DIM//2 - 1 + idx2[1]])
    d = tf.gather_nd(image, tf.transpose(idx3))

    return tf.reshape(d, [DIM, DIM, 3])


def transform_rotation(image, height, rotation):
    """Rotate one square image by a random angle between 0 and ``rotation`` degrees.

    Args:
        image: one image of size [height, height, 3], not a batch.
        height: side length of the square image.
        rotation: maximum angle in degrees. Its sign sets the direction, and
            a fresh fraction U[0, 1) of it is drawn on every call.

    Returns:
        The rotated image with shape [height, height, 3].
    """
    DIM = height

    rotation = rotation * tf.random.uniform([1], dtype='float32')
    # CONVERT DEGREES TO RADIANS
    rotation = math.pi * rotation / 180.

    # ROTATION MATRIX
    c1 = tf.math.cos(rotation)
    s1 = tf.math.sin(rotation)
    one = tf.constant([1], dtype='float32')
    zero = tf.constant([0], dtype='float32')
    rotation_matrix = tf.reshape(tf.concat([c1, s1, zero, -s1, c1, zero, zero, zero, one], axis=0), [3, 3])

    return _gather_affine(image, DIM, rotation_matrix)


def transform_shear(image, height, shear):
    """Shear one square image by a random angle between 0 and ``shear`` degrees.

    Args:
        image: one image of size [height, height, 3], not a batch.
        height: side length of the square image.
        shear: maximum shear angle in degrees. Its sign sets the direction,
            and a fresh fraction U[0, 1) of it is drawn on every call.

    Returns:
        The sheared image with shape [height, height, 3].
    """
    DIM = height

    shear = shear * tf.random.uniform([1], dtype='float32')
    # CONVERT DEGREES TO RADIANS
    shear = math.pi * shear / 180.

    # SHEAR MATRIX
    one = tf.constant([1], dtype='float32')
    zero = tf.constant([0], dtype='float32')
    c2 = tf.math.cos(shear)
    s2 = tf.math.sin(shear)
    shear_matrix = tf.reshape(tf.concat([one, s2, zero, zero, c2, zero, zero, zero, one], axis=0), [3, 3])

    return _gather_affine(image, DIM, shear_matrix)

# CutOut
def data_augment_cutout(image, min_mask_size=(int(HEIGHT * .1), int(HEIGHT * .1)),
                        max_mask_size=(int(HEIGHT * .125), int(HEIGHT * .125))):
    """Zero out a random number of rectangular patches of ``image``.

    The number of patches is drawn in tiers. ``tf.random.uniform`` excludes
    ``maxval``, so the upper bounds below are exclusive:

    * p > .85 -> 10-14 patches
    * p > .60 -> 5-9 patches
    * p > .25 -> 2-4 patches
    * otherwise exactly 1 patch

    Args:
        image: one image of shape [HEIGHT, WIDTH, channels].
        min_mask_size, max_mask_size: (h, w) bounds for each patch.

    Returns:
        The masked image as float32.
    """
    p_cutout = tf.random.uniform([], 0, 1.0, dtype=tf.float32)

    if p_cutout > .85:    # 10~14 cut outs
        n_cutout = tf.random.uniform([], 10, 15, dtype=tf.int32)
    elif p_cutout > .6:   # 5~9 cut outs
        n_cutout = tf.random.uniform([], 5, 10, dtype=tf.int32)
    elif p_cutout > .25:  # 2~4 cut outs
        n_cutout = tf.random.uniform([], 2, 5, dtype=tf.int32)
    else:                 # 1 cut out
        # A tensor rather than the int 1, so that every branch yields the same
        # type when AutoGraph turns this chain into tf.cond.
        n_cutout = tf.constant(1, dtype=tf.int32)

    return random_cutout(image, HEIGHT, WIDTH,
                         min_mask_size=min_mask_size, max_mask_size=max_mask_size,
                         k=n_cutout)


def random_cutout(image, height, width, channels=3, min_mask_size=(10, 10), max_mask_size=(80, 80), k=1):
    """Zero out ``k`` randomly placed rectangles in ``image``.

    Args:
        image: one image of shape [height, width, channels].
        height, width, channels: image dimensions used to build each mask.
        min_mask_size, max_mask_size: (h, w) bounds. Each rectangle's size is
            drawn from [min, max).
        k: number of rectangles (Python int or scalar int32 tensor).

    Returns:
        The masked image as float32. If ``k`` is 0 the image is returned
        untouched.

    Raises:
        ValueError: if a mask bound is not strictly smaller than the image.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not (height > min_mask_size[0] and width > min_mask_size[1]
            and height > max_mask_size[0] and width > max_mask_size[1]):
        raise ValueError(
            f'mask sizes {min_mask_size} / {max_mask_size} must be smaller '
            f'than the image ({height}x{width})')

    for _ in range(k):
        mask_height = tf.random.uniform(shape=[], minval=min_mask_size[0], maxval=max_mask_size[0], dtype=tf.int32)
        mask_width = tf.random.uniform(shape=[], minval=min_mask_size[1], maxval=max_mask_size[1], dtype=tf.int32)

        pad_h = height - mask_height
        pad_top = tf.random.uniform(shape=[], minval=0, maxval=pad_h, dtype=tf.int32)
        pad_bottom = pad_h - pad_top

        pad_w = width - mask_width
        pad_left = tf.random.uniform(shape=[], minval=0, maxval=pad_w, dtype=tf.int32)
        pad_right = pad_w - pad_left

        # Mask is 0 inside the rectangle and 1 (the pad value) everywhere else.
        cutout_area = tf.zeros(shape=[mask_height, mask_width, channels], dtype=tf.uint8)
        cutout_mask = tf.pad([cutout_area], [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]], constant_values=1)
        cutout_mask = tf.squeeze(cutout_mask, axis=0)
        image = tf.multiply(tf.cast(image, tf.float32), tf.cast(cutout_mask, tf.float32))

    return image

In [None]:
# Visual sanity check: apply `data_augment` to the first 16 training images,
# show them in a 4x4 grid titled with their integer label, and print how
# long the whole preview took.
start_time = time.time()
plt.figure(figsize=(10, 10))
for i in range(16):
    image_id = train['image_id'].iloc[i]
    label = train['label'].iloc[i]
    ax = plt.subplot(4, 4, i + 1)
    image = imread(f'{database_base_path}train_images/{image_id}')
    # data_augment only passes the label through, so 0 is a placeholder.
    aug_image, a = data_augment(image, label=0)
    plt.imshow(tf.cast(aug_image, dtype=tf.int32))
    plt.title(label)
    plt.axis('off')
print(time.time() - start_time)

In [None]:
from functools import partial

def decode_image(image_data):
    """Decode a JPEG byte string into a float32 image scaled to [0, 1].

    1. Decode the JPEG into a 3-channel uint8 tensor.
    2. Cast to float32 and divide by 255.

    No resizing or reshaping happens here. The image keeps its encoded
    height and width.
    """
    image = tf.image.decode_jpeg(image_data, channels=3)
    image = tf.cast(image, tf.float32) / 255

    return image

def correct_shapes(image, label):
    """Resize ``image`` to HEIGHT x WIDTH and pin its static 3-channel shape.

    ``label`` is passed through unchanged.
    """
    resized = tf.image.resize(image, [HEIGHT, WIDTH])
    return tf.reshape(resized, [HEIGHT, WIDTH, 3]), label

def random_crop(image, label):
    """Take a random HEIGHT x WIDTH x CHANNELS crop of ``image``; ``label`` is unchanged."""
    crop_size = [HEIGHT, WIDTH, CHANNELS]
    return tf.image.random_crop(image, size=crop_size), label

def process_path(file_path):
    """Read and decode the JPEG at ``file_path``; return ``(image, file_name)``."""
    raw_bytes = tf.io.read_file(file_path)
    return decode_image(raw_bytes), get_name(file_path)

def get_name(file_path):
    """Return the last ``os.path.sep``-separated component of ``file_path``."""
    return tf.strings.split(file_path, os.path.sep)[-1]

def read_tfrecord(example, labeled=True):
    """Parse one serialized TFRecord example into ``(image, label_or_name)``.

    The ``image`` feature is decoded with ``decode_image``. Labeled records
    carry an int64 ``target``, returned as int32. Unlabeled records carry a
    string ``image_name``, returned as-is.
    """
    extra_key, extra_dtype = ('target', tf.int64) if labeled else ('image_name', tf.string)
    feature_spec = {
        'image': tf.io.FixedLenFeature([], tf.string),
        extra_key: tf.io.FixedLenFeature([], extra_dtype),
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    image = decode_image(parsed['image'])
    if not labeled:
        return image, parsed['image_name']
    return image, tf.cast(parsed['target'], tf.int32)

def load_dataset(filenames, labeled=True, ordered=False):
    """Build a dataset of decoded ``(image, label_or_name)`` pairs from TFRecords.

    When ``ordered`` is False, tf.data is allowed to return elements
    non-deterministically (out of order).
    """
    options = tf.data.Options()
    if not ordered:
        options.experimental_deterministic = False

    return (tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
            .with_options(options)
            .map(lambda record: read_tfrecord(record, labeled=labeled),
                 num_parallel_calls=AUTO))

def one_hot_fn(x, y):
    """One-hot encode the integer label ``y`` into a length-N_CLASSES vector.

    ``x`` (the image) is passed through unchanged.
    """
    y = tf.one_hot(y, N_CLASSES)
    # tf.ensure_shape returns a new tensor. Its result must be used, otherwise
    # the check is dropped from the traced graph.
    y = tf.ensure_shape(y, [N_CLASSES])
    return x, y


def get_training_dataset(dataset, do_aug=True, repeat=False, is_one_hot=ONE_HOT):
    """Turn an unbatched ``(image, label)`` dataset into shuffled training batches.

    Args:
        dataset: unbatched dataset, e.g. from ``load_dataset``. (This function
            used to take TFRecord file names instead.)
        do_aug: apply ``data_augment`` to every element.
        repeat: repeat the dataset indefinitely.
        is_one_hot: one-hot encode labels with ``one_hot_fn``.

    Returns:
        A dataset of BATCH_SIZE batches with prefetching.

    NOTE(review): nothing is cached before the augmentation map, so tf.data
    should re-run ``data_augment`` on every pass over the data. Parallel maps
    also mean the random augmentation stream is not reproducible element by
    element.
    """
    if do_aug:
        # Augmentation is the most expensive step; run it in parallel.
        dataset = dataset.map(data_augment, num_parallel_calls=AUTO)
    if is_one_hot:
        dataset = dataset.map(one_hot_fn, num_parallel_calls=AUTO)
    dataset = dataset.shuffle(2048, reshuffle_each_iteration=True)
    if repeat:
        dataset = dataset.repeat()
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.prefetch(AUTO)
    return dataset

def get_validation_dataset(validation_files, ordered=False, is_crop=False, is_one_hot=ONE_HOT):
    """Build batched, cached validation data from labeled TFRecord files.

    Args:
        validation_files: TFRecord file names.
        ordered: keep deterministic element order (see ``load_dataset``).
        is_crop: take a random HEIGHT x WIDTH crop instead of resizing.
        is_one_hot: one-hot encode labels with ``one_hot_fn``.

    Returns:
        A dataset of BATCH_SIZE batches. It is cached after the first full
        pass, so any random crops are fixed from the second epoch on.
    """
    dataset = load_dataset(validation_files, labeled=True, ordered=ordered)
    shape_fn = random_crop if is_crop else correct_shapes
    dataset = dataset.map(shape_fn, num_parallel_calls=AUTO)
    if is_one_hot:
        dataset = dataset.map(one_hot_fn, num_parallel_calls=AUTO)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.cache()
    dataset = dataset.prefetch(AUTO)
    return dataset


def get_test_dataset(files_path, shuffled=False, tta=True, extension='jpg'):
    """Build a batch-size-1 dataset of ``(image, file_name)`` from image files.

    Files matching ``f'{files_path}*{extension}'`` are decoded by
    ``process_path``. With ``tta``, each image also goes through the
    Albumentations pipeline (``process_data``) and has its static shape
    restored (``set_shapes``).
    """
    pattern = f'{files_path}*{extension}'
    dataset = tf.data.Dataset.list_files(pattern, shuffle=shuffled)
    dataset = dataset.map(process_path, num_parallel_calls=AUTO)
    if tta:
        for fn in (process_data, set_shapes):
            dataset = dataset.map(fn, num_parallel_calls=AUTO)
    return dataset.batch(1).prefetch(AUTO)

def count_data_items(filenames):
    """Sum the record counts encoded in TFRecord file names.

    Each file name is expected to end in ``-<count>.<ext>``, e.g.
    ``ld_train00-1338.tfrec``. Only the base name is inspected, so digits or
    dashes in the bucket/directory part of the path cannot be mistaken for
    the count.

    Returns:
        The total as a Python int (0 for no files).

    Raises:
        ValueError: if a file name carries no ``-<count>.`` suffix.
    """
    pattern = re.compile(r"-([0-9]+)\.")
    total = 0
    for filename in filenames:
        match = pattern.search(os.path.basename(filename))
        if match is None:
            raise ValueError(f"cannot infer record count from file name: {filename!r}")
        total += int(match.group(1))
    return total

In [None]:
LR_START = 5e-6
LR_MIN = 1e-5
LR_MAX = LEARNING_RATE
LR_RAMPUP_EPOCHS = 3
LR_SUSTAIN_EPOCHS = 0
N_CYCLES = .5


def lrfn(epoch, change_on_step=False, step=0, max_steps=10):
    """Learning rate for ``epoch``: linear warm-up, optional plateau, cosine decay.

    * epochs [0, LR_RAMPUP_EPOCHS): linear ramp from LR_START towards LR_MAX;
    * the next LR_SUSTAIN_EPOCHS epochs: constant LR_MAX;
    * afterwards: cosine curve over the remaining epochs of EPOCHS
      (N_CYCLES = .5 gives a single half-period decay from LR_MAX), floored
      at LR_MIN.

    With ``change_on_step`` the result is linearly interpolated
    ``step / max_steps`` of the way towards the next epoch's rate.

    The decay uses ``math``/``max`` instead of TensorFlow ops, so every phase
    returns a plain Python float that can be compared, formatted and plotted
    uniformly.
    """
    if epoch < LR_RAMPUP_EPOCHS:
        lr = (LR_MAX - LR_START) / LR_RAMPUP_EPOCHS * epoch + LR_START
    elif epoch < LR_RAMPUP_EPOCHS + LR_SUSTAIN_EPOCHS:
        lr = LR_MAX
    else:
        decay_epochs = EPOCHS - LR_RAMPUP_EPOCHS - LR_SUSTAIN_EPOCHS
        progress = (epoch - LR_RAMPUP_EPOCHS - LR_SUSTAIN_EPOCHS) / decay_epochs
        lr = LR_MAX * (0.5 * (1.0 + math.cos(math.pi * N_CYCLES * 2.0 * progress)))
        if LR_MIN is not None:
            lr = max(LR_MIN, lr)
    if change_on_step:
        next_lr = lrfn(epoch + 1)
        return lr + step / max_steps * (next_lr - lr)
    return lr

# Plot the learning-rate schedule over every training epoch.
rng = list(range(EPOCHS))
y = list(map(lrfn, rng))

sns.set(style='whitegrid')
fig, ax = plt.subplots(figsize=(20, 6))
plt.plot(rng, y)

print(f'{EPOCHS} total epochs and {len(train) // BATCH_SIZE} steps per epoch')
print(f'Learning rate schedule: {y[0]:.3g} to {max(y):.3g} to {y[-1]:.3g}')

In [None]:
def for_loop(num_iters, body, initial_args):
    """Run a simple Python for-loop, feeding each output back into ``body``.

    Args:
        num_iters: Number of iterations (0 returns ``initial_args`` unchanged).
        body: Callable taking the unpacked previous outputs and returning the
            next outputs as a sequence.
        initial_args: Arguments for the first iteration.

    Returns:
        Output of the final iteration, or ``initial_args`` if ``num_iters`` is 0.
    """
    outputs = initial_args
    for _ in range(num_iters):
        outputs = body(*outputs)
    return outputs
def log_t(u, t):
    """Tempered logarithm of ``u``.

    Returns ``log(u)`` when ``t == 1`` and ``(u**(1 - t) - 1) / (1 - t)``
    otherwise.
    """
    def _standard_log():
        return tf.math.log(u)

    def _tempered_log():
        return (u**(1.0 - t) - 1.0) / (1.0 - t)

    return tf.cond(tf.equal(t, 1.0), _standard_log, _tempered_log)
def exp_t(u, t):
    """Tempered exponential of ``u``.

    Returns ``exp(u)`` when ``t == 1`` and ``relu(1 + (1 - t) * u)**(1 / (1 - t))``
    otherwise.
    """
    def _standard_exp():
        return tf.math.exp(u)

    def _tempered_exp():
        return tf.nn.relu(1.0 + (1.0 - t) * u)**(1.0 / (1.0 - t))

    return tf.cond(tf.equal(t, 1.0), _standard_exp, _tempered_exp)

def compute_normalization_fixed_point(activations, t, num_iters=5):
    """Returns the normalization value for each example (t > 1.0).

    Iteratively rescales the max-shifted activations by the current
    partition estimate raised to ``1 - t``. ``num_iters`` rounds are
    unrolled in Python by ``for_loop``.

    Args:
    activations: A multi-dimensional tensor with last dimension `num_classes`.
    t: Temperature 2 (> 1.0 for tail heaviness).
    num_iters: Number of iterations to run the method.
    Return: A tensor of same rank as activation with the last dimension being 1.
    """

    # Shift by the per-example maximum; the shift is added back at the end.
    mu = tf.math.reduce_max(activations, -1, keepdims=True)
    normalized_activations_step_0 = activations - mu
    shape_normalized_activations = tf.shape(normalized_activations_step_0)

    def iter_body(i, normalized_activations):
        # Current partition estimate from the latest rescaled activations.
        logt_partition = tf.math.reduce_sum(
            exp_t(normalized_activations, t), -1, keepdims=True)
        # Always rescale the *initial* shifted activations.
        normalized_activations_t = tf.reshape(
            normalized_activations_step_0 * tf.math.pow(logt_partition, 1.0 - t),
            shape_normalized_activations)
        return [i + 1, normalized_activations_t]

    _, normalized_activations_t = for_loop(num_iters, iter_body,
                                         [0, normalized_activations_step_0])
    logt_partition = tf.math.reduce_sum(
      exp_t(normalized_activations_t, t), -1, keepdims=True)
    return -log_t(1.0 / logt_partition, t) + mu

def compute_normalization_binary_search(activations, t, num_iters=10):
    """Returns the normalization value for each example (t < 1.0).

    Bisects, per example, on the constant c such that
    ``sum(exp_t(activations - mu - c, t))`` is 1. The sum decreases as c
    grows, so a sum below 1 moves the upper bound down and otherwise the
    lower bound moves up.

    Args:
    activations: A multi-dimensional tensor with last dimension `num_classes`.
    t: Temperature 2 (< 1.0 for finite support).
    num_iters: Number of iterations to run the method.
    Return: A tensor of same rank as activation with the last dimension being 1.
    """
    # Shift by the per-example maximum; the shift is added back at the end.
    mu = tf.math.reduce_max(activations, -1, keepdims=True)
    normalized_activations = activations - mu
    shape_activations = tf.shape(activations)
    # Number of classes whose shifted activation exceeds -1 / (1 - t); used
    # to set the initial upper bound of the search.
    effective_dim = tf.cast(
      tf.math.reduce_sum(
          tf.cast(
              tf.greater(normalized_activations, -1.0 / (1.0 - t)), tf.int32),
          -1,
          keepdims=True), tf.float32)
    shape_partition = tf.concat([shape_activations[:-1], [1]], 0)
    lower = tf.zeros(shape_partition)
    upper = -log_t(1.0 / effective_dim, t) * tf.ones(shape_partition)

    def iter_body(i, lower, upper):
        logt_partition = (upper + lower)/2.0
        sum_probs = tf.math.reduce_sum(exp_t(
            normalized_activations - logt_partition, t), -1, keepdims=True)
        # update == 1 where the midpoint is too large (sum < 1): move upper.
        update = tf.cast(tf.less(sum_probs, 1.0), tf.float32)
        lower = tf.reshape(lower * update + (1.0 - update) * logt_partition,
                           shape_partition)
        upper = tf.reshape(upper * (1.0 - update) + update * logt_partition,
                           shape_partition)
        return [i + 1, lower, upper]

    _, lower, upper = for_loop(num_iters, iter_body, [0, lower, upper])
    logt_partition = (upper + lower)/2.0
    return logt_partition + mu

def compute_normalization(activations, t, num_iters=5):
    """Per-example normalization constant for the tempered softmax.

    Uses binary search when ``t < 1`` (finite support) and fixed-point
    iteration otherwise.

    Args:
        activations: A multi-dimensional tensor with last dimension `num_classes`.
        t: Temperature 2 (< 1.0 for finite support, > 1.0 for tail heaviness).
        num_iters: Number of iterations to run the method.

    Returns:
        A tensor of the same rank as ``activations`` with last dimension 1.
    """
    def _by_binary_search():
        return compute_normalization_binary_search(activations, t, num_iters)

    def _by_fixed_point():
        return compute_normalization_fixed_point(activations, t, num_iters)

    return tf.cond(tf.less(t, 1.0), _by_binary_search, _by_fixed_point)

def _internal_bi_tempered_logistic_loss(activations, labels, t1, t2):
    """Computes the Bi-Tempered logistic loss directly (no custom gradient).

    Args:
        activations: A multi-dimensional tensor with last dimension `num_classes`.
        labels: Target distribution with the same shape as ``activations``.
        t1: Temperature 1 (< 1.0 for boundedness).
        t2: Temperature 2 (> 1.0 for tail heaviness).

    Returns:
        A loss tensor for robust loss, with the class axis summed out.
    """
    if t2 == 1.0:
        normalization_constants = tf.math.log(
            tf.math.reduce_sum(tf.exp(activations), -1, keepdims=True))
        if t1 == 1.0:
            # Plain softmax cross-entropy. normalization_constants keeps a
            # trailing size-1 axis; drop it so it adds per example instead of
            # broadcasting [B, 1] + [B] into [B, B].
            return tf.squeeze(normalization_constants, -1) + tf.math.reduce_sum(
                tf.multiply(labels, tf.math.log(labels + 1e-10) - activations), -1)
        else:
            shifted_activations = tf.math.exp(activations - normalization_constants)
            one_minus_t1 = (1.0 - t1)
            one_minus_t2 = 1.0
    else:
        one_minus_t1 = (1.0 - t1)
        one_minus_t2 = (1.0 - t2)
        normalization_constants = compute_normalization(
            activations, t2, num_iters=5)
        shifted_activations = tf.nn.relu(1.0 + one_minus_t2 *
                                         (activations - normalization_constants))

    if t1 == 1.0:
        return tf.math.reduce_sum(
            tf.math.multiply(
                tf.math.log(labels + 1e-10) -
                tf.math.log(tf.math.pow(shifted_activations, 1.0 / one_minus_t2)), labels),
            -1)
    else:
        beta = 1.0 + one_minus_t1
        logt_probs = (tf.math.pow(shifted_activations, one_minus_t1 / one_minus_t2) -
                      1.0) / one_minus_t1
        return tf.math.reduce_sum(
            tf.math.multiply(log_t(labels, t1) - logt_probs, labels) - 1.0 / beta *
            (tf.math.pow(labels, beta) -
             tf.math.pow(shifted_activations, beta / one_minus_t2)), -1)
    
def tempered_sigmoid(activations, t, num_iters=5):
    """Tempered sigmoid function.

    Each scalar activation x is treated as a two-class problem with logits
    [0, x]. The tempered-softmax probability of the second class is returned,
    reshaped back to the input shape.

    Args:
        activations: Activations for the positive class for binary classification.
        t: Temperature tensor > 0.0.
        num_iters: Number of iterations to run the method.

    Returns:
        A probabilities tensor.
    """
    t = tf.convert_to_tensor(t)
    input_shape = tf.shape(activations)
    column = tf.reshape(activations, [-1, 1])
    two_class = tf.concat([tf.zeros_like(column), column], 1)

    def _standard_partition():
        return tf.math.log(tf.math.reduce_sum(tf.exp(two_class), -1, keepdims=True))

    def _tempered_partition():
        return compute_normalization(two_class, t, num_iters)

    normalization_constants = tf.cond(tf.equal(t, 1.0), _standard_partition, _tempered_partition)
    two_class_probs = exp_t(two_class - normalization_constants, t)
    positive_probs = tf.split(two_class_probs, 2, axis=1)[1]
    return tf.reshape(positive_probs, input_shape)

def tempered_softmax(activations, t, num_iters=5):
    """Tempered softmax function.

    With ``t == 1`` this is the ordinary softmax. Otherwise the normalization
    constant comes from ``compute_normalization``.

    Args:
        activations: A multi-dimensional tensor with last dimension `num_classes`.
        t: Temperature tensor > 0.0.
        num_iters: Number of iterations to run the method.

    Returns:
        A probabilities tensor.
    """
    t = tf.convert_to_tensor(t)

    def _standard_partition():
        return tf.math.log(tf.math.reduce_sum(tf.exp(activations), -1, keepdims=True))

    def _tempered_partition():
        return compute_normalization(activations, t, num_iters)

    log_partition = tf.cond(tf.equal(t, 1.0), _standard_partition, _tempered_partition)
    return exp_t(activations - log_partition, t)

def bi_tempered_logistic_loss(activations,
                              labels,
                              t1=T1,
                              t2=T2,
                              label_smoothing=LABEL_SMOOTH,
                              num_iters=5):
    """Bi-Tempered Logistic Loss with custom gradient.

    Args:
      activations: A multi-dimensional tensor with last dimension `num_classes`.
        These are unnormalized scores; the loss applies ``tempered_softmax``
        to them itself.
      labels: A tensor with shape and dtype as activations.
      t1: Temperature 1 (< 1.0 for boundedness).
      t2: Temperature 2 (> 1.0 for tail heaviness, < 1.0 for finite support).
      label_smoothing: Label smoothing parameter between [0, 1).
      num_iters: Number of iterations to run the method.

    Returns:
      A loss tensor with the last (class) axis reduced away, i.e. one value
      per example. It is not averaged over the batch.
    """
    with tf.name_scope('bitempered_logistic'):
        t1 = tf.convert_to_tensor(t1)
        t2 = tf.convert_to_tensor(t2)
        # Compared in a Python `if`, so a plain number is expected here.
        if label_smoothing > 0.0:
            # For one-hot labels: the true class becomes 1 - label_smoothing and
            # every other class label_smoothing / (num_classes - 1).
            num_classes = tf.cast(tf.shape(labels)[-1], tf.float32)
            labels = (
              1 - num_classes /
              (num_classes - 1) * label_smoothing) * labels + label_smoothing / (
                  num_classes - 1)

        @tf.custom_gradient
        def _custom_gradient_bi_tempered_logistic_loss(activations):
            """Bi-Tempered Logistic Loss with custom gradient.

            Args:
              activations: A multi-dimensional tensor with last dim `num_classes`.

            Returns:
              A per-class (not yet summed) loss tensor, and its gradient function.
            """
            with tf.name_scope('gradient_bitempered_logistic'):
                probabilities = tempered_softmax(activations, t2, num_iters)
                loss_values = tf.multiply(
                labels,
                log_t(labels + 1e-10, t1) - log_t(probabilities, t1)) - 1.0 / (2.0 - t1) * (tf.pow(labels, 2.0 - t1) - tf.pow(probabilities, 2.0 - t1))

                def grad(d_loss):
                    """Explicit gradient calculation.

                    Args:
                      d_loss: Infinitesimal change in the loss value.

                    Returns:
                      Loss gradient with respect to ``activations``.
                    """
                    delta_probs = probabilities - labels
                    forget_factor = tf.math.pow(probabilities, t2 - t1)
                    delta_probs_times_forget_factor = tf.math.multiply(delta_probs,
                                                                forget_factor)
                    delta_forget_sum = tf.math.reduce_sum(
                      delta_probs_times_forget_factor, -1, keepdims=True)
                    # Escort distribution: probabilities**t2, renormalized.
                    escorts = tf.math.pow(probabilities, t2)
                    escorts = escorts / tf.math.reduce_sum(escorts, -1, keepdims=True)
                    derivative = delta_probs_times_forget_factor - tf.multiply(
                      escorts, delta_forget_sum)
                    return tf.multiply(d_loss, derivative)

            return loss_values, grad

    # t1 == t2 == 1 reduces to ordinary softmax cross-entropy (already summed
    # over classes); otherwise use the custom-gradient version (per class).
    loss_values = tf.cond(tf.logical_and(tf.equal(t1, 1.0), tf.equal(t2, 1.0)),
                          functools.partial(
                              tf.nn.softmax_cross_entropy_with_logits,
                              labels=labels,
                              logits=activations),
                          functools.partial(
                              _custom_gradient_bi_tempered_logistic_loss,
                              activations))
    # Sum over classes only for the per-class custom-gradient branch.
    reduce_sum_last = lambda x: tf.math.reduce_sum(x, -1)
    loss_values = tf.cond(tf.logical_and(tf.equal(t1, 1.0), tf.equal(t2, 1.0)),
                          functools.partial(tf.identity, loss_values),
                          functools.partial(reduce_sum_last, loss_values))
    return loss_values

In [None]:
# Define the model
def model_fn(input_shape, N_CLASSES):
    """Build and compile an EfficientNet-B3 classifier inside ``strategy.scope()``.

    Args:
        input_shape: (height, width, channels) of the input images.
        N_CLASSES: number of output units of the final softmax layer.

    Returns:
        A compiled ``tf.keras`` model that outputs softmax probabilities,
        using label-smoothed categorical cross-entropy and Adam(LR_START).
    """
    with strategy.scope():
        input_image = L.Input(shape=input_shape, name='input_image')
        base_model = efn.EfficientNetB3(input_tensor=input_image,
                                        include_top=False,
                                        weights='imagenet',
                                        pooling='avg')

        model = tf.keras.Sequential([
            base_model,
            L.Dropout(.25),
            L.Dense(N_CLASSES, activation='softmax', name='output')
        ])

        optimizer = optimizers.Adam(LR_START)
        # Use the tf.keras loss (`losses` from tensorflow.keras), not the
        # standalone `keras` package, so loss and model share one Keras
        # implementation.
        model.compile(optimizer=optimizer,
                      loss=losses.CategoricalCrossentropy(label_smoothing=LABEL_SMOOTH),
                      metrics=['categorical_accuracy'])

        return model


model = model_fn(input_shape=(HEIGHT, WIDTH, CHANNELS), N_CLASSES=N_CLASSES)

In [None]:
# Print the model's layer-by-layer summary.
model.summary()

In [None]:
def train_step(x_batch_train, y_batch_train):
    """Run one forward/backward pass and optimizer update on a single replica.

    Intended to be executed through ``strategy.run`` (see
    ``distributed_train_step``). Reads the globals ``model``, ``optimizer``,
    ``loss_fn``, ``epoch_accuracy``, ``tpu``, ``T1``, ``T2`` and
    ``LABEL_SMOOTH``.

    Args:
        x_batch_train: this replica's batch of input images.
        y_batch_train: matching labels (presumably one-hot, since the datasets
            are built with ``is_one_hot=True``).

    Returns:
        The unscaled per-example loss values, so the caller can aggregate them
        for logging.
    """
    with tf.GradientTape() as tape:
        # Forward pass; every op is recorded on the tape.
        # model_fn ends in a softmax layer, so ``preds`` are probabilities.
        preds = model(x_batch_train, training=True)

        # Per-example loss for this replica's minibatch.
        if tpu:
            loss_value = loss_fn(y_batch_train, preds)
        else:
            # NOTE(review): bi_tempered_logistic_loss treats its first argument
            # as activations/logits, but ``preds`` is already softmax output —
            # confirm this is intended.
            loss_value = bi_tempered_logistic_loss(preds,
                              y_batch_train,
                              t1=T1,
                              t2=T2,
                              label_smoothing=LABEL_SMOOTH,
                              num_iters=5)

        # apply_gradients sums gradients across replicas, so differentiate the
        # mean over the *global* batch (sum / (per-replica batch * replicas))
        # rather than the raw per-example vector (whose gradient is a sum).
        scaled_loss = tf.nn.compute_average_loss(loss_value)

    grads = tape.gradient(scaled_loss, model.trainable_weights)

    # One optimizer step on the trainable weights.
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    epoch_accuracy.update_state(y_batch_train, preds)
    return loss_value

def test_step(x_test,y_test):
    """Evaluate one validation batch on a single replica and update val metrics.

    Executed through ``strategy.run`` (see ``distributed_test_step``). Reads
    the globals ``model``, ``loss_fn``, ``val_loss`` and ``val_accuracy``.
    Returns nothing; results accumulate in the metric objects.
    """
    # Inference mode: dropout is disabled. Outputs are softmax probabilities.
    probabilities = model(x_test, training=False)
    per_example_loss = loss_fn(y_test, probabilities)
    val_loss.update_state(per_example_loss)
    val_accuracy.update_state(y_test, probabilities)
    


In [None]:
import gc
import functools

# Logging cadence: print progress every `step_print` training steps.
if tpu:
    tf.tpu.experimental.initialize_tpu_system(tpu)
    step_print = 25
else:
    # Off-TPU, execute tf.function-decorated code eagerly instead of tracing graphs.
    tf.config.run_functions_eagerly(True)
    step_print = 500
print_time = True

# Per-fold trained models and their metric histories.
models, histories = [], []
#
skf = KFold(n_splits=N_FOLDS, shuffle=True, random_state=seed)
# Custom distributed training loop, one fresh model per fold. KFold splits the
# shard indices 0..14; each index selects the matching ld_train<NN>*.tfrec files.
for fold,(idxT, idxV) in enumerate(skf.split(np.arange(15))):
    # Redefined every fold so each tf.function traces against this fold's
    # model/optimizer/metrics, which train_step/test_step read as globals.
    @tf.function
    def distributed_train_step(x_batch_train, y_batch_train):
        per_replica_losses = strategy.run(train_step, args=(x_batch_train, y_batch_train))
        return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,axis=0) / BATCH_SIZE / REPLICAS

    @tf.function
    def distributed_test_step(x_test, y_test):
        return strategy.run(test_step, args=(x_test,y_test,))

    K.clear_session()
    fold_time = time.time()
    print("Start of FOLD %d." % (fold))
    print("Training: {} Validation: {}".format(idxT,idxV))
    # Instantiate the model for this fold
    model = model_fn(input_shape=(HEIGHT,WIDTH,CHANNELS), N_CLASSES=N_CLASSES)
    # Per-epoch history for this fold
    train_loss_history = []
    train_acc_history = []
    val_loss_history = []
    val_acc_history = []

    #____________________DATASETS INITIALIZATION
    TRAIN_FILENAMES = tf.io.gfile.glob([GCS_PATH + '/train_tfrecords/ld_train%.2i*.tfrec' % x for x in idxT])
    VALID_FILENAMES = tf.io.gfile.glob([GCS_PATH + '/train_tfrecords/ld_train%.2i*.tfrec' % x for x in idxV])
    TRAIN_COUNT = count_data_items(TRAIN_FILENAMES)
    steps = math.floor(TRAIN_COUNT/BATCH_SIZE)
    np.random.shuffle(TRAIN_FILENAMES)
    train_dataset = get_training_dataset(load_dataset(TRAIN_FILENAMES,labeled = True),do_aug = True,is_one_hot = True)
    val_dataset = get_validation_dataset(VALID_FILENAMES,is_crop = True,is_one_hot = True)

    best_val_loss = 1e3
    fold_patience = 0
    with strategy.scope():
        # Create the optimizer under the strategy scope together with the
        # model and metrics, as recommended for tf.distribute custom loops.
        optimizer = optimizers.Adam(LR_START)
        # model_fn's head is a softmax, so the model outputs probabilities:
        # from_logits must be False or softmax would be applied twice.
        # Reduction.NONE keeps per-example losses; train_step does the scaling.
        loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=False,reduction=tf.keras.losses.Reduction.NONE,label_smoothing= LABEL_SMOOTH)
        epoch_accuracy = tf.keras.metrics.CategoricalAccuracy()
        val_accuracy = tf.keras.metrics.CategoricalAccuracy()
        val_loss = tf.keras.metrics.Mean()
    for epoch in range(EPOCHS):
        step = 0
        loss = 0
        start_time = time.time()
        print("Start of epoch %d. LR: %f \n" % (epoch,optimizer.learning_rate))
        # NOTE(review): this loop only terminates if train_dataset is finite
        # (not .repeat()-ed) — confirm in get_training_dataset.
        for x_batch_train,y_batch_train in train_dataset:
            # Per-step learning-rate schedule.
            optimizer.learning_rate = lrfn(epoch = epoch,change_on_step = True,step = step,max_steps = steps)
            with strategy.scope():
                loss_value = distributed_train_step(
                    x_batch_train,
                    y_batch_train)
            loss += loss_value
            step += 1
            # Log on the first step, every step_print steps and the last step.
            if step % step_print == 0 or step == steps or step == 1:
                if print_time or step == steps - 1:
                    print("Time taken this epoch: %.2fs per step: %.2fs current lr: %f" % (
                        time.time() - start_time,(time.time() - start_time)/(step),
                        optimizer.learning_rate))
                print(
                    "Training loss at step %d/%d: %.4f  Accuracy: %.4f"
                    % ((step),steps,loss/step,epoch_accuracy.result()))
        # Mean training loss for the epoch; max() guards an empty dataset.
        epoch_train_loss = loss / max(step, 1)

        val_time = time.time()
        for x_val,y_val in val_dataset:
            with strategy.scope():
                distributed_test_step(x_val,y_val)

        # Record exactly one history entry per epoch, after validation has
        # consumed every batch.
        train_loss_history.append(float(epoch_train_loss))
        train_acc_history.append(float(epoch_accuracy.result()))
        val_loss_history.append(float(val_loss.result()))
        val_acc_history.append(float(val_accuracy.result()))

        print("\nValidation for epoch %d:    Time taken for validation: %.2fs"%(epoch,time.time() - val_time))
        print("val_acc : %.4f val_loss : %.4f"%(val_accuracy.result(),val_loss.result()))
        # Keep the checkpoint with the lowest validation loss; stop after
        # ES_PATIENCE consecutive epochs without improvement.
        if val_loss.result().numpy() < best_val_loss:
            best_val_loss = val_loss.result().numpy()
            path  = f'Best_Save_EFFNETB3:_Fold:{fold}.hdf5'
            model.save(path)
            print('Model Saved \n')
            fold_patience = 0
        else:
            fold_patience += 1
            if fold_patience == ES_PATIENCE:
                print('Early Stopping due to model not improving for {} epochs'.format(ES_PATIENCE))
                break

        print("\nTime taken for this fold so far : %.2fs \n"%(time.time() - fold_time))

        with strategy.scope():
            epoch_accuracy.reset_states()
            val_accuracy.reset_states()
            val_loss.reset_states()
    models.append(model)
    # History layout: [train_loss, train_acc, val_loss, val_acc].
    histories.append([train_loss_history,train_acc_history,val_loss_history,val_acc_history])
    path  = f'Cassava_Model_EFFNETB3:_Fold:{fold}.hdf5'
    model.save(path)

# Second K-fold run using the built-in ``model.fit`` loop. Its models and
# histories go into separate lists so the custom-loop ``models`` and
# ``histories`` (read by the plotting cells) are not discarded.
fit_models = []
fit_histories = []
skf = KFold(n_splits=N_FOLDS, shuffle=True, random_state=seed)
# Loop for each model fold
for fold,(idxT, idxV) in enumerate(skf.split(np.arange(15))):
    # model_fn compiles the model with its own Adam optimizer and loss.
    with strategy.scope():
        model = model_fn(input_shape=(HEIGHT,WIDTH,CHANNELS), N_CLASSES=N_CLASSES)
    TRAIN_FILENAMES = tf.io.gfile.glob([GCS_PATH + '/train_tfrecords/ld_train%.2i*.tfrec' % x for x in idxT])
    VALID_FILENAMES = tf.io.gfile.glob([GCS_PATH + '/train_tfrecords/ld_train%.2i*.tfrec' % x for x in idxV])
    TRAIN_COUNT = count_data_items(TRAIN_FILENAMES)
    steps = math.floor(TRAIN_COUNT/BATCH_SIZE)
    np.random.shuffle(TRAIN_FILENAMES)
    train_dataset = get_training_dataset(load_dataset(TRAIN_FILENAMES,labeled = True),do_aug = True,is_one_hot = True)
    val_dataset = get_validation_dataset(VALID_FILENAMES,is_crop = True,is_one_hot = True)
    es = EarlyStopping(monitor='val_loss', mode='min',
                       patience=ES_PATIENCE, restore_best_weights=True, verbose=1)
    history = model.fit(x=train_dataset,
                        validation_data=val_dataset,
                        steps_per_epoch=steps,
                        callbacks=[es, LearningRateScheduler(lrfn, verbose=0)],
                        epochs=EPOCHS,
                        verbose=1).history
    # Keep each fold's result instead of overwriting it on the next fold.
    fit_models.append(model)
    fit_histories.append(history)

In [None]:
# Fold 0: training vs. validation loss history.
plt.plot(histories[0][0], label='train loss')
plt.plot(histories[0][2], label='val loss')
plt.title('Fold 0 loss')
plt.legend()

In [None]:
# Fold 0: training vs. validation accuracy history.
plt.plot(histories[0][1], label='train accuracy')
plt.plot(histories[0][3], label='val accuracy')
plt.title('Fold 0 accuracy')
plt.legend()

In [None]:
# Fold 1: training vs. validation loss history.
plt.plot(histories[1][0], label='train loss')
plt.plot(histories[1][2], label='val loss')
plt.title('Fold 1 loss')
plt.legend()

In [None]:
# Fold 1: training vs. validation accuracy history.
plt.plot(histories[1][1], label='train accuracy')
plt.plot(histories[1][3], label='val accuracy')
plt.title('Fold 1 accuracy')
plt.legend()

In [None]:
# Fold 2: training vs. validation loss history.
plt.plot(histories[2][0], label='train loss')
plt.plot(histories[2][2], label='val loss')
plt.title('Fold 2 loss')
plt.legend()

In [None]:
# Fold 2: training vs. validation accuracy history.
plt.plot(histories[2][1], label='train accuracy')
plt.plot(histories[2][3], label='val accuracy')
plt.title('Fold 2 accuracy')
plt.legend()