Hello fellow Kagglers,

This notebook demonstrates how MixUp and CutMix augmentations are applied to achieve a cross validation accuracy of >0.90.


**[MixUp paper](https://arxiv.org/abs/1710.09412)**

MixUp combines two images, *a* and *b*, into one image. Each pixel is *x%* image *a* and *y%* image *b*, with *x* + *y* = 100, and the label is mixed in the same proportions. The idea behind this augmentation method is to train the model on images whose label lies in between two classes.
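Below is a minimal numpy sketch of MixUp for a single pair of images. It assumes one-hot labels. `mixup_pair` is a hypothetical helper. `alpha=0.4` matches the default of the `mixup` function in this notebook, and the mixing factor is drawn from $Beta(\alpha, \alpha)$ as in the paper.

```python
import numpy as np

def mixup_pair(img_a, img_b, label_a, label_b, rng, alpha=0.4):
    """Hypothetical helper: MixUp of two images, every pixel and the label are a convex combination of a and b."""
    lam = rng.beta(alpha, alpha)                 # mixing factor in [0, 1]
    image = lam * img_a + (1.0 - lam) * img_b    # lam * 100% of a, (1 - lam) * 100% of b
    label = lam * label_a + (1.0 - lam) * label_b
    return image, label, lam

rng = np.random.default_rng(0)
img, lbl, lam = mixup_pair(np.zeros((4, 4, 3)), np.ones((4, 4, 3)), np.eye(5)[3], np.eye(5)[0], rng)
print(round(lam, 3), lbl.round(3))
```

With a small `alpha` the beta distribution puts most of its mass near 0 and 1, so most mixed images are dominated by one of the two originals.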

**[CutMix paper](https://arxiv.org/abs/1905.04899)**

CutMix also combines two images, *a* and *b*, but builds the new image from complete, non-overlapping regions of the original images. For example, if the left half comes from image *a* and the right half from image *b*, the label becomes 0.50 × label *a* + 0.50 × label *b*. The idea behind this augmentation method is roughly the same as with MixUp: train the model on a combination of two images with a mixed label. In contrast to MixUp, this method only uses original pixels and acts as a regional dropout, because only a part of each original image is used.
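Below is a minimal numpy sketch of CutMix for one pair of (H, W, C) images with one-hot labels. `cutmix_pair` is a hypothetical helper. The box side is $\sqrt{1-\lambda}$ times the image side, so the pasted area is about $1-\lambda$. The label is weighted by the exact fraction of pasted pixels. Drawing $\lambda$ uniformly is the same as $Beta(1, 1)$.

```python
import numpy as np

def cutmix_pair(img_a, img_b, label_a, label_b, rng):
    """Hypothetical helper: paste a random box of img_b into img_a and weight the labels by pixel share."""
    h, w = img_a.shape[:2]
    lam = rng.uniform()                          # Beta(1, 1), i.e. uniform on [0, 1]
    box_h = int(h * np.sqrt(1.0 - lam))
    box_w = int(w * np.sqrt(1.0 - lam))
    # top-left corner chosen so that the box always lies inside the image
    y0 = rng.integers(0, h - box_h + 1)
    x0 = rng.integers(0, w - box_w + 1)
    mixed = img_a.copy()
    mixed[y0:y0 + box_h, x0:x0 + box_w] = img_b[y0:y0 + box_h, x0:x0 + box_w]
    frac_b = (box_h * box_w) / (h * w)           # exact share of pixels taken from img_b
    label = (1.0 - frac_b) * label_a + frac_b * label_b
    return mixed, label

rng = np.random.default_rng(0)
mixed, label = cutmix_pair(np.zeros((8, 8, 3)), np.ones((8, 8, 3)), np.eye(5)[3], np.eye(5)[0], rng)
print(label)
```

The notebook's `cutmix` weights the labels by $\lambda$ directly. Because its box never needs clipping, this differs from the exact pixel share only by the integer rounding of the box size.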

**[GridMask paper](https://arxiv.org/pdf/2001.04086.pdf)**

GridMask is one of many image cutout methods, but it distinguishes itself by using a grid-shaped mask, hence the name GridMask. The size of the grid unit is sampled from a fixed range, and the grid is positioned over the image with a random top and left offset.
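Below is a minimal numpy sketch of a checkerboard-style grid mask with a random unit size and offset. The helper name `checkerboard_gridmask` is hypothetical. It follows the idea of the `gridmask` function in this notebook, which scales a 2×2 checkerboard to a random size `d` and crops it at a random offset. It is not a reproduction of the GridMask paper's exact parameterization.

```python
import numpy as np

def checkerboard_gridmask(size, rng, d_min, d_max):
    """Hypothetical helper: (size, size) checkerboard mask (1 = keep, 0 = drop) with a
    random grid unit d in [d_min, d_max) and a random top/left offset."""
    d = int(rng.integers(d_min, d_max))        # grid unit, roughly one dropped + one kept square per axis
    half = max(d // 2, 1)                      # side length of a single square
    off_y, off_x = rng.integers(0, d, size=2)  # random top and left offset of the grid
    ys = (np.arange(size) + off_y) // half     # square index of every row
    xs = (np.arange(size) + off_x) // half     # square index of every column
    # a pixel is dropped when its row square and column square have the same parity
    return ((ys[:, None] + xs[None, :]) % 2).astype(np.float32)

rng = np.random.default_rng(0)
mask = checkerboard_gridmask(512, rng, int(512 * 96 / 224), 512)
image = np.ones((512, 512, 3), dtype=np.float32)
masked = image * mask[..., None]               # broadcast the mask over the channels
print(mask.mean())
```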

If these augmentation methods sound abstract, don't worry, examples will be shown in this notebook.

Validation is performed with stratified k-fold cross validation with k=5. In each fold, 20% of the training data is used for validation, and the proportion of samples per class is the same in the training and validation datasets.
Although a cross validation accuracy of >0.90 is achieved, the leaderboard score is 0.893. Test time augmentation and combining the models trained on the different folds could improve this score.


**V4** Added GridMask and changed the batch size from 64 to 32. Also reduced the number of epochs from 25 to 15. Only CutMix and GridMask augmentations are used, as this combination gave the best results. Altogether, the LB score improved from 0.893 to 0.896.

**V5** Added 2019 Competition data (no duplicates). Improved train dataset speed, increased epochs to 30.

**V6** Batch size decreased from 32 -> 16

The inference notebook can be found [here](https://www.kaggle.com/markwijkhuizen/cassava-leaf-disease-inference-5-fold?scriptVersionId=54736925). 8x Test Time Augmentation is applied.

In [None]:
!pip install -q --upgrade pip
!pip install -q efficientnet

In [None]:
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_probability as tfp
import matplotlib.pyplot as plt
import efficientnet.tfkeras as efn
import seaborn as sns

from kaggle_datasets import KaggleDatasets
from tensorflow.keras.mixed_precision import experimental as mixed_precision
from tqdm.notebook import tqdm
from sklearn.metrics import classification_report, confusion_matrix

import sys
import glob
import math
import gc
import time

print(f'tensorflow version: {tf.__version__}')
print(f'tensorflow keras version: {tf.keras.__version__}')
print(f'python version: {sys.version}')

# TPU and bfloat16 Configuration

A bfloat16 is a 16-bit float with the same range as a 32-bit float, but with lower precision. It keeps the 8 exponent bits of a float32 but only 7 of its 23 mantissa bits. Using bfloat16 instead of float32 reduces memory consumption and speeds up training and augmentation. In practice, the loss in numerical precision is not a problem for machine learning models. bfloat16 still gives roughly 2 to 3 significant decimal digits, and in most cases this does not affect performance.

**[Some background knowledge on bfloat16](https://cloud.google.com/blog/products/ai-machine-learning/bfloat16-the-secret-to-high-performance-on-cloud-tpus)**
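The snippet below gives a feel for the precision loss. It approximates a float32 to bfloat16 conversion in plain Python by keeping only the upper 16 bits of the float32 encoding: the sign, the 8 exponent bits and the top 7 mantissa bits. `to_bfloat16_truncated` is a hypothetical helper. It truncates, whereas framework casts typically round to nearest.

```python
import struct

def to_bfloat16_truncated(x):
    """Hypothetical helper: approximate bfloat16 by zeroing the lower 16 bits of the float32 encoding."""
    (bits,) = struct.unpack('>I', struct.pack('>f', x))           # float32 bit pattern
    (y,) = struct.unpack('>f', struct.pack('>I', bits & 0xFFFF0000))
    return y

print(to_bfloat16_truncated(3.14159))  # 3.140625
for v in [0.1, 123.456, 1e38]:
    rel_err = abs(to_bfloat16_truncated(v) - v) / v
    print(f'{v:.4g} -> {to_bfloat16_truncated(v):.4g} (relative error {rel_err:.1e})')
```

The relative error stays below $2^{-7} \approx 0.8\%$ for every normal value. Very large values such as 1e38 remain representable, whereas float16 overflows above 65504.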

In [None]:
AUTO = tf.data.experimental.AUTOTUNE
# Detect hardware, return appropriate distribution strategy
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection. No parameters necessary if TPU_NAME environment variable is set. On Kaggle this is always the case.
    print('Running on TPU ', tpu.master())
except ValueError:
    tpu = None

if tpu:
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
else:
    strategy = tf.distribute.get_strategy() # default distribution strategy in Tensorflow. Works on CPU and single GPU.

REPLICAS = strategy.num_replicas_in_sync
print(f'REPLICAS: {REPLICAS}')

# set mixed precision policy: bfloat16 compute, float32 variables
mixed_precision.set_policy('mixed_bfloat16')

# enable XLA optimizations
tf.config.optimizer.set_jit(True)

print(f'Compute dtype: {mixed_precision.global_policy().compute_dtype}')
print(f'Variable dtype: {mixed_precision.global_policy().variable_dtype}')

In [None]:
IMG_HEIGHT = 600
IMG_WIDTH = 800

IMG_SIZE = 600
IMG_TARGET_SIZE = 512
N_CHANNELS = 3

N_TRAIN_IMGS = 21642
N_VAL_IMGS = 5410
BATCH_SIZE_VAL = 139 * REPLICAS # 139 validation images per replica per step

N_LABELS = 5
N_FOLDS = 5
EPOCHS = 30

BATCH_SIZE_BASE = 16
BATCH_SIZE = BATCH_SIZE_BASE * REPLICAS

TARGET_DTYPE = tf.bfloat16

# ImageNet mean and standard deviation
IMAGENET_MEAN = tf.constant([0.485, 0.456, 0.406], dtype=tf.float32)
IMAGENET_STD = tf.constant([0.229, 0.224, 0.225], dtype=tf.float32)

In [None]:
GCS_DS_PATH = KaggleDatasets().get_gcs_path('cassava-leaf-disease-tfrecords-600x600')

# Train Dataset

A public dataset is used in which the JPEGs are combined into TFRecords. This allows for a faster data pipeline, as images do not have to be read one by one, but 1024 at a time. The original full-resolution JPEGs are stored so they can be used for data augmentation. The original images are 800\*600 pixels, and each epoch a random square crop is taken.

V5: The image height and width are also stored, because the 2019 competition images have a wide variety of resolutions and TensorFlow needs an explicit shape to reshape the decoded image.
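The random square crop in `decode_tfrecord_train` slides a square of side min(H, W) along the longer axis. Below is a small numpy sketch of the same logic. `random_square_crop` is a hypothetical helper, not part of the notebook.

```python
import numpy as np

def random_square_crop(image, rng):
    """Hypothetical helper: crop a random square of side min(H, W) from an (H, W, C) array."""
    h, w = image.shape[:2]
    if h > w:
        # slide along the height; the upper bound is exclusive, as with tf.random.uniform
        offset = rng.integers(0, h - w)
        return image[offset:offset + w]
    if w > h:
        # slide along the width
        offset = rng.integers(0, w - h)
        return image[:, offset:offset + h]
    return image  # already square

rng = np.random.default_rng(0)
landscape = np.zeros((600, 800, 3), dtype=np.uint8)
print(random_square_crop(landscape, rng).shape)  # (600, 600, 3)
```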

In [None]:
def decode_tfrecord_train(record_bytes):
    features = tf.io.parse_single_example(record_bytes, {
        'image': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.int64),
        'height': tf.io.FixedLenFeature([], tf.int64),
        'width': tf.io.FixedLenFeature([], tf.int64),
    })
    
    height = features['height']
    width = features['width']

    image = tf.io.decode_jpeg(features['image'])
    image = tf.reshape(image, [height, width, N_CHANNELS])
    
    # get random square
    if height > width:
        offset = tf.random.uniform(shape=(), minval=0, maxval=height-width, dtype=tf.int64)
        image = tf.slice(image, [offset, 0, 0], [width, width, N_CHANNELS])
    elif width > height:
        offset = tf.random.uniform(shape=(), minval=0, maxval=width-height, dtype=tf.int64)
        image = tf.slice(image, [0, offset, 0], [height, height, N_CHANNELS])
    else:
        image = tf.slice(image, [0, 0, 0], [height, width, N_CHANNELS])
        
    size = tf.cast(tf.minimum(height, width), tf.float32)
    
    # cast label to uint8
    label = tf.cast(features['label'], tf.uint8)

    return image, label, size

In [None]:
# returns True with probability x / y, used for conditional data augmentation
def chance(x, y):
    return tf.random.uniform(shape=[], minval=0, maxval=y, dtype=tf.int32) < x

In [None]:
def augment_image(image, label, size):
    # random flip image horizontally
    image = tf.image.random_flip_left_right(image)
    # random flip image vertically
    image = tf.image.random_flip_up_down(image)
    
    # random transpose
    if chance(1,2):
        image = tf.image.transpose(image)
    
    # random crop between 75%-100% of the square side
    crop_size = tf.random.uniform(shape=(), minval=size*0.75, maxval=size)
    crop_size = tf.cast(crop_size, tf.int32)
    image = tf.image.random_crop(image, [crop_size, crop_size, N_CHANNELS])
    
    # resize to the model input size (tf.image.resize returns float32)
    image = tf.image.resize(image, [IMG_TARGET_SIZE, IMG_TARGET_SIZE])
    
    # normalize according to imagenet mean and std
    image /= 255.0
    image = (image - IMAGENET_MEAN) / IMAGENET_STD
    
    # one hot encode label
    label = tf.one_hot(label, N_LABELS, dtype=tf.float32)
    
    return image, label

In [None]:
def read_augment_image(record_bytes):
    image, label, size = decode_tfrecord_train(record_bytes)
    image, label = augment_image(image, label, size)
    
    return image, label

The next function returns the index of a random other image in the batch, regardless of its label. This performs better than choosing a random image with a different label. A possible reason is that choosing an image from another class changes the class imbalance. Images from the most dominant class, class 3, would always be mixed with an image from another class, so the other classes would be overrepresented in the training data.
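The same sampling rule (any other image in the batch, uniformly) can also be written without `tf.where`. Draw an offset in [1, n) and wrap around, which can never map a position onto itself. Below is a numpy sketch of this alternative. `random_partner_indices` is a hypothetical helper.

```python
import numpy as np

def random_partner_indices(n, rng):
    """Hypothetical helper: for every position i in a batch of n images, pick a uniformly random j != i."""
    # an offset drawn from [1, n) followed by a wrap-around never maps i onto itself
    offsets = rng.integers(1, n, size=n)
    return (np.arange(n) + offsets) % n

rng = np.random.default_rng(0)
print(random_partner_indices(16, rng))
print(random_partner_indices(2, rng))  # [1 0]
```

For n = 2 this always returns [1, 0], the special case hard-coded in `mixup` and `cutmix`.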

In [None]:
def get_mix_img_idx(labels_idxs, idx):
    idx_candidates = tf.where(labels_idxs != idx)
    r = tf.random.uniform(minval=0, maxval=len(idx_candidates), shape=[], dtype=tf.int32)
    idx = tf.gather(idx_candidates, r)
    idx = tf.cast(idx, tf.int32)
    idx = tf.squeeze(idx)
    
    return idx

# MixUp implementation

In [None]:
def mixup(images, labels, alpha=0.40):
    l = len(images)
    # get image factors
    a = tfp.distributions.Beta(alpha, alpha).sample(l)
    a_label = tf.reshape(a, shape=(l,1))
    a_label = tf.tile(a_label, [1, N_LABELS])
    b_label = 1 - a_label
    
    a_image = tf.reshape(a, shape=(l,1,1,1))
    a_image = tf.tile(a_image, [1, IMG_TARGET_SIZE, IMG_TARGET_SIZE ,N_CHANNELS])
    a_image = tf.cast(a_image, tf.float32)
    b_image = 1 - a_image
    
    # get mixup image indices
    if l == 2:
        idxs = tf.constant([1, 0])
    else:
        labels_idxs = tf.range(len(labels))
        idxs = tf.map_fn(lambda idx: get_mix_img_idx(labels_idxs, idx), tf.range(len(labels)))
    
    images_mixup = tf.gather(images, idxs)
    labels_mixup = tf.gather(labels, idxs)
    
    # mixup images and labels
    images =  images * a_image + images_mixup * b_image
    labels = labels * a_label + labels_mixup * b_label
    
    images = tf.cast(images, TARGET_DTYPE)
    
    return images, labels

# CutMix implementation

This CutMix implementation differs from the one in version 3 and is closer to the method described in the paper. The mixing factor is sampled from a beta distribution, $\lambda \sim Beta(\alpha, \alpha)$ with $\alpha=1$, and it sets the size of a random square mask. The masked region of the original image is replaced by the same region of another image. The label is updated according to the labels and areas of both images: the original image's label gets weight $\lambda$ and the pasted image's label gets weight $1-\lambda$.
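The formulas below restate `create_cutmix_mask` and `cutmix`. Here $S$ = `IMG_TARGET_SIZE`, $M$ is the binary box mask returned by `create_cutmix_mask`, $x_a, y_a$ are the original image and label, and $x_b, y_b$ are the pasted image and label:

$$
\lambda \sim Beta(\alpha, \alpha), \quad \alpha = 1 \;\Rightarrow\; \lambda \sim \mathcal{U}(0, 1)
$$

$$
r_w = r_h = S\sqrt{1 - \lambda} \quad\Rightarrow\quad \frac{r_w \, r_h}{S^2} = 1 - \lambda
$$

$$
\tilde{x} = (\mathbf{1} - M) \odot x_a + M \odot x_b, \qquad \tilde{y} = \lambda\, y_a + (1 - \lambda)\, y_b
$$

Because the box side is truncated to an integer, the pasted area is only approximately $1 - \lambda$.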

In [None]:
def create_cutmix_mask(a):
    # create random mask size and coordinates
    r_w = tf.cast(IMG_TARGET_SIZE * tf.math.sqrt(1 - a), tf.int32)
    r_h = tf.cast(IMG_TARGET_SIZE * tf.math.sqrt(1 - a), tf.int32)
    
    if r_w == IMG_TARGET_SIZE:
        r_x = 0
    else:
        r_x = tf.random.uniform(minval=0, maxval=IMG_TARGET_SIZE - r_w, shape=[], dtype=tf.int32)
        
    if r_h == IMG_TARGET_SIZE:
        r_y = 0
    else:
        r_y = tf.random.uniform(minval=0, maxval=IMG_TARGET_SIZE - r_h, shape=[], dtype=tf.int32)

    # compute padding sizes
    pad_left = r_x
    pad_right = IMG_TARGET_SIZE - (r_x + r_w)
    pad_top = r_y
    pad_bottom = IMG_TARGET_SIZE - (r_y + r_h)
    
    # create the box mask: ones inside the box, zeros elsewhere (axis 0 = height, axis 1 = width)
    mask_a = tf.ones(shape=[r_h, r_w], dtype=tf.float32)
    mask_a = tf.pad(mask_a, [[pad_top, pad_bottom], [pad_left, pad_right]], mode='CONSTANT', constant_values=0)
    mask_a = tf.expand_dims(mask_a, axis=2)
    
    return mask_a

def cutmix(images, labels):
    l = len(images)
    a_float32 = tfp.distributions.Beta(1.0, 1.0).sample([l])

    mask_b = tf.map_fn(create_cutmix_mask, a_float32)
    mask_a = tf.math.abs(mask_b - 1)
    
    # images_idxs
    if l == 2:
        idxs = tf.constant([1, 0])
    else:
        labels_idxs = tf.range(len(labels))
        idxs = tf.map_fn(lambda idx: get_mix_img_idx(labels_idxs, idx), tf.range(len(labels)))
    
    images_cutmix = tf.gather(images, idxs)
    labels_cutmix = tf.gather(labels, idxs)
    
    a_float32_labels = tf.expand_dims(a_float32, axis=1)
    a_float32_labels = tf.repeat(a_float32_labels, N_LABELS, axis=1)
    labels_factor = a_float32_labels
    labels_cutmix_factor = 1 - a_float32_labels
    
    # cutmix images and labels
    images = images * mask_a + images_cutmix * mask_b
    labels = labels * labels_factor + labels_cutmix * labels_cutmix_factor
    
    images = tf.cast(images, TARGET_DTYPE)
    
    return images, labels

# GridMask implementation

In [None]:
def gridmask(images, labels):
    l = len(images)
    
    d = tf.random.uniform(minval=int(IMG_TARGET_SIZE * (96/224)), maxval=IMG_TARGET_SIZE, shape=[], dtype=tf.int32)
    grid = tf.constant([[[0], [1]],[[1], [0]]], dtype=tf.float32)
    grid = tf.image.resize(grid, [d, d], method='nearest')
    
    # 50% chance to rotate mask
    if chance(1, 2):
        grid = tf.image.rot90(grid, 1)

    repeats = IMG_TARGET_SIZE // d + 1
    grid = tf.tile(grid, multiples=[repeats, repeats, 1])
    grid = tf.image.random_crop(grid, [IMG_TARGET_SIZE, IMG_TARGET_SIZE, 1])
    grid = tf.expand_dims(grid, axis=0)
    grid = tf.tile(grid, multiples=[l, 1, 1, 1])

    images = images * grid
    images = tf.cast(images, TARGET_DTYPE)
    
    return images, labels

In [None]:
def augment_batch(images, labels, augmentations=None):
    if augmentations is None:
        r = tf.random.uniform(minval=0, maxval=4, shape=[], dtype=tf.int32)
    else:
        r = tf.random.uniform(minval=0, maxval=len(augmentations), shape=[], dtype=tf.int32)
        r = tf.gather(augmentations, r)
        
    if r == 0:
        images = tf.cast(images, TARGET_DTYPE)
        return images, labels
    elif r == 1:
        return mixup(images, labels)
    elif r == 2:
        return cutmix(images, labels)
    elif r == 3:
        return gridmask(images, labels)
    else:
        images = tf.cast(images, TARGET_DTYPE)
        return images, labels

The batch is reshaped from [x, y, ...] to [x\*y, ...] and shuffled. This is much faster than unbatching and batching again.
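The numpy illustration below shows why the reshape is equivalent to unbatching and re-batching. Merging the two leading axes simply concatenates the sub-batches in order. The sizes (8 replicas × 16 images, tiny 2×2×3 "images") are stand-ins for `REPLICAS`, `BATCH_SIZE_BASE` and the real image shape.

```python
import numpy as np

replicas, per_replica = 8, 16
# stand-in for the output of .batch(BATCH_SIZE_BASE).batch(REPLICAS): [REPLICAS, BATCH_SIZE_BASE, H, W, C]
nested = np.arange(replicas * per_replica * 2 * 2 * 3).reshape(replicas, per_replica, 2, 2, 3)

# merge the two leading axes into one batch of REPLICAS * BATCH_SIZE_BASE images
flat = nested.reshape(replicas * per_replica, 2, 2, 3)

# identical to unbatching every sub-batch and batching the images again in the same order
assert np.array_equal(flat, np.concatenate(list(nested), axis=0))
print(flat.shape)  # (128, 2, 2, 3)
```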

In [None]:
def reshape_batch(images, labels):
    images = tf.reshape(images, shape=[BATCH_SIZE, IMG_TARGET_SIZE, IMG_TARGET_SIZE, N_CHANNELS])
    labels = tf.reshape(labels, shape=[BATCH_SIZE, N_LABELS])
    
    random_idxs = tf.random.shuffle(tf.range(BATCH_SIZE))
    images = tf.gather(images, random_idxs)
    labels = tf.gather(labels, random_idxs)
    
    return images, labels

V3: Improved dataset pipeline speed by adding a prefetch for the TFRecords samples and a static number of parallel calls for the batch augmentations

V4: Using a reshape map instead of unbatch and batch improves throughput by ~30%

In [None]:
def get_train_dataset(bs=BATCH_SIZE, fold=0, augmentations=None):
    ignore_order = tf.data.Options()
    ignore_order.experimental_deterministic = False
    
    FNAMES_TRAIN_TFRECORDS = tf.io.gfile.glob(f'{GCS_DS_PATH}/fold_{fold}/train/*.tfrecords')
    train_dataset = tf.data.TFRecordDataset(FNAMES_TRAIN_TFRECORDS, num_parallel_reads=AUTO)
    train_dataset = train_dataset.with_options(ignore_order)
    train_dataset = train_dataset.prefetch(AUTO)
    train_dataset = train_dataset.repeat()
    train_dataset = train_dataset.map(read_augment_image, num_parallel_calls=AUTO)

    train_dataset = train_dataset.batch(BATCH_SIZE_BASE)
    train_dataset = train_dataset.map(lambda images, labels: augment_batch(images, labels, augmentations=augmentations), num_parallel_calls=REPLICAS)
    
    train_dataset = train_dataset.batch(REPLICAS)
    train_dataset = train_dataset.map(reshape_batch, num_parallel_calls=1)
    
    train_dataset = train_dataset.prefetch(1)
    
    return train_dataset

train_dataset = get_train_dataset()

In [None]:
def benchmark(num_epochs=3, n_steps_per_epoch=10, augmentations=None, bs=BATCH_SIZE):
    dataset = get_train_dataset(augmentations=augmentations)
    start_time = time.perf_counter()
    for epoch_num in range(num_epochs):
        epoch_start = time.perf_counter()
        for idx, (images, labels) in enumerate(dataset.take(n_steps_per_epoch)):
            if idx == 1:
                print(images.shape, labels.shape)
            pass
        print(f'epoch {epoch_num} took: {round(time.perf_counter() - epoch_start, 2)}')
    print("Execution time:", round(time.perf_counter() - start_time, 2))
    
benchmark(num_epochs=3, augmentations=[2,3])

# Batch Example

The next function plots examples of the final augmented images. The title of each image shows the pixel statistics (mean, std, min and max) and the label. Note that all possible augmentations are shown here, but only CutMix and GridMask are used for training.

In [None]:
def show_first_train_batch(augmentations=None, rows=4, cols=4, print_info=False):
    # log info of batch and first few train images
    imgs, lbls = next(iter(get_train_dataset(augmentations=augmentations)))
    if print_info:
        print(f'Number of train images: {N_TRAIN_IMGS}')
        print(f'imgs.shape: {imgs.shape}, images.dtype: {imgs.dtype}, lbls.shape: {lbls.shape}, lbls.dtype: {lbls.dtype}')
        img0 = imgs[0].numpy().astype(np.float32)
        print('img0 mean: {:.3f}, img0 std {:.3f}, img0 min: {:.3f}, img0 max: {:.3f}'.format(img0.mean(), img0.std(), img0.min(), img0.max()))
        print(f'first label: {lbls[0]}')

    fig, axes = plt.subplots(nrows=rows, ncols=cols, figsize=(cols*6, rows*6))
    for r in range(rows):
        for c in range(cols):
            img = imgs[r*cols+c].numpy().astype(np.float32)
            lbl = lbls[r*cols+c].numpy().astype(np.float32).tolist()
            
            # add title with image information
            lbl_str = '[' + ', '.join(['%.3f' % i for  i in lbl]) + ']'
            axes[r, c].set_title('mean: {:.3f}, std {:.3f}, min: {:.3f}, max: {:.3f}\n label: {}'.format(img.mean(), img.std(), img.min(), img.max(), lbl_str))
            axes[r, c].axhline(y=IMG_TARGET_SIZE // 2, color='r')
            axes[r, c].axvline(x=IMG_TARGET_SIZE // 2, color='r')
            
            # min-max scale to [0, 1] for display
            img -= img.min()
            img /= img.max()
            axes[r, c].imshow(img)
            
show_first_train_batch(augmentations=[0,1,2,3])

In [None]:
# MixUp examples
show_first_train_batch(augmentations=[1], rows=2, cols=3)

In [None]:
# CutMix examples
show_first_train_batch(augmentations=[2], rows=2, cols=3)

In [None]:
# GridMask examples
show_first_train_batch(augmentations=[3], rows=2, cols=3)

This function shows the per-image augmentation, which is applied before CutMix, MixUp or GridMask, and shows how the same image can differ each epoch.

In [None]:
def resize_image(image, label, size):
    image = tf.image.resize(image, [IMG_TARGET_SIZE, IMG_TARGET_SIZE])
    
    return image, label, tf.cast(IMG_TARGET_SIZE, tf.float32)

def show_data_augmentations():
    FNAMES_TRAIN_TFRECORDS = tf.io.gfile.glob(f'{GCS_DS_PATH}/fold_0/train/*.tfrecords')
    dataset = tf.data.TFRecordDataset(FNAMES_TRAIN_TFRECORDS)
    dataset = dataset.map(decode_tfrecord_train)
    dataset = dataset.map(resize_image)
    dataset = dataset.batch(BATCH_SIZE)

    imgs, lbls, szs = next(iter(dataset))
    
    # to test data augmentation
    rows, cols = 4, 4
    fig, axes = plt.subplots(nrows=rows, ncols=cols, figsize=(cols*6, rows*6))
    for r in range(rows):
        for c in range(cols):
            img, _ = augment_image(imgs[15], -1, szs[15])
            img = img.numpy().astype(np.float32)
            
            # add title with image information
            axes[r, c].set_title('mean: {:.3f}, std {:.3f}, min: {:.3f}, max: {:.3f}'.format(img.mean(), img.std(), img.min(), img.max()))
            
            # min-max scale to [0, 1] for display
            img -= img.min()
            img /= img.max()
            
            axes[r, c].imshow(img)
                
show_data_augmentations()

# Validation Dataset

In [None]:
def decode_tfrecord_val(record_bytes):
    features = tf.io.parse_single_example(record_bytes, {
        'image': tf.io.FixedLenFeature([], tf.string),
        'label': tf.io.FixedLenFeature([], tf.int64),
        'height': tf.io.FixedLenFeature([], tf.int64),
        'width': tf.io.FixedLenFeature([], tf.int64),
    })
    
    height = features['height']
    width = features['width']

    image = tf.io.decode_jpeg(features['image'])
    image = tf.reshape(image, [height, width, N_CHANNELS])
    
    # get center square
    if height > width:
        offset = (height - width) // 2
        image = tf.slice(image, [offset, 0, 0], [width, width, N_CHANNELS])
    elif width > height:
        offset = (width - height) // 2
        image = tf.slice(image, [0, offset, 0], [height, height, N_CHANNELS])
    else:
        image = tf.slice(image, [0, 0, 0], [height, width, N_CHANNELS])
    
    # resize to target size
    image = tf.image.resize(image, [IMG_TARGET_SIZE, IMG_TARGET_SIZE])
    
    # normalize according to imagenet mean and std
    image /= 255.0
    image = (image - IMAGENET_MEAN) / IMAGENET_STD
    
    # cast to TARGET_DTYPE
    image = tf.cast(image, TARGET_DTYPE)
    
    label = tf.cast(features['label'], tf.int32)
    
    # one hot encode label
    label = tf.one_hot(label, N_LABELS, dtype=tf.int32)
    
    return image, label

In [None]:
def get_val_dataset(bs=BATCH_SIZE, fold=0):
    FNAMES_VAL_TFRECORDS = tf.io.gfile.glob(f'{GCS_DS_PATH}/fold_{fold}/val/*.tfrecords')
    val_dataset = tf.data.TFRecordDataset(FNAMES_VAL_TFRECORDS, num_parallel_reads=AUTO)
    val_dataset = val_dataset.prefetch(BATCH_SIZE_VAL)
    val_dataset = val_dataset.repeat()
    val_dataset = val_dataset.map(decode_tfrecord_val, num_parallel_calls=AUTO)
    val_dataset = val_dataset.batch(bs, drop_remainder=True)
    val_dataset = val_dataset.prefetch(1)
    
    return val_dataset

val_dataset = get_val_dataset()

In [None]:
# Show batch info and first few test images
def show_first_val_batch():
    imgs, lbls = next(iter(val_dataset))
    
    print(f'Number of val images: {N_VAL_IMGS}')
    print(f'imgs.shape: {imgs.shape}, images.dtype: {imgs.dtype}, lbls.shape: {lbls.shape}, lbls.dtype: {lbls.dtype}')
    img0 = imgs[0].numpy().astype(np.float32)
    print('img0 mean: {:.3f}, img0 std {:.3f}, img0 min: {:.3f}, img0 max: {:.3f}'.format(img0.mean(), img0.std(), img0.min(), img0.max()))
    print(f'first label: {lbls[0]}')

    rows, cols = 4, 5
    fig, axes = plt.subplots(nrows=rows, ncols=cols, figsize=(cols*6, rows*6))
    for r in range(rows):
        for c in range(cols):
            img = imgs[r * cols + c].numpy().astype(np.float32)
            
            # add title with image information
            axes[r, c].set_title('mean: {:.3f}, std {:.3f}, min: {:.3f}, max: {:.3f}'.format(img.mean(), img.std(), img.min(), img.max()))
            
            # min-max scale to [0, 1] for display
            img -= img.min()
            img /= img.max()

            axes[r, c].imshow(img)
            
show_first_val_batch()

# Learning Rate Scheduler

The learning rate schedule is a warmup, in which the learning rate grows with the 2.5th power of the epoch, followed by a cosine decay. The warmup prevents the model from overfitting early on the first images. When training starts the loss is high, because the model is pretrained on ImageNet, not on the training dataset. With a high initial learning rate, the large gradients of the first few batches modify the weights strongly, so the model could overfit on those samples. With a very low initial learning rate, the model sees many training images while making only small adjustments to the weights. It therefore learns more equally from all training images during the phase in which the loss is high.
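Below is a standalone plain-Python sketch of the same schedule shape. It omits the sustain phase (`LR_SUSTAIN_EPOCHS = 0` in this notebook) and copies its default values from `lrfn`. The function name `warmup_cosine_lr` is hypothetical. The warmup interpolates from `lr_start` to `lr_max`, so the peak is reached exactly at the end of the warmup.

```python
import math

def warmup_cosine_lr(epoch, epochs=30, lr_start=1e-6, lr_max=2e-4, lr_final=1e-6, rampup_epochs=4):
    """Hypothetical helper: warmup from lr_start to lr_max, then cosine decay to lr_final."""
    if epoch < rampup_epochs:
        # warmup: the lr grows with the 2.5th power of the warmup progress
        return lr_start + (lr_max - lr_start) * (epoch / rampup_epochs) ** 2.5
    # the decay spans the remaining epochs so that the last epoch (epochs - 1) ends at lr_final
    decay_epochs = epochs - rampup_epochs - 1
    progress = (epoch - rampup_epochs) / decay_epochs
    return lr_final + (lr_max - lr_final) * (math.cos(progress * math.pi) + 1) / 2

for epoch in (0, 2, 4, 16, 29):
    print(epoch, f'{warmup_cosine_lr(epoch):.2e}')
```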

In [None]:
def lrfn(epoch, bs=BATCH_SIZE, epochs=EPOCHS):
    # Config
    LR_START = 1e-6
    LR_MAX = 2e-4
    LR_FINAL = 1e-6
    LR_RAMPUP_EPOCHS = 4
    LR_SUSTAIN_EPOCHS = 0
    DECAY_EPOCHS = epochs - LR_RAMPUP_EPOCHS - LR_SUSTAIN_EPOCHS - 1

    if epoch < LR_RAMPUP_EPOCHS: # warmup: lr grows with the 2.5th power of the warmup progress
        lr = LR_START + (LR_MAX - LR_START) * (epoch / LR_RAMPUP_EPOCHS) ** 2.5
    elif epoch < LR_RAMPUP_EPOCHS + LR_SUSTAIN_EPOCHS: # sustain lr
        lr = LR_MAX
    else: # cosine decay
        epoch_diff = epoch - LR_RAMPUP_EPOCHS - LR_SUSTAIN_EPOCHS
        decay_factor = (epoch_diff / DECAY_EPOCHS) * math.pi
        decay_factor = (math.cos(decay_factor) + 1) / 2
        lr = LR_FINAL + (LR_MAX - LR_FINAL) * decay_factor
        lr = LR_FINAL + (LR_MAX - LR_FINAL) * decay_factor

    return lr

In [None]:
# plots the learning rate schedule
def show_lr_schedule(bs=BATCH_SIZE, epochs=EPOCHS):
    rng = [i for i in range(epochs)]
    y = [lrfn(x, bs=bs, epochs=epochs) for x in rng]
    x = np.arange(epochs)
    x_axis_labels = list(map(str, np.arange(1, epochs+1)))
    print('init lr {:.1e} to {:.1e} final {:.1e}'.format(y[0], max(y), y[-1]))
    
    plt.figure(figsize=(30, 10))
    plt.xticks(x, x_axis_labels, fontsize=16) # set tick step to 1 and let x axis start at 1
    plt.yticks(fontsize=16)
    plt.plot(rng, y)
    plt.grid()
    plt.show()
    
show_lr_schedule()

# Model

In [None]:
def get_model():
    # reset to free memory and training variables
    tf.keras.backend.clear_session()
    
    with strategy.scope():
        
        net = efn.EfficientNetB4(
            include_top=False,
            weights='noisy-student',
            input_shape=(IMG_TARGET_SIZE, IMG_TARGET_SIZE, N_CHANNELS),
        )
        
        for layer in reversed(net.layers):
            if isinstance(layer, tf.keras.layers.BatchNormalization):
                layer.trainable = False
            else:
                layer.trainable = True
        
        model = tf.keras.Sequential([
            net,
            tf.keras.layers.Dropout(0.45),
            tf.keras.layers.GlobalAveragePooling2D(),
            tf.keras.layers.Dropout(0.45),
            tf.keras.layers.Dense(N_LABELS, activation='softmax', dtype=tf.float32),
        ])

        # add metrics
        metrics = [
            tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
            tf.keras.metrics.TopKCategoricalAccuracy(k=2, name='top_2_accuracy'),
        ]

        optimizer = tf.keras.optimizers.Adam()
        loss = tf.keras.losses.CategoricalCrossentropy()

        model.compile(optimizer=optimizer, loss=loss, metrics=metrics)

        return model

# Validation Report

In [None]:
def show_validation_report_per_class(model, dataset, steps, name, bs):
    print(f'--- {name} REPORT ---')
    # classification report
    y = np.ndarray(shape=steps * bs, dtype=np.uint16)
    y_pred = np.ndarray(shape=steps * bs, dtype=np.uint16)
    for idx, (images, labels) in tqdm(enumerate(dataset.take(steps)), total=steps):
        with tf.device('cpu:0'):
            y[idx*bs:(idx+1)*bs] = np.argmax(labels, axis=1)
            y_pred[idx*bs:(idx+1)*bs] = np.argmax(model.predict(images).astype(np.float32), axis=1)
            
    print(classification_report(y, y_pred))
    
    # Confusion matrix
    fig, ax = plt.subplots(1, 1, figsize=(20, 12))
    cfn_matrix = confusion_matrix(y, y_pred, labels=range(N_LABELS))
    cfn_matrix = (cfn_matrix.T / cfn_matrix.sum(axis=1)).T
    df_cm = pd.DataFrame(cfn_matrix, index=np.arange(N_LABELS), columns=np.arange(N_LABELS))
    ax = sns.heatmap(df_cm, cmap='Blues', annot=True, fmt='.3f', linewidths=.5, annot_kws={'size':14}).set_title(f'{name} CONFUSION MATRIX')
    plt.xticks(fontsize=16)
    plt.yticks(fontsize=16)
    plt.xlabel('PREDICTED', fontsize=24, labelpad=10)
    plt.ylabel('ACTUAL', fontsize=24, labelpad=10)
    plt.show()

# Training History

In [None]:
def plot_history_metric(history, metric):
    TRAIN_EPOCHS = len(history.history['loss'])
    x = np.arange(TRAIN_EPOCHS)
    x_axis_labels = list(map(str, np.arange(1, TRAIN_EPOCHS+1)))
    val = 'val' in ''.join(history.history.keys())
    # plot the training (and validation) curve of the metric
    plt.figure(figsize=(20, 10))
    plt.plot(history.history[metric])
    if val:
        plt.plot(history.history[f'val_{metric}'])
    
    plt.title(f'Model {metric}', fontsize=24)
    plt.ylabel(metric, fontsize=20)
    plt.yticks(fontsize=16)
    plt.xlabel('epoch', fontsize=20)
    plt.xticks(x, x_axis_labels, fontsize=16) # set tick step to 1 and let x axis start at 1
    plt.legend(['train'] + (['validation'] if val else []), loc='upper left')
    plt.grid()
    plt.show()

# Training

This is the training loop. The training metrics and the confusion matrix are displayed after each fold.

The validation report shows that the model predicts label 3 with very high precision and recall. This is not surprising, as label 3 is by far the most common label and the model will therefore most likely be biased towards it. Label 0 is the least common label and also has the lowest precision and recall. Label 0 is more than 10 times less common than label 3, making the dataset highly imbalanced.

The confusion matrix shows how the model mixes up labels. Label 0 is mostly confused with labels 4 and 1, and label 2 is often confused with label 3.

In [None]:
print(f'TRAINING FOR {EPOCHS} EPOCHS WITH BATCH SIZE {BATCH_SIZE}\n')
print(f'TRAIN IMAGES: {N_TRAIN_IMGS}, VAL IMAGES: {N_VAL_IMGS}\n')

augmentations_dic = dict({
    0: 'None',
    1: 'MixUp',
    2: 'CutMix',
    3: 'GridMask',
})

MEAN_VAL_ACC = []
augmentations = [2,3] # only CutMix and GridMask are used
fold = 0
epochs = EPOCHS

for idx, fold in enumerate(range(N_FOLDS)):
    # callbacks
    lr_callback = tf.keras.callbacks.LearningRateScheduler(lambda epoch: lrfn(epoch, epochs=epochs), verbose=1)
    show_lr_schedule(epochs=epochs)
    
    # get the model
    model = get_model()
    
    if idx == 0:
        # model summary
        model.summary()
        # compute and variable data types
        print(f'Compute dtype: {mixed_precision.global_policy().compute_dtype}')
        print(f'Variable dtype: {mixed_precision.global_policy().variable_dtype}')
        
    print('\n')
    print('*'*25, f'augmentations {augmentations}', '*'*25, '\n')
    print(f'fold: {fold}, epochs: {epochs}')
    print(' AND '.join([augmentations_dic.get(i) for i in augmentations]), '\n')
    
    train_dataset = get_train_dataset(bs=BATCH_SIZE, fold=fold, augmentations=augmentations)
    val_dataset = get_val_dataset(bs=BATCH_SIZE_VAL, fold=fold)
    
    history = model.fit(
        train_dataset,
        steps_per_epoch = N_TRAIN_IMGS // BATCH_SIZE,

        validation_data = val_dataset,
        validation_steps = N_VAL_IMGS // BATCH_SIZE_VAL,

        epochs = epochs,
        callbacks = [
            lr_callback,
        ],
        verbose=1,
    )
    
    # add val accuracy to list
    MEAN_VAL_ACC.append(history.history['val_accuracy'][-1])
    
    # plot training histories
    plot_history_metric(history, 'loss')
    plot_history_metric(history, 'accuracy')
    plot_history_metric(history, 'top_2_accuracy')
    
    # show train and validation report
    show_validation_report_per_class(model, val_dataset, N_VAL_IMGS // BATCH_SIZE_VAL, 'VALIDATION', BATCH_SIZE_VAL)

    # save the model
    model.save_weights(f'model_fold_{fold}_weights.h5')
    
    del model, train_dataset, val_dataset
    gc.collect()

In [None]:
print(f'Mean final-epoch validation accuracy over folds: {np.array(MEAN_VAL_ACC).mean()}')