In [None]:
import numpy as np
import pandas as pd
import os
import re
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.activations import softmax, relu
from tensorflow.keras.layers import Dense, BatchNormalization, Activation
from tensorflow.keras.losses import categorical_crossentropy
from tensorflow.keras.metrics import categorical_accuracy
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.optimizers import SGD, Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.python.keras.layers import Input, Conv2D, MaxPool2D, Dropout, \
    GlobalAveragePooling2D
from tensorflow.python.keras.layers import RandomRotation, RandomFlip, Add, GlobalAvgPool2D
from tensorflow.keras.losses import categorical_crossentropy, sparse_categorical_crossentropy
from tensorflow.data import Dataset, TFRecordDataset
from kaggle_datasets import KaggleDatasets

# Report the TensorFlow build in use so runs can be reproduced later.
print(f"Using TensorFlow version {tf.__version__}")

In [None]:
# Try to locate a TPU; a ValueError from the resolver is treated as "no TPU available".
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print(f'Running on TPU {tpu.master()}')
except ValueError:
    tpu = None

if not tpu:
    # No TPU: use whatever distribution strategy is currently active.
    strategy = tf.distribute.get_strategy()
else:
    # Connect to the TPU cluster, initialise it, then distribute across its cores.
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)

# Number of replicas the strategy keeps in sync; used to scale the global batch size.
REPLICAS = strategy.num_replicas_in_sync
print(f'REPLICAS: {REPLICAS}')

In [None]:
# Root of the competition dataset, resolved to its GCS path.
dataset_base_folder = KaggleDatasets().get_gcs_path("cassava-leaf-disease-classification")

# Inputs as shipped with the dataset.
csv_file = f"{dataset_base_folder}/train.csv"
original_train_data_folder = f"{dataset_base_folder}/train_images"
original_train_tfrecs_folder = f'{dataset_base_folder}/train_tfrecords'

# Per-class directory layout: one sub-folder per class id "0".."4".
destination_classes = list(map(str, range(5)))
dataset_train_folder = f"{dataset_base_folder}/train"
dataset_val_folder = f"{dataset_base_folder}/val"

# Locations used for the on-disk tf.data cache.
cache_train = f'{dataset_base_folder}/dataset_cache/train'
cache_test = f'{dataset_base_folder}/dataset_cache/val'


def compute_class_images_count(base_folder: str, class_name: str):
    """Return the number of directory entries in `<base_folder>/<class_name>`."""
    class_dir = f'{base_folder}/{class_name}'
    return len(os.listdir(class_dir))


def compute_all_classes_images_count(base_folder: str):
    """Return the total entry count over every class folder listed in `destination_classes`."""
    total = 0
    for class_name in destination_classes:
        total += compute_class_images_count(base_folder, class_name)
    return total


def compute_train_images_count():
    """Return the total number of entries across the class folders of `dataset_train_folder`."""
    train_count = compute_all_classes_images_count(dataset_train_folder)
    return train_count


def compute_val_images_count():
    """Return the total number of entries across the class folders of `dataset_val_folder`."""
    val_count = compute_all_classes_images_count(dataset_val_folder)
    return val_count

def count_data_items_tfrecs(filenames):
    """Count the number of images in a set of TFRecord files.

    The count is read from each file *name*, which must contain
    ``-<count>.`` (for example ``ld_train00-1338.tfrec`` holds 1338 items).

    Args:
        filenames: iterable of TFRecord paths (local or ``gs://``).

    Returns:
        The total item count as a plain ``int`` (0 for an empty iterable).

    Raises:
        ValueError: if a file name does not carry a ``-<count>.`` marker.
    """
    # Compile once instead of once per file; require at least one digit so an
    # empty capture can never reach int().
    count_pattern = re.compile(r"-([0-9]+)\.")

    total = 0
    for filename in filenames:
        # Only inspect the final path component so that a directory or bucket
        # name containing "-<digits>." is not mistaken for the item count.
        match = count_pattern.search(os.path.basename(filename))
        if match is None:
            raise ValueError(
                f"Cannot read the item count from TFRecord file name: {filename!r}")
        total += int(match.group(1))
    return total

