In [None]:
!pip install --quiet efficientnet

In [None]:
!pip install keras-swa
from swa.tfkeras import SWA

In [None]:
!pip install tensorflow-addons=='0.9.1'
import tensorflow_addons as tfa

In [None]:
!pip install git+https://github.com/qubvel/classification_models.git

In [None]:
import math, os, re, warnings, random
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from kaggle_datasets import KaggleDatasets
from sklearn.utils import class_weight
from sklearn.model_selection import KFold
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
import tensorflow as tf
import tensorflow.keras.layers as L
import tensorflow.keras.backend as K
from tensorflow.keras import optimizers, applications, Sequential, losses, metrics
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, LearningRateScheduler , ReduceLROnPlateau
import efficientnet.tfkeras as efn
from sklearn.model_selection import StratifiedKFold

def seed_everything(seed=0):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy and TensorFlow with the same value
    and exports ``PYTHONHASHSEED`` and ``TF_DETERMINISTIC_OPS`` to the process
    environment.

    Args:
        seed: Integer seed shared by all generators (default 0).
    """
    # Environment flags; these do not depend on the seeding calls below.
    os.environ.update({
        'PYTHONHASHSEED': str(seed),
        'TF_DETERMINISTIC_OPS': '1',
    })
    for apply_seed in (random.seed, np.random.seed, tf.random.set_seed):
        apply_seed(seed)

# Seed all generators once, up front, before any data or model code runs.
seed = 0
seed_everything(seed)
# Silence every Python warning for the rest of the session.
warnings.filterwarnings('ignore')

In [None]:
# Try to locate a TPU; a ValueError from the resolver is treated as "no TPU".
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print(f'Running on TPU {tpu.master()}')
except ValueError:
    tpu = None

if tpu:
    # Connect to the TPU cluster, initialise it, and build a TPU strategy.
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
else:
    # Fall back to whatever distribution strategy TensorFlow has active.
    strategy = tf.distribute.get_strategy()

# Shorthand passed as num_parallel_calls / prefetch size in the tf.data code.
AUTO = tf.data.experimental.AUTOTUNE
# Replica count; BATCH_SIZE and LEARNING_RATE are multiplied by it.
REPLICAS = strategy.num_replicas_in_sync
print(f'REPLICAS: {REPLICAS}')

In [None]:
# Global batch size and peak learning rate, both scaled by the replica count.
BATCH_SIZE = 16 * REPLICAS
LEARNING_RATE = 3e-5 * REPLICAS
EPOCHS = 40
# Image geometry used when decoding / resizing images.
HEIGHT = 512
WIDTH = 512
CHANNELS = 3
N_CLASSES = 5
ES_PATIENCE = 10
N_FOLDS = 5
MODEL_NAME = 'EfficientNet4'
focal_loss = 0
# NOTE(review): this rebinds `SWA`, the name an earlier cell imported from
# swa.tfkeras; after this line `SWA` is the bool False, not the imported object.
SWA = False
label_smoothing = 0.005
# Default per-sample probability of cutmix(); also an on/off flag in get_dataset().
cutmix_rate = 0.2
# Only tested for truthiness in get_dataset(); mixup() itself defaults to PROBABILITY=1.0.
mixup_rate = 0.2
# Default probability of apply_grid_mask(); 0 also skips the gridmask map in get_dataset().
gridmask_rate = 0
# Side length (pixels) that cutmix / mixup / gridmask assume for batch images.
img_size = 512

In [None]:
def count_data_items(filenames):
    """Sum the record counts encoded in TFRecord file names.

    Each name is expected to contain ``-<digits>.`` (for example
    ``train00-1338.tfrec``); the digits are read as that file's item count.

    Args:
        filenames: Iterable of file name / path strings.

    Returns:
        The NumPy sum of all per-file counts.
    """
    count_pattern = re.compile(r"-([0-9]*)\.")
    per_file_counts = []
    for name in filenames:
        per_file_counts.append(int(count_pattern.search(name).group(1)))
    return np.sum(per_file_counts)


# Competition data root; train.csv is loaded into the `train` DataFrame.
database_base_path = '/kaggle/input/cassava-leaf-disease-classification/'
train = pd.read_csv(database_base_path + 'train.csv')

print('Train samples: %d' % len(train))

# Alternative TFRecord sources, kept for reference:
# GCS_PATH = KaggleDatasets().get_gcs_path('cassava-leaf-disease-classification') # Original dataset
# GCS_PATH = KaggleDatasets().get_gcs_path(f'cassava-leaf-disease-tfrecords-{HEIGHT}x{WIDTH}') # Only resized
GCS_PATH = KaggleDatasets().get_gcs_path(f'cassava-leaf-disease-tfrecords-center-{HEIGHT}x{WIDTH}') # Center cropped
# Second TFRecord dataset; its files are appended to the training list below.
GCS_PATH_EXT = KaggleDatasets().get_gcs_path('cassava-leaf-disease-tfrecords-classes-ext-512x512') 
# TRAINING_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/train_tfrecords/*.tfrec') # Original TFRecords
TRAINING_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/*.tfrec')
EXT_FILENAMES = tf.io.gfile.glob(GCS_PATH_EXT + '/*.tfrec')
TRAINING_FILENAMES += EXT_FILENAMES
# Item count is derived from the numbers embedded in the file names.
NUM_TRAINING_IMAGES = count_data_items(TRAINING_FILENAMES)
print(f"GCS: train images: {NUM_TRAINING_IMAGES}")

display(train.head())


# Class names; they are indexed by integer label elsewhere in the notebook.
CLASSES = ['Cassava Bacterial Blight', 
           'Cassava Brown Streak Disease', 
           'Cassava Green Mottle', 
           'Cassava Mosaic Disease', 
           'Healthy']

# CutMix

In [None]:
import sys
def cutmix(image, label, PROBABILITY = cutmix_rate):
    """Apply CutMix to one full batch of images and labels.

    For every sample ``j`` a random partner ``k`` from the same batch is drawn
    and, with probability ``PROBABILITY``, a square patch of ``image[k]`` is
    pasted into ``image[j]``. The label becomes a blend of both labels
    weighted by the pasted area fraction.

    Args:
        image: Batch of images; it is indexed as ``[BATCH_SIZE, DIM, DIM, 3]``
            with ``DIM = img_size``, so the batch must hold exactly
            ``BATCH_SIZE`` samples.
        label: Either a rank-1 batch of integer class ids (one-hot encoded
            here) or a rank-2 batch of already encoded label vectors.
        PROBABILITY: Per-sample chance of applying CutMix.

    Returns:
        ``(image2, label2)`` with shapes ``(BATCH_SIZE, DIM, DIM, 3)`` and
        ``(BATCH_SIZE, N_CLASSES)``.
    """
    DIM = img_size
    
    imgs = []; labs = []
    for j in range(BATCH_SIZE):
        # Gate: 1 when this sample gets CutMix, 0 otherwise (patch width becomes 0).
        apply_cut = tf.cast( tf.random.uniform([],0,1)<=PROBABILITY, tf.int32)
        # Random partner image from the same batch.
        k = tf.cast( tf.random.uniform([],0,BATCH_SIZE),tf.int32)
        # Random patch centre.
        x = tf.cast( tf.random.uniform([],0,DIM),tf.int32)
        y = tf.cast( tf.random.uniform([],0,DIM),tf.int32)
        b = tf.random.uniform([],0,1) # this is beta dist with alpha=1.0
        # Local name chosen so the module-level WIDTH constant is not shadowed.
        cut_width = tf.cast( DIM * tf.math.sqrt(1-b),tf.int32) * apply_cut
        ya = tf.math.round(tf.math.maximum(0,y-cut_width//2))
        yb = tf.math.round(tf.math.minimum(DIM,y+cut_width//2))
        xa = tf.math.round(tf.math.maximum(0,x-cut_width//2))
        xb = tf.math.round(tf.math.minimum(DIM,x+cut_width//2))
        # Assemble the image: rows above, [left | partner patch | right], rows below.
        one = image[j,ya:yb,0:xa,:]
        two = image[k,ya:yb,xa:xb,:]
        three = image[j,ya:yb,xb:DIM,:]
        middle = tf.concat([one,two,three],axis=1)
        img = tf.concat([image[j,0:ya,:,:],middle,image[j,yb:DIM,:,:]],axis=0)
        imgs.append(img)
        # Label weight = fraction of the image area covered by the patch.
        a = tf.cast((xb-xa)*(yb-ya)/DIM/DIM,tf.float32)
        if len(label.shape)==1:
            lab1 = tf.one_hot(label[j],N_CLASSES)
            lab2 = tf.one_hot(label[k],N_CLASSES)
        else:
            lab1 = label[j,]
            lab2 = label[k,]
        
        # (The former per-sample debug tf.print was removed: it formatted
        # symbolic tensors at trace time and printed on every batch.)
        labs.append((1-a)*lab1 + a*lab2)
            
    # RESHAPE HACK SO TPU COMPILER KNOWS SHAPE OF OUTPUT TENSOR (maybe use Python typing instead?)
    image2 = tf.reshape(tf.stack(imgs),(BATCH_SIZE,DIM,DIM,3))
    label2 = tf.reshape(tf.stack(labs),(BATCH_SIZE,N_CLASSES))
    return image2,label2

# MixUp

In [None]:
def mixup(image, label, PROBABILITY = 1.0):
    """Apply MixUp to one full batch of images and labels.

    Each sample ``j`` is blended with a random partner ``k`` from the same
    batch: ``(1-a)*sample_j + a*sample_k`` for both image and label, where
    ``a`` is uniform in [0, 1) when MixUp fires and 0 otherwise.

    Args:
        image: Batch indexed as ``[BATCH_SIZE, DIM, DIM, 3]`` (DIM = img_size);
            the loop assumes exactly ``BATCH_SIZE`` samples.
        label: Rank-1 integer class ids (one-hot encoded here) or rank-2
            already encoded label vectors.
        PROBABILITY: Per-sample chance of mixing. NOTE(review): defaults to
            1.0, not the `mixup_rate` config value — confirm this is intended.

    Returns:
        ``(image2, label2)`` reshaped to ``(BATCH_SIZE, DIM, DIM, 3)`` and
        ``(BATCH_SIZE, N_CLASSES)``.
    """
    # input image - is a batch of images of size [n,dim,dim,3] not a single image of [dim,dim,3]
    # output - a batch of images with mixup applied
    DIM = img_size
    
    imgs = []; labs = []
    for j in range(BATCH_SIZE):
        # DO MIXUP WITH PROBABILITY DEFINED ABOVE
        P = tf.cast( tf.random.uniform([],0,1)<=PROBABILITY, tf.float32)
        # CHOOSE RANDOM
        k = tf.cast( tf.random.uniform([],0,BATCH_SIZE),tf.int32)
        a = tf.random.uniform([],0,1)*P # this is beta dist with alpha=1.0
        # MAKE MIXUP IMAGE
        img1 = image[j,]
        img2 = image[k,]
        imgs.append((1-a)*img1 + a*img2)
        # MAKE MIXUP LABEL
        if len(label.shape)==1:
            lab1 = tf.one_hot(label[j],N_CLASSES)
            lab2 = tf.one_hot(label[k],N_CLASSES)
        else:
            lab1 = label[j,]
            lab2 = label[k,]
        labs.append((1-a)*lab1 + a*lab2)
            
    # RESHAPE HACK SO TPU COMPILER KNOWS SHAPE OF OUTPUT TENSOR (maybe use Python typing instead?)
    image2 = tf.reshape(tf.stack(imgs),(BATCH_SIZE,DIM,DIM,3))
    label2 = tf.reshape(tf.stack(labs),(BATCH_SIZE,N_CLASSES))
    return image2,label2

# GridMask

In [None]:
def transform(image, inv_mat, image_shape):
    """Warp `image` with an inverse 3x3 matrix using nearest-pixel lookup.

    Every destination pixel (in coordinates centred on the image) is mapped
    through `inv_mat` to a source pixel; destinations whose source falls
    outside the image are left at 0.

    Args:
        image: Tensor indexed as [y, x, channel].
        inv_mat: 3x3 float matrix applied to stacked (x, y, 1) coordinates.
        image_shape: (h, w, c) as Python ints.

    Returns:
        Tensor of shape [h, w, c].
    """
    h, w, c = image_shape
    cx, cy = w//2, h//2
    # Destination coordinates, centred on the image middle.
    new_xs = tf.repeat( tf.range(-cx, cx, 1), h)
    new_ys = tf.tile( tf.range(-cy, cy, 1), [w])
    new_zs = tf.ones([h*w], dtype=tf.int32)
    # Map destination -> source, then shift back to array coordinates and round.
    old_coords = tf.matmul(inv_mat, tf.cast(tf.stack([new_xs, new_ys, new_zs]), tf.float32))
    old_coords_x, old_coords_y = tf.round(old_coords[0, :] + w//2), tf.round(old_coords[1, :] + h//2)
    # Drop every pair whose source coordinate lies outside the image.
    clip_mask_x = tf.logical_or(old_coords_x<0, old_coords_x>w-1)
    clip_mask_y = tf.logical_or(old_coords_y<0, old_coords_y>h-1)
    clip_mask = tf.logical_or(clip_mask_x, clip_mask_y)
    old_coords_x = tf.boolean_mask(old_coords_x, tf.logical_not(clip_mask))
    old_coords_y = tf.boolean_mask(old_coords_y, tf.logical_not(clip_mask))
    new_coords_x = tf.boolean_mask(new_xs+cx, tf.logical_not(clip_mask))
    new_coords_y = tf.boolean_mask(new_ys+cy, tf.logical_not(clip_mask))
    old_coords = tf.cast(tf.stack([old_coords_y, old_coords_x]), tf.int32)
    new_coords = tf.cast(tf.stack([new_coords_y, new_coords_x]), tf.int64)
    # Gather source values and scatter them per channel into a dense [h, w] plane.
    rotated_image_values = tf.gather_nd(image, tf.transpose(old_coords))
    rotated_image_channel = list()
    for i in range(c):
        vals = rotated_image_values[:,i]
        sparse_channel = tf.SparseTensor(tf.transpose(new_coords), vals, [h, w])
        rotated_image_channel.append(tf.sparse.to_dense(sparse_channel, default_value=0, validate_indices=False))
    # Stack gives [c, h, w]; transpose back to [h, w, c].
    return tf.transpose(tf.stack(rotated_image_channel), [1,2,0])

In [None]:
def random_rotate(image, angle, image_shape):
    """Rotate `image` about its centre by a randomly scaled angle.

    The effective angle (degrees) is `angle` multiplied by one sample from
    `tf.random.normal`, so `angle` acts as a scale on a normally distributed
    rotation rather than as a hard limit.

    Args:
        image: Tensor passed on to `transform`.
        angle: Angle scale in degrees.
        image_shape: (h, w, c) forwarded to `transform`.

    Returns:
        The rotated image as produced by `transform`.
    """
    def get_rotation_mat_inv(angle):
        # Build the 3x3 matrix handed to `transform` from an angle in degrees.
        # transform to radian
        angle = math.pi * angle / 180
        cos_val = tf.math.cos(angle)
        sin_val = tf.math.sin(angle)
        one = tf.constant([1], tf.float32)
        zero = tf.constant([0], tf.float32)
        rot_mat_inv = tf.concat([cos_val, sin_val, zero, -sin_val, cos_val, zero, zero, zero, one], axis=0)
        rot_mat_inv = tf.reshape(rot_mat_inv, [3,3])
        return rot_mat_inv
    # Shape [1] so it concatenates with the [1]-shaped constants above.
    angle = float(angle) * tf.random.normal([1],dtype='float32')
    rot_mat_inv = get_rotation_mat_inv(angle)
    return transform(image, rot_mat_inv, image_shape)

In [None]:
def GridMask(image_height, image_width, d1, d2, rotate_angle=1, ratio=0.5):
    """Build a random GridMask of shape [image_height, image_width, 1].

    A square working mask of side `hh` (the image diagonal rounded up to an
    even number) is filled with 1s. Bands of rows and bands of columns, each
    `l` pixels wide and repeating every `d` pixels, are marked; the mask is 0
    only where a marked row band and a marked column band intersect. The mask
    is then randomly rotated and centre-cropped to the image size.

    Args:
        image_height: Output mask height (Python int).
        image_width: Output mask width (Python int).
        d1: Lower bound (inclusive) of the random grid period `d`.
        d2: Upper bound (exclusive) of the random grid period `d`.
        rotate_angle: Angle scale forwarded to `random_rotate`.
        ratio: Band width as a fraction of `d` (rounded to nearest int).

    Returns:
        Integer mask tensor with values 0/1.
    """
    h, w = image_height, image_width
    # Side of the square working area: image diagonal, forced to be even.
    hh = int(np.ceil(np.sqrt(h*h+w*w)))
    hh = hh+1 if hh%2==1 else hh
    # Random grid period and band width.
    d = tf.random.uniform(shape=[], minval=d1, maxval=d2, dtype=tf.int32)
    l = tf.cast(tf.cast(d,tf.float32)*ratio+0.5, tf.int32)

    # Random phase of the grid along each axis.
    st_h = tf.random.uniform(shape=[], minval=0, maxval=d, dtype=tf.int32)
    st_w = tf.random.uniform(shape=[], minval=0, maxval=d, dtype=tf.int32)

    # Start one period before the origin so a band crossing the edge is included.
    y_ranges = tf.range(-1 * d + st_h, -1 * d + st_h + l)
    x_ranges = tf.range(-1 * d + st_w, -1 * d + st_w + l)

    # Append one band of indices per grid period across the working area.
    for i in range(0, hh//d+1):
        s1 = i * d + st_h
        s2 = i * d + st_w
        y_ranges = tf.concat([y_ranges, tf.range(s1,s1+l)], axis=0)
        x_ranges = tf.concat([x_ranges, tf.range(s2,s2+l)], axis=0)

    # An index pair is discarded when either its x or its y value is out of bounds.
    x_clip_mask = tf.logical_or(x_ranges < 0 , x_ranges > hh-1)
    y_clip_mask = tf.logical_or(y_ranges < 0 , y_ranges > hh-1)
    clip_mask = tf.logical_or(x_clip_mask, y_clip_mask)

    x_ranges = tf.boolean_mask(x_ranges, tf.logical_not(clip_mask))
    y_ranges = tf.boolean_mask(y_ranges, tf.logical_not(clip_mask))

    # Expand each kept row/column index into a full line of hh coordinates.
    hh_ranges = tf.tile(tf.range(0,hh), [tf.cast(tf.reduce_sum(tf.ones_like(x_ranges)), tf.int32)])
    x_ranges = tf.repeat(x_ranges, hh)
    y_ranges = tf.repeat(y_ranges, hh)

    y_hh_indices = tf.transpose(tf.stack([y_ranges, hh_ranges]))
    x_hh_indices = tf.transpose(tf.stack([hh_ranges, x_ranges]))

    # Dense masks: 0 on the marked rows (y_mask) / columns (x_mask), 1 elsewhere.
    y_mask_sparse = tf.SparseTensor(tf.cast(y_hh_indices, tf.int64),  tf.zeros_like(y_ranges), [hh, hh])
    y_mask = tf.sparse.to_dense(y_mask_sparse, 1, False)

    x_mask_sparse = tf.SparseTensor(tf.cast(x_hh_indices, tf.int64), tf.zeros_like(x_ranges), [hh, hh])
    x_mask = tf.sparse.to_dense(x_mask_sparse, 1, False)

    # Sum clipped to [0, 1]: stays 0 only where both masks are 0.
    mask = tf.expand_dims( tf.clip_by_value(x_mask + y_mask, 0, 1), axis=-1)

    mask = random_rotate(mask, rotate_angle, [hh, hh, 1])
    # Centre-crop the working area down to the requested image size.
    mask = tf.image.crop_to_bounding_box(mask, (hh-h)//2, (hh-w)//2, image_height, image_width)

    return mask

In [None]:
def apply_grid_mask(image, image_shape, PROBABILITY = gridmask_rate):
    """Multiply `image` by a random GridMask with the given probability.

    Args:
        image: Float image tensor (or batch) whose trailing dimensions match
            `image_shape`; the mask is broadcast against it.
        image_shape: (height, width, channels) as Python ints.
        PROBABILITY: Chance of applying the mask.

    Returns:
        `image * mask` when the mask is applied, otherwise `image` unchanged.
    """
    # Local name chosen so the module-level AugParams dict is not shadowed.
    grid_params = {
        'd1' : 100,
        'd2': 160,
        'rotate' : 45,
        'ratio' : 0.3
    }

    mask = GridMask(image_shape[0], image_shape[1], grid_params['d1'], grid_params['d2'], grid_params['rotate'], grid_params['ratio'])
    if image_shape[-1] == 3:
        mask = tf.concat([mask, mask, mask], axis=-1)
    # The cast and the random gate used to live inside the branch above, which
    # left `P` undefined (and the mask integer-typed) for non-3-channel input.
    mask = tf.cast(mask,tf.float32)
    P = tf.cast(tf.random.uniform([], 0, 1) <= PROBABILITY, tf.int32)
    if P==1:
        return image*mask
    else:
        return image

def gridmask(img_batch, label_batch):
    """Dataset-map adapter: GridMask the images, pass the labels through."""
    batch_image_shape = (img_size, img_size, 3)
    masked_batch = apply_grid_mask(img_batch, batch_image_shape)
    return masked_batch, label_batch

# More data augmentations

In [None]:
# Parameters read by augmentation(). Each `*_prob` value is compared as
# `uniform > prob`, so the transform fires with probability (1 - prob).
AugParams = {
    'scale_factor':0.5,   # forwarded to zoom_in / zoom_out
    'scale_prob':0.5,
    'rot_range':90,       # rotation angle is drawn as an integer in [-rot_range, rot_range)
    'rot_prob':0.5,
    'blockout_sl':0.1,    # random_blockout `sl`
    'blockout_sh':0.2,    # random_blockout `sh`
    'blockout_rl':0.4,    # random_blockout `rl`
    'blockout_prob':0.5,
    'blur_ksize':3,       # gaussian_blur kernel size
    'blur_sigma':1,       # gaussian_blur sigma
    'blur_prob':0.5
}
# (height, width) target size used by zoom_in / zoom_out.
IMG_DIM = (HEIGHT,WIDTH)

In [None]:
def image_rotate(image, angle):
    """Rotate a single HEIGHT x WIDTH x 3 image about its centre.

    Uses nearest-pixel lookup; destination pixels whose source falls outside
    the image are filled with 0.

    Args:
        image: Rank-3 tensor (h, w, c).
        angle: Rotation angle, passed directly to tf.cos / tf.sin.
            NOTE(review): that means radians, while `augmentation` draws an
            integer from +/-AugParams['rot_range'] — confirm the intended unit.

    Returns:
        Rotated image tensor of shape [HEIGHT, WIDTH, 3].

    Raises:
        ValueError: If `image` is not rank 3.
    """

    if len(image.get_shape().as_list()) != 3:
        raise ValueError('`image_rotate` only support image with 3 dimension(h, w, c)`')

    angle = tf.cast(angle, tf.float32)
    h, w, c = HEIGHT, WIDTH, 3
    cy, cx = h//2, w//2

    ys = tf.range(h)
    xs = tf.range(w)

    # Flattened destination coordinates covering every (y, x) pair.
    ys_vec = tf.tile(ys, [w])
    xs_vec = tf.reshape( tf.tile(xs, [h]), [h,w] )
    xs_vec = tf.reshape( tf.transpose(xs_vec, [1,0]), [-1])

    ys_vec_centered, xs_vec_centered = ys_vec - cy, xs_vec - cx
    new_coord_centered = tf.cast(tf.stack([ys_vec_centered, xs_vec_centered]), tf.float32)

    # 2x2 matrix mapping centred destination coordinates to source coordinates.
    inv_rot_mat = tf.reshape( tf.dynamic_stitch([0,1,2,3], [tf.cos(angle), tf.sin(angle), -tf.sin(angle), tf.cos(angle)]), [2,2])
    old_coord_centered = tf.matmul(inv_rot_mat, new_coord_centered)

    old_ys_vec_centered, old_xs_vec_centered = old_coord_centered[0,:], old_coord_centered[1,:]
    old_ys_vec = tf.cast( tf.round(old_ys_vec_centered+cy), tf.int32)
    old_xs_vec = tf.cast( tf.round(old_xs_vec_centered+cx), tf.int32)

    # True where the source coordinate lies outside the image.
    outside_ind = tf.logical_or( tf.logical_or(old_ys_vec > h-1 , old_ys_vec < 0), tf.logical_or(old_xs_vec > w-1 , old_xs_vec<0))

    old_ys_vec = tf.boolean_mask(old_ys_vec, tf.logical_not(outside_ind))
    old_xs_vec = tf.boolean_mask(old_xs_vec, tf.logical_not(outside_ind))

    ys_vec = tf.boolean_mask(ys_vec, tf.logical_not(outside_ind))
    xs_vec = tf.boolean_mask(xs_vec, tf.logical_not(outside_ind))

    old_coord = tf.cast(tf.transpose(tf.stack([old_ys_vec, old_xs_vec]), [1,0]), tf.int32)
    new_coord = tf.cast(tf.transpose(tf.stack([ys_vec, xs_vec]), [1,0]), tf.int64)

    # Gather and scatter each channel separately, then restack to [h, w, c].
    channel_vals = tf.split(image, c, axis=-1)
    rotated_channel_vals = list()
    for channel_val in channel_vals:
        rotated_channel_val = tf.gather_nd(channel_val, old_coord)

        sparse_rotated_channel_val = tf.SparseTensor(new_coord, tf.squeeze(rotated_channel_val,axis=-1), [h, w])
        rotated_channel_vals.append(tf.sparse.to_dense(sparse_rotated_channel_val, default_value=0, validate_indices=False))

    rotated_image = tf.transpose(tf.stack(rotated_channel_vals), [1, 2, 0])
    return rotated_image
    
def random_blockout(img, sl=0.1, sh=0.2, rl=0.4):
    """Zero out one random rectangle of a HEIGHT x WIDTH x 3 image.

    The rectangle's height and width are each drawn uniformly from
    ``[sqrt(area*sl*rl), min(sqrt(area*sh/rl), dimension))`` and it is placed
    at a random position inside the image.

    Args:
        img: Image tensor of shape [HEIGHT, WIDTH, 3].
        sl: Factor used for the lower bound of the rectangle side.
        sh: Factor used for the upper bound of the rectangle side.
        rl: Ratio factor combined with `sl` / `sh` in those bounds.

    Returns:
        The image with the rectangle set to 0, cast back to `img.dtype`.
    """

    # Rows first, then columns (was swapped, which only worked for square
    # images and disagreed with image_rotate).
    h, w, c = HEIGHT, WIDTH, 3
    origin_area = tf.cast(h*w, tf.float32)

    e_size_l = tf.cast(tf.round(tf.sqrt(origin_area * sl * rl)), tf.int32)
    e_size_h = tf.cast(tf.round(tf.sqrt(origin_area * sh / rl)), tf.int32)

    # The rectangle can never exceed the image itself.
    e_height_h = tf.minimum(e_size_h, h)
    e_width_h = tf.minimum(e_size_h, w)

    erase_height = tf.random.uniform(shape=[], minval=e_size_l, maxval=e_height_h, dtype=tf.int32)
    erase_width = tf.random.uniform(shape=[], minval=e_size_l, maxval=e_width_h, dtype=tf.int32)

    erase_area = tf.zeros(shape=[erase_height, erase_width, c])
    erase_area = tf.cast(erase_area, tf.uint8)

    # Random placement: split the leftover rows/columns into padding on each side.
    pad_h = h - erase_height
    pad_top = tf.random.uniform(shape=[], minval=0, maxval=pad_h, dtype=tf.int32)
    pad_bottom = pad_h - pad_top

    pad_w = w - erase_width
    pad_left = tf.random.uniform(shape=[], minval=0, maxval=pad_w, dtype=tf.int32)
    pad_right = pad_w - pad_left

    # Pad the zero block with ones to obtain a full-size multiplicative mask.
    erase_mask = tf.pad([erase_area], [[0,0],[pad_top, pad_bottom], [pad_left, pad_right], [0,0]], constant_values=1)
    erase_mask = tf.squeeze(erase_mask, axis=0)
    erased_img = tf.multiply(tf.cast(img,tf.float32), tf.cast(erase_mask, tf.float32))

    return tf.cast(erased_img, img.dtype)
    
def zoom_out(x, scale_factor):
    """Shrink the image to a random size, zero-pad it back, and resize to IMG_DIM.

    Args:
        x: Single image tensor (h, w, c).
        scale_factor: Smallest allowed size as a fraction of IMG_DIM; the new
            width / height are drawn independently from
            [IMG_DIM * scale_factor, IMG_DIM).

    Returns:
        Image tensor resized to IMG_DIM with the shrunken content centred.
    """

    resize_x = tf.random.uniform(shape=[], minval=tf.cast(IMG_DIM[1]//(1/scale_factor), tf.int32), maxval=IMG_DIM[1], dtype=tf.int32)
    resize_y = tf.random.uniform(shape=[], minval=tf.cast(IMG_DIM[0]//(1/scale_factor), tf.int32), maxval=IMG_DIM[0], dtype=tf.int32)
    # Split the missing rows / columns as evenly as possible between both sides.
    top_pad = (IMG_DIM[0] - resize_y) // 2
    bottom_pad = IMG_DIM[0] - resize_y - top_pad
    left_pad = (IMG_DIM[1] - resize_x ) // 2
    right_pad = IMG_DIM[1] - resize_x - left_pad
        
    x = tf.image.resize(x, (resize_y, resize_x))
    # Wrapped in a list to add a leading batch axis for the 4-D padding spec.
    x = tf.pad([x], [[0,0], [top_pad, bottom_pad], [left_pad, right_pad], [0,0]])
    x = tf.image.resize(x, IMG_DIM)
    return tf.squeeze(x, axis=0)
    
def zoom_in(x, scale_factor):
    """Return a random centred crop of the image, resized to IMG_DIM.

    Candidate crops cover 50% to 95% of each side in 5% steps; one of them is
    picked uniformly at random.

    Args:
        x: Single image tensor (h, w, c).
        scale_factor: Accepted but not used by this function.

    Returns:
        The chosen crop, resized to IMG_DIM.
    """

    scales = list(np.arange(0.5, 1.0, 0.05))
    boxes = np.zeros((len(scales),4))
            
    # One centred, normalised box per scale (symmetric, so x/y order is moot).
    for i, scale in enumerate(scales):
        x_min = y_min = 0.5 - (0.5*scale)
        x_max = y_max = 0.5 + (0.5*scale)
        boxes[i] = [x_min, y_min, x_max, y_max]
        
    def random_crop(x):
        # Produce every candidate crop from the single image, then pick one.
        crop = tf.image.crop_and_resize([x], boxes=boxes, box_indices=np.zeros(len(boxes)), crop_size=IMG_DIM)
        return crop[tf.random.uniform(shape=[], minval=0, maxval=len(scales), dtype=tf.int32)]
        
    return random_crop(x)

def gaussian_blur(img, ksize=5, sigma=1):
    """Blur a 3-channel image with a normalised Gaussian kernel.

    Args:
        img: Single float image tensor (h, w, 3).
        ksize: Kernel side length.
        sigma: Standard deviation used in the Gaussian formula.

    Returns:
        Blurred image with the same height and width ('SAME' padding, stride 1).
    """
    
    def gaussian_kernel(size=3, sigma=1):
        # Build a size x size Gaussian kernel normalised to sum to 1.

        x_range = tf.range(-(size-1)//2, (size-1)//2 + 1, 1)
        y_range = tf.range((size-1)//2, -(size-1)//2 - 1, -1)

        xs, ys = tf.meshgrid(x_range, y_range)
        kernel = tf.exp(-(xs**2 + ys**2)/(2*(sigma**2))) / (2*np.pi*(sigma**2))
        return tf.cast( kernel / tf.reduce_sum(kernel), tf.float32)
    
    kernel = gaussian_kernel(ksize, sigma)
    # Reshape to [k, k, in_channels=1, out_channels=1] for conv2d.
    kernel = tf.expand_dims(tf.expand_dims(kernel, axis=-1), axis=-1)
    
    # Convolve each colour channel separately with the same kernel.
    r, g, b = tf.split(img, [1,1,1], axis=-1)
    r_blur = tf.nn.conv2d([r], kernel, [1,1,1,1], 'SAME')
    g_blur = tf.nn.conv2d([g], kernel, [1,1,1,1], 'SAME')
    b_blur = tf.nn.conv2d([b], kernel, [1,1,1,1], 'SAME')

    blur_image = tf.concat([r_blur, g_blur, b_blur], axis=-1)
    # Drop the batch axis introduced by wrapping each channel in a list.
    return tf.squeeze(blur_image, axis=0)

In [None]:
def augmentation(image,label):
    """Randomly blur, block out, zoom and rotate one image.

    Each step fires when a uniform sample exceeds the matching
    AugParams['*_prob'], i.e. with probability (1 - prob).

    NOTE(review): the result is cast to uint8, but decode_image in this
    notebook scales pixels to [0, 1] — confirm the expected input range before
    enabling this (its use in get_dataset is currently commented out).
    NOTE(review): the rotation angle is an integer drawn from
    +/-AugParams['rot_range'] and image_rotate feeds it straight to
    tf.cos / tf.sin — confirm degrees vs. radians.

    Args:
        image: Single image tensor.
        label: Passed through unchanged.

    Returns:
        (uint8 image, label).
    """
    image = tf.cast(image, tf.float32)
    #Gaussian blur
    if tf.random.uniform(shape=[], minval=0.0, maxval=1.0) > AugParams['blur_prob']:
        image = gaussian_blur(image, AugParams['blur_ksize'], AugParams['blur_sigma'])
    
    #Random block out
    if tf.random.uniform(shape=[], minval=0.0, maxval=1.0) > AugParams['blockout_prob']:
        image = random_blockout(image, AugParams['blockout_sl'], AugParams['blockout_sh'], AugParams['blockout_rl'])
        
    #Random scale
    if tf.random.uniform(shape=[], minval=0.0, maxval=1.0) > AugParams['scale_prob']:
        # Even split between zooming in and zooming out.
        if tf.random.uniform(shape=[], minval=0.0, maxval=1.0) > 0.5:
            image = zoom_in(image, AugParams['scale_factor'])
        else:
            image = zoom_out(image, AugParams['scale_factor'])
    #Random rotate
    if tf.random.uniform(shape=[], minval=0.0, maxval=1.0) > AugParams['rot_prob']:
        angle = tf.random.uniform(shape=[], minval=-AugParams['rot_range'], maxval=AugParams['rot_range'], dtype=tf.int32)
        image = image_rotate(image,angle)
    
    return tf.cast(image, tf.uint8),label

In [None]:
def data_augment(image, label):
    """Apply the per-image training augmentations and resize to HEIGHT x WIDTH.

    One uniform sample per transform family decides whether, and which
    variant of, that transform is applied: shear, rotation, flips/transpose,
    90-degree rotations, colour jitter and cropping. `transform_shear` and
    `transform_rotation` are defined in a later cell, so that cell must have
    run before this function is traced.

    Args:
        image: Single image tensor (HEIGHT, HEIGHT, 3 is assumed by the
            shear / rotation helpers).
        label: Passed through unchanged.

    Returns:
        (augmented image resized to [HEIGHT, WIDTH], label).
    """
    p_rotation = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_spatial = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_rotate = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_pixel_1 = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_pixel_2 = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_pixel_3 = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_shear = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_crop = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    
    # Shear (skipped when p_shear <= .2; sign chosen by the .6 threshold)
    if p_shear > .2:
        if p_shear > .6:
            image = transform_shear(image, HEIGHT, shear=20.)
        else:
            image = transform_shear(image, HEIGHT, shear=-20.)
    # Rotation (skipped when p_rotation <= .2; sign chosen by the .6 threshold)
    if p_rotation > .2:
        if p_rotation > .6:
            image = transform_rotation(image, HEIGHT, rotation=45.)
        else:
            image = transform_rotation(image, HEIGHT, rotation=-45.)
    # Flips
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    if p_spatial > .75:
        image = tf.image.transpose(image)
    # Rotates (no rotation when p_rotate <= .25)
    if p_rotate > .75:
        image = tf.image.rot90(image, k=3) # rotate 270º
    elif p_rotate > .5:
        image = tf.image.rot90(image, k=2) # rotate 180º
    elif p_rotate > .25:
        image = tf.image.rot90(image, k=1) # rotate 90º
    # Pixel-level transforms (each applied when its sample is >= .4)
    if p_pixel_1 >= .4:
        image = tf.image.random_saturation(image, lower=.7, upper=1.3)
    if p_pixel_2 >= .4:
        image = tf.image.random_contrast(image, lower=.8, upper=1.2)
    if p_pixel_3 >= .4:
        image = tf.image.random_brightness(image, max_delta=.1)
    # Crops: central crop above .7, random square crop in (.4, .7], none otherwise
    if p_crop > .7:
        if p_crop > .9:
            image = tf.image.central_crop(image, central_fraction=.6)
        elif p_crop > .8:
            image = tf.image.central_crop(image, central_fraction=.7)
        else:
            image = tf.image.central_crop(image, central_fraction=.8)
    elif p_crop > .4:
        crop_size = tf.random.uniform([], int(HEIGHT*.6), HEIGHT, dtype=tf.int32)
        image = tf.image.random_crop(image, size=[crop_size, crop_size, CHANNELS])
            
    # Crops change the spatial size, so always resize back to the model input size.
    image = tf.image.resize(image, size=[HEIGHT, WIDTH])

    return image, label

In [None]:
# data augmentation @cdeotte kernel: https://www.kaggle.com/cdeotte/rotation-augmentation-gpu-tpu-0-96
def transform_rotation(image, height, rotation):
    """Rotate a square image by a random fraction of `rotation` degrees.

    Args:
        image: One image of size [dim, dim, 3] (not a batch).
        height: Side length `dim` of the square image.
        rotation: Maximum angle in degrees; it is multiplied by one
            tf.random.uniform sample before use.

    Returns:
        Rotated image reshaped to [dim, dim, 3].
    """
    # input image - is one image of size [dim,dim,3] not a batch of [b,dim,dim,3]
    # output - image randomly rotated
    DIM = height
    XDIM = DIM%2 #fix for size 331
    
    rotation = rotation * tf.random.uniform([1],dtype='float32')
    # CONVERT DEGREES TO RADIANS
    rotation = math.pi * rotation / 180.
    
    # ROTATION MATRIX
    c1 = tf.math.cos(rotation)
    s1 = tf.math.sin(rotation)
    one = tf.constant([1],dtype='float32')
    zero = tf.constant([0],dtype='float32')
    rotation_matrix = tf.reshape(tf.concat([c1,s1,zero, -s1,c1,zero, zero,zero,one],axis=0),[3,3])

    # LIST DESTINATION PIXEL INDICES
    x = tf.repeat( tf.range(DIM//2,-DIM//2,-1), DIM )
    y = tf.tile( tf.range(-DIM//2,DIM//2),[DIM] )
    z = tf.ones([DIM*DIM],dtype='int32')
    idx = tf.stack( [x,y,z] )
    
    # ROTATE DESTINATION PIXELS ONTO ORIGIN PIXELS
    idx2 = K.dot(rotation_matrix,tf.cast(idx,dtype='float32'))
    idx2 = K.cast(idx2,dtype='int32')
    # Clip so every looked-up source index stays inside the image.
    idx2 = K.clip(idx2,-DIM//2+XDIM+1,DIM//2)
    
    # FIND ORIGIN PIXEL VALUES 
    idx3 = tf.stack( [DIM//2-idx2[0,], DIM//2-1+idx2[1,]] )
    d = tf.gather_nd(image, tf.transpose(idx3))
        
    return tf.reshape(d,[DIM,DIM,3])

def transform_shear(image, height, shear):
    """Shear a square image by a random fraction of `shear` degrees.

    Args:
        image: One image of size [dim, dim, 3] (not a batch).
        height: Side length `dim` of the square image.
        shear: Maximum shear angle in degrees; it is multiplied by one
            tf.random.uniform sample before use.

    Returns:
        Sheared image reshaped to [dim, dim, 3].
    """
    # input image - is one image of size [dim,dim,3] not a batch of [b,dim,dim,3]
    # output - image randomly sheared
    DIM = height
    XDIM = DIM%2 #fix for size 331
    
    shear = shear * tf.random.uniform([1],dtype='float32')
    # Degrees to radians.
    shear = math.pi * shear / 180.
        
    # SHEAR MATRIX
    one = tf.constant([1],dtype='float32')
    zero = tf.constant([0],dtype='float32')
    c2 = tf.math.cos(shear)
    s2 = tf.math.sin(shear)
    shear_matrix = tf.reshape(tf.concat([one,s2,zero, zero,c2,zero, zero,zero,one],axis=0),[3,3])    

    # LIST DESTINATION PIXEL INDICES
    x = tf.repeat( tf.range(DIM//2,-DIM//2,-1), DIM )
    y = tf.tile( tf.range(-DIM//2,DIM//2),[DIM] )
    z = tf.ones([DIM*DIM],dtype='int32')
    idx = tf.stack( [x,y,z] )
    
    # MAP DESTINATION PIXELS ONTO ORIGIN PIXELS (shear, not rotation)
    idx2 = K.dot(shear_matrix,tf.cast(idx,dtype='float32'))
    idx2 = K.cast(idx2,dtype='int32')
    # Clip so every looked-up source index stays inside the image.
    idx2 = K.clip(idx2,-DIM//2+XDIM+1,DIM//2)
    
    # FIND ORIGIN PIXEL VALUES 
    idx3 = tf.stack( [DIM//2-idx2[0,], DIM//2-1+idx2[1,]] )
    d = tf.gather_nd(image, tf.transpose(idx3))
        
    return tf.reshape(d,[DIM,DIM,3])

In [None]:
# Datasets utility functions
def to_float32_2(image, label):
    """Cast the image to float32 and harden the label to a 0/1 int32 vector.

    Every position equal to the row maximum of `label` becomes 1 and all
    other positions become 0 (ties therefore yield several 1s).
    """
    row_max = tf.reduce_max(label, axis=-1, keepdims=True)
    is_row_max = tf.equal(label, row_max)
    hard_label = tf.where(is_row_max, tf.ones_like(label), tf.zeros_like(label))
    image_f32 = tf.cast(image, tf.float32)
    return image_f32, tf.cast(hard_label, tf.int32)

def to_float32(image, label):
    """Cast the image to float32; the label is passed through untouched."""
    image_f32 = tf.cast(image, tf.float32)
    return image_f32, label


def one_hot(image,label):
    """One-hot encode the label over len(CLASSES) classes; image is unchanged."""
    encoded_label = tf.one_hot(label, len(CLASSES))
    return image, encoded_label

def decode_image(image_data):
    """Decode JPEG bytes into a float32 [HEIGHT, WIDTH, 3] image in [0, 1]."""
    target_size = [HEIGHT, WIDTH]
    pixels = tf.image.decode_jpeg(image_data, channels=3)
    # Scale 8-bit values to the unit interval.
    pixels = tf.cast(pixels, tf.float32) / 255.0
    pixels = tf.image.resize(pixels, target_size)
    # Explicit reshape pins the static shape of the result.
    return tf.reshape(pixels, target_size + [3])

def read_tfrecord(example, labeled=True):
    """Parse one serialized TFRecord example.

    Args:
        example: Serialized example.
        labeled: When truthy the record must carry an int64 'target';
            otherwise a string 'image_name'.

    Returns:
        (image, int32 label) when labeled, else (image, image name).
    """
    if labeled:
        extra_key, extra_dtype = 'target', tf.int64
    else:
        extra_key, extra_dtype = 'image_name', tf.string
    feature_spec = {
        'image': tf.io.FixedLenFeature([], tf.string),
        extra_key: tf.io.FixedLenFeature([], extra_dtype),
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    image = decode_image(parsed['image'])
    if not labeled:
        return image, parsed['image_name']
    return image, tf.cast(parsed['target'], tf.int32)

def load_dataset(filenames, labeled=True, ordered=False):
    """Create a dataset of parsed (image, label_or_name) pairs from TFRecords.

    Args:
        filenames: TFRecord file paths.
        labeled: Forwarded to read_tfrecord.
        ordered: When falsy, deterministic ordering is switched off.
    """
    read_options = tf.data.Options()
    if not ordered:
        read_options.experimental_deterministic = False

    def _parse(serialized):
        return read_tfrecord(serialized, labeled=labeled)

    records = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
    records = records.with_options(read_options)
    return records.map(_parse, num_parallel_calls=AUTO)

def get_dataset(FILENAMES, labeled=True, ordered=False, repeated=False, augment=False,valid=False):
    """Build the batched, prefetched input pipeline.

    Args:
        FILENAMES: TFRecord file paths.
        labeled: Forwarded to load_dataset / read_tfrecord.
        ordered: When falsy, reading is non-deterministic and samples are
            shuffled with a 2048-element buffer.
        repeated: Repeat the dataset indefinitely.
        augment: Map `data_augment` over the individual images.
        valid: When truthy, labels are only one-hot encoded; otherwise the
            batch-level cutmix / mixup / gridmask maps are applied according
            to their `*_rate` settings.

    Returns:
        A tf.data.Dataset of (image batch, label batch).
    """
    dataset = load_dataset(FILENAMES, labeled=labeled, ordered=ordered)
    if augment:
        dataset = dataset.map(data_augment, num_parallel_calls=AUTO)
        # Alternative per-image pipeline (disabled): dataset.map(augmentation, ...)
    if repeated:
        dataset = dataset.repeat()
    if not ordered:
        dataset = dataset.shuffle(2048)
    # cutmix / mixup index exactly BATCH_SIZE samples and reshape to that size,
    # so a shorter final batch would fail; drop it whenever they are active.
    needs_full_batches = (not valid) and bool(cutmix_rate or mixup_rate)
    dataset = dataset.batch(BATCH_SIZE, drop_remainder=needs_full_batches)
    if valid :
        dataset = dataset.map(one_hot,num_parallel_calls=AUTO)
    else :
        if cutmix_rate :
            dataset = dataset.map(cutmix, num_parallel_calls=AUTO)
        if mixup_rate :
            dataset = dataset.map(mixup, num_parallel_calls=AUTO)
        if gridmask_rate :
            dataset = dataset.map(gridmask, num_parallel_calls=AUTO)

    dataset = dataset.prefetch(AUTO)
    return dataset

In [None]:
# Visualization utility functions
# Keep NumPy output short: summarise arrays above 15 elements, wrap at 80 columns.
np.set_printoptions(threshold=15, linewidth=80)

def batch_to_numpy_images_and_labels(data):
    """Convert an (images, labels) batch of tensors to NumPy.

    Returns:
        (numpy images, numpy labels). When the labels have object dtype
        (image-id byte strings, as with test data) the labels are replaced by
        a list of None, one per image.
    """
    image_batch, label_batch = data
    images_np, labels_np = image_batch.numpy(), label_batch.numpy()
    if labels_np.dtype == object:
        # No real labels available, only ids.
        labels_np = [None] * len(images_np)
    return images_np, labels_np

def title_from_label_and_target(label, correct_label):
    """Build a class-name plot title and a correctness flag.

    NOTE: a second function with this same name is defined further down in
    this cell; once the whole cell has run, that later definition is the one
    bound to the name.

    Args:
        label: Predicted class index into CLASSES.
        correct_label: True class index, or None when unknown.

    Returns:
        (title string, correct flag).
    """
    if correct_label is None:
        return CLASSES[label], True
    correct = (label == correct_label)
    if correct:
        verdict, arrow, expected_name = 'OK', '', ''
    else:
        verdict, arrow, expected_name = 'NO', u"\u2192", CLASSES[correct_label]
    return "{} [{}{}{}]".format(CLASSES[label], verdict, arrow, expected_name), correct

def display_one_flower(image, title, subplot, red=False, titlesize=16):
    """Draw one image into the given (rows, cols, index) subplot slot.

    The title is drawn only when non-empty; a red title uses a slightly
    smaller font.

    Returns:
        The subplot tuple advanced to the next index.
    """
    plt.subplot(*subplot)
    plt.axis('off')
    plt.imshow(image)
    if len(title) > 0:
        if red:
            font_size, title_color = int(titlesize/1.2), 'red'
        else:
            font_size, title_color = int(titlesize), 'black'
        plt.title(title, fontsize=font_size, color=title_color, fontdict={'verticalalignment':'center'}, pad=int(titlesize/1.5))
    n_rows, n_cols, position = subplot[0], subplot[1], subplot[2]
    return (n_rows, n_cols, position + 1)

def display_batch_of_images(databatch, predictions=None):
    """Show a batch of images in a roughly square grid.

    This will work with:
    display_batch_of_images(images)
    display_batch_of_images(images, predictions)
    display_batch_of_images((images, labels))
    display_batch_of_images((images, labels), predictions)

    Titles are the class name of each label, or — when `predictions` is
    given — whatever `title_from_label_and_target` returns (the name is
    resolved at call time, so the most recently executed definition is used).
    Images that do not fit the rows x cols grid are not shown.
    """
    # data
    images, labels = batch_to_numpy_images_and_labels(databatch)
    if labels is None:
        labels = [None for _ in enumerate(images)]
        
    # auto-squaring: this will drop data that does not fit into square or square-ish rectangle
    rows = int(math.sqrt(len(images)))
    cols = len(images)//rows
        
    # size and spacing
    FIGSIZE = 13.0
    SPACING = 0.1
    subplot=(rows,cols,1)
    # Keep the longer grid side at FIGSIZE and scale the other proportionally.
    if rows < cols:
        plt.figure(figsize=(FIGSIZE,FIGSIZE/cols*rows))
    else:
        plt.figure(figsize=(FIGSIZE/rows*cols,FIGSIZE))
    
    # display
    for i, (image, label) in enumerate(zip(images[:rows*cols], labels[:rows*cols])):
        title = '' if label is None else CLASSES[label]
        correct = True
        if predictions is not None:
            title, correct = title_from_label_and_target(predictions[i], label)
        dynamic_titlesize = FIGSIZE*SPACING/max(rows,cols)*40+3 # magic formula tested to work from 1x1 to 10x10 images
        subplot = display_one_flower(image, title, subplot, not correct, titlesize=dynamic_titlesize)
    
    #layout
    plt.tight_layout()
    # `label` here is the loop variable left over from the last iteration above.
    if label is None and predictions is None:
        plt.subplots_adjust(wspace=0, hspace=0)
    else:
        plt.subplots_adjust(wspace=SPACING, hspace=SPACING)
    plt.show()
    
# Visualize model predictions
def dataset_to_numpy_util(dataset, N):
    """Return the first N samples of a dataset as (numpy images, numpy labels).

    The dataset is unbatched and re-batched to size N; only the first batch
    is read.
    """
    rebatched = dataset.unbatch().batch(N)
    for image_batch, label_batch in rebatched:
        numpy_images, numpy_labels = image_batch.numpy(), label_batch.numpy()
        break
    return numpy_images, numpy_labels

def title_from_label_and_target(label, correct_label):
    """Build a numeric plot title from a prediction vector and the true label.

    Args:
        label: Per-class scores; the predicted class is their argmax.
        correct_label: True class index.

    Returns:
        (title string, correct flag).
    """
    predicted = np.argmax(label, axis=-1)
    correct = (predicted == correct_label)
    if correct:
        hint, expected = '', ''
    else:
        hint, expected = ', shoud be ', correct_label
    return "{} [{}{}{}]".format(predicted, str(correct), hint, expected), correct

def display_one_flower_eval(image, title, subplot, red=False):
    """Draw one image into the integer-coded subplot and return the next code."""
    title_color = 'red' if red else 'black'
    plt.subplot(subplot)
    plt.axis('off')
    plt.imshow(image)
    plt.title(title, fontsize=14, color=title_color)
    next_subplot = subplot + 1
    return next_subplot

def display_9_images_with_predictions(images, predictions, labels):
    """Plot up to nine images in a 3x3 grid, titled with prediction vs. label.

    Titles of wrong predictions are drawn in red.
    """
    plt.figure(figsize=(13,13))
    subplot = 331
    # zip with range(9) stops after the ninth image without reading further.
    for i, image in zip(range(9), images):
        title, correct = title_from_label_and_target(predictions[i], labels[i])
        subplot = display_one_flower_eval(image, title, subplot, not correct)

    plt.tight_layout()
    plt.subplots_adjust(wspace=0.1, hspace=0.1)
    plt.show()
    
    
# Model evaluation
def plot_metrics(history):
    """Plot every training metric against its validation counterpart.

    Args:
        history: Mapping of metric name to a per-epoch list of values. The
            'lr' entry is ignored. The remaining keys are paired by position:
            entry ``i`` is plotted with entry ``i + size`` — presumably train
            metrics first, then validation metrics in the same order; verify
            against the caller.
    """
    # Compare by value: `is not 'lr'` tested object identity against a string
    # literal, which is not guaranteed to filter the key.
    metric_list = [m for m in list(history.keys()) if m != 'lr']
    size = len(metric_list)//2
    fig, axes = plt.subplots(size, 1, sharex='col', figsize=(20, size * 4))
    # plt.subplots returns a bare Axes for a single row; normalise to a sequence.
    if size > 1:
        axes = axes.flatten()
    else:
        axes = [axes]
    
    for index in range(size):
        metric_name = metric_list[index]
        val_metric_name = metric_list[index+size]
        axes[index].plot(history[metric_name], label='Train %s' % metric_name)
        axes[index].plot(history[val_metric_name], label='Validation %s' % metric_name)
        axes[index].legend(loc='best', fontsize=16)
        axes[index].set_title(metric_name)
        # Mark the best epoch: minimum for losses, maximum for everything else.
        if 'loss' in metric_name:
            axes[index].axvline(np.argmin(history[metric_name]), linestyle='dashed')
            axes[index].axvline(np.argmin(history[val_metric_name]), linestyle='dashed', color='orange')
        else:
            axes[index].axvline(np.argmax(history[metric_name]), linestyle='dashed')
            axes[index].axvline(np.argmax(history[val_metric_name]), linestyle='dashed', color='orange')

    plt.xlabel('Epochs', fontsize=16)
    sns.despine()
    plt.show()

In [None]:
# Ordered (unshuffled) dataset for previewing. `valid` is left at False, so
# get_dataset() still maps cutmix / mixup over the batches when their rates are non-zero.
train_dataset = get_dataset(TRAINING_FILENAMES, ordered=True)
# Re-batch into groups of 20 images for display.
train_iter = iter(train_dataset.unbatch().batch(20))

# The preview calls are disabled: they sit inside a string literal and never run.
'''display_batch_of_images(next(train_iter))
display_batch_of_images(next(train_iter))'''

In [None]:
# Count rows per class and swap the numeric label for its class name.
label_count = (
    train.groupby('label', as_index=False)
    .count()
    .rename(columns={'image_id': 'Count', 'label': 'Label'})
    .assign(Label=lambda frame: frame['Label'].apply(lambda class_idx: CLASSES[class_idx]))
)

# Horizontal bar chart of the class distribution.
fig, ax = plt.subplots(1, 1, figsize=(18, 10))
ax = sns.barplot(x=label_count['Count'], y=label_count['Label'], palette='viridis')
ax.tick_params(labelsize=16)

plt.show()

In [None]:
# Learning-rate schedule parameters consumed by lrfn().
LR_START = 1e-8            # rate at epoch 0 (start of the linear ramp-up)
LR_MIN = 1e-8              # floor applied during the cosine decay
LR_MAX = LEARNING_RATE     # peak rate reached at the end of the ramp-up
LR_RAMPUP_EPOCHS = 3
LR_SUSTAIN_EPOCHS = 0      # epochs held at LR_MAX before decaying
N_CYCLES = .5              # cosine cycles over the decay phase (.5 = one half wave)


def lrfn(epoch):
    """Return the learning rate for `epoch` as a plain Python float.

    The schedule has three phases, driven by the module-level LR_* constants:
      1. linear warm-up from LR_START to LR_MAX over LR_RAMPUP_EPOCHS epochs,
      2. LR_SUSTAIN_EPOCHS epochs held at LR_MAX,
      3. cosine decay covering N_CYCLES periods over the remaining epochs up to
         EPOCHS, clamped from below at LR_MIN when LR_MIN is not None.

    Parameters:
      epoch -- zero-based epoch index
    Returns:
      the learning rate to use during that epoch (float)
    """
    if epoch < LR_RAMPUP_EPOCHS:
        return (LR_MAX - LR_START) / LR_RAMPUP_EPOCHS * epoch + LR_START

    decay_start = LR_RAMPUP_EPOCHS + LR_SUSTAIN_EPOCHS
    if epoch < decay_start:
        return LR_MAX

    # Fraction of the decay phase already completed, in [0, 1) for epoch < EPOCHS.
    progress = (epoch - decay_start) / (EPOCHS - decay_start)
    # math.cos / max instead of tf.math.cos / tf.math.maximum: every branch now
    # returns the same type (float) rather than a tf.Tensor in this branch only.
    lr = LR_MAX * (0.5 * (1.0 + math.cos(math.pi * N_CYCLES * 2.0 * progress)))
    if LR_MIN is not None:
        lr = max(LR_MIN, lr)

    return lr

# Evaluate the schedule once per epoch and plot it as a sanity check.
rng = list(range(EPOCHS))
y = list(map(lrfn, rng))

sns.set(style='whitegrid')
fig, ax = plt.subplots(figsize=(20, 6))
ax.plot(rng, y)

# Summary of the run length and of the first / peak / last learning rates.
print(
    f'{EPOCHS} total epochs and {NUM_TRAINING_IMAGES//BATCH_SIZE} steps per epoch',
    f'Learning rate schedule: {y[0]:.3g} to {max(y):.3g} to {y[-1]:.3g}',
    sep='\n',
)

In [None]:
def model_fn(input_shape, N_CLASSES):
    """Build and compile an EfficientNet-B4 classifier with a single softmax head.

    The backbone is created without its top, with 'noisy-student' weights and global
    average pooling, followed by Dropout(.25) and a Dense softmax layer. The model
    is compiled for integer (sparse) labels.

    Parameters:
      input_shape -- shape of one input image, passed to `L.Input`
      N_CLASSES   -- number of output classes
    Returns:
      the compiled `tf.keras.Sequential` model
    """
    input_image = L.Input(shape=input_shape, name='input_image')
    base_model = efn.EfficientNetB4(input_tensor=input_image,
                                    include_top=False,
                                    weights='noisy-student',
                                    pooling='avg')

    model = tf.keras.Sequential([
        base_model,
        L.Dropout(.25),
        L.Dense(N_CLASSES, activation='softmax', name='output')
    ])

    # `learning_rate` is the supported keyword; `lr` is only a deprecated alias.
    optimizer = optimizers.Adam(learning_rate=LEARNING_RATE)
    model.compile(optimizer=optimizer,
                  loss=losses.SparseCategoricalCrossentropy(),
                  metrics=['sparse_categorical_accuracy'])

    return model

# Focal Loss 

In [None]:
from tensorflow.keras import backend as K

import dill


def binary_focal_loss(gamma=2., alpha=.25):
    """
    Factory for the binary focal loss.
      FL(p_t) = -alpha * (1 - p_t)**gamma * log(p_t)
      with p_t = p for positive labels and 1 - p for negative labels.
    Parameters:
      gamma -- focusing exponent applied to the modulating factor
      alpha -- weight of the positive term (negatives are weighted 1 - alpha)
    References:
        https://arxiv.org/pdf/1708.02002.pdf
    Usage:
     model.compile(loss=[binary_focal_loss(alpha=.25, gamma=2)], metrics=["accuracy"], optimizer=adam)
    """
    def binary_focal_loss_fixed(y_true, y_pred):
        """
        :param y_true: A tensor of the same shape as `y_pred`
        :param y_pred: A tensor of predicted probabilities
        :return: Scalar tensor: the loss summed over every element.
        """
        eps = K.epsilon()
        is_positive = tf.equal(y_true, 1)
        is_negative = tf.equal(y_true, 0)

        # Positions of the other class are filled with a neutral value (1 for the
        # positive term, 0 for the negative term); clipping keeps the logs finite.
        p_positive = K.clip(tf.where(is_positive, y_pred, tf.ones_like(y_pred)), eps, 1. - eps)
        p_negative = K.clip(tf.where(is_negative, y_pred, tf.zeros_like(y_pred)), eps, 1. - eps)

        positive_term = K.sum(alpha * K.pow(1. - p_positive, gamma) * K.log(p_positive))
        negative_term = K.sum((1 - alpha) * K.pow(p_negative, gamma) * K.log(1. - p_negative))
        return -positive_term - negative_term

    return binary_focal_loss_fixed


def categorical_focal_loss(gamma=2., alpha=.25):
    """
    Softmax version of focal loss.
           m
      FL = sum  -alpha * (1 - p_o,c)^gamma * y_o,c * log(p_o,c)
          c=1
      where m = number of classes, c = class and o = observation
    Parameters:
      alpha -- the same as weighing factor in balanced cross entropy
      gamma -- focusing parameter for modulating factor (1-p)
    Default value:
      gamma -- 2.0 as mentioned in the paper
      alpha -- 0.25 as mentioned in the paper
    Returns:
      a loss function `f(y_true, y_pred)` usable in `model.compile`
    References:
        Official paper: https://arxiv.org/pdf/1708.02002.pdf
        https://www.tensorflow.org/api_docs/python/tf/keras/backend/categorical_crossentropy
    Usage:
     model.compile(loss=[categorical_focal_loss(alpha=.25, gamma=2)], metrics=["accuracy"], optimizer=adam)
    """
    def categorical_focal_loss_fixed(y_true, y_pred):
        """
        :param y_true: A tensor of the same shape as `y_pred`
        :param y_pred: A tensor resulting from a softmax
        :return: Tensor with the class axis (last axis) summed out, i.e. one loss
                 value per observation.
        """

        # Scale predictions so that the class probas of each sample sum to 1
        y_pred = y_pred / K.sum(y_pred, axis=-1, keepdims=True)

        # Clip the prediction value to prevent NaN's and Inf's
        epsilon = K.epsilon()
        y_pred = K.clip(y_pred, epsilon, 1. - epsilon)

        # Calculate Cross Entropy
        cross_entropy = -y_true * K.log(y_pred)

        # Calculate Focal Loss
        loss = alpha * K.pow(1 - y_pred, gamma) * cross_entropy

        # Sum over the class axis. axis=-1 matches the normalisation above; for the
        # usual (batch, classes) input this is the same axis as the former axis=1.
        return K.sum(loss, axis=-1)

    return categorical_focal_loss_fixed

In [None]:
from tensorflow.keras.applications import DenseNet121, DenseNet201 , DenseNet169
from tensorflow.keras.applications import vgg16
from tensorflow.keras.applications.inception_v3 import InceptionV3
from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.keras.applications import ResNet101 , ResNet101V2 , ResNet152 , ResNet152V2 , ResNet50V2 , ResNet50
from tensorflow.keras.applications import MobileNet , MobileNetV2
from tensorflow.keras.applications import InceptionResNetV2
import tensorflow.keras.layers as L
from tensorflow.keras.models import Model,Sequential
from tensorflow.keras.losses import CategoricalCrossentropy
from classification_models.keras import Classifiers

# Backbone registries used by `get_model_generalized`. They hold only strings, so an
# unknown name can be rejected before any model library is touched.

# Notebook name -> constructor attribute on `efn` (loaded with 'noisy-student' weights).
_EFFICIENTNET_BACKBONES = {'EfficientNet%d' % i: 'EfficientNetB%d' % i for i in range(8)}

# Notebook name -> constructor attribute on `tf.keras.applications` ('imagenet' weights).
_KERAS_APPLICATION_BACKBONES = {
    'DenseNet201': 'DenseNet201',
    'DenseNet169': 'DenseNet169',
    'DenseNet121': 'DenseNet121',
    'MobileNet': 'MobileNet',
    'Inception': 'InceptionV3',
    'ResNet50': 'ResNet50',
    'ResNet50V2': 'ResNet50V2',
    'ResNet101': 'ResNet101',
    'ResNet101V2': 'ResNet101V2',
    'ResNet152': 'ResNet152',
    'ResNet152V2': 'ResNet152V2',
    'Incepresnet': 'InceptionResNetV2',
}

# Names resolved through `Classifiers.get` ('imagenet' weights).
_SE_BACKBONES = ('seresnet18', 'seresnet34', 'seresnet50', 'seresnet101',
                 'seresnet152', 'seresnext50', 'seresnext101')


def _build_base_model(name, input_shape):
    """Instantiate the headless (include_top=False) pretrained backbone called `name`."""
    if name in _EFFICIENTNET_BACKBONES:
        constructor = getattr(efn, _EFFICIENTNET_BACKBONES[name])
        return constructor(weights='noisy-student',
                           include_top=False,
                           input_shape=input_shape)
    if name in _KERAS_APPLICATION_BACKBONES:
        constructor = getattr(tf.keras.applications, _KERAS_APPLICATION_BACKBONES[name])
        return constructor(weights='imagenet', include_top=False, input_shape=input_shape)
    if name in _SE_BACKBONES:
        # Classifiers.get returns (constructor, preprocess_input); only the first is needed.
        constructor, _ = Classifiers.get(name)
        return constructor(input_shape=input_shape, weights='imagenet', include_top=False)

    supported = sorted(list(_EFFICIENTNET_BACKBONES)
                       + list(_KERAS_APPLICATION_BACKBONES)
                       + list(_SE_BACKBONES))
    raise ValueError('Unknown backbone name %r; expected one of: %s'
                     % (name, ', '.join(supported)))


def get_model_generalized(name,input_shape,N_CLASSES):
    """Build and compile a classifier made of a pretrained backbone and a dense head.

    Head: GlobalAveragePooling2D -> Dense(512, relu) -> Dropout(0.3)
          -> Dense(256, relu) -> Dropout(0.2) -> Dense(N_CLASSES, softmax).

    The loss is chosen from the module-level flags:
      focal_loss truthy       -- categorical focal loss, plus a macro F1 metric
      label_smoothing truthy  -- CategoricalCrossentropy with that smoothing value
      otherwise               -- plain 'categorical_crossentropy'

    Parameters:
      name        -- backbone identifier: 'EfficientNet0'..'EfficientNet7', a key of
                     _KERAS_APPLICATION_BACKBONES, or an entry of _SE_BACKBONES
      input_shape -- shape of one input image, passed to the backbone constructor
      N_CLASSES   -- number of output classes
    Returns:
      the compiled Keras `Model`
    Raises:
      ValueError -- if `name` is not a known backbone
    """
    base_model = _build_base_model(name, input_shape)

    x = L.GlobalAveragePooling2D()(base_model.output)
    x = L.Dense(512, activation='relu')(x)
    x = L.Dropout(0.3)(x)
    x = L.Dense(256, activation='relu')(x)
    x = L.Dropout(0.2)(x)
    predictions = L.Dense(N_CLASSES, activation='softmax')(x)
    model = Model(inputs=base_model.input, outputs=predictions)

    metric = ['categorical_accuracy']
    if focal_loss:
        loss = categorical_focal_loss(gamma=2., alpha=.25)
        # The F1 metric follows the requested class count instead of a fixed 5.
        metric = [tfa.metrics.F1Score(num_classes=N_CLASSES, average="macro"), 'categorical_accuracy']
    elif label_smoothing:
        loss = CategoricalCrossentropy(label_smoothing=label_smoothing)
    else:
        loss = 'categorical_crossentropy'

    # NOTE(review): an optimizer used to be built from the SWA flag here
    # (tfa.optimizers.SWA around Adam(1e-5), else 'adam') but it was never passed to
    # compile(); the model has always been compiled with Adam(1e-4). That effective
    # behaviour is kept. Wire SWA in here deliberately if weight averaging is wanted.
    optimizer = optimizers.Adam(learning_rate=1e-4)
    model.compile(optimizer=optimizer,
                  loss=loss,
                  metrics=metric)
    return model

In [None]:
# K-fold cross-validation over the 15 training TFRecord shards (Id_train00 .. Id_train14).
# Each fold trains a fresh model on its shards plus the external shards, then stores
# out-of-fold predictions for the held-out shards.
#skf = StratifiedKFold(n_splits=N_FOLDS, random_state=seed, shuffle=True)
skf = KFold(n_splits=N_FOLDS, shuffle=True, random_state=seed)
# oof_pred  : predictions from the weights left in memory when fit() returns
# oof_pred2 : predictions from the best checkpoint written by ModelCheckpoint
oof_pred = []; oof_labels = []; history_list = []
oof_pred2 = []; oof_labels2 = []
# External-dataset shards; appended to the training files of every fold.
ext_train = ['/CBB0-330.tfrec',
             '/CBB1-75.tfrec',
             '/CBSD0-330.tfrec',
             '/CBSD1-330.tfrec',
             '/CBSD2-330.tfrec',
             '/CBSD3-330.tfrec',
             '/CGM0-330.tfrec',
             '/CGM1-330.tfrec',
             '/CMD0-330.tfrec',
             '/CMD1-330.tfrec',
             '/CMD2-330.tfrec',
             '/CMD3-330.tfrec',
             '/CMD4-330.tfrec',
             '/CMD5-330.tfrec',
             '/CMD6-330.tfrec',
             '/Healthy0-313.tfrec']
ext_valid = []

for fold,(idxT, idxV) in enumerate(skf.split(np.arange(15))):
    if tpu: tf.tpu.experimental.initialize_tpu_system(tpu)
    print(f'\nFOLD: {fold+1}')
    print(f'TRAIN: {idxT} VALID: {idxV}')

    # Create train and validation sets
    TRAIN_FILENAMES = tf.io.gfile.glob([GCS_PATH + '/Id_train%.2i*.tfrec' % x for x in idxT])
    print(len(TRAIN_FILENAMES))
    VALID_FILENAMES = tf.io.gfile.glob([GCS_PATH + '/Id_train%.2i*.tfrec' % x for x in idxV])

    EXT_DATA_TRAIN = tf.io.gfile.glob([GCS_PATH_EXT + x for x in ext_train])
    TRAIN_FILENAMES += EXT_DATA_TRAIN
    print(len(TRAIN_FILENAMES))

    # Shuffle the list that is actually fed to get_dataset for this fold. Shuffling
    # the global TRAINING_FILENAMES here left TRAIN_FILENAMES in glob order and
    # mutated an unrelated list on every fold.
    np.random.shuffle(TRAIN_FILENAMES)

    ct_train = count_data_items(TRAIN_FILENAMES)
    STEPS_PER_EPOCH = ct_train // BATCH_SIZE

    ## MODEL
    K.clear_session()
    with strategy.scope():
        #model = model_fn((None, None, CHANNELS), N_CLASSES)
        model = get_model_generalized(MODEL_NAME,(None, None, CHANNELS), N_CLASSES)

    model_path = f'model_{fold}.h5'
    es = EarlyStopping(monitor='val_categorical_accuracy', mode='max',
                       patience=ES_PATIENCE, restore_best_weights=True, verbose=1)
    mc = ModelCheckpoint(f'best_model{fold}.h5', monitor='val_categorical_accuracy',
                         mode='max', verbose=1, save_best_only=True)

    # `min_delta` is the supported name of the deprecated `epsilon` argument. mode is
    # spelled out as 'max' (what 'auto' resolves to for an accuracy monitor) so that it
    # reads the same as the two callbacks above.
    reduce_lr = ReduceLROnPlateau(monitor="val_categorical_accuracy", factor=0.5, patience=4,
                                  verbose=1, mode="max", min_delta=1e-04, cooldown=0,
                                  min_lr=1e-8)

    ## TRAIN
    history = model.fit(x=get_dataset(TRAIN_FILENAMES, labeled=True, ordered=False, repeated=True, augment=True,valid=False).map(to_float32_2),
                        validation_data=get_dataset(VALID_FILENAMES, labeled=True, ordered=True, repeated=False, augment=False,valid=True).map(to_float32_2),
                        steps_per_epoch=STEPS_PER_EPOCH,
                        callbacks=[es, mc ,reduce_lr],  #LearningRateScheduler(lrfn, verbose=0)
                        epochs=EPOCHS,
                        verbose=2).history

    history_list.append(history)
    # Save last model weights
    model.save_weights(model_path)

    # OOF predictions
    ds_valid = get_dataset(VALID_FILENAMES, labeled=True, ordered=True, repeated=False,valid=True, augment=False).map(to_float32_2)
    oof_labels.append([target.numpy() for img, target in iter(ds_valid.unbatch())])
    # Drop the label so that predict() only receives images.
    x_oof = ds_valid.map(lambda image, label: image)
    oof_pred.append(np.argmax(model.predict(x_oof), axis=-1))
    # Second set of predictions from the best checkpoint saved by ModelCheckpoint.
    model.load_weights(f'best_model{fold}.h5')
    oof_pred2.append(np.argmax(model.predict(x_oof),axis=-1))

    ## RESULTS
    print(f"#### FOLD {fold+1} OOF Accuracy = {np.max(history['val_categorical_accuracy']):.3f}")

In [None]:
# Draw the training curves of every fold, in fold order (folds are printed 1-based).
for fold, history in enumerate(history_list):
    print(f'\nFOLD: {fold+1}')
    plot_metrics(history)

In [None]:
# Flatten the per-fold lists into single arrays covering every validation sample.
# The argmax over axis 1 converts each label vector into a class index
# (presumably the labels are one-hot encoded -- confirm against get_dataset).
y_true = np.argmax(np.concatenate(oof_labels),axis=1)
# oof_pred holds the predictions made with the weights in memory right after fit().
y_preds = np.concatenate(oof_pred)

print(classification_report(y_true, y_preds, target_names=CLASSES))

In [None]:
# Same report, for the predictions made after reloading each fold's best checkpoint.
y_preds2 = np.concatenate(oof_pred2)

print(classification_report(y_true, y_preds2, target_names=CLASSES))

In [None]:
from sklearn.metrics import accuracy_score

# Overall out-of-fold accuracy of the best-checkpoint predictions.
print(accuracy_score(y_true,y_preds2))

In [None]:
from sklearn.metrics import accuracy_score

# Overall out-of-fold accuracy of the predictions made right after fit().
print(accuracy_score(y_true,y_preds))

In [None]:
# Row-normalised confusion matrix of the out-of-fold predictions `y_preds`: every row
# is divided by the number of true samples of that class, so each cell is the fraction
# of that class predicted as the column label.
fig, ax = plt.subplots(1, 1, figsize=(20, 12))
train_cfn_matrix = confusion_matrix(y_true, y_preds, labels=range(len(CLASSES)))
train_cfn_matrix = train_cfn_matrix / train_cfn_matrix.sum(axis=1, keepdims=True)
train_df_cm = pd.DataFrame(train_cfn_matrix, index=CLASSES, columns=CLASSES)
# Draw on the axes created above and keep `ax` bound to them; chaining .set_title()
# onto the heatmap call used to rebind `ax` to the returned Text object.
sns.heatmap(train_df_cm, cmap='Blues', annot=True, fmt='.2f', linewidths=.5, ax=ax)
ax.set_title('Train', fontsize=30)
plt.show()

In [None]:
# Disabled preview of predictions on 18 training samples (two grids of 9): the code is
# wrapped in a string literal, so this cell evaluates a string and runs nothing.
'''x_samp, y_samp = dataset_to_numpy_util(train_dataset, 18)

x_samp_1, y_samp_1 = x_samp[:9,:,:,:], y_samp[:9]
samp_preds_1 = model.predict(x_samp_1, batch_size=9)
display_9_images_with_predictions(x_samp_1, samp_preds_1, y_samp_1)

x_samp_2, y_samp_2 = x_samp[9:,:,:,:], y_samp[9:]
samp_preds_2 = model.predict(x_samp_2, batch_size=9)
display_9_images_with_predictions(x_samp_2, samp_preds_2, y_samp_2)'''