# Transfer Learning - FINETUNED

## Set up the pretrained model


In [None]:
"""Build a convolutional neural network (CNN) for image classification.

Dataset: CIFAR-10.
The project consists of three phases:
* train the CNN from scratch
* apply standard transfer learning
* fine-tune the transfer-learning model
"""

import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import VGG16
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report, confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
# from google.colab import drive
# drive.mount('/content/drive')

# Load CIFAR-10 dataset (`data` keeps the nested tuple in case later cells use it)
data = (train_images, train_labels), (test_images, test_labels) = cifar10.load_data()

# Preprocess pixels the way the ImageNet-pretrained VGG16 weights expect.
# Keras documents vgg16.preprocess_input for this model. Plain /255 scaling
# feeds the pretrained filters inputs on a different scale and channel order
# than they were trained on.
train_images = tf.keras.applications.vgg16.preprocess_input(train_images)
test_images = tf.keras.applications.vgg16.preprocess_input(test_images)

# Define VGG16 convolutional base with pretrained ImageNet weights (no classifier top)
base_model = VGG16(weights='imagenet', include_top=False, input_shape=(32, 32, 3))

# Number of trailing base layers to fine-tune. In VGG16 without its top, the
# last 4 layers are block5_conv1..3 plus block5_pool (the pool has no weights).
N_FINETUNE_LAYERS = 4

# Keras layers are trainable by default, so freeze the whole base first.
# Otherwise "unfreezing" the tail is a no-op and every pretrained layer is
# retrained, which can destroy the ImageNet features.
for layer in base_model.layers:
    layer.trainable = False

# Unfreeze only the last N_FINETUNE_LAYERS layers for fine-tuning
for layer in base_model.layers[-N_FINETUNE_LAYERS:]:
    layer.trainable = True

In [None]:
# 00 Classifier head stacked on the VGG16 convolutional base:
# flatten the base's feature maps, one hidden ReLU layer with light dropout,
# then a 10-way softmax (one output per CIFAR-10 class).
classifier_head = [
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dropout(0.15),
    layers.Dense(10, activation='softmax'),
]
cnn_model = models.Sequential([base_model, *classifier_head])

# NOTE: base_model's per-layer `trainable` flags are set where base_model is
# created; freezing every base layer here would undo that fine-tuning setup.


## Compile

In [None]:
# Early stopping: halt once validation loss has not improved for 5 epochs
# and roll the weights back to the best epoch observed.
early_stop = EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True,
)

# Adam optimiser with sparse categorical cross-entropy, i.e. integer class ids
# as targets rather than one-hot vectors.
# NOTE(review): 1e-3 is a typical from-scratch rate; fine-tuning pretrained
# layers usually uses a much smaller one (e.g. 1e-5). Consider lowering it.
cnn_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

## Train


In [None]:
# Train the model.
# Validation data is carved from the training set, NOT taken from the test set:
# EarlyStopping(restore_best_weights=True) selects weights on the validation
# split, so validating on the test set would leak it into model selection and
# bias the final test metrics. Keras takes the last VALIDATION_SPLIT fraction
# of the training arrays (before shuffling) as the validation split.
VALIDATION_SPLIT = 0.1

history = cnn_model.fit(
    train_images,
    train_labels,
    epochs=50,
    validation_split=VALIDATION_SPLIT,
    callbacks=[early_stop],
)

In [None]:
# Save the model to Google Drive
#cnn_model.save('/content/drive/My Drive/cifar10_vgg16_model.h5')

## Evaluate the results

In [None]:
# Evaluate the model on the test data
test_loss, test_accuracy = cnn_model.evaluate(test_images, test_labels)

# Predict per-class scores for the test set and keep the highest-scoring class
test_predictions = cnn_model.predict(test_images)
test_pred_labels = test_predictions.argmax(axis=1)

# Keras' cifar10.load_data() documents labels as an (N, 1) column; flatten them
# so sklearn receives 1-D targets matching the 1-D predictions.
test_true_labels = test_labels.ravel()

# Human-readable CIFAR-10 class names, indexed by integer label 0-9
label_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

# Calculate performance metrics (macro average weights every class equally)
accuracy = accuracy_score(test_true_labels, test_pred_labels)
precision = precision_score(test_true_labels, test_pred_labels, average='macro')
recall = recall_score(test_true_labels, test_pred_labels, average='macro')
f1 = f1_score(test_true_labels, test_pred_labels, average='macro')

print(f"Test Accuracy: {test_accuracy}")
print(f"Test Precision: {precision}")
print(f"Test Recall: {recall}")
print(f"Test F1 Score: {f1}")

# Per-class precision/recall/F1, labelled with class names instead of ids
class_report = classification_report(test_true_labels, test_pred_labels, target_names=label_names)
print("Classification Report:")
print(class_report)

# Confusion matrix: entry [i, j] counts samples of true class i predicted as class j
conf_matrix = confusion_matrix(test_true_labels, test_pred_labels)

# Draw onto an explicitly sized axes. ConfusionMatrixDisplay.plot() creates its
# own figure when no `ax` is passed, which would leave a separately created
# figsize figure empty and ignore the requested size.
fig, ax = plt.subplots(figsize=(10, 8))
cm_display = ConfusionMatrixDisplay(confusion_matrix=conf_matrix, display_labels=label_names)
cm_display.plot(ax=ax, cmap=plt.cm.Blues, xticks_rotation=45)
ax.set_title("Confusion Matrix")
plt.show()

# Plot the training history: loss (left) and accuracy (right) per epoch
plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Training and Validation Loss')

plt.subplot(1, 2, 2)
plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.title('Training and Validation Accuracy')

plt.show()