def compute_train_val_size(base_folder: str, split: float) -> (int, int):
    """Return the train and validation sizes implied by `split`.

    Args:
        base_folder: folder containing the ``ld_train*.tfrec`` files.
        split: fraction of the dataset reserved for validation, in [0, 1].

    Returns:
        ``(train_size, val_size)`` as integers that sum to the dataset size.

    Raises:
        ValueError: if `split` is outside [0, 1].
    """
    if not 0 <= split <= 1:
        raise ValueError(f"split must be within [0, 1], got {split}")

    filenames = tf.io.gfile.glob(f'{base_folder}/ld_train*.tfrec')
    dataset_size = int(count_data_items_tfrecs(filenames))

    # The sizes feed Dataset.take/skip and steps_per_epoch, so they must be
    # whole numbers; derive the train size by subtraction so nothing is lost
    # to rounding.
    val_size = int(dataset_size * split)
    train_size = dataset_size - val_size
    return train_size, val_size

#  **PARAMETERS**

In [None]:
# --- Optimisation ---
LR = 1e-4
MOMENTUM = 0.8
DROP_RATE = 0.2
EPOCHS = 100

# --- Batching: global batch size = per-replica batch size x replica count ---
BATCH_SIZE_BASE = 32
BATCH_SIZE = REPLICAS * BATCH_SIZE_BASE

# --- Train/validation split, sized from the TFRecord file names ---
SPLIT = 0.2
TRAIN_SIZE, VAL_SIZE = compute_train_val_size(original_train_tfrecs_folder, SPLIT)

# --- Image geometry and labels ---
TARGET_SIZE = (512, 512)
OG_SIZE = (800, 600)
CHANNELS = 3
NB_CLASSES = 5

# --- Reproducibility / ensembling ---
SEED = 420
NB_MODELS = 5

# --- Logging: the model id is the count of numerically-named entries already present ---
LOCAL_LOGS_FOLDER = "./"
MODEL_ID = sum(1 for entry in os.listdir(LOCAL_LOGS_FOLDER) if str(entry).isnumeric())
LOCAL_LOGS_PATH = f'{LOCAL_LOGS_FOLDER}/{MODEL_ID}'

# Let tf.data choose parallelism / prefetch depth.
AUTO = tf.data.experimental.AUTOTUNE

# **PREPROCESSING**

In [None]:
def create_dataset_iterator(base_folder: str, size: int, cache_folder: str):
    """Build a NumPy iterator over images read from a class-per-folder directory.

    Images are rescaled by 1/255 and resized to TARGET_SIZE by the Keras
    generator, limited to `size` items, re-batched to BATCH_SIZE, cached on
    disk under `<cache_folder>/cache`, then repeated and prefetched.
    """
    single_image_shape = (1, *TARGET_SIZE, 3)
    single_label_shape = (1, len(destination_classes))

    def directory_batches():
        # One image per generator step; batching is redone by tf.data below.
        generator = ImageDataGenerator(rescale=1.0 / 255)
        return generator.flow_from_directory(base_folder,
                                             target_size=TARGET_SIZE,
                                             batch_size=1)

    pipeline = Dataset.from_generator(directory_batches,
                                      output_types=(tf.float32, tf.float32),
                                      output_shapes=(single_image_shape, single_label_shape))
    pipeline = pipeline.take(size)
    pipeline = pipeline.unbatch()
    pipeline = pipeline.batch(BATCH_SIZE)
    pipeline = pipeline.cache(f'{cache_folder}/cache')
    pipeline = pipeline.repeat()
    pipeline = pipeline.prefetch(tf.data.experimental.AUTOTUNE)
    return pipeline.as_numpy_iterator()


def resize_input(x):
    """Resize an image to TARGET_SIZE and reshape it to (*TARGET_SIZE, CHANNELS)."""
    height, width = TARGET_SIZE
    resized = tf.image.resize(x, [height, width])
    return tf.reshape(resized, [height, width, CHANNELS])


