# Intro

In this notebook, we integrate:
- Albumentations augmentations
- CutMix augmentation
- TFRecords
- Multi-GPU pipeline

These implementations will speed up training and allow more opportunities for generalization.

# SEED Everything

In [None]:
import tensorflow as tf
import os
import numpy as np
import random

SEED = 42  # single seed shared by every random number generator below

def set_seeds(seed=SEED):
    """Seed the Python, TensorFlow and NumPy global random generators.

    Args:
        seed: Integer seed applied to every generator (defaults to SEED).
    """
    os.environ.update({'PYTHONHASHSEED': str(seed)})
    # Same seeding order as before: random, TensorFlow, NumPy.
    for seeder in (random.seed, tf.random.set_seed, np.random.seed):
        seeder(seed)

def set_global_determinism(seed=SEED):
    """Seed every generator and request deterministic TensorFlow execution.

    Calls set_seeds, exports the two TensorFlow determinism environment
    flags and restricts TensorFlow to one inter-op and one intra-op thread.

    Args:
        seed: Integer seed forwarded to set_seeds (defaults to SEED).
    """
    set_seeds(seed=seed)

    for flag in ('TF_DETERMINISTIC_OPS', 'TF_CUDNN_DETERMINISTIC'):
        os.environ[flag] = '1'

    threading_config = tf.config.threading
    threading_config.set_inter_op_parallelism_threads(1)
    threading_config.set_intra_op_parallelism_threads(1)

    print("Random seed initialized.")

set_global_determinism(seed=SEED)

In [None]:
import pandas as pd
import json, cv2, re, math, ast
import seaborn as sns
import tqdm.notebook as tqdm
import matplotlib.pyplot as plt
sns.set(style='darkgrid')

from sklearn.model_selection import KFold
from tensorflow.keras.models import Model, load_model, Sequential
from tensorflow.keras.layers import Dense, Dropout, GlobalAvgPool2D, Input, BatchNormalization
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, LearningRateScheduler, ModelCheckpoint, CSVLogger, TensorBoard, LearningRateScheduler
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers.schedules import ExponentialDecay
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.applications import *
from tensorflow.keras.mixed_precision import Policy, set_global_policy, LossScaleOptimizer
from tensorflow.keras import layers
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.python.client import device_lib
import tensorflow_addons as tfa

import albumentations as A
from functools import partial

# Filter all Tensorflow logs except FATAL errors
tf.get_logger().setLevel('FATAL') #DEBUG,ERROR,FATAL,INFO,WARN

# Mixed Precision
set_global_policy(Policy('mixed_float16'))

gpus = tf.config.list_physical_devices('GPU')
if gpus:
  try:
    # Currently, memory growth needs to be the same across GPUs
    for gpu in gpus:
      tf.config.experimental.set_memory_growth(gpu, True)
    logical_gpus = tf.config.experimental.list_logical_devices('GPU')
  except RuntimeError as e:
    # Memory growth must be set before GPUs have been initialized
    print(e)

# Mirrored Strategy
# Build the device list from the GPUs that are actually present instead of
# hard-coding "/gpu:0" and "/gpu:1", so the notebook also starts on hosts with
# a different GPU count. With two GPUs this is identical to the previous setup.
if len(gpus) > 1:
    strategy = tf.distribute.MirroredStrategy(
        devices=["/gpu:%d" % i for i in range(len(gpus))],
        cross_device_ops=tf.distribute.HierarchicalCopyAllReduce()
    )
else:
    # Zero or one GPU: let TensorFlow pick its default device.
    strategy = tf.distribute.MirroredStrategy()

# Initialize paths
AUTOTUNE = tf.data.experimental.AUTOTUNE
PATH = '../input/plant-pathology-2021-fgvc8'
ns_weights = '../input/keras-efficientnetb3-noisy-student/noisy_student_efficientnet-b1.h5'

# Print the description of every GPU visible to TensorFlow.
for x in device_lib.list_local_devices():
    if x.device_type == 'GPU':
        print(x.physical_device_desc)

Due to the compute capability of Kaggle's built-in GPUs, we will not be able to:
1. Activate mixed precision
2. Increase our batch size per replica

In [None]:
# Initialize variables
REPLICAS = strategy.num_replicas_in_sync  # replica count reported by the strategy
BATCH_SIZE_PER_REPLICA = 16
BUFFER_SIZE = 512  # shuffle buffer size used by the tf.data pipelines
IMAGE_SIZE = [224, 224]  # target size passed to tf.image.resize in decode_image
n_train_augments = 3  # augmented copies produced per image in album_augment
CLASSES = 12  # depth of the one-hot encoded labels
# Batch size handed to dataset.batch(): per-replica size times replica count.
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * REPLICAS
print('Number of replicas:', REPLICAS)
print('Global batch size:', GLOBAL_BATCH_SIZE)

