# Random Erasing/CutOut, MixUp and CutMix Comparison using TPU

In this notebook I want to explore three extremely interesting data augmentation methods that I recently used in the [Plant Pathology 2020 - FGVC7](https://www.kaggle.com/c/plant-pathology-2020-fgvc7) competition: **Random Erasing**, **MixUp** and **CutMix**.

The goal of this notebook is to experiment with these three methods on TPU, comparing the results obtained by a simple model based on the EfficientNetB0 architecture on the [Flower Classification with TPUs](https://www.kaggle.com/c/flower-classification-with-tpus) dataset.

This notebook is not intended to provide general results; it only gives a glimpse of the usefulness of some recent data augmentation methods for image recognition.

I have tried to implement the Random Erasing, MixUp and CutMix functions from scratch, starting from the descriptions in the original papers (section "Related Papers and Links"), but I have also referred to two excellent notebooks by *Martin Görner* and *Chris Deotte* that I recommend reading and upvoting:
* [Getting started with 100+ flowers on TPU](https://www.kaggle.com/mgornergoogle/getting-started-with-100-flowers-on-tpu) explains how to use TPUs on Kaggle and provides lots of helper functions specific to the flower dataset.
* [CutMix and MixUp on GPU/TPU](https://www.kaggle.com/cdeotte/cutmix-and-mixup-on-gpu-tpu) contains excellent TensorFlow implementations of MixUp and CutMix from which I took some ideas and tricks, the main one being how to cope with the fact that, in TensorFlow 2.2, individual entries of a tensor cannot be modified directly.

To organize the notebook (this is my first one!) I took inspiration from some notebooks by *Aleksandra Deis*, which are always perfectly organized.

# Libraries

In [None]:
import math, re, os, gc
from os import listdir
from os.path import isfile, join

import tensorflow as tf
import numpy as np
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

from matplotlib import pyplot as plt
from matplotlib.ticker import PercentFormatter
plt.rcParams['figure.figsize'] = [20, 10]
#plt.style.use('seaborn-deep')

# numpy and matplotlib defaults
np.set_printoptions(threshold=15, linewidth=80)

from kaggle_datasets import KaggleDatasets
from sklearn.metrics import f1_score, precision_score, recall_score, confusion_matrix
print("Tensorflow version " + tf.__version__)

In [None]:
# EfficientNet
!pip install efficientnet
import efficientnet.tfkeras as efn

# TPU Configuration

In [None]:
# Detect hardware, return appropriate distribution strategy
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection. No parameters necessary if TPU_NAME environment variable is set. On Kaggle this is always the case.
    print('Running on TPU ', tpu.master())
except ValueError:
    tpu = None

if tpu:
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
else:
    strategy = tf.distribute.get_strategy() # default distribution strategy in Tensorflow. Works on CPU and single GPU.

print("REPLICAS: ", strategy.num_replicas_in_sync)

# Parameters

In [None]:
GCS_DS_PATH = KaggleDatasets().get_gcs_path() # you can list the bucket with "!gsutil ls $GCS_DS_PATH"

AUTO = tf.data.experimental.AUTOTUNE

# Seed for random number generation
SEED = 42

# Image size
IMAGE_SIZE = [224, 224] 
N_CHANNELS = 3

# Data path
GCS_PATH_SELECT = {192: GCS_DS_PATH + '/tfrecords-jpeg-192x192',  # available image sizes
                   224: GCS_DS_PATH + '/tfrecords-jpeg-224x224',
                   331: GCS_DS_PATH + '/tfrecords-jpeg-331x331',
                   512: GCS_DS_PATH + '/tfrecords-jpeg-512x512'}

GCS_PATH = GCS_PATH_SELECT[IMAGE_SIZE[0]]

# Training parameters
MODEL_NAME = 'effb0'
EPOCHS = 50
BATCH_SIZE = 16 * strategy.num_replicas_in_sync

N_ENSEMBLE = 2  # number of identical models to train

To perform the analysis on a labelled test set, I split the original validation set into two subsets, each containing roughly 15% of the labelled data, and used these two new sets for validation and test.

In [None]:
TRAINING_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/train/*.tfrec')
# Split the original validation set into two new disjoint subsets for validation and test.
VALIDATION_TEST_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/val/*.tfrec')
VALIDATION_FILENAMES = VALIDATION_TEST_FILENAMES[:8]
TEST_FILENAMES = VALIDATION_TEST_FILENAMES[8:]


CLASSES = ['pink primrose',    'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea',     'wild geranium',     'tiger lily',           'moon orchid',              'bird of paradise', 'monkshood',        'globe thistle',         # 00 - 09
           'snapdragon',       "colt's foot",               'king protea',      'spear thistle', 'yellow iris',       'globe-flower',         'purple coneflower',        'peruvian lily',    'balloon flower',   'giant white arum lily', # 10 - 19
           'fire lily',        'pincushion flower',         'fritillary',       'red ginger',    'grape hyacinth',    'corn poppy',           'prince of wales feathers', 'stemless gentian', 'artichoke',        'sweet william',         # 20 - 29
           'carnation',        'garden phlox',              'love in the mist', 'cosmos',        'alpine sea holly',  'ruby-lipped cattleya', 'cape flower',              'great masterwort', 'siam tulip',       'lenten rose',           # 30 - 39
           'barberton daisy',  'daffodil',                  'sword lily',       'poinsettia',    'bolero deep blue',  'wallflower',           'marigold',                 'buttercup',        'daisy',            'common dandelion',      # 40 - 49
           'petunia',          'wild pansy',                'primula',          'sunflower',     'lilac hibiscus',    'bishop of llandaff',   'gaura',                    'geranium',         'orange dahlia',    'pink-yellow dahlia',    # 50 - 59
           'cautleya spicata', 'japanese anemone',          'black-eyed susan', 'silverbush',    'californian poppy', 'osteospermum',         'spring crocus',            'iris',             'windflower',       'tree poppy',            # 60 - 69
           'gazania',          'azalea',                    'water lily',       'rose',          'thorn apple',       'morning glory',        'passion flower',           'lotus',            'toad lily',        'anthurium',             # 70 - 79
           'frangipani',       'clematis',                  'hibiscus',         'columbine',     'desert-rose',       'tree mallow',          'magnolia',                 'cyclamen ',        'watercress',       'canna lily',            # 80 - 89
           'hippeastrum ',     'bee balm',                  'pink quill',       'foxglove',      'bougainvillea',     'camellia',             'mallow',                   'mexican petunia',  'bromelia',         'blanket flower',        # 90 - 99
           'trumpet creeper',  'blackberry lily',           'common tulip',     'wild rose']                                                                                                                                               # 100 - 102

## Visualization Utilities

In [None]:
def mixed_label(x):
    """Build a plot title from a (possibly mixed) label vector."""
    ind = [i for i in range(x.shape[0]) if x[i] > 0]
    if len(ind) == 0:
        return 'ERROR!'  # no positive entry, nothing to show
    if len(ind) == 1:
        return CLASSES[ind[0]]
    # Two classes mixed by MixUp/CutMix: show each weight next to its class name
    return (str(np.round(x[ind[0]], 3)) + ' ' + CLASSES[ind[0]] + '\n' +
            str(np.round(x[ind[1]], 3)) + ' ' + CLASSES[ind[1]])

def batch_to_numpy_images_and_labels(data):
    images, labels = data
    numpy_images = images.numpy()
    numpy_labels = labels.numpy()
    if numpy_labels.dtype == object: # binary string in this case, these are image ID strings
        numpy_labels = [None for _ in enumerate(numpy_images)]
    # If no labels, only image IDs, return None for labels (this is the case for test data)
    return numpy_images, numpy_labels

def title_from_label_and_target(label, correct_label):
    if correct_label is None:
        return CLASSES[label], True
    correct = (label == correct_label)
    return "{} [{}{}{}]".format(CLASSES[label], 'OK' if correct else 'NO', u"\u2192" if not correct else '',
                                CLASSES[correct_label] if not correct else ''), correct

def display_one_flower(image, title, subplot, red=False, titlesize=16):
    plt.subplot(*subplot)
    plt.axis('off')
    plt.imshow(image)
    if len(title) > 0:
        plt.title(title, fontsize=int(titlesize) if not red else int(titlesize/1.2), color='red' if red else 'black', fontdict={'verticalalignment':'center'}, pad=int(titlesize/1.5))
    return (subplot[0], subplot[1], subplot[2]+1)
    
def display_batch_of_images(databatch, predictions=None):
    """This will work with:
    display_batch_of_images(images)
    display_batch_of_images(images, predictions)
    display_batch_of_images((images, labels))
    display_batch_of_images((images, labels), predictions)
    """
    # data
    images, labels = batch_to_numpy_images_and_labels(databatch)
    if labels is None:
        labels = [None for _ in enumerate(images)]
        
    # auto-squaring: this will drop data that does not fit into square or square-ish rectangle
    rows = int(math.sqrt(len(images)))
    cols = len(images)//rows
        
    # size and spacing
    FIGSIZE = 15.0 #13.0
    SPACING = 0.2 #0.1
    subplot=(rows,cols,1)
    if rows < cols:
        plt.figure(figsize=(FIGSIZE,FIGSIZE/cols*rows))
    else:
        plt.figure(figsize=(FIGSIZE/rows*cols,FIGSIZE))
    
    # display
    for i, (image, label) in enumerate(zip(images[:rows*cols], labels[:rows*cols])):
        try:
            title = '' if label is None else CLASSES[label]
        except:
            title = mixed_label(label)
        correct = True
        if predictions is not None:
            title, correct = title_from_label_and_target(predictions[i], label)
        #dynamic_titlesize = FIGSIZE*SPACING/max(rows,cols)*40+3 # magic formula tested to work from 1x1 to 10x10 images
        dynamic_titlesize = 13.0*0.1/max(rows,cols)*40+3 # magic formula tested to work from 1x1 to 10x10 images

        subplot = display_one_flower(image, title, subplot, not correct, titlesize=dynamic_titlesize)
    
    #layout
    plt.tight_layout()
    if label is None and predictions is None:
        plt.subplots_adjust(wspace=0, hspace=0)
    else:
        plt.subplots_adjust(wspace=SPACING, hspace=SPACING)
    plt.show()

def display_confusion_matrix(cmat, score, precision, recall):
    plt.figure(figsize=(15,15))
    ax = plt.gca()
    ax.matshow(cmat, cmap='Reds')
    ax.set_xticks(range(len(CLASSES)))
    ax.set_xticklabels(CLASSES, fontdict={'fontsize': 7})
    plt.setp(ax.get_xticklabels(), rotation=45, ha="left", rotation_mode="anchor")
    ax.set_yticks(range(len(CLASSES)))
    ax.set_yticklabels(CLASSES, fontdict={'fontsize': 7})
    plt.setp(ax.get_yticklabels(), rotation=45, ha="right", rotation_mode="anchor")
    titlestring = ""
    if score is not None:
        titlestring += 'f1 = {:.3f} '.format(score)
    if precision is not None:
        titlestring += '\nprecision = {:.3f} '.format(precision)
    if recall is not None:
        titlestring += '\nrecall = {:.3f} '.format(recall)
    if len(titlestring) > 0:
        ax.text(101, 1, titlestring, fontdict={'fontsize': 18, 'horizontalalignment':'right', 'verticalalignment':'top', 'color':'#804040'})
    plt.show()

In [None]:
def weights_calc(labels):
    """
    Utility function for histogram weights calculation.
    """
    return np.ones(len(labels)) / (len(labels))


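def autolabel(rects, ax=None):
    """
    Utility function for attaching a text label above each bar in rects.
    Credits: https://matplotlib.org/3.1.1/gallery/lines_bars_and_markers/barchart.html
    """
    ax = plt.gca() if ax is None else ax  # use the current axes unless one is passed explicitly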
    for rect in rects:
        height = rect.get_height()
        ax.annotate('{}'.format(np.round(height, 3)),
                    xy=(rect.get_x() + rect.get_width() / 2, height),
                    xytext=(0, 3),  # 3 points vertical offset
                    textcoords="offset points",
                    ha='center', va='bottom', fontsize=16)

        
def plot_metric(baseline, agumented_0, agumented_1, agumented_2, metric='loss', aug_names=['NoAugmentation', 'RandomErasing', 'MixUp', 'CutMix']):
    
    if metric == 'loss':
        legend_location = 'upper right'
        y_min = 0.0
        y_max = 2.81
        y_step = 0.2
    else: 
        legend_location = 'lower right'
        y_min = 0.4
        y_max = 1.01
        y_step = 0.05
    
    # One subplot per training history (2x2 grid), train vs validation curves
    histories = [baseline, agumented_0, agumented_1, agumented_2]
    for k, (history, aug_name) in enumerate(zip(histories, aug_names)):
        plt.subplot(2, 2, k + 1)
        plt.plot(history[metric])
        plt.plot(history['val_' + metric])
        plt.title('Model ' + metric.capitalize() + ' - ' + aug_name, fontsize=18)
        plt.ylabel(metric.capitalize(), fontsize=16)
        plt.xlabel('Epoch', fontsize=16)
        plt.xticks(np.arange(0, 51, step=5))
        plt.yticks(np.arange(y_min, y_max, step=y_step))
        plt.legend(['Train', 'Val'], loc=legend_location, fontsize=12)
        plt.grid(axis='y')

    plt.subplots_adjust(wspace=0.2, hspace=0.35)
    plt.show()
    

def plot_loss_error(baseline, agumented, augmentation_name='', set_type='validation'):
    """
    Example: plot_loss_error(hist_no_augmentation[0].history, hist_random_erasing[0].history, augmentation_name='Random Erasing', set_type='validation')
    """
    
    if set_type == 'validation':
        loss = 'val_loss'
        acc = 'val_categorical_accuracy'
    else:
        loss = 'loss'
        acc = 'categorical_accuracy'
        
    plt.subplot(1, 2, 1)
    plt.plot(baseline[loss])
    plt.plot(agumented[loss])
    locs, labels = plt.yticks()
    plt.yticks(np.arange(0, 1.41, step=0.2))
    plt.title('Model Loss - '+ 'No Augmentation vs '+ augmentation_name + ' (' + set_type + ' set)')
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.legend(['No Augmentation', augmentation_name], loc='upper right')

    error_baseline = [1-x for x in baseline[acc]]
    error_agumented = [1-x for x in agumented[acc]]
    
    plt.subplot(1, 2, 2)
    plt.plot(error_baseline)
    plt.plot(error_agumented)
    locs, labels = plt.yticks()
    plt.yticks(np.arange(0, 0.41, step=0.05))
    plt.title('Model Error - '+ 'No Augmentation vs '+ augmentation_name + ' (' + set_type + ' set)')
    plt.ylabel('Top-1 Error')
    plt.xlabel('Epoch')
    plt.legend(['No Augmentation', augmentation_name], loc='upper right')

    plt.show()

    
def plot_loss_error_all(baseline, agumented_0, agumented_1, agumented_2, set_type='validation'):

    if set_type == 'validation':
        loss = 'val_loss'
        acc = 'val_categorical_accuracy'
    else:
        loss = 'loss'
        acc = 'categorical_accuracy'
    
    # Model Loss subplot
    plt.subplot(1, 2, 1)
    plt.plot(baseline[loss])
    plt.plot(agumented_0[loss])
    plt.plot(agumented_1[loss])
    plt.plot(agumented_2[loss])
    locs, labels = plt.yticks()
    plt.xticks(np.arange(0, 51, step=5))
    plt.yticks(np.arange(0, 1.41, step=0.2))
    plt.title('Model Loss on ' + set_type.capitalize() + ' Set' +'\n'+ 'NoAugmentation vs RandomErasing vs MixUp vs CutMix', fontsize=18)
    plt.ylabel('Loss', fontsize=16)
    plt.xlabel('Epoch', fontsize=16)
    plt.legend(['No Augmentation', 'Random Erasing', 'MixUp', 'CutMix'], loc='upper right')
    plt.grid(axis='y')

    # Model Error subplot
    error_baseline = [1-x for x in baseline[acc]]
    error_agumented_0 = [1-x for x in agumented_0[acc]]
    error_agumented_1 = [1-x for x in agumented_1[acc]]
    error_agumented_2 = [1-x for x in agumented_2[acc]]
    
    plt.subplot(1, 2, 2)
    plt.plot(error_baseline)
    plt.plot(error_agumented_0)
    plt.plot(error_agumented_1)
    plt.plot(error_agumented_2)
    plt.xticks(np.arange(0, 51, step=5))
    locs, labels = plt.yticks()
    plt.yticks(np.arange(0, 0.41, step=0.05))
    plt.title('Model Error on ' + set_type.capitalize() + ' Set' +'\n'+ 'NoAugmentation vs RandomErasing vs MixUp vs CutMix', fontsize=18)
    plt.ylabel('Top-1 Error', fontsize=16)
    plt.xlabel('Epoch', fontsize=16)
    plt.legend(['No Augmentation', 'Random Erasing', 'MixUp', 'CutMix'], loc='upper right')
    plt.grid(axis='y')

    plt.show()

## Other Utility Functions

In [None]:
def accuracy(y_pred, y):
    """
    Calculate the accuracy.
    Args:
        y_pred: list, predicted labels.
        y: list, true labels.
    Returns:
        Accuracy measure (float).
    """
    return sum([y_pred[i] == y[i] for i in range(len(y))]) / len(y)
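`accuracy` loops in Python over the two lists. For long prediction lists the same quantity can be computed with one vectorized comparison; a minimal NumPy sketch (`accuracy_np` is a hypothetical helper, not used elsewhere in the notebook). The Top-1 error plotted by the visualization utilities is simply `1 - accuracy`.

```python
import numpy as np

def accuracy_np(y_pred, y):
    """Vectorized equivalent of accuracy(): fraction of matching labels."""
    y_pred = np.asarray(y_pred)
    y = np.asarray(y)
    return float(np.mean(y_pred == y))

print(accuracy_np([3, 1, 4, 1], [3, 1, 5, 9]))  # 0.5
```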

# Data

In [None]:
def decode_image(image_data):
    image = tf.image.decode_jpeg(image_data, channels=3)
    image = tf.cast(image, tf.float32) / 255.0  # convert image to floats in [0, 1] range
    image = tf.reshape(image, [*IMAGE_SIZE, 3]) # explicit size needed for TPU
    return image

def read_labeled_tfrecord(example):
    LABELED_TFREC_FORMAT = {
        "image": tf.io.FixedLenFeature([], tf.string), # tf.string means bytestring
        "class": tf.io.FixedLenFeature([], tf.int64),  # shape [] means single element
    }
    example = tf.io.parse_single_example(example, LABELED_TFREC_FORMAT)
    image = decode_image(example['image'])
    #label = tf.cast(example['class'], tf.int32)
    label = tf.one_hot(indices=tf.cast(example['class'], tf.int32), depth=len(CLASSES))
    return image, label # returns a dataset of (image, label) pairs

def read_unlabeled_tfrecord(example):
    UNLABELED_TFREC_FORMAT = {
        "image": tf.io.FixedLenFeature([], tf.string), # tf.string means bytestring
        "id": tf.io.FixedLenFeature([], tf.string),  # shape [] means single element
        # class is missing, this competition's challenge is to predict flower classes for the test dataset
    }
    example = tf.io.parse_single_example(example, UNLABELED_TFREC_FORMAT)
    image = decode_image(example['image'])
    idnum = example['id']
    return image, idnum # returns a dataset of image(s)

def load_dataset(filenames, labeled=True, ordered=False):
    # Read from TFRecords. For optimal performance, reading from multiple files at once and
    # disregarding data order. Order does not matter since we will be shuffling the data anyway.

    ignore_order = tf.data.Options()
    if not ordered:
        ignore_order.experimental_deterministic = False # disable order, increase speed

    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO) # automatically interleaves reads from multiple files
    dataset = dataset.with_options(ignore_order) # uses data as soon as it streams in, rather than in its original order
    dataset = dataset.map(read_labeled_tfrecord if labeled else read_unlabeled_tfrecord, num_parallel_calls=AUTO)
    # returns a dataset of (image, label) pairs if labeled=True or (image, id) pairs if labeled=False
    return dataset

def get_training_dataset(augmentation='base'):
    dataset = load_dataset(TRAINING_FILENAMES, labeled=True)
    dataset = dataset.repeat() # the training dataset must repeat for several epochs
    dataset = dataset.shuffle(2048)
    dataset = dataset.batch(BATCH_SIZE)
    
    if augmentation == 'random_erasing':
        dataset = dataset.map(random_erasing, num_parallel_calls=AUTO)
    elif augmentation == 'mixup':
        dataset = dataset.map(mixup, num_parallel_calls=AUTO)
    elif augmentation == 'cutmix':
        dataset = dataset.map(cutmix, num_parallel_calls=AUTO)
        
    dataset = dataset.prefetch(AUTO) # prefetch next batch while training (autotune prefetch buffer size)
    return dataset

def get_validation_dataset(ordered=False):
    dataset = load_dataset(VALIDATION_FILENAMES, labeled=True, ordered=ordered)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.cache()
    dataset = dataset.prefetch(AUTO) # prefetch next batch while training (autotune prefetch buffer size)
    return dataset

def get_test_dataset(ordered=False, labeled=False):
    dataset = load_dataset(TEST_FILENAMES, labeled=labeled, ordered=ordered)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.prefetch(AUTO) # prefetch next batch while training (autotune prefetch buffer size)
    return dataset

def count_data_items(filenames):
    # the number of data items is written in the name of the .tfrec files, i.e. flowers00-230.tfrec = 230 data items
    n = [int(re.compile(r"-([0-9]*)\.").search(filename).group(1)) for filename in filenames]
    return np.sum(n)

NUM_TRAINING_IMAGES = count_data_items(TRAINING_FILENAMES)
NUM_VALIDATION_IMAGES = count_data_items(VALIDATION_FILENAMES)
NUM_TEST_IMAGES = count_data_items(TEST_FILENAMES)
STEPS_PER_EPOCH = NUM_TRAINING_IMAGES // BATCH_SIZE
print('Dataset: {} training images, {} validation images, {} test images'.format(NUM_TRAINING_IMAGES, NUM_VALIDATION_IMAGES, NUM_TEST_IMAGES))
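`count_data_items` relies on the naming convention mentioned in its comment: the number just before the `.tfrec` extension is the number of records in the file. A quick sketch of the same regular expression applied to a few made-up file names that follow that convention (`items_in_file` is a hypothetical helper), plus the integer division used for `STEPS_PER_EPOCH`:

```python
import re

# Same pattern as count_data_items: the digits between a '-' and the file extension
ITEMS_PATTERN = re.compile(r"-([0-9]*)\.")

def items_in_file(filename):
    """Return the number of records encoded in a TFRecord file name."""
    return int(ITEMS_PATTERN.search(filename).group(1))

# Hypothetical file names following the "flowers00-230.tfrec" convention
names = ["flowers00-230.tfrec", "flowers01-230.tfrec",
         "flowers02-192.tfrec", "gs://bucket/val/flowers03-15.tfrec"]
counts = [items_in_file(name) for name in names]
print(counts, sum(counts))  # [230, 230, 192, 15] 667

# Integer division drops the last partial batch, as in STEPS_PER_EPOCH
batch_size = 16 * 8  # hypothetical: 16 images per replica on 8 TPU replicas
print(sum(counts) // batch_size)  # 5
```

Dropping the remainder is harmless for training here because the training dataset repeats indefinitely.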

In [None]:
# Training and Validation labels
# The training files are read once, without repeat/shuffle, so every image is counted exactly once
labels_train = [np.argmax(label.numpy()) for _, label in load_dataset(TRAINING_FILENAMES, labeled=True)]
labels_val = [np.argmax(x[1].numpy()) for x in iter(get_validation_dataset().unbatch().batch(1))] 
#labels_test = [np.argmax(x[1].numpy()) for x in iter(get_test_dataset(labeled=True).unbatch().batch(1))]

In [None]:
# Ordered Test labels
labels_test = np.concatenate([y for x, y in get_test_dataset(labeled=True, ordered=True)], axis=0)
labels_test = [x.argmax() for x in labels_test]

In [None]:
data = [labels_train, labels_val, labels_test]
# Each sample weighs 1 / (size of its set), so bars show the fraction of each set in every class
weights = [weights_calc(labels) for labels in data]
plt.hist(data, weights=weights, bins=104, label=['Train', 'Validation', 'Test'])
plt.legend(loc='upper right')
plt.title('Class distribution', fontsize=20)
plt.xlabel('Classes', fontsize=16)
plt.ylabel('Percentage in the set', fontsize=16)

plt.gca().yaxis.set_major_formatter(PercentFormatter(1))
plt.show()

As we can see in the above graph, the class distribution is quite similar among the three sets of data (Training, Validation and Test sets).
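To put a number on "quite similar", the per-class frequencies can also be compared directly. A minimal NumPy sketch (`class_percentages` is a hypothetical helper, and the maximum absolute difference is just one simple way to summarize the gap between two distributions):

```python
import numpy as np

def class_percentages(labels, n_classes):
    """Fraction of samples in each class.

    Equivalent to the bar heights of a histogram with one bin per class
    and a weight of 1 / len(labels) per sample (what weights_calc builds)."""
    counts = np.bincount(np.asarray(labels), minlength=n_classes)
    return counts / len(labels)

# Toy labels for 4 classes (hypothetical, not the flower data)
train = [0, 0, 1, 2, 2, 2, 3, 3]
test = [0, 1, 2, 2]
p_train = class_percentages(train, 4)
p_test = class_percentages(test, 4)

# Largest per-class gap between the two distributions
print(np.abs(p_train - p_test).max())  # 0.25
```

On the real data the calls would be `class_percentages(labels_train, len(CLASSES))`, and likewise for the validation and test labels.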

## Dataset Structure

In [None]:
# data dump
print("Training data shapes:")
for image, label in get_training_dataset().take(2):
    print(image.numpy().shape, label.numpy().shape)
print()
print("Training data label examples:\n", label.numpy()[:2])
print()
print("Validation data shapes:")
for image, label in get_validation_dataset().take(2):
    print(image.numpy().shape, label.numpy().shape)
print()
print("Validation data label examples:\n", label.numpy()[:2])
print()
print("Test data shapes:")
for image, idnum in get_test_dataset().take(2):
    print(image.numpy().shape, idnum.numpy().shape)
print()
print("Test data IDs:\n", idnum.numpy().astype('U')) # U=unicode string

# Data Augmentation

### Random Erasing / Cutout

In [None]:
def random_erasing(images, labels, p=0.5, s_l=0.02, s_h=0.4, r_1=0.3, r_2=1/0.3, 
                   erasing_type='RE-R', batch=BATCH_SIZE, rand_seed=SEED, max_attempts=100):
    """
    Perform Random Erasing over a batch of images, following the original 
    description in "https://arxiv.org/abs/1708.04896".
    This function is meant to be used with TPUs.
    
    Args:
        images: a batch of images, i.e. a tensor of size [BATCH_SIZE, HEIGHT, WIDTH, N_CHANNELS]. 
                Pixels should be normalized in [0, 1].
        labels: a batch of image labels (tf tensor) using one-hot encoding.
        p: float in [0, 1], probability of performing random erasing over an image.
        s_l: float in [0, 1], minimum fraction of the image area covered by the 
             randomly erased rectangle.
        s_h: float in [0, 1], maximum fraction of the image area covered by the 
             randomly erased rectangle.
        r_1: float, minimum aspect ratio (height / width) of the randomly erased 
             rectangle.
        r_2: float, maximum aspect ratio (height / width) of the randomly erased 
             rectangle.
        erasing_type: how to fill the randomly erased rectangle:
                      * "RE-R": Each pixel is assigned with a random value
                      * "RE-M": All pixels are assigned with the mean ImageNet 
                                pixel value, i.e. [125, 122, 114] / 255
                      * "RE-1": All pixels assigned with 1
                      Otherwise "RE-0" is implemented (all pixels assigned with 0)
        batch: positive integer, batch size.
        rand_seed: integer, random seed.
        max_attempts: maximum number of attempts in order to find a valid patch 
                      position.
    
    Returns:
        A tuple (batch of images, batch of labels) modified using Random Erasing.
    """
    # modified images
    mod_img = []
    
    # Image dimensions
    h = images.shape[1]  # image height
    w = images.shape[2]  # image width
    n_chan = images.shape[3]  # number of channels
    n_classes = labels.shape[1]  # number of classes

    for i in range(batch):
        q = np.random.uniform()  # NumPy draw: inside dataset.map this runs only once, at trace time
        if q >= p:
            mod_img.append(images[i])  # the image is kept unchanged
        else:  # Random Erasing
            attempts = 0
            while attempts < max_attempts:
            
                # Area of the erased rectangle patch
                s_e = np.random.uniform(low=s_l, high=s_h) * h * w
                # Aspect ratio of the erased rectangle patch
                r_e = np.random.uniform(low=r_1, high=r_2)
            
                h_e = int((s_e * r_e)**0.5)  # height of the Random Erased patch
                w_e = int((s_e / r_e)**0.5)  # width of the Random Erased patch
            
                # (x_e, y_e) are the coordinates of the upper-left corner of the 
                # Random Erased patch.
                x_e = int(np.random.uniform(low=0, high=w))  # column (horizontal position)
                y_e = int(np.random.uniform(low=0, high=h))  # row (vertical position)

                if (x_e + w_e <= w) and (y_e + h_e <= h):  # assess if the patch top-left corner position is valid
                    
                    # Choose the type of patch
                    if erasing_type == 'RE-R':
                        # RE-R: Each pixel is assigned with a random value 
                        # ranging in [0, 255] / 255.
                        patch = tf.random.uniform(shape=[h_e, w_e, n_chan], 
                                                  minval=0, maxval=1, seed=rand_seed)
                    elif erasing_type == 'RE-M':
                        # RE-M: All pixels are assigned with the mean ImageNet 
                        # pixel value, i.e. [125, 122, 114] / 255.
                        foo = tf.ones([h_e, w_e, 1], dtype=tf.float32)
                        patch = tf.concat([foo * 125 / 255., foo * 122 / 255., foo * 114 / 255.], axis=2)
                    elif erasing_type == 'RE-1':
                        # RE-1: All pixels assigned with 255 / 255.
                        patch = tf.ones([h_e, w_e, n_chan], dtype=tf.float32)
                    else: 
                        # RE-0: All pixels assigned with 0
                        patch = tf.zeros([h_e, w_e, n_chan], dtype=tf.float32)
                    
                    # Sections of the image directly above and below the 
                    # Random Erased patch.
                    above = images[i][:y_e, x_e:(x_e+w_e), :]
                    below = images[i][(y_e+h_e):, x_e:(x_e+w_e), :]

                    mid = tf.concat([above, patch, below], axis=0)
                
                    # Part of the image on the left of the Random Erased patch
                    left= images[i][:, :x_e, :]
                    # Part of the image on the right of the Random Erased patch
                    right= images[i][:, (x_e+w_e):, :]  
                
                    mod_img.append(tf.concat([left, mid, right], axis=1))
                    break
                
                attempts += 1  # number of attempts in order to find a valid patch position

            # At the time of writing, TensorFlow doesn't support an else 
            # clause on a while loop, so an extra "if" is needed to cover 
            # the case when the maximum number of attempts is exceeded.
            ## Not currently allowed:
            ## else:  
            ##     mod_img.append(images[i])
            #
            # If the maximum number of attempts is exceeded the image is kept 
            # unchanged.
            if attempts == max_attempts:
                mod_img.append(images[i])

    output_images = tf.reshape(tf.stack(mod_img), shape=(batch, h, w, n_chan))
    output_labels = tf.reshape(labels, shape=(batch, n_classes))
    
    return (output_images, output_labels)
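`random_erasing` is applied through `dataset.map`, which traces the Python function into a TensorFlow graph. As far as I understand tf.data's tracing, the `np.random` draws and the Python `if`/`while` statements that depend on them are therefore evaluated only once, when the graph is built, so the same batch positions would be erased in the same places for every batch; only TensorFlow ops such as `tf.random.uniform` are re-sampled at every call. The same consideration applies to the NumPy calls in `mixup` and `cutmix`. To check the sampling logic in isolation, here is an eager NumPy sketch of the Random Erasing procedure for a single image (`random_erasing_np` and its `rng` argument are hypothetical names; the 'RE-M' fill is left out for brevity):

```python
import numpy as np

def random_erasing_np(image, rng, p=0.5, s_l=0.02, s_h=0.4, r_1=0.3, r_2=1/0.3,
                      fill='RE-R', max_attempts=100):
    """Eager NumPy sketch of Random Erasing on one [H, W, C] image in [0, 1]."""
    if rng.uniform() >= p:
        return image                            # image kept unchanged
    h, w, c = image.shape
    for _ in range(max_attempts):
        s_e = rng.uniform(s_l, s_h) * h * w     # target patch area (pixels)
        r_e = rng.uniform(r_1, r_2)             # aspect ratio height / width
        h_e = int(np.sqrt(s_e * r_e))           # patch height
        w_e = int(np.sqrt(s_e / r_e))           # patch width
        x_e = int(rng.uniform(0, w))            # top-left corner: column
        y_e = int(rng.uniform(0, h))            # top-left corner: row
        if x_e + w_e <= w and y_e + h_e <= h:   # patch fits inside the image
            out = image.copy()
            patch = out[y_e:y_e + h_e, x_e:x_e + w_e, :]  # view into `out`
            if fill == 'RE-R':
                patch[...] = rng.uniform(size=patch.shape)  # random values
            elif fill == 'RE-1':
                patch[...] = 1.0
            else:                               # 'RE-0'
                patch[...] = 0.0
            return out
    return image                                # no valid position found

rng = np.random.default_rng(42)
erased = random_erasing_np(np.ones((32, 32, 3)), rng, p=1.0, fill='RE-0')
print(erased.shape, (erased == 0).mean())
```

Since `h_e * w_e <= s_e`, the erased fraction can never exceed `s_h`, which is a quick way to validate any implementation of the sampling step.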

### MixUp

In [None]:
def mixup(images, labels, alpha=0.4, batch=BATCH_SIZE, rand_seed=SEED):
    """
    Perform MixUp over a batch of images, similar to the algorithm described in 
    "https://arxiv.org/pdf/1710.09412.pdf".
    NOTE: in the paper some details such as how to choose the images for 
    mixing-up are not explained, so an approach similar to CutMix was chosen in 
    this implementation.
    This function is meant to be used with TPUs.
    
    Args:
        images: a batch of images, i.e. a tensor of size [BATCH_SIZE, HEIGHT, WIDTH, N_CHANNELS]. 
                Pixels should be normalized in [0, 1].
        labels: a batch of image labels (tf tensor) using one-hot encoding.
        alpha: positive float, parameter of a Beta(alpha, alpha) distribution.
        batch: int, batch size.
        rand_seed: int, random seed.
            
    Returns:
        A tuple (batch of images, batch of labels) modified using MixUp.
    """

    # Set numpy random seed
    np.random.seed(rand_seed)
    
    # Modified images that would be reshaped in a tensor with the same dimensions 
    # as the initial image batch.
    mod_img = []
    # Modified labels
    mod_lab = []
    
    # Input dimensions   
    h = images.shape[1]  # image height
    w = images.shape[2]  # image width
    n_chan = images.shape[3]  # number of channels
    n_classes = labels.shape[1]  # number of classes
    
    # Shuffle mini-batch
    batch_shuffle = np.arange(batch) 
    np.random.shuffle(batch_shuffle)

    for i, j in enumerate(batch_shuffle):
        lamb = np.random.beta(alpha, alpha)  # beta distribution using numpy
        
        # New image
        new_img = lamb * images[i] + (1 - lamb) * images[j]
        mod_img.append(new_img)
        
        # New "mixed" label
        new_lab = lamb * labels[i] + (1 - lamb) * labels[j]
        mod_lab.append(new_lab) 

    output_images = tf.reshape(tf.stack(mod_img), shape=(batch, h, w, n_chan))
    output_labels = tf.reshape(tf.stack(mod_lab), shape=(batch, n_classes))
    
    return (output_images, output_labels)

### CutMix

In [None]:
def cutmix(images, labels, batch=BATCH_SIZE, rand_seed=SEED):
    """
    Perform CutMix over a batch of images, following the original description in 
    "https://arxiv.org/abs/1905.04899".
    This function is meant to be used with TPUs.
    
    Args:
        images: a batch of images, i.e. a tensor of size [BATCH_SIZE, HEIGHT, WIDTH, N_CHANNELS]. 
                Pixels should be normalized in [0, 1].
        labels: a batch of one-hot encoded labels, i.e. a tensor of size [BATCH_SIZE, N_CLASSES].
        batch: int, batch size.
        rand_seed: int, random seed.
            
    Returns:
        A tuple (batch of images, batch of labels) modified using CutMix.
    """
    # Set numpy random seed
    np.random.seed(rand_seed)
    
    # Modified images that will be reshaped into a tensor with the same
    # dimensions as the initial image batch.
    mod_img = []
    # Modified labels
    mod_lab = []

    # Input dimensions
    h = images.shape[1]  # image height
    w = images.shape[2]  # image width
    n_chan = images.shape[3]  # number of channels
    n_classes = labels.shape[1]  # number of classes
    
    # Shuffle mini-batch
    batch_shuffle = np.arange(batch)
    np.random.shuffle(batch_shuffle)
    
    for i, j in enumerate(batch_shuffle):
        lamb = np.random.uniform()
        
        # Coordinates of a random point inside the image, "center" of the patch
        r_x = np.random.randint(0, w)
        r_y = np.random.randint(0, h)
        
        r_w = int((1 - lamb)**0.5 * w)  # patch width
        r_h = int((1 - lamb)**0.5 * h)  # patch height
        
        # Corner of the patch with the smallest coordinates (top-left in array
        # indexing, where row 0 is the top of the image)
        x_1 = int(np.max([r_x - r_w / 2, 0]))
        y_1 = int(np.max([r_y - r_h / 2, 0]))
        # Corner of the patch with the largest coordinates (bottom-right)
        x_2 = int(np.min([r_x + r_w / 2, w]))
        y_2 = int(np.min([r_y + r_h / 2, h]))
        
        patch = images[j][y_1:y_2, x_1:x_2, :]
        
        # Tensors cannot be modified in place, so the mixed image is rebuilt by
        # concatenating the patch with the surrounding sections of image i.
        # Sections of image i exactly above and below the patch
        above = images[i][:y_1, x_1:x_2, :]
        below = images[i][y_2:, x_1:x_2, :]

        mid = tf.concat([above, patch, below], axis=0)

        # Sections of image i exactly on the left and right of the patch
        left = images[i][:, :x_1, :]
        right = images[i][:, x_2:, :]
        
        mod_img.append(tf.concat([left, mid, right], axis=1))

        # Actual lambda coefficient, i.e. 1 - (patch area) / (image area). It can
        # differ from the sampled one because the patch may be clipped at the borders.
        lamb = 1 - (x_2 - x_1) * (y_2 - y_1) / (w * h)
        
        # New "mixed" label
        new_lab = lamb * labels[i] + (1 - lamb) * labels[j]
        
        mod_lab.append(new_lab)
        
    output_images = tf.reshape(tf.stack(mod_img), shape=(batch, h, w, n_chan))
    output_labels = tf.reshape(tf.stack(mod_lab), shape=(batch, n_classes))
        
    return (output_images, output_labels)

# Data Augmentation Techniques Visual Comparison

In [None]:
# Get a batch of images to display
training_dataset = get_training_dataset()
training_dataset = training_dataset.unbatch().batch(20)
train_batch = iter(training_dataset)
batch = next(train_batch)
images = batch[0]
labels = batch[1]

### No augmentation

In [None]:
display_batch_of_images((images, labels))

### Random erasing

In the Random Erasing/CutOut method, the original labels are preserved. 

Here I have used patches filled with random noise, but the original article also uses other options: the mean ImageNet pixel values, all ones (or 255s) and all zeros. My implementation can handle all these options through the parameter *erasing_type*.
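
The Random Erasing procedure can be sketched for a single image with NumPy. This is a hypothetical `random_erasing_single` helper, not the notebook's `random_erasing` function. The default hyperparameters (`p=0.5`, `s_l=0.02`, `s_h=0.4`, `r_1=0.3`) are assumed to match the values suggested in the Random Erasing article, the fill names `"RE-M"`, `"RE-0"` and `"RE-255"` are assumed by analogy with `"RE-R"`, and the per-channel ImageNet mean `[0.485, 0.456, 0.406]` used for `"RE-M"` is also an assumption.

```python
import numpy as np

# Assumed per-channel ImageNet mean for pixels in [0, 1] (used by "RE-M")
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])


def random_erasing_single(img, erasing_type="RE-R", p=0.5, s_l=0.02, s_h=0.4,
                          r_1=0.3, max_attempts=100, rng=None):
    """Hypothetical NumPy sketch of Random Erasing on one [H, W, C] image in [0, 1].

    Only the image is modified: Random Erasing keeps the original label.
    """
    rng = np.random.default_rng() if rng is None else rng
    h, w, c = img.shape
    out = img.copy()
    # With probability 1 - p the image is returned unchanged
    if rng.uniform() > p:
        return out
    for _ in range(max_attempts):
        # Erased area (a fraction of the image area) and aspect ratio of the patch
        s_e = rng.uniform(s_l, s_h) * h * w
        r_e = rng.uniform(r_1, 1.0 / r_1)
        h_e = int(round(np.sqrt(s_e * r_e)))
        w_e = int(round(np.sqrt(s_e / r_e)))
        # Random top-left corner; retry if the patch does not fit in the image
        y_e = rng.integers(0, h)
        x_e = rng.integers(0, w)
        if h_e == 0 or w_e == 0 or y_e + h_e > h or x_e + w_e > w:
            continue
        if erasing_type == "RE-R":      # random values
            fill = rng.uniform(0.0, 1.0, size=(h_e, w_e, c))
        elif erasing_type == "RE-M":    # mean pixel value
            fill = np.broadcast_to(IMAGENET_MEAN[:c], (h_e, w_e, c))
        elif erasing_type == "RE-0":    # all zeros
            fill = np.zeros((h_e, w_e, c))
        elif erasing_type == "RE-255":  # all ones, i.e. 255 before normalization
            fill = np.ones((h_e, w_e, c))
        else:
            raise ValueError(f"Unknown erasing_type: {erasing_type}")
        out[y_e:y_e + h_e, x_e:x_e + w_e, :] = fill
        return out
    # Maximum number of attempts exceeded: the image is kept unchanged
    return out
```

The caller passes the label through untouched, which is what distinguishes Random Erasing from MixUp and CutMix.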

In [None]:
display_batch_of_images((random_erasing(images, labels, erasing_type='RE-R', batch=20)))

### MixUp

The MixUp method modifies both images and labels. In particular, non-integer labels are allowed: each sample can have a label in the original format (one-hot encoding) or a vector of 104 entries with at most two non-zero entries that sum to one.

For example, "0.336 petunia 0.664 water lily" means that the augmented image is generated through the formula $\lambda \cdot image_{i} + (1 - \lambda) \cdot image_{j}$, where $\lambda = 0.336$, $image_{i}$ belongs to class "petunia" and $image_{j}$ belongs to class "water lily".
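
A minimal NumPy sketch of the example above: one pair of toy images is mixed with $\lambda = 0.336$, and the same coefficient is applied to the one-hot labels. The helper `mix_pair` and the class indices `PETUNIA` and `WATER_LILY` are hypothetical placeholders, not the real indices of the flower dataset.

```python
import numpy as np

N_CLASSES = 104              # number of flower classes in this dataset
PETUNIA, WATER_LILY = 3, 7   # hypothetical class indices, for illustration only


def one_hot(idx, n=N_CLASSES):
    v = np.zeros(n)
    v[idx] = 1.0
    return v


def mix_pair(img_i, lab_i, img_j, lab_j, lamb):
    """MixUp of a single pair: the same lambda weights both pixels and labels."""
    return lamb * img_i + (1 - lamb) * img_j, lamb * lab_i + (1 - lamb) * lab_j


rng = np.random.default_rng(0)
img_i = rng.uniform(size=(4, 4, 3))  # stand-in for a "petunia" image
img_j = rng.uniform(size=(4, 4, 3))  # stand-in for a "water lily" image
new_img, new_lab = mix_pair(img_i, one_hot(PETUNIA), img_j, one_hot(WATER_LILY), 0.336)

# Two non-zero label entries, 0.336 and 0.664, summing to one
print(new_lab[PETUNIA], new_lab[WATER_LILY], new_lab.sum())
```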

In [None]:
display_batch_of_images((mixup(images, labels, batch=20)))

### CutMix

The CutMix method modifies both images and labels. In particular, non-integer labels are allowed: each sample can have a label in the original format (one-hot encoding) or a vector of 104 entries with at most two non-zero entries that sum to one.

For example, "0.336 petunia 0.664 water lily" means that the augmented image is generated from two images belonging to the classes "petunia" and "water lily" respectively. The region taken from the "petunia" image covers 33.6% of the area of the augmented image, and the patch taken from the "water lily" image covers the remaining 66.4%.
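
The relation between patch size and label weight can be checked with a NumPy sketch that mirrors the box computation of `cutmix()` above. Unlike the TensorFlow version, it writes the patch with in-place NumPy assignment; `cutmix_pair` and the toy images are hypothetical. Because the patch is clipped at the image borders, the area-adjusted $\lambda$ can be larger than the sampled one, which is why `cutmix()` recomputes it from the actual patch area.

```python
import numpy as np


def cutmix_pair(img_i, img_j, lamb, r_x, r_y):
    """Paste a patch of img_j, centred in (r_x, r_y), into img_i.

    Returns the mixed image and the area-adjusted lambda, i.e. the fraction
    of the mixed image that still comes from img_i.
    """
    h, w, _ = img_i.shape
    r_w = int(np.sqrt(1 - lamb) * w)  # patch width
    r_h = int(np.sqrt(1 - lamb) * h)  # patch height
    # The patch is clipped at the image borders
    x_1, x_2 = int(max(r_x - r_w / 2, 0)), int(min(r_x + r_w / 2, w))
    y_1, y_2 = int(max(r_y - r_h / 2, 0)), int(min(r_y + r_h / 2, h))
    out = img_i.copy()
    out[y_1:y_2, x_1:x_2, :] = img_j[y_1:y_2, x_1:x_2, :]
    # Actual lambda: 1 - (patch area) / (image area)
    lamb_adj = 1 - (x_2 - x_1) * (y_2 - y_1) / (w * h)
    return out, lamb_adj


img_i = np.zeros((100, 100, 3))  # stand-in for the "petunia" image
img_j = np.ones((100, 100, 3))   # stand-in for the "water lily" image
lab_i, lab_j = np.array([1.0, 0.0]), np.array([0.0, 1.0])  # toy 2-class labels

# Patch fully inside the image: the area ratio matches the sampled lambda
mixed, lamb_adj = cutmix_pair(img_i, img_j, lamb=0.75, r_x=50, r_y=50)
new_lab = lamb_adj * lab_i + (1 - lamb_adj) * lab_j

# Patch centred on a corner: clipping shrinks the patch, so lambda grows
mixed_c, lamb_c = cutmix_pair(img_i, img_j, lamb=0.75, r_x=0, r_y=0)
print(lamb_adj, lamb_c)  # prints: 0.75 0.9375
```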

In [None]:
display_batch_of_images((cutmix(images, labels, batch=20)))

# Model

To test the three methods, a simple EfficientNetB0 with Global Average Pooling applied to the output of the last convolutional layer, followed by a softmax classification layer, was chosen. This choice was made because the EfficientNet architecture is becoming very popular, it has a [Keras implementation](https://keras.io/api/applications/efficientnet/), and its "B0" version is quite fast to train.

In [None]:
def get_model(model_name=''):
    """
    Utility function to create a model based on EfficientNetB0 architecture.
    Args:
        model_name: string, if empty, ImageNet weights are loaded and all
                    layers are trainable. Otherwise, path of the .h5 weights
                    file to be loaded (for predicting/evaluating).
    Returns:
        A tf.keras.Sequential model (backbone + softmax Dense layer).
    """
    backbone = efn.EfficientNetB0(weights='imagenet', 
                                  include_top=False, 
                                  pooling='avg', 
                                  input_shape=[*IMAGE_SIZE, N_CHANNELS])
    if model_name == '':
        backbone.trainable = True  # All layers of the backbone net are trainable 
        
    model = tf.keras.Sequential([backbone,
                                 tf.keras.layers.Dense(len(CLASSES), 
                                                       activation='softmax', 
                                                       dtype='float32')])
    if model_name != '':
        model.load_weights(model_name)
    
    return model

In [None]:
with strategy.scope():
    model = get_model()
        
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['categorical_accuracy'])

model.summary()

In [None]:
def compile_and_train(augmentation='nothing', iterations=N_ENSEMBLE):
    """
    Function for compiling and training a model based on EfficientNetB0 
    architecture.
    Args:
        augmentation: string, type of augmentation to perform among "cutmix",
                      "mixup" and "random_erasing". Otherwise no data 
                      augmentation is performed.
        iterations: positive int, number of models with the same augmentation 
                    to be trained.
    Returns:
         A list of all the training histories, one for each iteration.
         Note that at each iteration the best model in terms of val_loss
         is saved in the output directory.
    """
    
    # Garbage collector
    gc.collect()
    
    # Release TPU memory
    tf.tpu.experimental.initialize_tpu_system(tpu)
    
    # List of all the training histories, one for each iteration 
    history_collection = []
    
    for i in range(iterations):
        
        print('Model ' + str(i) + '/' + str(iterations - 1))
        model_name = 'effb0_' + augmentation + '_' + str(i)
        
        # Model definition
        with strategy.scope():   
            model = get_model()
    
        # Compile the model
        model.compile(optimizer='adam',
                      loss='categorical_crossentropy',
                      metrics=['categorical_accuracy'])
    
        # Callbacks
        checkpoint = tf.keras.callbacks.ModelCheckpoint(model_name + '.h5',
                                                        save_best_only=True,
                                                        save_weights_only=True,
                                                        # alternative: monitor='val_categorical_accuracy'
                                                        monitor='val_loss',
                                                        mode='auto',  # inferred as "min" for val_loss
                                                        save_freq='epoch',
                                                        verbose=0)
    
        # Training
        history = model.fit(get_training_dataset(augmentation), 
                            steps_per_epoch=STEPS_PER_EPOCH, 
                            epochs=EPOCHS,
                            validation_data=get_validation_dataset(),
                            callbacks=[checkpoint])
        
        history_collection.append(history)
        
        print()
    
    return history_collection

# Training

In [None]:
hist_no_augmentation = compile_and_train()

In [None]:
hist_random_erasing = compile_and_train(augmentation='random_erasing')

In [None]:
hist_mixup = compile_and_train(augmentation='mixup')

In [None]:
hist_cutmix = compile_and_train(augmentation='cutmix')

# Predictions

The accuracy is calculated on an ensemble of N_ENSEMBLE models (averaging their predictions) in order to obtain results that are more reliable and robust than those of a single model.
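
The ensemble step can be illustrated with toy softmax outputs from two hypothetical models on three samples: the probabilities are averaged first and only then is the argmax taken. `ensemble_accuracy` is an illustrative helper, not the notebook's `accuracy` function.

```python
import numpy as np


def ensemble_accuracy(models_pred, labels_onehot):
    """Average the softmax outputs of several models, take the argmax and
    compare it with the argmax of the one-hot ground truth."""
    avg_pred = np.mean(np.stack(models_pred), axis=0)  # [n_samples, n_classes]
    y_pred = np.argmax(avg_pred, axis=1)
    y_true = np.argmax(labels_onehot, axis=1)
    return float(np.mean(y_pred == y_true))


labels_onehot = np.array([[1, 0], [1, 0], [0, 1]])       # true classes: 0, 0, 1
pred_a = np.array([[0.9, 0.1], [0.4, 0.6], [0.3, 0.7]])  # wrong on the second sample
pred_b = np.array([[0.6, 0.4], [0.8, 0.2], [0.6, 0.4]])  # wrong on the third sample

print(ensemble_accuracy([pred_a], labels_onehot),
      ensemble_accuracy([pred_b], labels_onehot),
      ensemble_accuracy([pred_a, pred_b], labels_onehot))
```

In this toy case each model alone classifies two of the three samples correctly, while the averaged prediction gets all three right.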

In [None]:
# Names list of all the available models
models = [f for f in listdir('/kaggle/working/') if isfile(join('/kaggle/working/', f)) and f.endswith('.h5')]

In [None]:
# Accuracy on the test set 
accuracies = {}

for aug in ['nothing', 'random_erasing', 'mixup', 'cutmix']:
    
    models_pred = []
    for i in range(N_ENSEMBLE):
        
        print(aug, i)
        
        gc.collect()  # garbage collector
        
        # Release TPU memory
        tf.tpu.experimental.initialize_tpu_system(tpu)
        
        
        weights = '_'.join([MODEL_NAME, aug, str(i)]) +'.h5'
        
        with strategy.scope():
            model = get_model(weights)

        # Compile the model
        model.compile(optimizer='adam',
                      loss='categorical_crossentropy',
                      metrics=['categorical_accuracy'])
        # Predictions
        preds = model.predict(get_test_dataset(labeled=True, ordered=True))
        models_pred.append(preds)
    
    # Ensemble: averaging predictions
    avg_pred = sum(models_pred) / len(models_pred)
    avg_pred = [x.argmax() for x in avg_pred]
    
    # Ordered Test labels
    labels_test = np.concatenate([y for x, y in get_test_dataset(labeled=True, ordered=True)], axis=0)
    labels_test = [x.argmax() for x in labels_test]
    
    # Accuracy calculation
    accuracies[aug] = accuracy(avg_pred, labels_test)
    
    print()

# Analysis

### Graph 1: Accuracy

In [None]:
fig, ax = plt.subplots()
bar_graph = ax.bar(accuracies.keys(), accuracies.values(), width=0.6, color=['grey', 'blue', 'blue', 'blue'])
ax.set_ylabel('Accuracy', fontsize=16)
ax.set_title("Ensemble Models' Accuracy", fontsize=20)
plt.yticks(np.arange(0, 1.01, step=0.1))
plt.xticks(np.arange(4), ('NoAugmentation', 'RandomErasing', 'MixUp', 'CutMix'), fontsize=16)
autolabel(bar_graph)

As we can see in the above graph, all three data augmentation techniques improve accuracy in this example. 

An ensemble of N_ENSEMBLE models (section "Parameters") is used to make the results more robust (only for this graph!).

The improvement obtained with CutMix with respect to the non-augmented model seems to be quite large.

****

### Graph 2: Model Loss - Training Set vs Validation Set

In [None]:
plot_metric(hist_no_augmentation[0].history, hist_random_erasing[0].history, hist_mixup[0].history, hist_cutmix[0].history, metric='loss')

With MixUp and CutMix, the loss on the Validation Set is lower than the loss on the Training Set. This seemingly counterintuitive result can be explained by the fact that the examples in the Validation Set are not augmented, so they are easier to classify correctly.
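
Another factor may also contribute, although it is not measured in this notebook: MixUp and CutMix train on soft labels, and the categorical cross-entropy $-\sum_k y_k \log p_k$ of a soft label $y$ is minimized at $p = y$, where it equals the entropy of $y$, which is strictly positive. The training loss therefore cannot reach zero even for a perfect model, while the one-hot validation labels have no such floor. The sketch below checks this numerically on the "0.336 petunia 0.664 water lily" label; `cross_entropy` is a hypothetical helper using the natural logarithm.

```python
import numpy as np


def cross_entropy(y_true, y_pred, eps=1e-12):
    """Categorical cross-entropy of one sample, with the natural logarithm."""
    return float(-np.sum(y_true * np.log(np.clip(y_pred, eps, 1.0))))


soft = np.array([0.336, 0.664])  # soft label, e.g. "0.336 petunia 0.664 water lily"
hard = np.array([1.0, 0.0])      # one-hot label of a clean validation image

# Near-best prediction for each target: p = y_true (0.999 for the one-hot label)
best_soft = cross_entropy(soft, soft)
best_hard = cross_entropy(hard, np.array([0.999, 0.001]))

# Scanning other predictions: none goes below the entropy of the soft label
grid = np.linspace(0.01, 0.99, 99)
scan = [cross_entropy(soft, np.array([p, 1 - p])) for p in grid]
print(round(best_soft, 3), round(best_hard, 3), round(min(scan), 3))
```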

Notice that the MixUp and CutMix validation curves are both smoother than those of the non-augmented and Random Erasing cases.

### Graph 3: Model Accuracy - Training Set vs Validation Set

In [None]:
plot_metric(hist_no_augmentation[0].history, hist_random_erasing[0].history, hist_mixup[0].history, hist_cutmix[0].history, metric='categorical_accuracy')

With CutMix, the accuracy on the Validation Set is higher than the accuracy on the Training Set. As before, this can be explained by the fact that the examples in the Validation Set are not augmented, so they are easier to classify correctly.

In the MixUp case, the Training and Validation accuracy curves are very close. I would have expected a behaviour similar to CutMix and I do not know why there is this difference (any ideas? :))
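
One possible factor, only a hypothesis not verified in this notebook: `mixup()` draws $\lambda$ from Beta(0.4, 0.4) (its default `alpha`), which is U-shaped, while `cutmix()` draws it from a uniform distribution. With Beta(0.4, 0.4), many mixed images are dominated by a single source image, so they may look closer to clean images. The sketch below compares how often one image weighs more than 90% (an arbitrary threshold) under the two distributions, using the sampled $\lambda$ before the border clipping applied by `cutmix()`.

```python
import numpy as np

rng = np.random.default_rng(42)
n = 100_000
lamb_mixup = rng.beta(0.4, 0.4, size=n)  # mixup() default: alpha=0.4
lamb_cutmix = rng.uniform(size=n)        # cutmix(): lambda ~ Uniform(0, 1)


def dominant_fraction(lamb, threshold=0.9):
    """Fraction of mixed samples in which one source image has weight > threshold."""
    return float(np.mean(np.maximum(lamb, 1 - lamb) > threshold))


print(dominant_fraction(lamb_mixup), dominant_fraction(lamb_cutmix))
```

The uniform draw gives a fraction close to 0.2, as expected analytically, while the Beta(0.4, 0.4) draw gives a much larger one (roughly half of the samples).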

****

### Graph 4: Model Loss and Error - No Augmentation vs Random Erasing vs MixUp vs CutMix

In [None]:
plot_loss_error_all(hist_no_augmentation[0].history, hist_random_erasing[0].history, hist_mixup[0].history, hist_cutmix[0].history, set_type='validation')

The introduction of augmentation techniques seems to be beneficial on the Validation Set from epoch 5 onward in this example.

****

# Observations

* As I stated in the first paragraph, all the above code is intended only to give a glimpse of the usefulness of Random Erasing, MixUp and CutMix in the context of image recognition through some examples.


* From the graph "Ensemble Models' Accuracy" we can notice that accuracy increases going from the models trained without augmentation to those trained with Random Erasing, MixUp and, finally, CutMix. All the tested augmentation methods showed some improvement with respect to the basic model without augmentation: in our example, Random Erasing increases accuracy by about 0.5 percentage points (p.p.), MixUp by about 1 p.p. and CutMix by about 2 p.p. A similar ranking is reported in [\[3\]](https://arxiv.org/abs/1905.04899) ("Related Articles and Links") on an ImageNet classification problem.


* Another interesting piece of evidence about the usefulness of these augmentation methods, at least in our specific example, is given by the "Model Loss" / "Model Error" graphs on the Validation Set:
    * Random Erasing's Loss and Top-1 Error are only slightly lower than those of the non-augmented model, on average.
    * MixUp's Loss is substantially lower than the non-augmented one from the 10th epoch onward (about 0.4 vs about 0.6).
    * CutMix shows both the lowest Loss and the lowest Top-1 Error.
    
    
* These methods can also be combined. In particular, a common combination is CutMix with MixUp.


* There are several improvements that could be made to this notebook, for instance:
    * Repeat the analysis using different parameters for each augmentation method
    * Replicate the class activation mapping (CAM) analysis as in [\[3\]](https://arxiv.org/abs/1905.04899)
    * Use different datasets to see if different results are obtained
    * Try to implement [AugMix](https://arxiv.org/abs/1912.02781)
    * ...
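
As mentioned in the Observations, these methods can be combined. Below is a hypothetical NumPy sketch of one simple way to combine CutMix and MixUp: for each batch, CutMix is applied with probability `p_cutmix` and MixUp otherwise. The per-batch choice and `p_cutmix=0.5` are assumptions, and `mixup_batch`/`cutmix_batch` are simplified, vectorised stand-ins for the TensorFlow `mixup()` and `cutmix()` functions above, not the notebook's code.

```python
import numpy as np


def mixup_batch(images, labels, rng, alpha=0.4):
    """MixUp of a batch with a shuffled copy of itself (one lambda per sample)."""
    perm = rng.permutation(len(images))
    lamb = rng.beta(alpha, alpha, size=(len(images), 1))
    lamb_img = lamb.reshape(-1, 1, 1, 1)
    new_images = lamb_img * images + (1 - lamb_img) * images[perm]
    new_labels = lamb * labels + (1 - lamb) * labels[perm]
    return new_images, new_labels


def cutmix_batch(images, labels, rng):
    """CutMix of a batch with a shuffled copy of itself (same box rule as cutmix())."""
    n, h, w, _ = images.shape
    perm = rng.permutation(n)
    new_images = images.copy()
    lamb = np.empty((n, 1))
    for k in range(n):
        u = rng.uniform()
        r_x, r_y = rng.integers(0, w), rng.integers(0, h)
        r_w, r_h = int(np.sqrt(1 - u) * w), int(np.sqrt(1 - u) * h)
        x_1, x_2 = int(max(r_x - r_w / 2, 0)), int(min(r_x + r_w / 2, w))
        y_1, y_2 = int(max(r_y - r_h / 2, 0)), int(min(r_y + r_h / 2, h))
        new_images[k, y_1:y_2, x_1:x_2] = images[perm[k], y_1:y_2, x_1:x_2]
        lamb[k] = 1 - (x_2 - x_1) * (y_2 - y_1) / (w * h)  # area-adjusted lambda
    new_labels = lamb * labels + (1 - lamb) * labels[perm]
    return new_images, new_labels


def cutmix_or_mixup(images, labels, rng, p_cutmix=0.5):
    """Apply CutMix with probability p_cutmix, otherwise MixUp, to the whole batch."""
    if rng.uniform() < p_cutmix:
        return cutmix_batch(images, labels, rng)
    return mixup_batch(images, labels, rng)
```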

****

# Related Articles and Links

1. [Random Erasing Data Augmentation](https://arxiv.org/abs/1708.04896) was the article used for the Random Erasing implementation in this notebook. Another interesting article that describes basically the same method (called CutOut) is [Improved Regularization of Convolutional Neural Networks with Cutout](https://arxiv.org/abs/1708.04552).
2. [Mixup: Beyond Empirical Risk Minimization](https://arxiv.org/abs/1710.09412).
3. [CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features](https://arxiv.org/abs/1905.04899).
4. [AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty](https://arxiv.org/abs/1912.02781) explains the brand new AugMix method that could be implemented and included in the notebook.

***