def decode_image(image_data) -> tf.Tensor:
    """Decode JPEG bytes into a float32 image scaled by 1/255 and resized to OG_SIZE.

    Args:
        image_data: string tensor holding the JPEG-encoded image.

    Returns:
        A float32 tensor reshaped to (*OG_SIZE, CHANNELS).
    """
    image = tf.image.decode_jpeg(image_data, channels=CHANNELS)
    # Convert the decoded integer pixels to float32 and scale them by 1/255.
    image = tf.cast(image, tf.float32) / 255.0
    # NOTE(review): tf.image.resize takes [height, width]; confirm that OG_SIZE
    # is expressed in that order for the source images.
    image = tf.image.resize(image, [*OG_SIZE])
    # Pin the static shape so downstream ops see fully-defined dimensions.
    image = tf.reshape(image, [*OG_SIZE, CHANNELS])
    return image


def read_tfrecord(example, labeled=True):
    """Parse one serialized TFRecord example into an (image, label-or-name) pair.

    Args:
        example: scalar string tensor holding a serialized example.
        labeled: when True the record is expected to hold an int64 ``target``;
            otherwise a string ``image_name``.

    Returns:
        ``(image, label)`` with the label cast to int32 when `labeled` is True,
        else ``(image, image_name)``. The image comes from `decode_image`.
    """
    # Every record carries the encoded image; the second feature depends on `labeled`.
    feature_spec = {'image': tf.io.FixedLenFeature([], tf.string)}
    if labeled:
        feature_spec['target'] = tf.io.FixedLenFeature([], tf.int64)
    else:
        feature_spec['image_name'] = tf.io.FixedLenFeature([], tf.string)

    parsed = tf.io.parse_single_example(example, feature_spec)
    image = decode_image(parsed['image'])

    # The debug prints were removed: this function runs inside Dataset.map,
    # where printing a tensor shows only its symbolic form.
    if labeled:
        return image, tf.cast(parsed['target'], tf.int32)
    return image, parsed['image_name']


def load_dataset(filenames) -> Dataset:
    """Read TFRecord files in parallel and parse each record with `read_tfrecord`."""
    # Deterministic ordering is switched off for this pipeline.
    unordered = tf.data.Options()
    unordered.experimental_deterministic = False

    records = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
    records = records.with_options(unordered)
    return records.map(read_tfrecord, num_parallel_calls=AUTO)


# DATA AUGMENTATION
def random_apply_data_aug(dataset: Dataset) -> Dataset:
    """Randomly apply augmentation filters to the images of `dataset`.

    Four transforms (flip, color jitter, 90-degree rotation, crop) are chained;
    each one is applied to an element with probability 0.5, independently of
    the others. Labels pass through unchanged.

    Args:
        dataset: dataset of ``(image, label)`` pairs.

    Returns:
        The dataset with the four conditional map stages appended.
    """

    def flip(x: tf.Tensor) -> tf.Tensor:
        x = tf.image.random_flip_left_right(x)
        x = tf.image.random_flip_up_down(x)
        return x

    def color(x: tf.Tensor) -> tf.Tensor:
        x = tf.image.random_hue(x, 0.3, seed=SEED)
        x = tf.image.random_saturation(x, 0.6, 1.6)
        x = tf.image.random_brightness(x, 0.05)
        x = tf.image.random_contrast(x, 0.7, 1.3)
        return x

    def rotate(x: tf.Tensor) -> tf.Tensor:
        # Rotate by a random multiple of 90 degrees (0 to 3 quarter turns).
        quarter_turns = tf.random.uniform(shape=[], minval=0, maxval=4, dtype=tf.int32)
        return tf.image.rot90(x, quarter_turns)

    def zoom(x: tf.Tensor) -> tf.Tensor:
        # The crop window is 80% of TARGET_SIZE, not of the input's own size.
        crop_shape = [int(TARGET_SIZE[0] * 0.8), int(TARGET_SIZE[1] * 0.8), CHANNELS]
        return tf.image.random_crop(x, crop_shape)

    def maybe_apply(transform):
        """Return a map function that applies `transform` with probability 0.5."""
        # `transform` is bound per call here, so the map function never relies
        # on the value of a loop variable at the time it is traced.
        def mapper(x, y):
            return tf.cond(
                tf.random.uniform([], 0, 1) > 0.5,
                lambda: (transform(x), y),
                lambda: (x, y))
        return mapper

    for transform in (flip, color, rotate, zoom):
        dataset = dataset.map(maybe_apply(transform), num_parallel_calls=AUTO)
    return dataset