Images from both 2020 and 2021 training sets have been merged into train.csv.

In [None]:
# Metadata table; later cells read its 'image_id' and 'labels' columns.
train = pd.read_csv(r'../input/train-data2/train.csv')
# Directory that image_id values are joined onto whenever an image is read.
IMG_PATH = r'../input/plant-pathology-2021-fgvc8/train_images'
train

In [None]:
# Horizontal bar chart of label frequencies, most frequent label first.
sns.countplot(y=train['labels'], order=train['labels'].value_counts().index)
plt.show()

# Duplicate Detection

## Image hashing

In [None]:
# The commented-out code below is the computation that produces the duplicate
# pairs; only its saved result (duplicates.csv) is loaded at the end of the cell.
# NOTE(review): re-enabling it needs imagehash, PIL.Image and torch, none of
# which are imported in this notebook, and `tqdm` is imported here as the
# module `tqdm.notebook`, so `tqdm(...)` would have to be `tqdm.tqdm(...)`.
# funcs = [
#         imagehash.average_hash,
#         imagehash.phash,
#         imagehash.dhash,
#         imagehash.whash,
#     ]
# image_ids = []
# hashes = []

# for path in tqdm(train['image_id'], desc='Hashing images'):
#     image = Image.open(os.path.join(IMG_PATH, path))
#     image_id = os.path.basename(path)
#     image_ids.append(image_id)
#     hashes.append(np.array([f(image).hash for f in funcs]).reshape(256))

# hashes_all = np.array(hashes)
# hashes_all = torch.Tensor(hashes_all.astype(int))
# sims = np.array([(hashes_all[i] == hashes_all).sum(dim=1).numpy()/256 for i in tqdm(range(hashes_all.shape[0]), desc='Calculating similarities')])

# indices1 = np.where(sims > 0.9)
# indices2 = np.where(indices1[0] != indices1[1])
# image_ids1 = [image_ids[i] for i in indices1[0][indices2]]
# image_ids2 = [image_ids[i] for i in indices1[1][indices2]]
# dups = {tuple(sorted([image_id1,image_id2])):True for image_id1, image_id2 in zip(image_ids1, image_ids2)}
# duplicate_image_ids = sorted(list(dups))
# print('Found %d duplicates' % len(duplicate_image_ids))

# # Remove duplicates from external data
# imgs_to_remove = [x[1] for x in duplicate_image_ids]

# duplicates = pd.DataFrame(duplicate_image_ids)
# duplicates.columns = ['image0', 'image1']

# Pre-computed duplicate pairs; later cells use the 'image1' column as the
# list of images to drop.
duplicates = pd.read_csv('../input/train-data/duplicates.csv')
duplicates

In [None]:
# Show the last `nrows` duplicate pairs side by side (one pair per row).
nrows = 5; ncols=2
fig, axes = plt.subplots(nrows, ncols, figsize=(6*ncols, 4*nrows))
for i, row in enumerate(duplicates.to_numpy()[-nrows:]):
    for j, image_id in enumerate(row[:ncols]):
        image = cv2.imread(os.path.join(IMG_PATH, image_id))
        # OpenCV decodes to BGR while matplotlib expects RGB; convert so the
        # colours are displayed correctly.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        axes[i][j].axis('off')
        axes[i][j].set_title(image_id)
        axes[i][j].imshow(image)

plt.tight_layout()
plt.show()

In [None]:
# Label distribution of the rows whose image_id appears in duplicates['image1'].
train[train['image_id'].isin(duplicates['image1'].tolist())]['labels'].value_counts()

In [None]:
# Keep only rows whose image_id is NOT listed in duplicates['image1'].
train = train[~train['image_id'].isin(duplicates['image1'].tolist())]
train

# Reshape Resize

In [None]:
# The commented-out code computes each image's shape by reading it from disk;
# the saved result is loaded from train_shape.csv instead.
# def image_shape(x):
#     return cv2.imread(os.path.join(IMG_PATH, x)).shape

# tqdm.pandas(desc='Getting shapes')
# train['shape'] = train['image_id'].progress_apply(image_shape)

# NOTE(review): this rebinds `train`, replacing the de-duplicated frame built
# in the previous cell - confirm train_shape.csv already excludes duplicates.
train = pd.read_csv('../input/train-data/train_shape.csv')
# The 'shape' column is stored as text; parse it back into Python literals.
train['shape'] = train['shape'].apply(ast.literal_eval)
train['shape'].value_counts()

