**Week-6: Transfer Learning**

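Implement the standard LeNet, AlexNet and VGG CNN architectures to classify a multi-class image dataset. (In this notebook all three networks are built and trained from scratch; no pretrained weights are loaded.)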

Dataset: MNIST handwritten digits (10 classes, 0–9).

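Note down the accuracies obtained after 5, 50 and 250 epochs. Note: the training run recorded below logs test-set (validation) accuracy after epochs 5, 15 and 25 rather than 5, 50 and 250.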



In [10]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
import numpy as np

# Fix every RNG (Python, NumPy, TensorFlow) up front so weight initialisation,
# dropout masks and batch shuffling repeat across Restart & Run All.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

# Load MNIST as (images, integer labels) for the train and test splits.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Add a trailing channel axis (the models take (28, 28, 1) inputs) and scale
# pixel values from 0..255 into [0, 1].
x_train = x_train.reshape(-1, 28, 28, 1).astype('float32') / 255.0
x_test = x_test.reshape(-1, 28, 28, 1).astype('float32') / 255.0

# One-hot encode labels; the models are trained with categorical_crossentropy.
NUM_CLASSES = 10
y_train = to_categorical(y_train, NUM_CLASSES)
y_test = to_categorical(y_test, NUM_CLASSES)

# Cheap invariants: images and labels line up and have the expected shapes.
assert x_train.shape[1:] == (28, 28, 1) and len(x_train) == len(y_train)
assert x_test.shape[1:] == (28, 28, 1) and len(x_test) == len(y_test)
assert y_train.shape[1] == NUM_CLASSES and y_test.shape[1] == NUM_CLASSES

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
[1m11490434/11490434[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 0us/step


In [16]:
def build_lenet():
    """Build an uncompiled LeNet-5-style classifier for 28x28 single-channel images.

    Layout (spatial sizes for a 28x28 input):
      * 6 x 5x5 tanh conv with 'same' padding      -> 28x28x6
      * 2x2 average pooling                         -> 14x14x6
      * 16 x 5x5 tanh conv (valid padding)          -> 10x10x16
      * 2x2 average pooling                         -> 5x5x16
      * flatten -> Dense(120, tanh) -> Dense(84, tanh) -> Dense(10, softmax)

    Returns:
        A ``tf.keras`` Sequential model; the caller is responsible for compiling it.
    """
    stack = [layers.Input(shape=(28, 28, 1))]

    # Feature extractor. 'same' padding on the first conv keeps the 28x28 size,
    # presumably to stand in for LeNet-5's originally padded 32x32 input.
    stack += [
        layers.Conv2D(6, (5, 5), activation='tanh', padding='same'),
        layers.AveragePooling2D(pool_size=(2, 2)),
        layers.Conv2D(16, (5, 5), activation='tanh'),
        layers.AveragePooling2D(pool_size=(2, 2)),
    ]

    # Fully connected classifier head.
    stack += [
        layers.Flatten(),
        layers.Dense(120, activation='tanh'),
        layers.Dense(84, activation='tanh'),
        layers.Dense(10, activation='softmax'),
    ]

    return models.Sequential(stack)

In [17]:
def build_alexnet():
    """Build an uncompiled AlexNet-style classifier scaled for 28x28x1 inputs.

    Five 3x3 ReLU convolutions ('same' padding) with 96, 256, 384, 384 and 256
    filters; a 2x2 max-pool follows the 1st, 2nd and 5th conv, shrinking the
    spatial size 28 -> 14 -> 7 -> 3. The head is two Dense(4096, relu) layers,
    each followed by 50% dropout, and a 10-way softmax.

    Returns:
        A ``tf.keras`` Sequential model; the caller is responsible for compiling it.
    """
    # (filters, add a max-pool right after this conv?)
    conv_plan = [
        (96, True),
        (256, True),
        (384, False),
        (384, False),
        (256, True),
    ]

    stack = [layers.Input(shape=(28, 28, 1))]
    for n_filters, pool_after in conv_plan:
        stack.append(layers.Conv2D(n_filters, (3, 3), activation='relu', padding='same'))
        if pool_after:
            stack.append(layers.MaxPooling2D(pool_size=(2, 2)))

    stack.append(layers.Flatten())
    for _ in range(2):
        stack.append(layers.Dense(4096, activation='relu'))
        stack.append(layers.Dropout(0.5))
    stack.append(layers.Dense(10, activation='softmax'))

    return models.Sequential(stack)

In [18]:
def build_vgg():
    """Build an uncompiled, scaled-down VGG-style classifier for 28x28x1 inputs.

    Three blocks of [3x3 ReLU conv ('same' padding) -> max-pool with the default
    pool size] widen the channels 64 -> 128 -> 256 while halving the spatial
    size. The head is flatten -> Dense(256, relu) -> 50% dropout -> Dense(10, softmax).

    NOTE(review): each block has a single convolution, whereas the published
    VGG configurations stack several 3x3 convs per block; treat this as a
    lightweight VGG-inspired variant rather than VGG-16/19.

    Returns:
        A ``tf.keras`` Sequential model; the caller is responsible for compiling it.
    """
    stack = [layers.Input(shape=(28, 28, 1))]

    for width in (64, 128, 256):
        stack.extend([
            layers.Conv2D(width, (3, 3), activation='relu', padding='same'),
            layers.MaxPooling2D(),
        ])

    stack.extend([
        layers.Flatten(),
        layers.Dense(256, activation='relu'),
        layers.Dropout(0.5),
        layers.Dense(10, activation='softmax'),
    ])

    return models.Sequential(stack)

In [14]:
def train_and_log(model_fn, model_name, checkpoints=(5, 15, 25), batch_size=128,
                  train_data=None, validation_data=None):
    """Train a freshly built model and report validation accuracy at chosen epochs.

    The model is compiled with Adam and categorical cross-entropy, then trained
    for ``max(checkpoints)`` epochs. After every epoch listed in ``checkpoints``
    the epoch's ``val_accuracy`` is recorded; the values are printed once
    training finishes.

    Args:
        model_fn: Zero-argument callable returning an uncompiled Keras model.
        model_name: Label used in the printed report.
        checkpoints: 1-based epoch numbers at which to record validation
            accuracy. Training runs until the largest one, e.g. pass
            ``(5, 50, 250)`` for a longer run.
        batch_size: Mini-batch size passed to ``model.fit``.
        train_data: ``(x, y)`` training pair; defaults to the notebook-level
            ``(x_train, y_train)``.
        validation_data: ``(x, y)`` pair evaluated after each epoch; defaults
            to the notebook-level ``(x_test, y_test)``.

    Returns:
        dict mapping each checkpoint epoch to its recorded validation accuracy.
        Epochs for which Keras reported no ``val_accuracy`` are omitted.

    Raises:
        ValueError: if ``checkpoints`` is empty or contains an epoch < 1.
    """
    checkpoint_epochs = sorted(set(checkpoints))
    if not checkpoint_epochs or checkpoint_epochs[0] < 1:
        raise ValueError("checkpoints must contain at least one epoch number >= 1")

    # Fall back to the notebook-level MNIST arrays so existing calls keep working.
    if train_data is None:
        train_data = (x_train, y_train)
    if validation_data is None:
        validation_data = (x_test, y_test)
    x_fit, y_fit = train_data

    model = model_fn()
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

    class AccuracyLogger(tf.keras.callbacks.Callback):
        """Record ``val_accuracy`` after each requested (1-based) epoch."""

        def __init__(self, epochs_to_log):
            super().__init__()  # let the Keras base class initialise its own state
            self.epochs_to_log = frozenset(epochs_to_log)
            self.accuracies = {}

        def on_epoch_end(self, epoch, logs=None):
            epoch_number = epoch + 1  # Keras passes a 0-based epoch index
            if epoch_number not in self.epochs_to_log:
                return
            val_acc = (logs or {}).get("val_accuracy")
            if val_acc is not None:
                self.accuracies[epoch_number] = float(val_acc)

    logger = AccuracyLogger(checkpoint_epochs)

    model.fit(
        x_fit, y_fit,
        # Train exactly long enough to reach the last checkpoint, so the epoch
        # count can never drift out of sync with the logged epochs.
        epochs=checkpoint_epochs[-1],
        batch_size=batch_size,
        validation_data=validation_data,
        verbose=0,
        callbacks=[logger],
    )

    print(f"{model_name} Accuracies:")
    for epoch in checkpoint_epochs:
        acc = logger.accuracies.get(epoch)
        print(f"Epoch {epoch}: {acc:.4f}" if acc is not None else f"Epoch {epoch}: Not recorded")

    return dict(logger.accuracies)

In [19]:
# Train and report each architecture in turn (same order: LeNet, AlexNet, VGG).
for builder, label in (
    (build_lenet, "LeNet"),
    (build_alexnet, "AlexNet"),
    (build_vgg, "VGG"),
):
    train_and_log(builder, label)

LeNet Accuracies:
Epoch 5: 0.9827
Epoch 15: 0.9872
Epoch 25: 0.9862
AlexNet Accuracies:
Epoch 5: 0.9855
Epoch 15: 0.9893
Epoch 25: 0.9899
VGG Accuracies:
Epoch 5: 0.9925
Epoch 15: 0.9925
Epoch 25: 0.9928