def create_dataset_tfrec_input(base_folder: str) -> Dataset:
    """Load the ``ld_train*.tfrec`` files of `base_folder` as one shuffled dataset.

    Args:
        base_folder: folder containing the training TFRecord files.

    Returns:
        A single dataset of ``(image, label)`` pairs (not a pair of datasets).
    """
    filenames = tf.io.gfile.glob(f'{base_folder}/ld_train*.tfrec')
    dataset = load_dataset(filenames)
    # This dataset is later divided with take()/skip(). With the default
    # reshuffle_each_iteration=True the shuffle would draw a new order on every
    # pass, so examples would move between the train and validation parts from
    # one epoch to the next. Fix the shuffle order instead.
    # NOTE(review): load_dataset disables deterministic reads, so the order
    # reaching this shuffle can still vary between passes; a fully stable
    # split also needs deterministic reading.
    dataset = dataset.shuffle(buffer_size=2048, seed=SEED, reshuffle_each_iteration=False)
    return dataset


def _finalize_split(split: Dataset) -> Dataset:
    """Resize every image to TARGET_SIZE, then batch, repeat and prefetch."""
    split = split.map(lambda x, y: (resize_input(x), y))
    return split.batch(BATCH_SIZE).repeat().prefetch(AUTO)


def split_dataset(dataset: Dataset,
                  augment=True,
                  validation_only=False,
                  k_fold=0,
                  index=None):
    """Split `dataset` into a train and a validation dataset.

    Args:
        dataset: dataset to be split, made of ``(image, label)`` pairs.
        augment: apply augmentation on the train dataset.
        validation_only: return only the validation dataset.
        k_fold: if 0, the first VAL_SIZE elements form the validation set and
            the rest the train set. If > 0, the validation set is the
            `index`-th window of VAL_SIZE elements and the train set is
            everything before and after it.
        index: required when k_fold > 0; integer in ``[0, k_fold)`` selecting
            the validation fold.

    Returns:
        ``(train_dataset, validation_dataset)``, or only ``validation_dataset``
        when `validation_only` is True. Both are resized to TARGET_SIZE,
        batched with BATCH_SIZE, repeated indefinitely and prefetched.

    Raises:
        ValueError: if `k_fold` is negative, or if `index` is missing or out of
            range while `k_fold` > 0.

    Note:
        The fold width is always VAL_SIZE, whatever the value of `k_fold`; the
        folds tile the dataset only when VAL_SIZE * k_fold matches its size.
    """
    if k_fold < 0:
        raise ValueError(f"k_fold must be >= 0, got {k_fold}")

    # take()/skip() need whole numbers even if VAL_SIZE was computed as a float.
    val_size = int(VAL_SIZE)

    if k_fold == 0:
        val_start = 0
    else:
        if index is None or index < 0 or index >= k_fold:
            # Raise instead of calling exit(): an argument error should not
            # terminate the interpreter / notebook kernel.
            raise ValueError("index of validation fold not specified or not valid")
        val_start = val_size * index
    val_end = val_start + val_size

    # The original `dataset.shuffle(1000)` call discarded its result and so had
    # no effect; it is dropped rather than assigned, because reshuffling ahead
    # of take()/skip() would move examples between the two splits.
    validation_dataset = dataset.take(val_size) if val_start == 0 \
        else dataset.skip(val_start).take(val_size)
    validation_dataset = _finalize_split(validation_dataset)

    if validation_only:
        return validation_dataset

    if val_start == 0:
        train_dataset = dataset.skip(val_end)
    else:
        train_dataset = dataset.take(val_start).concatenate(dataset.skip(val_end))
    if augment:
        train_dataset = random_apply_data_aug(train_dataset)
    train_dataset = _finalize_split(train_dataset)

    return train_dataset, validation_dataset

