# 📘 Transfer Learning Comparison – Multiple Models on CIFAR-10 (Keras)

In this notebook, you can **show students end-to-end transfer learning** on **real data**
using **multiple pretrained CNN architectures**:

- VGG16
- ResNet50V2
- InceptionV3
- Xception
- DenseNet121
- MobileNetV2

For each model we will:
1. Prepare data (CIFAR-10 subset: `cat`, `dog`, `horse`).
2. Build a transfer learning model (pretrained on ImageNet).
3. Train for a few epochs (demo-level, not full training).
4. Evaluate test accuracy.
5. Compare models on a bar chart.
6. Show predictions from the **best model**.

This is designed as a **teaching notebook**: simple, visual, and easy to extend.

## 1. Imports

```python
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.optimizers import Adam

from tensorflow.keras.datasets import cifar10

from tensorflow.keras.applications import (
    VGG16,
    ResNet50V2,
    InceptionV3,
    Xception,
    DenseNet121,
    MobileNetV2,
)

from tensorflow.keras.applications.vgg16 import preprocess_input as vgg16_preprocess
from tensorflow.keras.applications.resnet_v2 import preprocess_input as resnet50v2_preprocess
from tensorflow.keras.applications.inception_v3 import preprocess_input as inceptionv3_preprocess
from tensorflow.keras.applications.xception import preprocess_input as xception_preprocess
from tensorflow.keras.applications.densenet import preprocess_input as densenet_preprocess
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input as mobilenetv2_preprocess

print('TensorFlow:', tf.__version__)
```

## 2. Load CIFAR-10 Dataset

We will:
- Load CIFAR-10 (60k images of size 32×32×3).
- Use only 3 classes for a fast demo:
  - `cat` (label 3)
  - `dog` (label 5)
  - `horse` (label 7)
- Subsample the dataset to keep training quick.


```python
(X_train, y_train), (X_test, y_test) = cifar10.load_data()

class_names_full = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                    'dog', 'frog', 'horse', 'ship', 'truck']
print('Original train shape:', X_train.shape, y_train.shape)
print('Original test shape:', X_test.shape, y_test.shape)
print('All classes:', class_names_full)
```

### 2.1 Visualise Some Original CIFAR-10 Images

```python
plt.figure(figsize=(8, 6))
for i in range(12):
    plt.subplot(3, 4, i + 1)
    plt.imshow(X_train[i])
    plt.axis('off')
    # CIFAR-10 labels have shape (N, 1), so take the scalar at [i, 0]
    plt.title(class_names_full[y_train[i, 0]])
plt.tight_layout()
plt.show()
```

### 2.2 Filter to 3 Classes: `cat`, `dog`, `horse`

We keep only labels 3, 5, 7 and remap them to **0, 1, 2**.

- `0 → cat`
- `1 → dog`
- `2 → horse`
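You can also express the same remapping with a lookup array instead of a dictionary. The sketch below does this and also flattens the `(N, 1)` label array returned by `cifar10.load_data()`. The helper name `remap_labels` is hypothetical and not part of Keras.

```python
import numpy as np

def remap_labels(y, selected, num_original_classes=10):
    """Map original labels (e.g. 3, 5, 7) to 0..len(selected)-1.

    Values that are not in `selected` map to -1, so filter the data first.
    """
    lookup = np.full(num_original_classes, -1, dtype=np.int64)
    lookup[selected] = np.arange(len(selected))  # 3 -> 0, 5 -> 1, 7 -> 2
    return lookup[np.ravel(y)]  # ravel turns shape (N, 1) into (N,)

y = np.array([[3], [7], [5], [3]], dtype=np.uint8)  # shape (N, 1), like CIFAR-10 labels
print(remap_labels(y, [3, 5, 7]))  # -> [0 2 1 0]
```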

```python
selected_classes = [3, 5, 7]  # cat, dog, horse

def filter_classes(X, y, selected):
    """Keep only the selected classes and remap their labels to 0..len(selected)-1."""
    y = y.flatten()  # CIFAR-10 labels have shape (N, 1); work with shape (N,)
    idx = np.isin(y, selected)
    label_map = {orig: i for i, orig in enumerate(selected)}
    y_mapped = np.array([label_map[int(label)] for label in y[idx]])
    return X[idx], y_mapped

X_train_sel, y_train_sel = filter_classes(X_train, y_train, selected_classes)
X_test_sel, y_test_sel = filter_classes(X_test, y_test, selected_classes)

print('Filtered train shape:', X_train_sel.shape, y_train_sel.shape)
print('Filtered test shape:', X_test_sel.shape, y_test_sel.shape)
print('Unique filtered labels:', np.unique(y_train_sel))
demo_class_names = ['cat', 'dog', 'horse']
print('Demo classes:', demo_class_names)
```

### 2.3 Subsample for Fast Demo

To train several models live in one session, we restrict the dataset size.
You can increase these numbers on a more powerful machine.
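To decide how far you can scale these numbers up, estimate the memory cost first. `tf.image.resize` returns float32 (4 bytes per value), so each resized image takes H × W × 3 × 4 bytes. Here is a back-of-the-envelope helper (the name `resized_bytes` is hypothetical):

```python
def resized_bytes(n_images, height, width, channels=3, bytes_per_value=4):
    """Approximate size in bytes of n_images resized to (height, width) as float32."""
    return n_images * height * width * channels * bytes_per_value

for size in [(224, 224), (299, 299)]:
    gib = resized_bytes(3000, *size) / 2**30
    print(f"3000 training images at {size}: {gib:.2f} GiB")
```

For 3,000 training images this comes to roughly 1.7 GiB at 224×224 and about 3 GiB at 299×299. That figure excludes the test set and any copies made during preprocessing.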

```python
rng = np.random.default_rng(42)  # fixed seed so the subsample is reproducible

def subsample(X, y, n):
    """Randomly pick n samples (or all of them if fewer exist)."""
    n = min(n, X.shape[0])
    idx = rng.permutation(X.shape[0])[:n]
    return X[idx], y[idx]

X_train_small, y_train_small = subsample(X_train_sel, y_train_sel, 3000)
X_test_small, y_test_small = subsample(X_test_sel, y_test_sel, 900)

print('Small train shape:', X_train_small.shape, y_train_small.shape)
print('Small test shape:', X_test_small.shape, y_test_small.shape)
```

## 3. Helper Functions

We now write helper functions to:
- Resize & preprocess images for each model.
- Build a transfer learning model.
- Train and evaluate the model.

```python
NUM_CLASSES = 3

def prepare_data_for_model(X_train, y_train, X_test, y_test, input_size, preprocess_fn):
    """Resize and preprocess data for a specific model.

    input_size: (H, W)
    preprocess_fn: Keras preprocess_input function for that model.
    """
    H, W = input_size
    # Resize from 32x32 -> HxW (tf.image.resize returns float32)
    X_train_resized = tf.image.resize(X_train, (H, W)).numpy()
    X_test_resized = tf.image.resize(X_test, (H, W)).numpy()

    # Apply model-specific preprocessing. preprocess_input may modify float arrays
    # in place, which is harmless here because the resized arrays are not reused.
    X_train_pp = preprocess_fn(X_train_resized)
    X_test_pp = preprocess_fn(X_test_resized)

    y_train_oh = to_categorical(y_train, NUM_CLASSES)
    y_test_oh = to_categorical(y_test, NUM_CLASSES)

    return X_train_pp, y_train_oh, X_test_pp, y_test_oh


def build_tl_model(base_model_fn, input_size, preprocess_fn, name='model'):
    """Build a transfer learning model using a Keras applications backbone.

    base_model_fn: function that returns a keras Model with include_top=False
    input_size: (H, W)
    preprocess_fn: kept for reference (not used here directly)
    """
    H, W = input_size

    base_model = base_model_fn(weights='imagenet', include_top=False,
                               input_shape=(H, W, 3))
    base_model.trainable = False  # start with feature extraction

    inputs = layers.Input(shape=(H, W, 3))
    x = base_model(inputs, training=False)  # keep BatchNorm in inference mode
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(256, activation='relu')(x)
    x = layers.Dropout(0.4)(x)
    outputs = layers.Dense(NUM_CLASSES, activation='softmax')(x)

    model = models.Model(inputs, outputs, name=name)
    model.compile(optimizer=Adam(1e-3),
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    return model


def plot_history(history, title_prefix='Model'):
    hist = history.history
    epochs = range(1, len(hist['loss']) + 1)

    plt.figure(figsize=(10, 4))
    plt.subplot(1, 2, 1)
    plt.plot(epochs, hist['loss'], label='Train Loss')
    plt.plot(epochs, hist['val_loss'], label='Val Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(f'{title_prefix} - Loss')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(epochs, hist['accuracy'], label='Train Acc')
    plt.plot(epochs, hist['val_accuracy'], label='Val Acc')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title(f'{title_prefix} - Accuracy')
    plt.legend()

    plt.tight_layout()
    plt.show()


def show_predictions(X_raw, X_pp, y_true, model, class_labels, n=9, title_prefix=''):
    """Plot the first n raw images with true label, predicted label and confidence."""
    preds = model.predict(X_pp[:n], verbose=0)
    y_pred = np.argmax(preds, axis=1)
    y_true_flat = np.ravel(y_true[:n])  # works for labels of shape (N,) or (N, 1)

    cols = 3
    rows = int(np.ceil(n / cols))
    plt.figure(figsize=(10, 10 * rows / cols))
    for i in range(n):
        plt.subplot(rows, cols, i + 1)
        plt.imshow(X_raw[i])
        plt.axis('off')
        true_label = class_labels[int(y_true_flat[i])]
        pred_label = class_labels[int(y_pred[i])]
        conf = np.max(preds[i])
        plt.title(f'T:{true_label} | P:{pred_label}\n({conf:.2f})')
    plt.suptitle(title_prefix)
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()
```

## 4. Define Model Configurations

We prepare a list of model configurations so we can **loop over all models**:

- Input size (224×224 or 299×299)
- Backbone function
- Preprocess function
- A short description

```python
MODEL_CONFIGS = [
    {
        'name': 'VGG16',
        'input_size': (224, 224),
        'base_fn': VGG16,
        'preprocess_fn': vgg16_preprocess,
        'description': 'Classic deep CNN, good feature extractor but heavy.'
    },
    {
        'name': 'ResNet50V2',
        'input_size': (224, 224),
        'base_fn': ResNet50V2,
        'preprocess_fn': resnet50v2_preprocess,
        'description': 'Residual network with skip connections, strong baseline.'
    },
    {
        'name': 'InceptionV3',
        'input_size': (299, 299),
        'base_fn': InceptionV3,
        'preprocess_fn': inceptionv3_preprocess,
        'description': 'Multi-scale convs, good accuracy vs compute.'
    },
    {
        'name': 'Xception',
        'input_size': (299, 299),
        'base_fn': Xception,
        'preprocess_fn': xception_preprocess,
        'description': 'Depthwise separable convs, powerful but heavier.'
    },
    {
        'name': 'DenseNet121',
        'input_size': (224, 224),
        'base_fn': DenseNet121,
        'preprocess_fn': densenet_preprocess,
        'description': 'Dense connections, parameter-efficient.'
    },
    {
        'name': 'MobileNetV2',
        'input_size': (224, 224),
        'base_fn': MobileNetV2,
        'preprocess_fn': mobilenetv2_preprocess,
        'description': 'Lightweight, fast, great for mobile/edge.'
    },
]

for cfg in MODEL_CONFIGS:
    print(f"{cfg['name']}: input_size={cfg['input_size']} - {cfg['description']}")
```

## 5. Train & Evaluate All Models (Feature Extraction)

For each model configuration:

1. Resize the CIFAR-10 subset to the model's input size and preprocess it with that model's own preprocessing function.
2. Build a transfer learning model with frozen backbone.
3. Train for a **small number of epochs** (demo only).
4. Evaluate on the test set.
5. Store accuracy for comparison.
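The training step below uses `validation_split`. According to the Keras `Model.fit` documentation, the validation data is taken from the *last* samples of `x` and `y`, before any shuffling. The data here was already randomly permuted by `subsample`, so holding out the tail is effectively a random split. The sketch below mirrors that documented behaviour in NumPy. `keras_style_validation_split` is a hypothetical name, and Keras' internal rounding may differ slightly.

```python
import math
import numpy as np

def keras_style_validation_split(x, y, validation_split):
    """Hold out the last `validation_split` fraction of samples, without shuffling."""
    split_at = int(math.floor(len(x) * (1.0 - validation_split)))
    return (x[:split_at], y[:split_at]), (x[split_at:], y[split_at:])

x = np.arange(3000)          # stand-in for 3000 preprocessed images
y = np.arange(3000) % 3      # stand-in labels
(x_tr, y_tr), (x_val, y_val) = keras_style_validation_split(x, y, 0.2)
print(len(x_tr), len(x_val))
```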

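⚠️ **Note:** Training all six models can take a few minutes on a GPU and considerably longer on a CPU. Each model's resized float32 arrays also need a few GB of RAM, and the 299×299 models need the most. You can:
- Reduce `EPOCHS_PER_MODEL`.
- Reduce the subsample sizes above.
- Comment out some configs in `MODEL_CONFIGS`.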
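Instead of commenting entries out, you can select a subset of configs by name. This is a minimal sketch. The helper `select_configs` is hypothetical, and the stand-in list replaces `MODEL_CONFIGS` only so the snippet runs on its own.

```python
def select_configs(configs, names):
    """Keep only configs whose 'name' is in `names`, preserving the original order."""
    wanted = set(names)
    unknown = wanted - {cfg["name"] for cfg in configs}
    if unknown:
        raise ValueError(f"Unknown model names: {sorted(unknown)}")
    return [cfg for cfg in configs if cfg["name"] in wanted]

# Stand-in configs; in the notebook you would pass MODEL_CONFIGS instead.
demo_configs = [{"name": n} for n in ["VGG16", "ResNet50V2", "MobileNetV2"]]
print([c["name"] for c in select_configs(demo_configs, ["MobileNetV2", "VGG16"])])
```

Raising on unknown names catches typos early, before a long training loop starts.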

```python
EPOCHS_PER_MODEL = 3   # adjust as needed
BATCH_SIZE = 32
VAL_SPLIT = 0.2

results = []
trained_models = {}

for cfg in MODEL_CONFIGS:
    name = cfg['name']
    input_size = cfg['input_size']
    base_fn = cfg['base_fn']
    preprocess_fn = cfg['preprocess_fn']

    print("\n" + "=" * 60)
    print(f"Training model: {name} (input_size={input_size})")
    print("=" * 60)

    # Preprocess separately for every model. Data cannot be cached by input size
    # alone: VGG16 ('caffe' mode), DenseNet121 ('torch' mode) and the remaining
    # models ('tf' mode) expect different pixel scaling even at the same size.
    X_train_pp, y_train_oh, X_test_pp, y_test_oh = prepare_data_for_model(
        X_train_small, y_train_small,
        X_test_small, y_test_small,
        input_size,
        preprocess_fn,
    )

    model = build_tl_model(base_fn, input_size, preprocess_fn, name=name)

    history = model.fit(
        X_train_pp, y_train_oh,
        epochs=EPOCHS_PER_MODEL,
        batch_size=BATCH_SIZE,
        validation_split=VAL_SPLIT,
        verbose=1,
    )

    test_loss, test_acc = model.evaluate(X_test_pp, y_test_oh, verbose=0)
    print(f"\n[{name}] Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}")

    results.append({'model': name, 'test_loss': float(test_loss), 'test_acc': float(test_acc)})
    trained_models[name] = {
        'model': model,
        'history': history,
        'input_size': input_size,
        'preprocess_fn': preprocess_fn,
    }

    # Release the large resized arrays before the next model
    del X_train_pp, X_test_pp

print("\nAll models trained.")
```

## 6. Compare Model Accuracies

Let's show test accuracies for all models in a table and a bar chart.

```python
import pandas as pd
from IPython.display import display  # available in Jupyter anyway; imported explicitly

results_df = pd.DataFrame(results)
results_df.sort_values('test_acc', ascending=False, inplace=True)
display(results_df)

plt.figure(figsize=(8, 4))
plt.bar(results_df['model'], results_df['test_acc'])
plt.ylabel('Test Accuracy')
plt.ylim(0, 1)
plt.title('Transfer Learning – CIFAR-10 (cat/dog/horse)')
plt.xticks(rotation=30)
plt.tight_layout()
plt.show()
```

## 7. Visualise Training Curves & Predictions for Best Model

We'll pick the **best test accuracy** model and:

- Plot its training vs validation loss/accuracy curves.
- Show **input images** with predicted vs true labels.
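Picking the best model is just an arg-max over the stored results. The pandas-free sketch below works on a list shaped like `results`. The entries are placeholders with made-up names and values, not measured accuracies.

```python
# Placeholder entries with the same keys as `results`; not real measurements.
results_demo = [
    {"model": "model_a", "test_loss": 0.9, "test_acc": 0.70},
    {"model": "model_b", "test_loss": 0.6, "test_acc": 0.80},
    {"model": "model_c", "test_loss": 0.7, "test_acc": 0.75},
]

best = max(results_demo, key=lambda r: r["test_acc"])  # ties: first one wins
ranked = sorted(results_demo, key=lambda r: r["test_acc"], reverse=True)
print("Best model:", best["model"])
print([r["model"] for r in ranked])
```

This is equivalent to sorting `results_df` by `test_acc` in descending order and taking the first row.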

```python
# Pick best model by test accuracy
best_row = results_df.iloc[0]
best_name = best_row['model']
print('Best model:', best_name)

best_info = trained_models[best_name]
best_model = best_info['model']
best_history = best_info['history']
best_input_size = best_info['input_size']
best_preprocess_fn = best_info['preprocess_fn']

plot_history(best_history, title_prefix=f'{best_name}')

# Prepare data again for that model (from raw small arrays)
X_train_pp_best, y_train_oh_best, X_test_pp_best, y_test_oh_best = prepare_data_for_model(
    X_train_small, y_train_small,
    X_test_small, y_test_small,
    best_input_size,
    best_preprocess_fn,
)

show_predictions(
    X_test_small,
    X_test_pp_best,
    y_test_small,
    best_model,
    demo_class_names,
    n=9,
    title_prefix=f'Best model: {best_name}'
)
```

## 8. Optional: Fine-Tune the Best Model

To show students **fine-tuning**, you can unfreeze some layers of the best model's backbone
and retrain with a **lower learning rate**.

This section is optional and may take more time.
You can skip it in live demos or run it with fewer epochs.
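The freezing rule used in the fine-tuning cell unfreezes the backbone and then re-freezes everything except its last layers. It can be sketched without TensorFlow. `FakeLayer` and `unfreeze_last_n` are hypothetical stand-ins that only model each layer's `trainable` flag. One Python detail matters here: `layers[:-0]` is an empty slice, so a plain `layers[:-n]` loop would freeze nothing when `n` is 0. The sketch computes an explicit cutoff instead.

```python
from dataclasses import dataclass

@dataclass
class FakeLayer:
    """Hypothetical stand-in for a Keras layer: just a name and a trainable flag."""
    name: str
    trainable: bool = True

def unfreeze_last_n(layers, n):
    """Make only the last n layers trainable (n=0 freezes everything)."""
    cutoff = max(len(layers) - n, 0)
    for i, layer in enumerate(layers):
        layer.trainable = i >= cutoff
    return layers

backbone = [FakeLayer(f"block_{i}") for i in range(100)]
unfreeze_last_n(backbone, 30)
print(sum(l.trainable for l in backbone), "of", len(backbone), "layers trainable")
```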

```python
# Example: fine-tune the last N layers of the backbone for the best model
# (Run only if you want to demonstrate fine-tuning.)

FINE_TUNE_EPOCHS = 3
FINE_TUNE_LR = 1e-5
N_TRAINABLE_LAYERS = 30  # backbone layers to unfreeze, counted from the end (must be > 0)

# Find the backbone: the nested Keras Model inside best_model
# (more robust than assuming it is best_model.layers[1])
base_layer = next(layer for layer in best_model.layers if isinstance(layer, tf.keras.Model))
print('Backbone layer name:', base_layer.name)

# Unfreeze the backbone, then re-freeze all but its last N layers.
# BatchNorm layers stay in inference mode because build_tl_model called the
# backbone with training=False.
base_layer.trainable = True
for layer in base_layer.layers[:-N_TRAINABLE_LAYERS]:
    layer.trainable = False

# Recompile so the new trainable flags take effect
best_model.compile(
    optimizer=Adam(FINE_TUNE_LR),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

print('Fine-tuning the best model...')
history_ft = best_model.fit(
    X_train_pp_best, y_train_oh_best,
    epochs=FINE_TUNE_EPOCHS,
    batch_size=BATCH_SIZE,
    validation_split=VAL_SPLIT,
    verbose=1,
)

test_loss_ft, test_acc_ft = best_model.evaluate(X_test_pp_best, y_test_oh_best, verbose=0)
print(f'[Fine-tuned {best_name}] Test Loss: {test_loss_ft:.4f}, Test Accuracy: {test_acc_ft:.4f}')

plot_history(history_ft, title_prefix=f'{best_name} (Fine-tuned)')
```

---
## 9. How to Use This Notebook in Class

- Start from section 2 → explain the dataset.
- Show how we **reuse knowledge** from ImageNet models.
- Run section 5 for a couple of models (or all, if time allows).
- Discuss the **accuracy comparison** in section 6.
- Use section 7 to visually show **predictions**.
- (Optional) Run fine-tuning to show whether accuracy improves further.

You can ask students to:
- Change which CIFAR-10 classes are used.
- Add/remove models from `MODEL_CONFIGS`.
- Change number of epochs / batch size.
- Add **data augmentation** for better generalisation.
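As a starting point for the augmentation exercise, here is a random horizontal flip in plain NumPy for a batch shaped `(N, H, W, C)`. The function name `random_horizontal_flip` is hypothetical. Inside a Keras model, a layer such as `tf.keras.layers.RandomFlip("horizontal")` placed before the backbone serves the same purpose.

```python
import numpy as np

def random_horizontal_flip(images, rng, p=0.5):
    """Flip each image in an (N, H, W, C) batch left-right with probability p."""
    flip = rng.random(len(images)) < p
    out = images.copy()
    out[flip] = out[flip][:, :, ::-1, :]  # reverse the width axis
    return out, flip

rng = np.random.default_rng(0)
batch = np.arange(2 * 2 * 3 * 1).reshape(2, 2, 3, 1)  # two tiny 2x3 "images"
augmented, flipped = random_horizontal_flip(batch, rng)
print(flipped)
```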

This gives a **unified, visual, hands-on** understanding of transfer learning across many backbones. 🧠📊