<a href="https://colab.research.google.com/github/nyp-sit/iti107/blob/main/session-2/3.fine-tuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Fine-tuning

Another widely used transfer learning technique is _fine-tuning_.
Fine-tuning involves unfreezing a few of the top layers
of a frozen model base used for feature extraction, and jointly training both the newly added part of the model (in our case, the
fully-connected classifier) and these unfrozen top layers. This is called "fine-tuning" because it slightly adjusts the more abstract
representations of the model being reused, in order to make them more relevant for the problem at hand.



![fine-tuning VGG16](https://nyp-aicourse.s3.ap-southeast-1.amazonaws.com/iti107/resources/vgg16_fine_tuning.png)

In [None]:
import os
import tensorflow as tf
import tensorflow.keras as keras

## Creating Datasets

We will set up our training and validation datasets as we did in the earlier exercise.

In [None]:
dataset_url = 'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz'
path_to_zip = tf.keras.utils.get_file(origin=dataset_url, extract=True, cache_dir='.')
dataset_folder = os.path.dirname(path_to_zip)
dataset_folder = os.path.join(dataset_folder, 'flower_photos')

In [None]:
batch_size = 32
image_size = (128,128)

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    dataset_folder,
    validation_split=0.2,
    subset="training",
    seed=1337,
    image_size=image_size,
    batch_size=batch_size,
    label_mode='int'
)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    dataset_folder,
    validation_split=0.2,
    subset="validation",
    seed=1337,
    image_size=image_size,
    batch_size=batch_size,
    label_mode='int'
)
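
As a rough check on what these two calls produce, the split and batch counts can be worked out by hand. The sketch below assumes the archive holds 3,670 images (a figure not stated in this notebook) and that the validation subset receives `int(validation_split * total)` files; treat both as assumptions and compare against the "Found ... files" message printed when the cell above runs.

```python
import math

total_images = 3670        # assumption: size of flower_photos, not stated in this notebook
validation_split = 0.2     # same value as passed to image_dataset_from_directory
batch_size = 32

# The validation subset takes a fixed fraction; the training subset takes the rest.
num_val = int(validation_split * total_images)
num_train = total_images - num_val

# The last batch may be smaller than batch_size, hence the ceiling.
train_batches = math.ceil(num_train / batch_size)
val_batches = math.ceil(num_val / batch_size)

print(f"train images: {num_train}, batches per epoch: {train_batches}")
print(f"val images:   {num_val}, batches per epoch: {val_batches}")
```

Using the same `seed` in both calls is what keeps the two subsets disjoint.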

In [None]:
num_classes = len(val_ds.class_names)
print(num_classes)

## Transfer Learning Workflow

It is necessary to freeze the convolutional base before training a randomly initialized classifier head. If the classifier were not already trained, the error signal propagating through the network during training would be too large, and the representations previously learned by the layers being fine-tuned would be destroyed. The steps for fine-tuning a network are therefore as follows:

1. Add your custom network on top of an already trained base network.
2. Freeze the convolutional base network.
3. Train the classification top you added.
4. Unfreeze some layers in the base network.
5. Jointly train both these layers and the part you added.


#### BatchNormalization layer

Many CNN models contain BatchNormalization layers.
Each BatchNormalization layer contains two non-trainable variables that keep track of the mean and variance of its inputs. These variables are updated during training. Here are a few things to note when fine-tuning a model with BatchNormalization layers:
- When you set `bn_layer.trainable = False`, the BatchNormalization layer will run in inference mode, and will not update its mean & variance statistics.
- When you unfreeze a model that contains BatchNormalization layers in order to do fine-tuning, you should keep the BatchNormalization layers in inference mode by passing `training=False` when calling the base model. Otherwise the updates applied to the non-trainable weights will suddenly destroy what the model has learned.
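
To make the second point concrete, here is a simplified, scalar sketch of how a moving mean behaves in the two modes. It is not Keras's implementation. It assumes the usual exponential-moving-average update with a momentum of 0.99, and the numeric values are hypothetical.

```python
def update_moving_mean(moving_mean, batch_mean, momentum=0.99, training=True):
    """Return the moving mean after one step (simplified, scalar version)."""
    if not training:
        # Inference mode: the stored statistic is used but never updated.
        return moving_mean
    # Training mode: exponential moving average toward the current batch statistic.
    return momentum * moving_mean + (1.0 - momentum) * batch_mean


pretrained_mean = 0.5   # hypothetical statistic learned during pre-training
new_batch_mean = 3.0    # hypothetical statistic of batches from the new dataset

drifted = pretrained_mean
kept = pretrained_mean
for _ in range(200):    # 200 training steps
    drifted = update_moving_mean(drifted, new_batch_mean, training=True)
    kept = update_moving_mean(kept, new_batch_mean, training=False)

print(f"training=True  -> moving mean drifts to {drifted:.3f}")
print(f"training=False -> moving mean stays at {kept:.3f}")
```

With `training=True` the stored statistic moves most of the way toward the new data within a couple of hundred steps, so the later layers suddenly see differently normalized activations. With `training=False` it stays where pre-training left it.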

## Build our Model

We will now construct our model: a convolutional base (initialized with pre-trained weights) and our own classification head (initialized with random weights).

In [None]:
data_augmentation = keras.Sequential(
        [
            tf.keras.layers.RandomRotation(0.1),
            tf.keras.layers.RandomFlip("horizontal")
        ]
    )

In [None]:
# Load the pre-trained model
base_model = keras.applications.EfficientNetB0(input_shape=image_size + (3,),
                                         include_top=False,
                                         weights='imagenet')

# Optional: for EfficientNet, preprocess_input is a pass-through, because the model
# already includes a rescaling layer that preprocesses the input.
# See https://www.tensorflow.org/api_docs/python/tf/keras/applications/efficientnet/preprocess_input
preprocess_input_fn = keras.applications.efficientnet.preprocess_input

# freeze the base layer
base_model.trainable = False

# Add input layer
inputs = keras.layers.Input(shape=image_size+(3,))

x = data_augmentation(inputs)
# Add preprocessing layer (optional: a pass-through for EfficientNet,
# which rescales its input internally)
x = preprocess_input_fn(x)

# The base model contains batchnorm layers. We want to keep them in inference mode
# when we unfreeze the base model for fine-tuning, so we make sure that the
# base_model is running in inference mode here.
x = base_model(x, training=False)

# Add our classification head
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(rate=0.5)(x)
x = keras.layers.Dense(units=256, activation="relu")(x)
x = keras.layers.Dropout(rate=0.5)(x)

outputs = keras.layers.Dense(units=num_classes, activation="softmax")(x)

model = keras.models.Model(inputs=[inputs], outputs=[outputs])

base_learning_rate = 0.001

model.compile(loss="sparse_categorical_crossentropy",
              optimizer=keras.optimizers.Adam(learning_rate=base_learning_rate),
              metrics=["accuracy"])


Let's confirm that all the layers of the convolutional base are frozen.

In [None]:
for layer in base_model.layers:
    print(f'layer name = {layer.name}, trainable={layer.trainable}')

Let's print out the model summary and see how many trainable weights there are. We can see that we have only 329,221 trainable weights (parameters), all coming from the classification head that we put on top of the convolutional base. (For comparison, EfficientNetB0 has a total of 4,049,571 weights.)

In [None]:
model.summary()
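
The trainable count reported by the summary can be checked by hand, since only the two `Dense` layers of the head carry trainable weights at this point. The sketch below assumes the convolutional base outputs 1,280 channels after global average pooling and that the dataset has 5 classes; neither value is stated explicitly in this notebook, so compare them with the shapes shown in the summary.

```python
def dense_params(n_inputs, n_units):
    """Number of parameters in a fully connected layer: weights plus biases."""
    return n_inputs * n_units + n_units


base_output_channels = 1280   # assumption: feature size after GlobalAveragePooling2D
n_classes = 5                 # assumption: number of flower classes

hidden_layer = dense_params(base_output_channels, 256)   # Dense(units=256)
output_layer = dense_params(256, n_classes)              # Dense(units=num_classes)
head_total = hidden_layer + output_layer

print(f"hidden Dense: {hidden_layer:,}")
print(f"output Dense: {output_layer:,}")
print(f"head total:   {head_total:,}")
```

The pooling and dropout layers add no parameters, so the head total is just the sum of the two dense layers.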

## Train the classification head

We will go ahead and train our classification head.

In [None]:
# create model checkpoint callback to save the best model checkpoint
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath="best_checkpoint",
    save_weights_only=True,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True)

model.fit(train_ds, validation_data=val_ds,
          epochs=50, callbacks=[model_checkpoint_callback])

In [None]:
model.load_weights('best_checkpoint')
model.evaluate(val_ds)

Now that our classification layers are trained, let's unfreeze some of the top layers of the convolutional base and fine-tune their weights.
Let's try fine-tuning the last convolutional block (i.e. from `block7a` onwards).

Why not fine-tune more layers? Why not fine-tune the entire convolutional base? We could. However, we need to consider that:

* Earlier layers in the convolutional base encode more generic, reusable features, while layers higher up encode more specialized features. It is
more useful to fine-tune the more specialized features, as these are the ones that need to be repurposed for our new problem. There would
be rapidly diminishing returns in fine-tuning lower layers.
* The more parameters we train, the greater the risk of overfitting. The convolutional base has about 4M parameters, so it would be
risky to attempt to train all of it on our small dataset.

Thus, in our situation, it is a good strategy to fine-tune only the top few layers of the convolutional base.

Let's set this up. We will unfreeze our `base_model`,
and then freeze the individual layers inside it, except for those from `block7a` onwards.

In [None]:
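# Locate the first layer of block7a; layers from this index onwards will be unfrozen.
layer_names = [layer.name for layer in base_model.layers]
idx = layer_names.index('block7a_expand_conv')  # raises ValueError if the name is absent
print(idx)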

In [None]:
base_model.trainable = True
for layer in base_model.layers[:idx]:
    layer.trainable = False

In [None]:
for layer in base_model.layers:
    print(layer.name, layer.trainable)
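
The per-layer listing above is long, so a per-block summary can be easier to scan. The helper below is a hypothetical convenience, not part of Keras. It assumes layer names follow a `<block>_<rest>` pattern and only reads the `name` and `trainable` attributes, so stand-in objects are used here to keep the sketch runnable without TensorFlow.

```python
from types import SimpleNamespace


def summarize_trainable(layers):
    """Map each block prefix (text before the first underscore) to (trainable, total)."""
    summary = {}
    for layer in layers:
        block = layer.name.split("_")[0]
        trainable, total = summary.get(block, (0, 0))
        summary[block] = (trainable + int(layer.trainable), total + 1)
    return summary


# Stand-in layers for illustration only; the names mimic the base model's naming.
demo_layers = [
    SimpleNamespace(name="block6d_project_conv", trainable=False),
    SimpleNamespace(name="block6d_project_bn", trainable=False),
    SimpleNamespace(name="block7a_expand_conv", trainable=True),
    SimpleNamespace(name="block7a_expand_bn", trainable=True),
    SimpleNamespace(name="top_conv", trainable=True),
]

demo_summary = summarize_trainable(demo_layers)
for block, (trainable, total) in demo_summary.items():
    print(f"{block}: {trainable}/{total} layers trainable")
```

In the notebook, the same helper could be called as `summarize_trainable(base_model.layers)`; only `block7a` and `top` should then show trainable layers.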

Let us examine the model summary again. We can see that we now have 1,458,613 trainable weights, compared to 329,221 previously.

In [None]:
model.summary()
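
The increase in trainable weights can be traced to the newly unfrozen layers. The sketch below tallies them by hand under assumed EfficientNetB0 layer widths (192 input channels to `block7a`, expanded to 1,152, a squeeze-and-excite bottleneck of 48, 320 output channels, and a final 1,280-channel convolution). These widths and layer names are assumptions rather than values stated in this notebook, so check them against the summary output.

```python
def conv_params(kernel, c_in, c_out, bias=False):
    """Parameters of a standard convolution, with an optional bias per output channel."""
    return kernel * kernel * c_in * c_out + (c_out if bias else 0)


def bn_trainable_params(channels):
    """Trainable BatchNormalization parameters: gamma and beta per channel."""
    return 2 * channels


# Assumed layer widths for block7a and the top of EfficientNetB0.
c_in, c_exp, c_se, c_out, c_top = 192, 1152, 48, 320, 1280

unfrozen = {
    "block7a_expand_conv": conv_params(1, c_in, c_exp),
    "block7a_expand_bn": bn_trainable_params(c_exp),
    "block7a_dwconv": 3 * 3 * c_exp,              # depthwise: one 3x3 filter per channel
    "block7a_bn": bn_trainable_params(c_exp),
    "block7a_se_reduce": conv_params(1, c_exp, c_se, bias=True),
    "block7a_se_expand": conv_params(1, c_se, c_exp, bias=True),
    "block7a_project_conv": conv_params(1, c_exp, c_out),
    "block7a_project_bn": bn_trainable_params(c_out),
    "top_conv": conv_params(1, c_out, c_top),
    "top_bn": bn_trainable_params(c_top),
}

head_params = 329_221                      # trainable weights of the classification head
base_unfrozen = sum(unfrozen.values())
total_trainable = base_unfrozen + head_params

print(f"unfrozen in base: {base_unfrozen:,}")
print(f"total trainable:  {total_trainable:,}")
```

Note that the BatchNormalization moving mean and variance are not counted here: they remain non-trainable, and because the base model is called with `training=False` they are not updated either.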

As we are now training a much larger model and want to readapt the pre-trained weights, it is important to use a lower learning rate at this stage: we do not want to make overly drastic changes to the weights of the convolutional layers being fine-tuned.
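
As a toy illustration of why the learning rate matters here, the sketch below applies a single plain gradient-descent step to one pre-trained weight at the two learning rates. This is a deliberate simplification: the notebook uses Adam, which rescales updates, and the weight and gradient values are hypothetical.

```python
def sgd_step(weight, gradient, learning_rate):
    """One plain gradient-descent update (a simplification; the notebook uses Adam)."""
    return weight - learning_rate * gradient


pretrained_weight = 0.80   # hypothetical pre-trained value
gradient = 5.0             # hypothetical gradient at the start of fine-tuning

base_lr = 0.001
finetune_lr = base_lr / 10.0

base_change = abs(sgd_step(pretrained_weight, gradient, base_lr) - pretrained_weight)
finetune_change = abs(sgd_step(pretrained_weight, gradient, finetune_lr) - pretrained_weight)

print(f"lr={base_lr:g}:  weight moves by {base_change:.4f}")
print(f"lr={finetune_lr:g}: weight moves by {finetune_change:.4f}")
```

For the same gradient, the smaller learning rate moves the pre-trained weight ten times less per step, which is the effect we want while the unfrozen layers adapt.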

In [None]:
finetune_learning_rate = base_learning_rate / 10.

model.compile(loss="sparse_categorical_crossentropy",
              optimizer=keras.optimizers.Adam(learning_rate=finetune_learning_rate),
              metrics=["accuracy"])

model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath="best_finetune_checkpoint",
    save_weights_only=True,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True)

model.fit(
    train_ds,
    epochs=20,
    validation_data=val_ds,
    callbacks=[model_checkpoint_callback])

In [None]:
model.load_weights('best_finetune_checkpoint')
model.evaluate(val_ds)

**Question:**

Is our fine-tuned model performing better or worse than the previous model?

Provide a possible explanation to your observation.
