# Transfer Learning Example
From https://www.tensorflow.org/tutorials/images/transfer_learning

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf

### Download the Data

In [None]:
# Download the filtered cats-vs-dogs archive (extracted as part of the fetch)
# and derive the dataset root from the directory the zip was stored in.
# NOTE(review): assumes get_file returns the path of the zip file itself and
# unpacks 'cats_and_dogs_filtered' alongside it - confirm for the installed
# Keras version.
data_url = (
    'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
)
path_to_zip = tf.keras.utils.get_file(
    fname='cats_and_dogs.zip',
    origin=data_url,
    extract=True,
)
PATH = os.path.join(os.path.dirname(path_to_zip), 'cats_and_dogs_filtered')

In [None]:
# Locate the training and validation splits underneath the dataset root.
train_dir, validation_dir = (
    os.path.join(PATH, split_name) for split_name in ('train', 'validation')
)

### Create Training and Validation Datasets from the Downloaded Data

In [None]:
# Number of images per batch when the datasets are built.
BATCH_SIZE = 32
# Height and width every image is resized to on load.
IMG_SIZE = (160, 160)

In [None]:
# Build the shuffled, batched training dataset straight from the image folder.
train_dataset = tf.keras.utils.image_dataset_from_directory(
    directory=train_dir,
    shuffle=True,
    batch_size=BATCH_SIZE,
    image_size=IMG_SIZE,
)

In [None]:
# Build the validation dataset with the same batching and image size as the
# training dataset.
validation_dataset = tf.keras.utils.image_dataset_from_directory(
    directory=validation_dir,
    shuffle=True,
    batch_size=BATCH_SIZE,
    image_size=IMG_SIZE,
)

### Show the First Few Images With Class Names

In [None]:
# Keep the dataset's class names around; they are used to title the sample
# images and, later, the test-set predictions.
class_names = train_dataset.class_names

In [None]:
# Draw the first nine images of one training batch in a 3x3 grid, each
# titled with its class name.
plt.figure(figsize=(10, 10))
for images, labels in train_dataset.take(1):
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        ax.imshow(images[i].numpy().astype("uint8"))
        ax.set_title(class_names[labels[i]])
        ax.axis("off")

### Create a Test Set from a Subset of the Validation Set

In [None]:
# Carve a test set out of the validation data: the first fifth of the batches
# becomes the test set, the remainder stays as the validation set.
# NOTE(review): the validation dataset was created with shuffle=True; confirm
# that take/skip select the same images on every pass over the data.
val_batches = tf.data.experimental.cardinality(validation_dataset)
test_batches = val_batches // 5
test_dataset = validation_dataset.take(test_batches)
validation_dataset = validation_dataset.skip(test_batches)

In [None]:
# Report how many batches ended up in each of the two held-out splits.
num_validation_batches = tf.data.experimental.cardinality(validation_dataset)
num_test_batches = tf.data.experimental.cardinality(test_dataset)
print(f"Number of Validation Batches: {num_validation_batches}")
print(f"Number of Test Batches: {num_test_batches}")

### Configure Prefetching to Increase Performance
Prefetching lets TensorFlow begin preparing the next batch of images while the model is still training on the current batch.

In [None]:
# Add prefetching to all three datasets, letting tf.data choose the buffer size.
AUTOTUNE = tf.data.AUTOTUNE
train_dataset, validation_dataset, test_dataset = (
    dataset.prefetch(buffer_size=AUTOTUNE)
    for dataset in (train_dataset, validation_dataset, test_dataset)
)

### Augment the Dataset (Create Modified Copies of Existing Dataset to Artificially Increase Size of Dataset)

In [None]:
# Augmentation pipeline: a random horizontal flip followed by a random
# rotation with factor 0.2.
data_augmentation = tf.keras.Sequential(
    layers=[
        tf.keras.layers.RandomFlip(mode='horizontal'),
        tf.keras.layers.RandomRotation(factor=0.2),
    ]
)

In [None]:
# Run the first image of one training batch through the augmentation pipeline
# nine times and show the results in a 3x3 grid.
for image, _ in train_dataset.take(1):
    plt.figure(figsize=(10, 10))
    first_image = image[0]
    # The pipeline is fed a batch, so add a leading batch axis once up front.
    first_image_batch = tf.expand_dims(first_image, 0)
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        augmented_image = data_augmentation(first_image_batch)
        # Scale pixel values down by 255 for display.
        plt.imshow(augmented_image[0] / 255)
        plt.axis('off')

### Preprocess the Input For MobileNetV2

In [None]:
# Alias for MobileNetV2's own input preprocessing function; it is applied to
# the (augmented) inputs when the full model is assembled.
preprocess_input = tf.keras.applications.mobilenet_v2.preprocess_input

In [None]:
# An alternative preprocessing option is to use tf.keras.layers.Rescaling:
# rescale = tf.keras.layers.Rescaling(1./127.5, offset=-1)