In [None]:
# Build the shuffled TFRecord input pipeline, then split it into an augmented
# train dataset and a validation dataset (k_fold defaults to 0).
dataset = create_dataset_tfrec_input(original_train_tfrecs_folder)
dataset_train, dataset_test = split_dataset(dataset, augment=True)

# **DEFINE MODELS**

In [None]:
def create_base_model(add_custom_layers_func) -> Model:
    """Build a Sequential model: custom layers followed by an NB_CLASSES-way softmax.

    `add_custom_layers_func` receives the empty Sequential model and is
    expected to add its layers to it in place.
    """
    net = Sequential()
    add_custom_layers_func(net)

    classifier_head = Dense(NB_CLASSES, activation=softmax)
    net.add(classifier_head)
    return net


def add_resNet50(m: Sequential):
    """Add a frozen ImageNet ResNet50 feature extractor with a 1000-unit ReLU layer to `m`.

    The sub-model is: ResNet50 (no top, not trainable, called with
    training=False) -> global average pooling -> batch norm -> Dense(1000, relu).
    """
    image_in = Input(shape=(*TARGET_SIZE, 3))
    backbone = tf.keras.applications.ResNet50(include_top=False, weights='imagenet',
                                              input_shape=(*TARGET_SIZE, 3))
    backbone.trainable = False

    features = backbone(image_in, training=False)
    pooled = GlobalAveragePooling2D(name='avg_pool')(features)
    normalized = BatchNormalization(name='top_bn')(pooled)
    head = Dense(1000, activation=relu, name='fc1000')(normalized)

    m.add(Model(image_in, head))
    
    
def make_resNet50():
    """Build a classifier on top of a frozen ImageNet ResNet50V2.

    The model is: ResNet50V2 (no top, not trainable, called with
    training=False) -> global average pooling -> batch norm ->
    Dense(NB_CLASSES, softmax).
    """
    image_in = Input(shape=(*TARGET_SIZE, 3))
    backbone = tf.keras.applications.ResNet50V2(include_top=False, weights='imagenet',
                                                input_shape=(*TARGET_SIZE, 3))
    backbone.trainable = False

    features = backbone(image_in, training=False)
    pooled = GlobalAveragePooling2D(name='avg_pool')(features)
    normalized = BatchNormalization(name='top_bn')(pooled)
    class_probs = Dense(NB_CLASSES, activation=softmax)(normalized)

    return Model(image_in, class_probs)


def get_callbacks(monitor='val_categorical_accuracy'):
    """Build the training callbacks: checkpointing, LR reduction and early stopping.

    Args:
        monitor: name of the logged metric every callback watches (maximised).
            It must match a key Keras actually logs; the default matches the
            ``categorical_accuracy`` metric name used by `train_models`.

    Returns:
        A list with a ModelCheckpoint, a ReduceLROnPlateau and an EarlyStopping
        callback.
    """
    # Keep only the best model seen so far, judged on `monitor`.
    save_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=f'ResNet50_best.h5',
        monitor=monitor,
        mode='max',
        save_best_only=True)

    # This callback previously watched 'val_acc', a key that is never logged,
    # so the learning rate was never reduced. It now watches the same metric
    # as the other two callbacks.
    reduce_LR = tf.keras.callbacks.ReduceLROnPlateau(
        monitor=monitor, factor=0.2, patience=5,
        mode='max', min_delta=0.0005, min_lr=0.000001
    )

    # Stop after 10 epochs without a 0.005 improvement and restore the best weights.
    earlystopper = tf.keras.callbacks.EarlyStopping(
        monitor=monitor,
        min_delta=0.005,
        patience=10,
        mode='max',
        restore_best_weights=True
    )

    return [save_callback, reduce_LR, earlystopper]


