# **Transfer Learning with TensorFlow**

## **What is Transfer Learning?**

**Transfer Learning** reuses a **pre-trained model** for a new task.  
Instead of training from scratch, we adapt knowledge learned from a **large dataset** to a **smaller dataset**.

---

## **Model: `MobileNetV2`**

- Pre-trained on **`ImageNet`**
- Lightweight and efficient
- Suitable for **image classification**

---

## **Dataset: `CIFAR-10`**

- 60,000 color images (50,000 for training, 10,000 for testing)  
- 32 × 32 pixels  
- 10 classes  

---

## **Goal**

- Load **`MobileNetV2`**
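- Replace the 1000-class `ImageNet` classifier with a new 10-class head
- Train the new head, then fine-tune part of the base on **`CIFAR-10`**
- Compare accuracy before and after fine-tuning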

---

## **Why Use Transfer Learning?**

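- Faster training, because the feature extractor is already trained  
- Requires less labeled data  
- Often better accuracy than training from scratch on a small dataset  

Transfer learning makes it practical to build accurate computer vision models with limited data and compute.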

# **`Step 1: Import Libraries and Load Data`**

In [None]:
import tensorflow as tf
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Load CIFAR-10 data
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Normalize pixel values and convert labels to one-hot encoding
x_train, x_test = x_train / 255.0, x_test / 255.0
y_train, y_test = to_categorical(y_train, 10), to_categorical(y_test, 10)
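
The two preprocessing lines above can be reproduced in plain NumPy to see what they do. The sketch below is only an illustration on a tiny made-up batch (not real CIFAR-10 data). It also shows the `[-1, 1]` scaling that `tf.keras.applications.mobilenet_v2.preprocess_input` applies, which is the range the ImageNet weights were trained with. Whether that scaling helps here compared with the `[0, 1]` scaling used above is something to check experimentally.

```python
import numpy as np

# Hypothetical stand-in batch: 4 RGB images of 32x32 with integer pixels in [0, 255]
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(4, 32, 32, 3), dtype=np.uint8)
labels = np.array([[3], [0], [9], [3]])  # CIFAR-10 labels have shape (N, 1)

# Scaling used in the cell above: [0, 255] -> [0, 1]
scaled_01 = images / 255.0

# Scaling matching MobileNetV2's ImageNet preprocessing: [0, 255] -> [-1, 1]
scaled_pm1 = images / 127.5 - 1.0

# One-hot encoding, equivalent in effect to to_categorical(labels, 10)
one_hot = np.eye(10, dtype=np.float32)[labels.ravel()]

print(scaled_01.min() >= 0.0, scaled_01.max() <= 1.0)
print(scaled_pm1.min() >= -1.0, scaled_pm1.max() <= 1.0)
print(one_hot[0])
```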

# **Data Augmentation**

## **What is Data Augmentation?**

**Data Augmentation** increases the **size** and **diversity** of training data by creating modified versions of existing data without collecting new samples.

Common techniques (for images):

- **Rotation**
- **Flipping**
- **Scaling / Zooming**
- **Cropping**
- **Adding Noise**
- **Brightness Adjustment**

---

## **Why It Matters**

- **Prevents Overfitting** – Improves generalization  
- **Boosts Performance** – More diverse training data  
- **Handles Imbalance** – Augments minority classes  

> Always choose augmentations that make sense for your task (e.g., avoid flipping text images).
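
A few of the techniques listed above can be sketched directly in NumPy. This is a minimal illustration on a random made-up image, assuming float pixels in `[0, 1]`; the helper names are hypothetical and are not part of Keras, which performs these transformations internally.

```python
import numpy as np

rng = np.random.default_rng(42)
# Hypothetical image: 32x32 RGB with float pixels in [0, 1]
image = rng.random((32, 32, 3))

def horizontal_flip(img):
    """Mirror the image left to right (reverse the width axis)."""
    return img[:, ::-1, :]

def adjust_brightness(img, delta):
    """Add a constant to every pixel and clip back to the valid range."""
    return np.clip(img + delta, 0.0, 1.0)

def add_gaussian_noise(img, std, rng):
    """Add zero-mean Gaussian noise and clip back to the valid range."""
    return np.clip(img + rng.normal(0.0, std, size=img.shape), 0.0, 1.0)

flipped = horizontal_flip(image)
brighter = adjust_brightness(image, 0.2)
noisy = add_gaussian_noise(image, 0.05, rng)
print(flipped.shape, brighter.shape, noisy.shape)
```

Each function returns an image of the same shape, so the label stays valid and the augmented copy can be used as an extra training sample.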

---

# **`Step 2: Data Preprocessing`**

## **Image Size Adjustment**

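- **`MobileNetV2`** was pre-trained at **224 × 224** (the Keras default input size)
- **`CIFAR-10`** images are **32 × 32**

### **Required Steps**

- **Resize to 224 × 224** to match the pre-training resolution, or pass a smaller `input_shape` with `include_top=False` (the code in Step 3 keeps **32 × 32**, the smallest size Keras accepts, which trains faster but typically costs accuracy)
- Apply **Data Augmentation**
- **Normalize / Rescale** pixel values

Proper **Augmentation + Preprocessing** ensures compatibility with **`MobileNetV2`** and improves model performance.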
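
To make the resizing step concrete, here is a minimal NumPy sketch, assuming nearest-neighbour upsampling on a placeholder image. It also estimates the memory needed to hold the whole training set at the larger size, which is why resizing is usually done per batch rather than up front.

```python
import numpy as np

# Hypothetical stand-in for one CIFAR-10 image
image = np.zeros((32, 32, 3), dtype=np.float32)

# 224 / 32 = 7, so nearest-neighbour upsampling repeats each pixel 7 times per axis
factor = 224 // 32
resized = np.repeat(np.repeat(image, factor, axis=0), factor, axis=1)
print(resized.shape)

# Memory needed to hold all 50,000 training images at 224x224x3 in float32 (4 bytes each)
bytes_total = 50_000 * 224 * 224 * 3 * 4
print(f"{bytes_total / 1e9:.1f} GB")
```

In a TensorFlow pipeline the same step would more commonly use `tf.image.resize`, which interpolates bilinearly by default, applied on the fly to each batch.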

In [None]:
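# Data Augmentation
datagen = ImageDataGenerator(
    featurewise_center=False,
    samplewise_center=False,
    featurewise_std_normalization=False,
    samplewise_std_normalization=False,
    zca_whitening=False,
    rotation_range=15,        # Rotate by up to 15 degrees
    width_shift_range=0.1,    # Shift horizontally by up to 10% of the width
    height_shift_range=0.1,   # Shift vertically by up to 10% of the height
    horizontal_flip=True,
    vertical_flip=False,      # Upside-down CIFAR-10 objects (cars, ships, horses) are unrealistic
    zoom_range=0.1,
    fill_mode='nearest'
)

# fit() only computes dataset statistics for featurewise_* and zca_whitening,
# which are all disabled above, so this call is optional here
datagen.fit(x_train)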

# **`Step 3: Modify the Pre-Trained Model`**

## Why Modify the Model?

MobileNetV2 is pre-trained on ImageNet with 1000 classes.  
Our task (CIFAR-10) has only 10 classes, so we must:

- Remove the original top layer  
- Add a new classifier for 10 classes  
- Freeze the base layers (initially)

---

## What We Do

### 1. Load MobileNetV2 Without Top Layer
- `include_top=False`
- Keep feature extraction layers
- Remove the 1000-class classifier

### 2. Add New Classifier
- Global Average Pooling
- Dense layer with 1024 units and ReLU activation
- Dense output layer with 10 units and softmax activation

The new head predicts 10 classes instead of ImageNet's 1000.

### 3. Freeze Base Layers
- Pre-trained weights are not updated
- Only the new classifier trains first
- Later fine-tuning can be applied

In [None]:
# Load MobileNetV2 without the top layer
# (with include_top=False, Keras accepts inputs as small as 32 × 32)
base_model = MobileNetV2(weights='imagenet', include_top=False, input_shape=(32, 32, 3))

# Freeze the base_model
base_model.trainable = False

# Add custom layers on top for our task
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(1024, activation='relu')(x)  # New FC layer, random init
predictions = Dense(10, activation='softmax')(x)  # New softmax layer

# Define the model
model = Model(inputs=base_model.input, outputs=predictions)
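
The size of the new classifier can be worked out by hand. The short calculation below is a sketch under two assumptions about MobileNetV2 with the default width multiplier: it downsamples the input by a factor of 32 and ends with 1280 feature channels. Calling `model.summary()` on the real model is the way to confirm these numbers.

```python
# Assumptions: MobileNetV2 (alpha=1.0) has a total stride of 32
# and outputs 1280 feature channels.
input_size = 32
total_stride = 32
channels = 1280

# Spatial size of the base model's output for a 32x32 input
feature_map = (input_size // total_stride, input_size // total_stride, channels)
print(feature_map)

# Dense layer parameters = inputs * units + units (weights plus biases)
dense_1024 = channels * 1024 + 1024
dense_10 = 1024 * 10 + 10
head_params = dense_1024 + dense_10
print(dense_1024, dense_10, head_params)
```

With a 32 × 32 input the base output is a single spatial position, so global average pooling has nothing left to average. While the base is frozen, the head parameters counted here are the only ones being trained.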

# **`Step 4: Compile and Train the Model`**

In [None]:
%%time
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# Train the model
history = model.fit(datagen.flow(x_train, y_train, batch_size=32),
                    steps_per_epoch=len(x_train) // 32, epochs=10,
                    validation_data=(x_test, y_test), verbose=1)
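
The `steps_per_epoch` value is plain integer arithmetic on the dataset size. The sketch below, assuming the 50,000 CIFAR-10 training images, shows how many full batches fit into one epoch for two batch sizes and how many samples are left over.

```python
import math

num_samples = 50_000  # CIFAR-10 training set size

for batch_size in (32, 64):
    full_batches = num_samples // batch_size            # integer number of full batches
    leftover = num_samples - full_batches * batch_size  # samples outside the full batches
    all_batches = math.ceil(num_samples / batch_size)   # also counts the final partial batch
    print(batch_size, full_batches, leftover, all_batches)
```

Floor division (`//`) keeps the value an integer, which is what `steps_per_epoch` expects; true division (`/`) would produce a float.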

In [None]:
import matplotlib.pyplot as plt
# Plot training & validation accuracy values
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('Model accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Test'], loc='upper left')
plt.show()

# Plot training & validation loss values
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Test'], loc='upper left')
plt.show()

**An extra run with more epochs and a larger batch size** (training continues from the weights learned above, because the model is not rebuilt)

In [None]:
%%time
import matplotlib.pyplot as plt
from tensorflow.keras.callbacks import EarlyStopping

# Stop when val_loss has not improved for 5 epochs and restore the best weights
early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# Train the model (steps_per_epoch must be an integer, hence //)
history = model.fit(datagen.flow(x_train, y_train, batch_size=64),
                    steps_per_epoch=len(x_train) // 64, epochs=50,
                    validation_data=(x_test, y_test), verbose=1, callbacks=[early_stopping])

# Plot training & validation accuracy values
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('Model accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Test'], loc='upper left')
plt.show()

# Plot training & validation loss values
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Test'], loc='upper left')
plt.show()
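
The patience rule behind early stopping can be illustrated without training anything. The function below is a simplified sketch, not the Keras implementation: it assumes lower is better, ignores options such as `min_delta`, and runs on made-up validation losses. The name `early_stopping_epoch` is hypothetical.

```python
def early_stopping_epoch(val_losses, patience=5):
    """Return (epoch at which training stops, epoch with the best loss)."""
    best = float("inf")
    best_epoch = -1
    wait = 0  # epochs since the last improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    # Patience never ran out: training used every epoch
    return len(val_losses) - 1, best_epoch

# Made-up validation losses: the best value is at epoch 2
losses = [1.0, 0.8, 0.7, 0.72, 0.71, 0.75, 0.74, 0.73, 0.9]
stop_epoch, best_epoch = early_stopping_epoch(losses, patience=5)
print(stop_epoch, best_epoch)
```

With `restore_best_weights=True`, the model ends up with the weights from `best_epoch` rather than from the epoch at which training stopped.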

# Step 5: Fine-Tuning (Optional)

After training the model with frozen base layers, you can improve performance by **unfreezing some layers** of MobileNetV2 and training again.

Fine-tuning allows the pre-trained weights to adjust to your specific dataset.

---

## Why Fine-Tuning?

- Improves accuracy
- Adapts feature extraction to your task
- Better performance on custom datasets

---

## Important Before Fine-Tuning

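- Unfreeze selected layers of the base model
- Recompile the model with a lower learning rate (here `1e-4`, ten times smaller than Adam's default of `1e-3`)

Recompiling is necessary because changes to a layer's `trainable` attribute only take effect once `compile()` is called again. The low learning rate keeps large updates from overwriting the pre-trained weights.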

In [None]:
# Unfreeze the base model, then re-freeze its first `fine_tune_at` layers
base_model.trainable = True
fine_tune_at = 100  # Layers with index < 100 (closest to the input) stay frozen
for layer in base_model.layers[:fine_tune_at]:
    layer.trainable = False

# Recompile with a lower learning rate (`learning_rate` replaces the deprecated `lr` argument)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# Continue training
history_fine = model.fit(datagen.flow(x_train, y_train, batch_size=32),
                         steps_per_epoch=len(x_train) // 32, epochs=5,
                         validation_data=(x_test, y_test), verbose=1)
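
The slice `layers[:fine_tune_at]` selects the layers nearest the input, so those are the ones that stay frozen. The sketch below illustrates this with a hypothetical `ToyLayer` class standing in for Keras layers and a pretend base model of 10 layers; it shows only the indexing, not real training.

```python
from dataclasses import dataclass

@dataclass
class ToyLayer:
    """Hypothetical stand-in for a Keras layer; only tracks the trainable flag."""
    name: str
    trainable: bool = True

# Pretend base model with 10 layers, index 0 closest to the input
layers = [ToyLayer(f"layer_{i}") for i in range(10)]

fine_tune_at = 7
for layer in layers[:fine_tune_at]:  # indices 0..6 are frozen
    layer.trainable = False

frozen = [layer.name for layer in layers if not layer.trainable]
trainable = [layer.name for layer in layers if layer.trainable]
print(frozen)
print(trainable)
```

On the real model, `sum(layer.trainable for layer in base_model.layers)` reports how many layers remain trainable after the loop.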

This example demonstrates the basics of applying transfer learning with TensorFlow to a computer vision task with a small dataset. Data augmentation and fine-tuning can raise accuracy further by adapting the pre-trained model to the new task.

In [None]:
import matplotlib.pyplot as plt

def plot_history(histories, key='accuracy'):
    """Plot training and validation curves of `key` for several (name, History) pairs."""
    plt.figure(figsize=(16, 4))

    for name, history in histories:
        val = plt.plot(history.epoch, history.history['val_'+key],
                       '--', label=name.title()+' Val')
        plt.plot(history.epoch, history.history[key], color=val[0].get_color(),
                 label=name.title()+' Train')

    plt.xlabel('Epochs')
    plt.ylabel(key.replace('_', ' ').title())
    plt.legend()
    # Span the longest run; using only the last history would truncate earlier curves
    plt.xlim([0, max(max(h.epoch) for _, h in histories)])
    plt.show()

# Plot accuracy
plot_history([('Pre Fine-Tuning', history),
              ('Fine-Tuning', history_fine)],
             key='accuracy')

# Plot loss
plot_history([('Pre Fine-Tuning', history),
              ('Fine-Tuning', history_fine)],
             key='loss')
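
Both runs above are plotted from epoch 0, so their curves overlap on the x-axis. One way to show fine-tuning as a continuation is to shift its epoch indices, as in the sketch below. The helper `continue_epochs` and the epoch lists are hypothetical; the lists only mimic the shape of `History.epoch` from two `fit()` calls.

```python
def continue_epochs(first_epochs, second_epochs):
    """Shift the second run's epoch indices so they follow the first run."""
    offset = first_epochs[-1] + 1 if first_epochs else 0
    return [epoch + offset for epoch in second_epochs]

# Hypothetical epoch lists, shaped like History.epoch
initial = list(range(10))  # 10 epochs with the frozen base
fine = list(range(5))      # 5 fine-tuning epochs

shifted = continue_epochs(initial, fine)
print(shifted)
```

Keras can also do this at training time through the `initial_epoch` argument of `fit()`, in which case `epochs` is the index of the final epoch rather than the number of additional epochs.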