### Create Base Model (excluding the Top Layer)

In [None]:
# Full input shape: the image size plus three colour channels.
IMG_SHAPE = (*IMG_SIZE, 3)

# MobileNetV2 with ImageNet weights and without its classification top.
base_model = tf.keras.applications.MobileNetV2(
    input_shape=IMG_SHAPE,
    include_top=False,
    weights='imagenet',
)

In [None]:
# Take one batch from the training set and pass its images through the base
# model to inspect the shape of the features it produces.
image_batch, label_batch = next(iter(train_dataset))
feature_batch = base_model(image_batch)
print(feature_batch.shape)

### Freeze the Convolutional Base (prevents weights in base_model layers from being updated during training)

In [None]:
# Freeze the whole base model so its weights are not updated during training.
base_model.trainable = False

In [None]:
# Print the base model's layer-by-layer summary.
base_model.summary()

### Add a Classification Head

The global average layer will convert the features into a single vector per image

In [None]:
# Pooling layer that reduces the base model's features to a single vector per
# image; applied here to the sample feature batch to check the output shape.
global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
feature_batch_average = global_average_layer(feature_batch)
print(feature_batch_average.shape)

A Dense layer will create a single prediction per image.
(No activation function needed because the prediction will be treated as a logit 
where positive numbers are class 1 and negative numbers are class 0)

In [None]:
# Single-unit Dense layer with no activation: its output is a raw logit.
# Applied to the pooled sample batch to check the output shape.
prediction_layer = tf.keras.layers.Dense(units=1)
prediction_batch = prediction_layer(feature_batch_average)
print(prediction_batch.shape)

### Chain All of the Steps Above Together to Build the Model

In [None]:
# Assemble the end-to-end model:
#   augmentation -> MobileNetV2 preprocessing -> base model -> pooling
#   -> dropout -> single-logit prediction.
# The input shape reuses IMG_SHAPE (rather than a hard-coded tuple) so it
# always matches the shape the base model was built with.
inputs = tf.keras.Input(shape=IMG_SHAPE)
x = data_augmentation(inputs)
x = preprocess_input(x)

# Call the base model with training=False so that its BatchNormalization
# layers stay in inference mode. This matters once the base model is unfrozen
# for fine-tuning: otherwise the updates applied to the non-trainable weights
# would destroy what the model has learned.
x = base_model(x, training=False)
x = global_average_layer(x)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = prediction_layer(x)
model = tf.keras.Model(inputs, outputs)

### Compile the Model

In [None]:
# Compile for the first training phase.
# The model outputs raw logits, so the loss is told to expect logits and the
# accuracy metric uses a decision threshold of 0 on the logit (the same cut as
# probability 0.5). The plain "accuracy" string would instead threshold the
# logits at 0.5 and disagree with how predictions are made later on.
# The metric keeps the name 'accuracy' so the history keys are unchanged.
# NOTE(review): confirm the installed Keras version accepts threshold=0.
base_learning_rate = 0.0001
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=base_learning_rate),
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.BinaryAccuracy(threshold=0, name='accuracy')],
)

In [None]:
# Print the summary of the assembled model.
model.summary() # note the number of trainable vs non-trainable params

In [None]:
# Count the model's trainable variables, then display them (the bare
# expression on the last line is shown as the notebook cell's output).
print("Trainable Variables:", len(model.trainable_variables))
model.trainable_variables

### Train the Model

In [None]:
# Number of epochs for the first training phase.
initial_epochs = 10

# Evaluate on the validation set before any training to record a baseline.
loss0, accuracy0 = model.evaluate(validation_dataset)

In [None]:
# Report the baseline loss and accuracy measured before training.
print(f"Initial Loss: {loss0}")
print(f"Initial Accuracy: {accuracy0}")

In [None]:
# First training phase: fit for `initial_epochs` epochs, validating after
# each one, and keep the returned History for plotting.
history = model.fit(
    train_dataset,
    validation_data=validation_dataset,
    epochs=initial_epochs,
)

### Plot the Learning Curves

In [None]:
# Pull the per-epoch metrics recorded during the first training phase.
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

# Two stacked panels: accuracy on top, loss underneath.
fig, (acc_ax, loss_ax) = plt.subplots(2, 1, figsize=(8, 8))

acc_ax.plot(acc, label='Training Accuracy')
acc_ax.plot(val_acc, label='Validation Accuracy')
acc_ax.legend(loc='lower right')
acc_ax.set_ylabel('Accuracy')
# Keep the automatically chosen lower bound but pin the top of the axis at 1.
acc_ax.set_ylim([min(acc_ax.get_ylim()), 1])
acc_ax.set_title('Training and Validation Accuracy')

loss_ax.plot(loss, label='Training Loss')
loss_ax.plot(val_loss, label='Validation Loss')
loss_ax.legend(loc='upper right')
loss_ax.set_ylabel('Cross Entropy')
loss_ax.set_ylim([0, 1.0])
loss_ax.set_title('Training and Validation Loss')
loss_ax.set_xlabel('epoch')
plt.show()

