# Model Training

## Introduction: what is EfficientNet?

EfficientNet, first introduced in [Tan and Le, 2019](https://arxiv.org/abs/1905.11946), is among the most efficient models (i.e., requiring the fewest FLOPs for inference) that reach state-of-the-art accuracy on both ImageNet and common image classification transfer learning tasks.

The smallest base model is similar to [MnasNet](https://arxiv.org/abs/1807.11626), which reached near-SOTA accuracy with a significantly smaller model. By introducing a heuristic way to scale the model, EfficientNet provides a family of models (B0 to B7) that offers a good trade-off between efficiency and accuracy at a variety of scales. This scaling heuristic (compound scaling; see [Tan and Le, 2019](https://arxiv.org/abs/1905.11946) for details) lets the efficiency-oriented base model (B0) be scaled up into models that surpass others at every scale, while avoiding an extensive grid search of hyperparameters.
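The idea behind compound scaling can be sketched in a few lines of plain Python: a single coefficient φ scales depth, width, and input resolution together. This is a minimal illustration, assuming the constants α = 1.2, β = 1.1, γ = 1.15 reported for B0 in the paper; `compound_scale` and `flops_factor` are hypothetical helpers, and the released B1 to B7 models use rounded, hand-adjusted coefficients rather than these exact multipliers.

```python
# Compound scaling sketch: one coefficient phi scales depth, width and resolution.
# ALPHA, BETA, GAMMA are assumed to be the B0 constants reported by Tan and Le (2019);
# the released B1-B7 configurations do not follow these multipliers exactly.
ALPHA, BETA, GAMMA = 1.2, 1.1, 1.15


def compound_scale(phi):
    """Return hypothetical (depth, width, resolution) multipliers for a given phi."""
    return ALPHA**phi, BETA**phi, GAMMA**phi


def flops_factor(phi):
    """FLOPs of a conv net grow roughly with depth * width^2 * resolution^2."""
    depth, width, resolution = compound_scale(phi)
    return depth * width**2 * resolution**2


for phi in range(4):
    d, w, r = compound_scale(phi)
    print(
        f"phi={phi}: depth x{d:.2f}, width x{w:.2f}, "
        f"resolution x{r:.2f}, FLOPs x{flops_factor(phi):.2f}"
    )
```

Because α·β²·γ² ≈ 1.92, each unit increase of φ roughly doubles the FLOPs, which is the constraint the paper places on the search for α, β and γ.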

## Setup and data loading

In [None]:
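# Required packages
import math

import numpy as np
import matplotlib.pyplot as plt

# TensorFlow and Keras imports
import tensorflow as tf
import keras
from keras import layers, losses, ops, optimizers
from keras.applications import EfficientNetB0
from keras.optimizers import schedules

# TensorFlow Datasets
import tensorflow_datasets as tfds

# IMG_SIZE is determined by model choice (EfficientNetB0)
IMG_SIZE = 224
BATCH_SIZE = 64
EPOCHS = 25  # maximum number of training epochs, used by the learning rate schedule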

### Loading data

Here we load data from [tensorflow_datasets](https://www.tensorflow.org/datasets). The Malaria dataset is provided in TFDS as [malaria](https://www.tensorflow.org/datasets/catalog/malaria). It contains 27,558 images belonging to 2 classes: parasitized and uninfected.

In [None]:
# Dataset name
dataset_name = "malaria"

# The malaria dataset only provides a "train" split, so hold out the last 20% for testing
(ds_train, ds_test), ds_info = tfds.load(
    dataset_name, split=["train[:80%]", "train[80%:]"], with_info=True, as_supervised=True
)

# Number of classes in the dataset
NUM_CLASSES = ds_info.features["label"].num_classes

When a dataset includes images of various sizes, we need to resize them to a shared size. The malaria dataset contains images of different sizes, so here we resize them to the `IMG_SIZE` input size expected by EfficientNetB0.

In [None]:
# Target (height, width) for every image
size = (IMG_SIZE, IMG_SIZE)

# Resize images to the desired size (EfficientNetB0 input size)
ds_train = ds_train.map(lambda image, label: (tf.image.resize(image, size), label))
ds_test = ds_test.map(lambda image, label: (tf.image.resize(image, size), label))

### Visualizing the data

The following code shows the first 9 images with their labels.

In [None]:
def format_label(label):
    """
    Converts an integer class label into its human-readable class name.

    Args:
        label (int or scalar Tensor): The label to be formatted.

    Returns:
        str: The class name, e.g. "parasitized" or "uninfected".

    """
    # The malaria class names contain no hyphen, so return the name directly
    return label_info.int2str(int(label))


# Label metadata (class names) for the dataset
label_info = ds_info.features["label"]

# iterate through the dataset and plot the first 9 images
for i, (image, label) in enumerate(ds_train.take(9)):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image.numpy().astype("uint8"))
    plt.title(f"{format_label(label)}")
    plt.axis("off")

### Data augmentation

We can use the Keras preprocessing layers for image augmentation. Here we randomly rotate, translate, and flip the input images, and randomly adjust their contrast.
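To make the augmentation factors more concrete, the sketch below converts them into ranges and reimplements the contrast adjustment in NumPy. It assumes the documented Keras semantics: `RandomRotation`'s factor is a fraction of a full turn (2π), `RandomTranslation`'s factors are fractions of the image height and width, and `RandomContrast` draws a multiplier in [1 - f, 1 + f] and applies it around each channel's mean. `augmentation_ranges` and `adjust_contrast` are hypothetical helpers for illustration only (the Keras layer additionally clips values to its value range, which is omitted here).

```python
import numpy as np


def augmentation_ranges(rotation_factor, translation_factor, contrast_factor, img_size):
    """Hypothetical helper: express the Keras augmentation factors as concrete ranges."""
    max_degrees = rotation_factor * 360.0  # fraction of a full turn
    max_shift_px = translation_factor * img_size  # fraction of height/width
    contrast_range = (1.0 - contrast_factor, 1.0 + contrast_factor)
    return max_degrees, max_shift_px, contrast_range


def adjust_contrast(image, factor):
    """Scale each channel's deviation from its per-image mean by `factor` (no clipping)."""
    image = image.astype("float32")
    mean = image.mean(axis=(0, 1), keepdims=True)  # one mean per channel
    return (image - mean) * factor + mean


degrees, shift, contrast = augmentation_ranges(0.15, 0.1, 0.1, img_size=224)
print(
    f"rotation up to +/-{degrees:.0f} deg, shift up to +/-{shift:.1f} px, "
    f"contrast factor in {contrast}"
)
```

With the factors used in this notebook, images are rotated by up to about ±54 degrees and shifted by up to about ±22 pixels, which is worth keeping in mind when judging whether the augmented cells still look realistic.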

In [None]:
# Image augmentation layers to be applied to the dataset
img_augmentation_layers = [
    layers.RandomRotation(factor=0.15),
    layers.RandomTranslation(height_factor=0.1, width_factor=0.1),
    layers.RandomFlip(),
    layers.RandomContrast(factor=0.1),
]


# Apply the augmentation layers one after another
def img_augmentation(images):
    """
    Apply image augmentation techniques to the given images.

    Args:
        images (Tensor or numpy.ndarray): A single image or a batch of images.

    Returns:
        Tensor: Augmented images.
    """
    for layer in img_augmentation_layers:
        images = layer(images)
    return images

Here we plot nine random augmentations of a single training image.

In [None]:
# Plot 9 random augmentations of the first image in the dataset
for image, label in ds_train.take(1):
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        aug_img = img_augmentation(np.expand_dims(image.numpy(), axis=0))
        aug_img = np.array(aug_img)
        plt.imshow(aug_img[0].astype("uint8"))
        plt.title(f"{format_label(label)}")
        plt.axis("off")

### Prepare inputs

Once we have verified that the input data and augmentation work correctly, we prepare the datasets for training. The images have already been resized to `IMG_SIZE`, and the training images are additionally augmented. The labels are converted to one-hot (a.k.a. categorical) encoding, and both datasets are batched.
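For readers unfamiliar with one-hot encoding, here is a small NumPy sketch of what `tf.one_hot(label, NUM_CLASSES)` produces for the two malaria classes, together with the label smoothing applied later by the loss (`label_smoothing=0.1`). `one_hot` and `smooth_labels` are hypothetical helpers; the smoothing formula is the categorical one, y·(1 - ε) + ε/K, which for two classes gives the same targets as the 0.5·ε offset used by `BinaryCrossentropy`.

```python
import numpy as np


def one_hot(labels, num_classes):
    """Integer labels -> one-hot rows, analogous to tf.one_hot(labels, num_classes)."""
    return np.eye(num_classes, dtype="float32")[np.asarray(labels)]


def smooth_labels(one_hot_labels, smoothing):
    """Categorical label smoothing: y * (1 - eps) + eps / num_classes."""
    num_classes = one_hot_labels.shape[-1]
    return one_hot_labels * (1.0 - smoothing) + smoothing / num_classes


# Two classes, as in the malaria dataset (0 = parasitized, 1 = uninfected)
y = one_hot([0, 1, 1], num_classes=2)
print(y)
print(smooth_labels(y, 0.1))
```

Smoothing replaces the hard 0/1 targets with 0.05/0.95, which discourages the classifier from becoming over-confident.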

In [None]:
def input_preprocess_train(image, label):
    """
    Preprocesses the input image and label for training.

    Args:
        image (Tensor): The input image.
        label (int): The label corresponding to the image.

    Returns:
        Tuple: A tuple containing the preprocessed image and one-hot encoded label.
    """
    image = img_augmentation(image)
    label = tf.one_hot(label, NUM_CLASSES)
    return image, label


def input_preprocess_test(image, label):
    """
    Preprocesses the input image and label for testing.

    Args:
        image: The input image.
        label: The label corresponding to the image.

    Returns:
        The preprocessed image and the one-hot encoded label.
    """
    label = tf.one_hot(label, NUM_CLASSES)
    return image, label

# Apply preprocessing to the train dataset and batch the dataset for training
ds_train = ds_train.map(input_preprocess_train, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.batch(batch_size=BATCH_SIZE, drop_remainder=True)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)

# Apply preprocessing to the test dataset and batch the dataset for testing / validation
ds_test = ds_test.map(input_preprocess_test, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(batch_size=BATCH_SIZE, drop_remainder=True)

## Transfer learning from pre-trained weights

Here we initialize the model with pre-trained ImageNet weights, and we train the model on our own dataset.

### Optimizer Tuning

Optimizer choice and tuning are important for model performance, and a learning rate schedule often works better than a single fixed learning rate. Here we define an SGD optimizer with momentum and weight decay, driven by a schedule that warms up linearly to a target learning rate, holds it there, and then decays it along a cosine curve.
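Before looking at the Keras implementation, the shape of the schedule can be checked with a plain-Python mirror of the same piecewise rule. This is a sketch for inspection only: `warmup_hold_cosine` is a hypothetical name, it uses Python floats instead of Keras ops, and the step counts in the example (1,000 steps, 10% warmup, 45% hold) are arbitrary.

```python
import math


def warmup_hold_cosine(step, warmup_steps, hold, total_steps, target_lr=1e-2):
    """Plain-Python mirror of the warmup / hold / cosine-decay schedule."""
    if step > total_steps:
        return 0.0  # past the end of training
    if step < warmup_steps:
        return target_lr * step / warmup_steps  # linear warmup from 0
    if step <= warmup_steps + hold:
        return target_lr  # hold at the peak learning rate
    # Cosine decay from target_lr down to 0 over the remaining steps
    progress = (step - warmup_steps - hold) / (total_steps - warmup_steps - hold)
    return 0.5 * target_lr * (1 + math.cos(math.pi * progress))


# Example with arbitrary step counts: 1000 steps, 10% warmup, 45% hold
for step in [0, 50, 100, 550, 775, 1000]:
    print(step, round(warmup_hold_cosine(step, 100, 450, 1000), 5))
```

The learning rate climbs to the target over the first 100 steps, stays flat until step 550, is halfway down at step 775, and reaches 0 at the final step.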

In [None]:
def lr_warmup_cosine_decay(
    global_step,
    warmup_steps,
    hold=0,
    total_steps=0,
    start_lr=0.0,
    target_lr=1e-2,
):
    """
    Computes the learning rate using warmup and cosine decay.

    Args:
        global_step (int): The current global step.
        warmup_steps (int): The number of warmup steps.
        hold (int, optional): The number of steps to hold the learning rate after warmup. Defaults to 0.
        total_steps (int, optional): The total number of steps. Defaults to 0.
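        start_lr (float, optional): Not used by this implementation; the warmup
            always ramps linearly from 0 to target_lr. Defaults to 0.0.
        target_lr (float, optional): The target learning rate. Defaults to 1e-2.

    Returns:
        float: The computed learning rate.
    """
    # Cosine decay from target_lr towards 0 over the steps after warmup and hold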
    learning_rate = (
        0.5
        * target_lr
        * (
            1
            + ops.cos(
                math.pi
                * ops.convert_to_tensor(
                    global_step - warmup_steps - hold, dtype="float32"
                )
                / ops.convert_to_tensor(
                    total_steps - warmup_steps - hold, dtype="float32"
                )
            )
        )
    )

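    # Linear warmup from 0 to target_lr
    warmup_lr = target_lr * (global_step / warmup_steps)

    # Hold at target_lr between the end of warmup and the start of the decay
    if hold > 0:
        learning_rate = ops.where(
            global_step > warmup_steps + hold, learning_rate, target_lr
        )

    # Use the warmup value until warmup_steps is reached
    learning_rate = ops.where(global_step < warmup_steps, warmup_lr, learning_rate)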
    return learning_rate


class WarmUpCosineDecay(schedules.LearningRateSchedule):
    """
    Learning rate schedule that combines warm-up, cosine decay, and hold phases.

    Args:
        warmup_steps (int): Number of steps for the warm-up phase.
        total_steps (int): Total number of steps for the learning rate schedule.
        hold (int): Number of steps to hold the learning rate after the warm-up phase.
        start_lr (float, optional): Initial learning rate. Defaults to 0.0.
        target_lr (float, optional): Target learning rate. Defaults to 1e-2.
    """

    def __init__(self, warmup_steps, total_steps, hold, start_lr=0.0, target_lr=1e-2):
        super().__init__()
        self.start_lr = start_lr
        self.target_lr = target_lr
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.hold = hold

    def __call__(self, step):
        lr = lr_warmup_cosine_decay(
            global_step=step,
            total_steps=self.total_steps,
            warmup_steps=self.warmup_steps,
            start_lr=self.start_lr,
            target_lr=self.target_lr,
            hold=self.hold,
        )

        return ops.where(step > self.total_steps, 0.0, lr)

In [None]:
# Total number of images, warmup steps, and hold steps for the learning rate schedule
total_images = 27558
total_steps = (total_images // BATCH_SIZE) * EPOCHS
warmup_steps = int(0.1 * total_steps)
hold_steps = int(0.45 * total_steps)

# Learning rate schedule
schedule = WarmUpCosineDecay(
    start_lr=0.05,
    target_lr=1e-2,
    warmup_steps=warmup_steps,
    total_steps=total_steps,
    hold=hold_steps,
)

# Optimizer
optimizer_fn = optimizers.SGD(
    weight_decay=5e-4,
    learning_rate=schedule,
    momentum=0.9,
)

# Loss function
loss_fn = losses.BinaryCrossentropy(label_smoothing=0.1)

### Set Up Callbacks

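Callbacks can implement early stopping, save checkpoints, and log training information such as loss and accuracy. Here we use `EarlyStopping` to stop training once the validation loss has not improved for two consecutive epochs, and to restore the weights from the best epoch.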
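The early-stopping rule can be simulated on a list of validation losses to see which epoch training would stop at and which weights would be restored. This is a simplified sketch of `EarlyStopping(monitor="val_loss", patience=2, restore_best_weights=True)`: `early_stopping_epoch` is a hypothetical helper, and it ignores the callback's `min_delta`, `baseline`, and `start_from_epoch` options.

```python
def early_stopping_epoch(val_losses, patience=2):
    """Simulate early stopping on val_loss.

    Returns (last_epoch_run, best_epoch), both 0-based. The best epoch is the
    one whose weights restore_best_weights=True would put back into the model.
    """
    best_loss = float("inf")
    best_epoch = 0
    wait = 0  # epochs since the last improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:  # improvement (min_delta = 0)
            best_loss, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch  # training stops after this epoch
    return len(val_losses) - 1, best_epoch


# Losses after epochs 0..4; the 0.30 at epoch 4 is never reached
print(early_stopping_epoch([0.50, 0.40, 0.42, 0.41, 0.30]))
```

With `patience=2`, two non-improving epochs after the best one end training, even if a later epoch would have improved, which is the trade-off of a small patience value.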

In [None]:
# Callbacks for the training process
train_callbacks = [
    keras.callbacks.EarlyStopping(
        monitor="val_loss", patience=2, restore_best_weights=True
    )
]

### Train the model

The first step in transfer learning is to freeze the pre-trained layers and train only the newly added top layers. At this stage, validation accuracy and loss are often better than training accuracy and loss, because dropout and data augmentation are active only during training. This also suggests that the model has not yet overfit the training set and generalizes well to the validation set.

In [None]:
def build_model(num_classes):
    """
    Builds a model using EfficientNetB0 architecture for image classification.

    Args:
        num_classes (int): The number of classes for classification.

    Returns:
        keras.Model: The compiled model.

    """

    # Input layer for the model with the given input shape
    inputs = layers.Input(shape=(IMG_SIZE, IMG_SIZE, 3))

    # Load the EfficientNetB0 model
    model = EfficientNetB0(include_top=False, input_tensor=inputs, weights="imagenet")

    # Freeze the pretrained weights of the model
    model.trainable = False

    # Rebuild top layers of the model for classification
    x = layers.GlobalAveragePooling2D(name="avg_pool")(model.output)
    x = layers.BatchNormalization()(x)

    top_dropout_rate = 0.2
    x = layers.Dropout(top_dropout_rate, name="top_dropout")(x)
    outputs = layers.Dense(num_classes, activation="softmax", name="pred")(x)

    # Compile the model with the SGD optimizer (learning rate schedule) and loss defined earlier
    model = keras.Model(inputs, outputs, name="EfficientNet")
    model.compile(optimizer=optimizer_fn, loss=loss_fn, metrics=["accuracy"])
    return model

In [None]:
def plot_hist(history):
    """
    Plots the model accuracy and loss for the training and validation sets.

    Args:
        history (keras.callbacks.History): The history object returned by model.fit().

    Returns:
        None
    """

    # Plot the model accuracy and loss for the training and validation sets
    fig, axs = plt.subplots(1, 2, figsize=(15, 5))
    
    # summarize history for accuracy
    axs[0].plot(history.history['accuracy']) 
    axs[0].plot(history.history['val_accuracy']) 
    axs[0].set_title('Model Accuracy')
    axs[0].set_ylabel('Accuracy') 
    axs[0].set_xlabel('Epoch')
    axs[0].legend(['train', 'validate'], loc='upper left')

    # summarize history for loss
    axs[1].plot(history.history['loss']) 
    axs[1].plot(history.history['val_loss']) 
    axs[1].set_title('Model Loss')
    axs[1].set_ylabel('Loss') 
    axs[1].set_xlabel('Epoch')
    axs[1].legend(['train', 'validate'], loc='upper left')
    plt.show()


Fit the model on the training data and validate it on the held-out test split. We train for at most `epochs` epochs (25 by default), and the early stopping callback ends training sooner once the validation loss stops improving.

In [None]:
# Build the model
model = build_model(num_classes=NUM_CLASSES)

# Number of epochs for training
epochs = 25  # @param {type: "slider", min:8, max:80}

# Train the model and plot the training history
hist = model.fit(
    ds_train, epochs=epochs, validation_data=ds_test, callbacks=train_callbacks
)
plot_hist(hist)

## Conclusion

- Model accuracy of EfficientNetB0 is 93.20% on the training set and 94.81% on the validation set.