In [None]:
def image_resize(x):
    """Read an image from IMG_PATH, make it landscape and resize to 512x512.

    Args:
        x: File name of the image, relative to IMG_PATH.

    Returns:
        The resized image array as produced by cv2.resize (OpenCV decodes
        colour images in BGR channel order).

    Raises:
        FileNotFoundError: If OpenCV cannot read the file.
    """
    path = os.path.join(IMG_PATH, x)
    image = cv2.imread(path)
    # cv2.imread signals failure by returning None instead of raising.
    if image is None:
        raise FileNotFoundError('Could not read image: {}'.format(path))

    height, width = image.shape[:2]
    # Images taller than wide are transposed so every image is landscape.
    if height > width:
        image = cv2.transpose(image)
    return cv2.resize(image, dsize=(512, 512))

# Preview one resized image.
image = image_resize(train['image_id'][5000])
plt.axis('off')
# image_resize returns OpenCV's BGR channel order; convert to RGB for display
# only, leaving `image` itself untouched.
plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
plt.title(image.shape)
plt.show()

trainV1.csv is the finalized dataset with duplicates removed.

In [None]:
# NOTE(review): this column selection is superseded a few lines below, where
# `train` is re-read from trainV1.csv.
train = train[['image_id', 'labels']]

# The commented-out code recomputes image shapes from SAVE_PATH, a name that
# is not defined anywhere in this notebook.
# def image_shape(x):
#     return cv2.imread(os.path.join(SAVE_PATH, x)).shape

# tqdm.pandas(desc='Getting shapes')
# train['shape'] = train['image_id'].progress_apply(image_shape)

train = pd.read_csv(r"../input/train-data/trainV1.csv")
train['shape'].value_counts()

In [None]:
# Label encode
# Integer code of each label, taken from its pandas Categorical representation.
train['labels_codes'] = pd.Categorical(train['labels']).codes
# Lookup table: the row position of each category is its integer code.
pd.DataFrame({"Categories": pd.Categorical(train['labels']).categories})

# Convert images to TFREC shards

In [None]:
# The following functions can be used to convert a value to a type compatible
# with tf.train.Example.

def _bytes_feature(value):
  """Wrap a string / bytes value in a tf.train.Feature holding a bytes_list."""
  eager_tensor_type = type(tf.constant(0))
  # BytesList cannot unpack an EagerTensor, so extract its raw value first.
  raw = value.numpy() if isinstance(value, eager_tensor_type) else value
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[raw]))

def _float_feature(value):
  """Wrap a float / double in a tf.train.Feature holding a float_list."""
  float_list = tf.train.FloatList(value=[value])
  return tf.train.Feature(float_list=float_list)

def _int64_feature(value):
  """Wrap a bool / enum / int / uint in a tf.train.Feature holding an int64_list."""
  int64_list = tf.train.Int64List(value=[value])
  return tf.train.Feature(int64_list=int64_list)

In [None]:
def serialize_example(feature0, feature1, feature2):
  """Serialize one sample into a tf.train.Example byte string.

  Args:
    feature0: Value stored under 'image' as a bytes feature.
    feature1: Value stored under 'target' as an int64 feature.
    feature2: Value stored under 'image_name' as a bytes feature.

  Returns:
    The serialized tf.train.Example, ready to be written to a TFRecord file.
  """
  features = tf.train.Features(feature={
      'image': _bytes_feature(feature0),
      'target': _int64_feature(feature1),
      'image_name': _bytes_feature(feature2),
  })
  return tf.train.Example(features=features).SerializeToString()