### Fine Tune the Model by Unfreezing Just the Top Layers of the Base Model

In [None]:
# Make the base model trainable again; a later cell re-freezes its lower layers.
base_model.trainable = True

In [None]:
# Show how many layers the base model has, as context for choosing the layer
# index at which fine-tuning starts.
print(f"Number of layers in base model: {len(base_model.layers)}")

In [None]:
# Index of the first base-model layer that is left trainable.
fine_tune_at = 100

# Every layer below that index is frozen again.
for layer in base_model.layers[:fine_tune_at]:
    layer.trainable = False

### Recompile The Model
Use a lower learning rate since the model being trained is much larger and could overfit quickly

In [None]:
# Recompile for fine-tuning with a learning rate ten times lower than the
# first training phase.
# The model outputs raw logits, so accuracy is measured with a decision
# threshold of 0 on the logit (the same cut as probability 0.5). The plain
# "accuracy" string would instead threshold the logits at 0.5.
# The metric keeps the name 'accuracy' so the history keys are unchanged.
# NOTE(review): confirm the installed Keras version accepts threshold=0.
model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.RMSprop(learning_rate=base_learning_rate / 10),
    metrics=[tf.keras.metrics.BinaryAccuracy(threshold=0, name='accuracy')],
)

In [None]:
# Print the model summary again after unfreezing part of the base model.
model.summary() # note the number of trainable params now

In [None]:
# Display the number of trainable variables after the partial unfreeze.
len(model.trainable_variables)

### Continue Training the Model

In [None]:
# Second training phase: continue for `fine_tune_epochs` more epochs.
# `epochs` is the index of the final epoch (not a count), so training resumes
# at `initial_epoch` and stops once `total_epochs` is reached.
fine_tune_epochs = 10
total_epochs = initial_epochs + fine_tune_epochs

# Resume from the number of epochs already completed. history.epoch holds
# zero-based epoch indices, so using its last entry would restart one epoch
# early and run fine_tune_epochs + 1 epochs.
history_fine = model.fit(train_dataset,
                         epochs=total_epochs,
                         initial_epoch=len(history.epoch),
                         validation_data=validation_dataset)

### Plot the New Learning Curves

In [None]:
# Build the combined curves from the two History objects rather than
# extending the existing lists in place. `acc`, `val_acc`, `loss` and
# `val_loss` alias the lists stored inside history.history, so `+=` would
# mutate the first run's history and append the fine-tuning results again
# each time this cell is re-run.
acc = history.history['accuracy'] + history_fine.history['accuracy']
val_acc = history.history['val_accuracy'] + history_fine.history['val_accuracy']

loss = history.history['loss'] + history_fine.history['loss']
val_loss = history.history['val_loss'] + history_fine.history['val_loss']

In [None]:
# Plot the curves across both training phases, with a vertical line drawn at
# the x position labelled as the start of fine-tuning.
fine_tune_start = initial_epochs - 1

plt.figure(figsize=(8, 8))

# Top panel: accuracy.
plt.subplot(2, 1, 1)
plt.title('Training and Validation Accuracy')
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.ylim([0.8, 1])
plt.plot([fine_tune_start, fine_tune_start],
         plt.ylim(), label='Start Fine Tuning')
plt.legend(loc='lower right')

# Bottom panel: loss.
plt.subplot(2, 1, 2)
plt.title('Training and Validation Loss')
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.ylim([0, 1.0])
plt.plot([fine_tune_start, fine_tune_start],
         plt.ylim(), label='Start Fine Tuning')
plt.legend(loc='upper right')
plt.xlabel('epoch')
plt.show()

### Make Predictions on the Test Set

In [None]:
# Evaluate the fine-tuned model on the held-out test set.
# NOTE: this rebinds `loss`, which until now held the list of per-epoch
# training losses used by the plotting cells.
loss, accuracy = model.evaluate(test_dataset)

In [None]:
# Print the test-set loss and accuracy.
print(loss, accuracy)

In [None]:
# Pull a single batch of images and labels from the test set as NumPy arrays
# and get the model's raw outputs for it.
image_batch, label_batch = next(test_dataset.as_numpy_iterator())
predictions = model.predict_on_batch(image_batch).flatten()

# The model returns logits: squash them with a sigmoid, then map values
# below 0.5 to class 0 and everything else to class 1.
predictions = tf.where(tf.nn.sigmoid(predictions) < 0.5, 0, 1)
print(f"Predictions:\n {predictions.numpy()}")
print(f"Labels:\n {label_batch}")

# Show the first nine test images, each titled with its predicted class.
plt.figure(figsize=(10, 10))
for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    ax.imshow(image_batch[i].astype("uint8"))
    ax.set_title(class_names[predictions[i]])
    ax.axis("off")

### Next
See https://www.tensorflow.org/guide/keras/transfer_learning for more info on transfer learning