def train_models(m: Model, dataset_train_it, dataset_val_it, lr=None, epochs=None):
    """Compile `m` with SGD and fit it on the given datasets.

    Args:
        m: model to train.
        dataset_train_it: training data yielding ``(image, integer label)`` batches.
        dataset_val_it: validation data in the same format.
        lr: learning rate; falls back to the module-level LR when None.
        epochs: number of epochs; falls back to the module-level EPOCHS when None.

    Returns:
        The Keras History object returned by ``fit``.
    """
    # The original test was inverted (`lr if LR is None else lr`), so the LR
    # default was never applied.
    lr = LR if lr is None else lr
    epochs = EPOCHS if epochs is None else epochs

    # The metric object owns variables, so it has to be created under the same
    # distribution strategy as the model.
    with m.distribute_strategy.scope():
        m.compile(
            optimizer=SGD(momentum=MOMENTUM, learning_rate=lr),
            loss=sparse_categorical_crossentropy,
            # Labels are integer class ids, so accuracy must be the sparse
            # variant. The name is kept as 'categorical_accuracy' so the logged
            # keys ('val_categorical_accuracy') still match the callbacks.
            metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name='categorical_accuracy')]
        )
    m.summary()

    history = m.fit(
        dataset_train_it,
        validation_data=dataset_val_it,
        # Step counts must be integers even if the sizes were computed as floats.
        steps_per_epoch=int(TRAIN_SIZE // BATCH_SIZE),
        validation_steps=int(VAL_SIZE // BATCH_SIZE),
        # Use the requested epoch count; the global EPOCHS was always used before.
        epochs=epochs,
        callbacks=get_callbacks()
    )
    return history

def train_pretrained_model(m: Model, dataset_train_it, dataset_val_it, lr, epochs):
    """Train `m` in two phases and return both histories.

    Phase one trains the model as given using ``lr[0]`` / ``epochs[0]``; the
    model is then marked trainable and trained again with ``lr[1]`` /
    ``epochs[1]``.

    Returns:
        ``[history_before_unfreezing, history_after_unfreezing]``.
    """
    phase_histories = []

    phase_histories.append(
        train_models(m, dataset_train_it, dataset_val_it, lr=lr[0], epochs=epochs[0]))

    # Unfreeze before the second pass; train_models compiles the model again.
    m.trainable = True
    phase_histories.append(
        train_models(m, dataset_train_it, dataset_val_it, lr=lr[1], epochs=epochs[1]))

    return phase_histories

# **CREATE MODEL**

In [None]:
# Build the model inside the distribution strategy scope so it is created
# under that strategy.
with strategy.scope():
    model = make_resNet50()
#     Alternative kept for reference: Sequential wrapper around add_resNet50.
#     model = create_base_model(add_resNet50)

In [None]:
# Two-phase schedule: index 0 is used before the model is unfrozen, index 1 after.
lr = [0.0005, 0.00001]
epochs = [5, 5]
# `history` is a list holding one Keras History per phase.
history = train_pretrained_model(model, dataset_train, dataset_test, lr, epochs)

In [None]:
with strategy.scope():
    inputs = Input(shape=(*TARGET_SIZE, CHANNELS))

    # The input pipeline (decode_image) already divides pixels by 255, so the
    # model receives values around [0, 1]. Mapping that range to [-1, +1]
    # needs mean 0.5 and variance 0.25; the previous statistics
    # (127.5, 127.5 ** 2) assumed raw [0, 255] pixels.
    norm_layer = tf.keras.layers.experimental.preprocessing.Normalization()
    mean = np.array([0.5] * CHANNELS)
    var = mean ** 2
    # Scale inputs to [-1, +1]: (x - 0.5) / sqrt(0.25) = 2x - 1
    x = norm_layer(inputs)
    norm_layer.set_weights([mean, var])

    base_model = tf.keras.applications.ResNet50V2(include_top=False, weights='imagenet',
                                                  input_shape=(*TARGET_SIZE, CHANNELS))
    # Freeze the backbone; it is also called with training=False below.
    base_model.trainable = False

    x = base_model(x, training=False)
    x = GlobalAveragePooling2D()(x)
    x = Dropout(0.2)(x)  # Regularize with dropout
    outputs = Dense(NB_CLASSES, activation=softmax)(x)
    model = Model(inputs, outputs)

    model.summary()

In [None]:
model.compile(
    optimizer=Adam(),
    loss=sparse_categorical_crossentropy,
    # Labels are integer class ids, so the matching metric is the sparse
    # accuracy; categorical_accuracy expects one-hot targets.
    metrics=['sparse_categorical_accuracy'],
)

epochs = 20
# Both datasets repeat indefinitely, so explicit integer step counts are
# required to delimit an epoch.
model.fit(dataset_train, epochs=epochs, validation_data=dataset_test,
          steps_per_epoch=int(TRAIN_SIZE // BATCH_SIZE),
          validation_steps=int(VAL_SIZE // BATCH_SIZE))

In [None]:
# Unfreeze the backbone for fine-tuning.
# NOTE(review): Keras generally needs the model compiled again after changing
# `trainable`; no compile/fit of this `model` follows in the visible cells — confirm intent.
base_model.trainable = True


In [None]:
# `history` comes from train_pretrained_model and holds one Keras History per
# training phase; a single History object is accepted as well.
phase_histories = history if isinstance(history, (list, tuple)) else [history]

# squeeze=False keeps `axes` two-dimensional even with a single row.
fig, axes = plt.subplots(nrows=max(len(phase_histories), 1), ncols=2,
                         figsize=(16, 12), squeeze=False)
for row, phase_history in enumerate(phase_histories):
    history_df = pd.DataFrame(phase_history.history)
    # Pick whichever accuracy metrics were logged instead of assuming the
    # 'accuracy' / 'val_accuracy' keys exist.
    accuracy_columns = [c for c in history_df.columns if c.endswith('accuracy')]
    history_df[['loss', 'val_loss']].plot(ax=axes[row, 0], title=f'Phase {row}: loss')
    history_df[accuracy_columns].plot(ax=axes[row, 1], title=f'Phase {row}: accuracy')

# Two-phase schedule for every ensemble member (same values as the single-model
# run): index 0 before unfreezing, index 1 after.
ensemble_lr = [0.0005, 0.00001]
ensemble_epochs = [5, 5]

# One model per fold, created under the distribution strategy.
models = []
with strategy.scope():
    for _ in range(NB_MODELS):
        models.append(create_base_model(add_resNet50))

# Train model i with fold i held out as validation data.
# NOTE(review): get_callbacks() checkpoints every model to the same file name,
# so each model overwrites the previous checkpoint.
histories = []
for i, fold_model in enumerate(models):
    print(F"Training model: {i}")
    train_dataset, validation_dataset = split_dataset(
        dataset, augment=True, k_fold=NB_MODELS, index=i)
    fold_history = train_pretrained_model(
        fold_model, train_dataset, validation_dataset, ensemble_lr, ensemble_epochs)
    histories.append(fold_history)

# The validation datasets repeat indefinitely, so evaluate() needs an explicit
# step count to terminate.
eval_steps = int(VAL_SIZE // BATCH_SIZE)
for i in range(NB_MODELS):
    validation_dataset = split_dataset(
        dataset, augment=False, validation_only=True, k_fold=NB_MODELS, index=i)
    for fold_model in models:
        fold_model.evaluate(validation_dataset, steps=eval_steps)

# One row of plots per trained model; squeeze=False keeps `axes` two-dimensional.
fig, axes = plt.subplots(nrows=max(len(histories), 1), ncols=2,
                         figsize=(16, 12), squeeze=False)

for i, history in enumerate(histories):
    # Each entry from train_pretrained_model is a list of per-phase History
    # objects; join the phases into one continuous curve per model.
    phases = history if isinstance(history, (list, tuple)) else [history]
    history_df = pd.concat([pd.DataFrame(phase.history) for phase in phases],
                           ignore_index=True)
    # Use whichever accuracy metrics were logged rather than fixed key names.
    accuracy_columns = [c for c in history_df.columns if c.endswith('accuracy')]
    history_df[['loss', 'val_loss']].plot(ax=axes[i, 0], title=f'Model {i}: loss')
    history_df[accuracy_columns].plot(ax=axes[i, 1], title=f'Model {i}: accuracy')