In [None]:
# List every fold count below 30 that divides the training set evenly, so that
# all shards can hold the same number of images.
fold_list = []
print('Possible number of folds:')
for n in range(1,30):
    if train.shape[0] % n:
        continue  # n does not divide the number of rows evenly
    print('{} folds - {} images per fold'.format(n, train.shape[0] // n))
    fold_list.append(n)

In [None]:
# N_FILES = fold_list[1]
# IMG_QUALITY = 96 # Avoiding 100% Quality as it increases file size
# IMAGE_SIZE = (512, 512)
# train['shard'] = 0

# # Stratify the shards
# skf = StratifiedKFold(n_splits=N_FILES, shuffle=True, random_state=42)
# for fold, (train_idx, val_idx) in enumerate(skf.split(train, train['labels'])):
#     train.loc[val_idx, 'shard'] = fold

# # Rewrite the TFRecords after stratification and resize
# for tfrec_num in tqdm(range(N_FILES), desc='Writing TFRecords'):
#     samples = train[train['shard'] == tfrec_num]
#     n_samples = len(samples)
#     with tf.io.TFRecordWriter('train-%.2i-%i.tfrec'%(tfrec_num, n_samples)) as writer:
#         for row in tqdm(samples.itertuples(), desc=('Fold {}'.format(tfrec_num+1)), total=int(train.shape[0]/N_FILES)):
#             label = row.labels_codes
#             image_name = row.image_id
#             img = image_resize(image_name)
#             img = cv2.imencode('.jpg', img, (cv2.IMWRITE_JPEG_QUALITY, IMG_QUALITY))[1].tobytes()
#             example = serialize_example(img, label, str.encode(image_name))
#             writer.write(example)

# Loading tfrec shards

In [None]:
# Directory holding the pre-built TFRecord shards.
TFR_PATH = '../input/plant-pathology-20202021-tfrec'
# Every *.tfrec file directly inside TFR_PATH.
FILENAMES = tf.io.gfile.glob(TFR_PATH + "/*.tfrec")
print("Number of shards:", len(FILENAMES))

In [None]:
def decode_image(image):
    """Decode JPEG bytes into a float32 image of size IMAGE_SIZE.

    Args:
        image: Scalar string tensor holding the encoded JPEG.

    Returns:
        Tensor of shape [*IMAGE_SIZE, 3]; pixel values are divided by 255.
    """
    height, width = IMAGE_SIZE
    pixels = tf.image.decode_jpeg(image, channels=3)
    scaled = tf.cast(pixels, tf.float32)/255.0
    resized = tf.image.resize(scaled, size=IMAGE_SIZE, method='nearest')
    # The reshape pins the static shape for downstream ops.
    return tf.reshape(resized, [height, width, 3])

def read_tfrecord(example, labeled):
    """Parse one serialized example into an image (and optionally its label).

    Args:
        example: Serialized tf.train.Example.
        labeled: When true the 'target' feature is parsed and returned as well.

    Returns:
        `(image, label)` with an int32 label when `labeled`, else `image`.
    """
    tfrecord_format = {"image": tf.io.FixedLenFeature([], tf.string)}
    if labeled:
        tfrecord_format["target"] = tf.io.FixedLenFeature([], tf.int64)

    parsed = tf.io.parse_single_example(serialized=example,
                                        features=tfrecord_format)
    image = decode_image(parsed["image"])

    if not labeled:
        return image
    return image, tf.cast(parsed['target'], tf.int32)

def load_dataset(filenames, labeled=True):
    """Build a dataset of decoded samples from TFRecord shards.

    Args:
        filenames: TFRecord file paths to read.
        labeled: Forwarded to read_tfrecord; selects whether labels are parsed.

    Returns:
        A tf.data.Dataset of read_tfrecord outputs, with deterministic
        ordering switched off.
    """
    options = tf.data.Options()
    options.experimental_deterministic = False
    parse_example = partial(read_tfrecord, labeled=labeled)
    return (
        tf.data.TFRecordDataset(filenames)
        .with_options(options)
        .map(parse_example, num_parallel_calls=AUTOTUNE)
    )

def one_hot(image, label, CLASSES=CLASSES):
    """Return the image unchanged together with its one-hot float32 label.

    Args:
        image: Passed through untouched.
        label: Integer class index (or indices).
        CLASSES: Depth of the one-hot vector.
    """
    encoded = tf.one_hot(indices=label, depth=CLASSES, dtype=tf.float32)
    return image, encoded

def count_data_items(filenames):
    """Sum the image counts encoded in TFRecord shard file names.

    Each shard name is expected to carry its sample count as the digits
    between the last '-' and the extension, e.g. 'train-00-1712.tfrec'.

    Args:
        filenames: Iterable of shard paths.

    Returns:
        Total number of images as a Python int (0 for an empty iterable).

    Raises:
        ValueError: If a file name does not contain a '-<digits>.' group.
    """
    pattern = re.compile(r"-([0-9]+)\.")
    total = 0
    for filename in filenames:
        # Match on the base name only, so a directory such as 'run-3.v2/' in
        # the path cannot be mistaken for the sample count.
        match = pattern.search(os.path.basename(filename))
        if match is None:
            raise ValueError(
                'Cannot read the sample count from shard name: {}'.format(filename))
        total += int(match.group(1))
    return total

# Augmentation

We deliberately duplicate each image through augmentation to increase the image count. This mimics the behaviour of Keras' ImageDataGenerator API.

For demonstration purposes, only basic augmentations were performed.

In [None]:
def album_augment(image, label):
    """Return an image plus n_train_augments augmented copies of it.

    Intended to run inside tf.numpy_function, so it works on NumPy values and
    returns NumPy arrays matching the dtypes declared by train_augment.

    Args:
        image: NumPy image array.
        label: Integer label of the image.

    Returns:
        images: float32 array of n_train_augments + 1 stacked images, the
            unmodified input first.
        labels: int32 array holding `label` once per stacked image.
    """
    transforms = A.Compose([
        A.RandomResizedCrop(*IMAGE_SIZE, p=1),
        # A.Transpose(p=1),
        A.ShiftScaleRotate(
            shift_limit=0.0625,
            scale_limit=0.1,
            rotate_limit=45,
            p=1),
        # A.RandomBrightnessContrast(
        #     brightness_limit=0.2,
        #     contrast_limit=0.2,
        #     p=1
        # ),
        # A.HueSaturationValue(
        #     hue_shift_limit=20,
        #     sat_shift_limit=30,
        #     val_shift_limit=20,
        #     p=1),
        A.Flip(p=1),
        A.Cutout(num_holes=8,
            max_h_size=8,
            max_w_size=8,
            p=1)
        ], p=1)

    # Keep the original image as the first entry, then add the augmented ones.
    images = [image]
    for _ in range(n_train_augments):
        images.append(transforms(image=image)['image'])

    # Plain NumPy arrays (no TensorFlow ops) are what tf.numpy_function expects
    # back from its Python callback.
    stacked = np.stack(images).astype(np.float32)
    labels = np.full(len(images), label, dtype=np.int32)
    return stacked, labels

def train_augment(image, label):
    """Apply album_augment to one sample from inside a tf.data pipeline.

    Args:
        image: Image tensor of shape [*IMAGE_SIZE, 3].
        label: Scalar integer label tensor.

    Returns:
        `(images, labels)`: a float32 stack of n_train_augments + 1 images and
        the matching int32 labels.
    """
    images, labels = tf.numpy_function(
        func=album_augment, inp=[image, label], Tout=[tf.float32, tf.int32])
    # tf.numpy_function drops static shape information; restore it so that
    # downstream unbatch / batch / slicing ops see fully defined shapes.
    n_copies = n_train_augments + 1
    images.set_shape([n_copies, *IMAGE_SIZE, 3])
    labels.set_shape([n_copies])
    return images, labels

In [None]:
def cutmix(image, label, PROBABILITY=0.5):
    """Apply CutMix to a batch of images and one-hot labels.

    For every sample, with probability PROBABILITY, a random square patch is
    replaced by the same region of another randomly chosen sample of the
    batch, and the labels are mixed in proportion to the pasted area.

    Args:
        image: Batch of images of shape [GLOBAL_BATCH_SIZE, DIM, DIM, 3]
            (a batch, not a single image).
        label: Batch of one-hot labels of shape [GLOBAL_BATCH_SIZE, CLASSES].
        PROBABILITY: Chance that a given sample receives a patch.

    Returns:
        `(image2, label2)` with the same shapes as the inputs.
    """
    DIM = IMAGE_SIZE[0]

    imgs = []; labs = []
    for j in range(GLOBAL_BATCH_SIZE):
        # DO CUTMIX WITH PROBABILITY DEFINED ABOVE (P is 1 or 0)
        P = tf.cast( tf.random.uniform([],0,1)<=PROBABILITY, tf.int32)
        # CHOOSE RANDOM IMAGE TO CUTMIX WITH
        k = tf.cast( tf.random.uniform([],0,GLOBAL_BATCH_SIZE),tf.int32)
        # CHOOSE RANDOM LOCATION
        x = tf.cast( tf.random.uniform([],0,DIM),tf.int32)
        y = tf.cast( tf.random.uniform([],0,DIM),tf.int32)
        b = tf.random.uniform([],0,1) # this is beta dist with alpha=1.0
        # Multiplying by P collapses the patch to width 0 when CutMix is skipped.
        WIDTH = tf.cast( DIM * tf.math.sqrt(1-b),tf.int32) * P
        # Clip the patch to the image borders.
        ya = tf.math.maximum(0,y-WIDTH//2)
        yb = tf.math.minimum(DIM,y+WIDTH//2)
        xa = tf.math.maximum(0,x-WIDTH//2)
        xb = tf.math.minimum(DIM,x+WIDTH//2)
        # MAKE CUTMIX IMAGE: rows ya:yb are rebuilt from left strip (sample j),
        # patch (sample k) and right strip (sample j).
        one = image[j,ya:yb,0:xa,:]
        two = image[k,ya:yb,xa:xb,:]
        three = image[j,ya:yb,xb:DIM,:]
        middle = tf.concat([one,two,three],axis=1)
        img = tf.concat([image[j,0:ya,:,:],middle,image[j,yb:DIM,:,:]],axis=0)
        imgs.append(img)
        # MAKE CUTMIX LABEL
        # Use the area that was actually pasted: the nominal WIDTH*WIDTH
        # overstates it whenever the patch is clipped at a border (or WIDTH
        # is odd), which would give the second label too much weight.
        a = tf.cast((yb-ya)*(xb-xa),tf.float32) / float(DIM*DIM)

        lab1 = label[j,]
        lab2 = label[k,]

        labs.append((1-a)*lab1 + a*lab2)

    image2 = tf.reshape(tf.stack(imgs),(GLOBAL_BATCH_SIZE,DIM,DIM,3))
    label2 = tf.reshape(tf.stack(labs),(GLOBAL_BATCH_SIZE,CLASSES))
    return image2,label2

In [None]:
def get_train_dataset(filenames, labeled=True):
    """Build the training pipeline: albumentations copies, then CutMix.

    Args:
        filenames: TFRecord shards to read.
        labeled: Forwarded to load_dataset.

    Returns:
        An endlessly repeating, prefetched dataset of GLOBAL_BATCH_SIZE batches.
    """
    # Stage 1: expand every image into its augmented copies and flatten them
    # back into individual one-hot encoded samples.
    samples = (
        load_dataset(filenames, labeled=labeled)
        .map(partial(train_augment), num_parallel_calls=AUTOTUNE)
        .unbatch()
        .shuffle(BUFFER_SIZE*n_train_augments)
        .repeat()
        .map(one_hot, num_parallel_calls=AUTOTUNE)
    )
    # Stage 2: CutMix operates on full batches.
    mixed = samples.batch(GLOBAL_BATCH_SIZE).map(cutmix, num_parallel_calls=AUTOTUNE)
    # Stage 3: reshuffle the mixed samples and form the final batches.
    return (
        mixed
        .unbatch()
        .shuffle(BUFFER_SIZE)
        .batch(GLOBAL_BATCH_SIZE)
        .prefetch(AUTOTUNE)
    )

In [None]:
def augm_dataset(filenames, labeled=True):
    """Return the dataset with only the albumentations step applied.

    Each element is the output of train_augment for one source image; no
    batching, shuffling or CutMix is performed.
    """
    # dataset = dataset.batch(BATCH_SIZE_PER_REPLICA)
    return load_dataset(filenames, labeled=labeled).map(
        partial(train_augment), num_parallel_calls=AUTOTUNE)

## Validation set

In [None]:
"""repeat() required to increase validation_steps"""
def get_val_dataset(filenames, labeled=True):
    """Build the validation pipeline (no augmentation).

    Args:
        filenames: TFRecord shards to read.
        labeled: Forwarded to load_dataset.

    Returns:
        A cached, prefetched dataset of one-hot encoded GLOBAL_BATCH_SIZE batches.
    """
    return (
        load_dataset(filenames, labeled=labeled)
        .map(one_hot, num_parallel_calls=AUTOTUNE)
        .shuffle(BUFFER_SIZE)
        .batch(GLOBAL_BATCH_SIZE)
        .cache()
        .prefetch(AUTOTUNE)
    )

## Augment Visualization

In [None]:
def display_image(dataset):
    """Plot the first 16 images of a dataset in a 4x4 grid.

    Args:
        dataset: Dataset of `(images, labels)` elements with a leading batch
            dimension; it is re-batched to 16 samples before plotting.
    """
    row, col = 4, 4
    first_batch = dataset.unbatch().batch(16)
    images, _ = next(iter(first_batch))

    plt.figure(figsize=(10, 10))
    for i in range(row*col):
        plt.subplot(row, col, i+1)
        plt.axis('off')
        plt.imshow(tf.image.convert_image_dtype(images[i], 'uint8'))

    plt.tight_layout()
    plt.show()

### Albumentations Augmentation

In [None]:
# Each source image is followed by its n_train_augments augmented copies.
display_image(augm_dataset(FILENAMES))

### Albumentations + CutMix Augmentation

In [None]:
# Output of the full training pipeline (albumentations copies, then CutMix).
display_image(get_train_dataset(FILENAMES))

### Raw images

In [None]:
# The validation pipeline applies no augmentation, so these are the decoded images.
display_image(get_val_dataset(FILENAMES))

# Model training

The macro-averaged F1-score is used here as it is stricter on imbalanced datasets.

### Wrapper Functions

In [None]:
"""Set from_logits=False when using Dense Softmax output layer"""
def create_model(model, learning_rate, dropout_rate, 
label_smoothing, weights):
    """Build and compile the classifier inside the distribution strategy scope.

    Args:
        model: Backbone constructor (e.g. a tf.keras.applications class); it
            is called with include_top=False and average pooling.
        learning_rate: Learning rate given to the Adam optimizer.
        dropout_rate: Dropout rate applied to the pooled backbone features.
        label_smoothing: Label smoothing used by the cross-entropy loss.
        weights: `weights` argument forwarded to the backbone constructor.

    Returns:
        The compiled Keras model with a CLASSES-way softmax output.
    """
    with strategy.scope():
        inputs = Input(shape=(*IMAGE_SIZE, 3)) 
        backbone = model(
                include_top=False, 
                weights=weights,
                pooling='avg',
                )
        features = backbone(inputs)

        hidden = Dropout(dropout_rate)(features)
        # The output width follows CLASSES so it stays in sync with the labels
        # and the F1 metric; float32 keeps the softmax out of float16.
        outputs = Dense(CLASSES, activation='softmax', dtype='float32')(hidden)
        classifier = Model(inputs=inputs, outputs=outputs)

        classifier.compile(
            optimizer=LossScaleOptimizer(
                Adam(learning_rate=learning_rate)),
            # from_logits=False because the output layer already applies softmax.
            loss=CategoricalCrossentropy(
                from_logits=False, 
                label_smoothing=label_smoothing),
            metrics=[
                'categorical_accuracy',
                tfa.metrics.F1Score(
                    num_classes=CLASSES,
                    average='macro',
                    name='f1_score'
                )]
            )
        return classifier

def model_training(model, model_name, weights, learning_rate, 
  dropout_rate, label_smoothing, fold, n_splits, epochs):
    """Train one cross-validation fold and return its training history.

    FILENAMES is split into `n_splits` folds with a seeded KFold; the shards
    of fold number `fold` are used for validation and the rest for training.
    Weights are saved after every epoch and metrics are logged to
    '<model_name>-fold<fold>.csv'.

    Args:
        model: Backbone constructor forwarded to create_model.
        model_name: Prefix for the checkpoint and CSV log file names.
        weights: Backbone weights forwarded to create_model.
        learning_rate: Initial learning rate for both the optimizer and the
            exponential decay schedule.
        dropout_rate: Forwarded to create_model.
        label_smoothing: Forwarded to create_model.
        fold: 1-based index of the fold to train.
        n_splits: Number of cross-validation folds.
        epochs: Number of epochs to train.

    Returns:
        The History object returned by `fit`.

    Raises:
        ValueError: If `fold` is not between 1 and `n_splits`.
    """
    # Previously an out-of-range fold silently trained nothing.
    if not 1 <= fold <= n_splits:
        raise ValueError(
            'fold must be between 1 and n_splits ({}), got {}'.format(n_splits, fold))

    cv = KFold(n_splits=n_splits, shuffle=True, random_state=SEED)
    train_idx, valid_idx = list(cv.split(FILENAMES))[fold-1]

    # The schedule starts from the requested learning_rate (it used to be
    # hard-coded to 1e-4, silently overriding the argument) and halves it
    # every two epochs.
    lr_schedule = LearningRateScheduler(ExponentialDecay(
        initial_learning_rate=learning_rate,
        decay_steps=2,
        decay_rate=0.5,
        staircase=True), verbose=1)
    # ReduceLROnPlateau / EarlyStopping callbacks can be appended to the
    # callback list below if desired.

    run_name = model_name + '-fold' + str(fold)
    model_save = ModelCheckpoint(
        filepath=run_name + '-epoch{epoch:02d}-f1{f1_score:.4f}-valf1{val_f1_score:.4f}.h5',
        save_best_only=False,
        save_freq='epoch',
        save_weights_only=True,
        monitor='val_f1_score',
        mode='max',
        verbose=1)

    csv_logger = CSVLogger(
        filename=run_name + '.csv', 
        separator=',', 
        append=False)

    TRAIN_FILENAMES = [FILENAMES[idx] for idx in train_idx]
    VALID_FILENAMES = [FILENAMES[idx] for idx in valid_idx]

    N_TRAIN_IMAGES = count_data_items(TRAIN_FILENAMES)
    N_VALID_IMAGES = count_data_items(VALID_FILENAMES)

    print('Number of training images:', N_TRAIN_IMAGES)
    print('Number of validation images:', N_VALID_IMAGES)

    print(f'\n-------------- FOLD {fold}/{n_splits} --------------')

    train_ds = get_train_dataset(TRAIN_FILENAMES)
    val_ds = get_val_dataset(VALID_FILENAMES)

    classifier = create_model(
        model=model,
        weights=weights,
        dropout_rate=dropout_rate,
        learning_rate=learning_rate,
        label_smoothing=label_smoothing
    )

    # batch_size and shuffle are not passed: the datasets are already batched
    # and shuffled, and Keras documents both arguments as not applicable to
    # tf.data.Dataset inputs.
    history = classifier.fit(
        x=train_ds,
        verbose=1,
        epochs=epochs,
        validation_data=val_ds,
        steps_per_epoch=(N_TRAIN_IMAGES//GLOBAL_BATCH_SIZE),
        validation_steps=(N_VALID_IMAGES//GLOBAL_BATCH_SIZE),
        callbacks=[
            lr_schedule,
            model_save,
            csv_logger]
        )
    return history

### Learning Schedule

Other learning schedules can be implemented here, like cosine annealing with a warm-up phase, or simply using TensorFlow's built-in ReduceLROnPlateau callback.

In [None]:
# Preview of the exponential decay schedule. This module-level `lr_schedule`
# is used for plotting only; model_training builds its own scheduler.
lr_schedule = ExponentialDecay(
    initial_learning_rate=1e-4,
    decay_steps=2,
    decay_rate=0.5,
    staircase=True)

# Number of epochs shown in the plot.
epochs = 35

# Learning rate at every epoch index.
y = [lr_schedule(x) for x in range(epochs)]

# Print only the first and the last learning rate.
for epoch, lr in enumerate(y):
  if epoch==0 or epoch==(len(y)-1):
    print(epoch, lr)

plt.figure(figsize=(10,5))
plt.plot(range(epochs), y)
plt.xticks(range(epochs))
plt.show()

### EfficientNet

For demonstration purposes, we will only train for 10 epochs without any dropout or label smoothing incorporated.

In [None]:
# Train fold 5 of 5 with an ImageNet-initialised EfficientNetB0 backbone.
model_training(
    model=EfficientNetB0,
    model_name='EFNB0',
    weights='imagenet',
    learning_rate=1e-4,
    dropout_rate=0,
    label_smoothing=0,
    n_splits=5,
    fold=5,
    epochs=10)

In [None]:
# Per-epoch metrics written by the CSVLogger callback during training.
history = pd.read_csv('./EFNB0-fold5.csv')
history

In [None]:
# Left panel: training / validation loss. Right panel: training / validation F1.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15,5))

for column, axis in (
        ('loss', ax1),
        ('val_loss', ax1),
        ('f1_score', ax2),
        ('val_f1_score', ax2)):
    sns.lineplot(x='epoch', y=column, data=history, label=column, ax=axis)

ax1.set_ylabel('loss')
ax2.set_ylabel('score')
plt.show()

# Test time Augmentation

Specify `n_test_augments` and insert your own augmentation functions for TTA.

In [None]:
# Number of transformed copies generated per test image. A value defined in
# an earlier cell is kept; otherwise fall back to a default so the pipeline
# does not fail with a NameError.
try:
    n_test_augments
except NameError:
    n_test_augments = 3

def test_transform(image, label):
    """Return an image plus n_test_augments transformed copies of it.

    Intended to run inside tf.numpy_function, so it works on NumPy values.
    The transform list is a placeholder: insert the augmentations wanted for
    test-time augmentation.

    Args:
        image: NumPy image array.
        label: Integer label of the image.

    Returns:
        images: float32 array of n_test_augments + 1 stacked images, the
            unmodified input first.
        labels: int32 array holding `label` once per stacked image.
    """
    transforms = A.Compose([
        A.CenterCrop(*IMAGE_SIZE),
        A.Resize(
            *IMAGE_SIZE, 
            interpolation=cv2.INTER_CUBIC)
    ])

    images = [image]
    for _ in range(n_test_augments):
        images.append(transforms(image=image)['image'])

    # Return plain NumPy arrays matching the dtypes declared in test_augment.
    stacked = np.stack(images).astype(np.float32)
    labels = np.full(len(images), label, dtype=np.int32)
    return stacked, labels

def test_augment(image, label=None):
    """Apply test_transform to one sample from inside a tf.data pipeline.

    Args:
        image: Image tensor of shape [*IMAGE_SIZE, 3].
        label: Optional scalar integer label. Unlabeled datasets yield only
            the image, so the label may be omitted.

    Returns:
        `images` when no label is given, else `(images, labels)`; `images`
        stacks the original image and its transformed copies.
    """
    labeled = label is not None
    if not labeled:
        # test_transform always needs a label; use a dummy that is discarded.
        label = tf.constant(-1, dtype=tf.int32)

    images, labels = tf.numpy_function(
        func=test_transform, inp=[image, label], Tout=[tf.float32, tf.int32])
    # tf.numpy_function drops static shapes; restore what is known.
    images.set_shape([None, *IMAGE_SIZE, 3])
    labels.set_shape([None])

    if labeled:
        return images, labels
    return images

def get_test_dataset(filenames, labeled=False):
    """Build the test-time-augmentation pipeline.

    Args:
        filenames: TFRecord shards to read.
        labeled: Forwarded to load_dataset.

    Returns:
        An endlessly repeating, prefetched dataset of GLOBAL_BATCH_SIZE
        batches in which every source image is followed by its copies.
    """
    dataset = load_dataset(filenames, labeled=labeled)
    dataset = dataset.map(partial(test_augment))
    # Flatten each stack of copies into individual images before batching,
    # mirroring the training pipeline; otherwise batches would be 5-D.
    dataset = dataset.unbatch()
    dataset = dataset.repeat()
    dataset = dataset.batch(GLOBAL_BATCH_SIZE)
    dataset = dataset.prefetch(AUTOTUNE)
    return dataset

# Findings

- Increasing n_train_augments to about 6 or 7 will reduce generalization difference
- CutMix augmentation tends to cause model to underfit, reducing n_train_augments might help
- The raw images have resolutions of up to 2000 by 4000 pixels, using larger image sizes like 1024 might increase scores but requires higher GPU memory