# ResNet50 Transfer Learning Example in Google Colab
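This notebook demonstrates transfer learning with an ImageNet-pretrained ResNet50 on the CIFAR-10 dataset. Set `MODEL_NAME` in the configuration cell to try one of the other supported architectures instead (LeNet, AlexNet, MobileNetV2, EfficientNetB0, VGG16).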


In [None]:
# Supports: lenet, alexnet, mobilenetv2, resnet50, efficientnetb0, vgg16

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import cifar10

In [None]:
# Mixed precision (float16 compute, float32 variables) speeds up training and
# reduces activation memory on GPUs with Tensor Cores. On a CPU-only runtime it
# brings no benefit and TensorFlow warns that it will likely run slowly, so it
# is only enabled when a GPU is visible. The explicit float32 branch keeps the
# cell's effect well-defined if it is re-run.
from tensorflow.keras import mixed_precision
if tf.config.list_physical_devices('GPU'):
    mixed_precision.set_global_policy('mixed_float16')
else:
    mixed_precision.set_global_policy('float32')

In [None]:
# One of: lenet, alexnet, mobilenetv2, resnet50, efficientnetb0, vgg16
MODEL_NAME = "resnet50"
EPOCHS = 3
BATCH_SIZE = 8  # kept small, presumably to limit GPU memory at IMG_SIZE; raise if memory allows
IMG_SIZE = (128, 128)  # resized on the fly
SEED = 42

# Seed Python, NumPy and the TensorFlow backend in one call so weight
# initialisation, dataset shuffling and dropout are repeatable across reruns.
# (Some GPU kernels can still be non-deterministic.)
keras.utils.set_random_seed(SEED)

In [None]:
# Load CIFAR-10 and turn the integer class labels into one-hot vectors
# (10 classes), the format categorical_crossentropy expects.
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
y_train, y_test = (
    keras.utils.to_categorical(labels, num_classes=10)
    for labels in (y_train, y_test)
)

In [None]:
# Scale pixels the way the chosen network expects.
# ImageNet-pretrained backbones were trained with their own input
# normalisation (ResNet50/VGG16: caffe-style BGR mean subtraction on the 0-255
# range; MobileNetV2: [-1, 1]; EfficientNetB0 rescales internally and wants raw
# 0-255). Feeding them [0, 1] images silently degrades the frozen features.
# The from-scratch models (lenet, alexnet) simply use [0, 1].
# These are per-channel affine maps, so applying them before the bilinear
# resize in the input pipeline gives the same result as applying them after.
PREPROCESS_FNS = {
    "mobilenetv2": keras.applications.mobilenet_v2.preprocess_input,
    "resnet50": keras.applications.resnet50.preprocess_input,
    "efficientnetb0": keras.applications.efficientnet.preprocess_input,
    "vgg16": keras.applications.vgg16.preprocess_input,
}
preprocess = PREPROCESS_FNS.get(MODEL_NAME, lambda images: images / 255.0)

# cifar10.load_data() returns uint8 and this cell converts to float32, so the
# dtype check stops a re-run from scaling the data twice. If MODEL_NAME is
# changed, re-run the data-loading cell before this one.
if x_train.dtype == "uint8":
    # Cast to float32 first: `uint8 / 255.0` would produce float64 (2x memory),
    # and preprocess_input modifies float arrays in place, so it works on this copy.
    x_train = preprocess(x_train.astype("float32"))
    x_test = preprocess(x_test.astype("float32"))

In [None]:
# Training input pipeline.
# - shuffle: tf.data reshuffles on every iteration by default, so each epoch
#   sees the examples in a different order. Shuffling the small, not-yet-resized
#   images keeps the 10_000-example shuffle buffer cheap.
# - batch before map: tf.image.resize accepts a whole batch, so the resize is
#   vectorised instead of running once per image.
train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(10_000)
    .batch(BATCH_SIZE)
    .map(lambda x, y: (tf.image.resize(x, IMG_SIZE), y),
         num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
# Evaluation input pipeline: resize each image to IMG_SIZE on the fly, batch,
# and prefetch. No shuffling, so evaluation order is fixed.
test_ds = (
    tf.data.Dataset.from_tensor_slices((x_test, y_test))
    .map(lambda image, label: (tf.image.resize(image, IMG_SIZE), label),
         num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
# Architecture name -> keras.applications constructor name. Kept as strings so
# defining the mapping does not touch Keras.
_PRETRAINED_BACKBONES = {
    "mobilenetv2": "MobileNetV2",
    "resnet50": "ResNet50",
    "efficientnetb0": "EfficientNetB0",
    "vgg16": "VGG16",
}
# Architectures built and trained from scratch.
_SCRATCH_MODELS = ("lenet", "alexnet")


def _build_lenet(input_shape, num_classes):
    """LeNet-5-style CNN, trained from scratch."""
    return models.Sequential([
        layers.Input(shape=input_shape),
        layers.Conv2D(6, (5, 5), activation='relu'),
        layers.AveragePooling2D(),
        layers.Conv2D(16, (5, 5), activation='relu'),
        layers.AveragePooling2D(),
        layers.Flatten(),
        layers.Dense(120, activation='relu'),
        layers.Dense(84, activation='relu'),
        layers.Dense(num_classes, activation='softmax', dtype='float32'),
    ])


def _build_alexnet(input_shape, num_classes):
    """AlexNet-style CNN (with BatchNormalization), trained from scratch."""
    return models.Sequential([
        layers.Input(shape=input_shape),
        layers.Conv2D(96, (11, 11), strides=4, activation='relu'),
        layers.BatchNormalization(),
        layers.MaxPooling2D(3, strides=2),
        layers.Conv2D(256, (5, 5), padding='same', activation='relu'),
        layers.BatchNormalization(),
        layers.MaxPooling2D(3, strides=2),
        layers.Conv2D(384, (3, 3), padding='same', activation='relu'),
        layers.Conv2D(384, (3, 3), padding='same', activation='relu'),
        layers.Conv2D(256, (3, 3), padding='same', activation='relu'),
        layers.MaxPooling2D(3, strides=2),
        layers.Flatten(),
        layers.Dense(4096, activation='relu'),
        layers.Dropout(0.5),
        layers.Dense(4096, activation='relu'),
        layers.Dropout(0.5),
        layers.Dense(num_classes, activation='softmax', dtype='float32'),
    ])


def _build_pretrained(constructor_name, input_shape, num_classes):
    """Frozen ImageNet backbone + global average pooling + new softmax head."""
    backbone_cls = getattr(keras.applications, constructor_name)
    base_model = backbone_cls(
        input_shape=input_shape, include_top=False, weights='imagenet')
    # Transfer learning: keep the pretrained weights fixed, train only the head.
    base_model.trainable = False
    return models.Sequential([
        base_model,
        layers.GlobalAveragePooling2D(),
        layers.Dense(num_classes, activation='softmax', dtype='float32'),
    ])


def build_model(name, img_size=None, num_classes=10):
    """Build an uncompiled image classifier by architecture name.

    Args:
        name: one of "lenet", "alexnet" (trained from scratch) or
            "mobilenetv2", "resnet50", "efficientnetb0", "vgg16" (frozen
            ImageNet backbone + trainable head). Pretrained backbones expect
            inputs prepared with their matching
            `keras.applications.<module>.preprocess_input`.
        img_size: (height, width) of the RGB input images. Defaults to the
            notebook-level IMG_SIZE.
        num_classes: number of softmax outputs (10 for CIFAR-10).

    Returns:
        A keras Sequential model. Its final Dense layer is forced to float32 so
        the softmax stays numerically stable under a mixed_float16 policy.

    Raises:
        ValueError: if `name` is not a supported architecture.
    """
    # Validate first: the original if/elif chain fell through on unknown names
    # and crashed with UnboundLocalError instead of a readable message.
    if name not in _SCRATCH_MODELS and name not in _PRETRAINED_BACKBONES:
        supported = ", ".join(_SCRATCH_MODELS + tuple(_PRETRAINED_BACKBONES))
        raise ValueError(
            f"Unknown model name {name!r}; expected one of: {supported}")

    input_shape = tuple(IMG_SIZE if img_size is None else img_size) + (3,)

    if name == "lenet":
        return _build_lenet(input_shape, num_classes)
    if name == "alexnet":
        return _build_alexnet(input_shape, num_classes)
    return _build_pretrained(_PRETRAINED_BACKBONES[name], input_shape, num_classes)

In [None]:
# Build the selected architecture and attach the training setup: default Adam,
# cross-entropy over the one-hot labels, and accuracy as the reported metric.
# Under a mixed_float16 policy, compile() adds loss scaling automatically.
model = build_model(MODEL_NAME)
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

In [None]:
# Print the layer-by-layer architecture with output shapes and parameter counts.
# The trainable vs non-trainable totals show how much of a pretrained backbone is frozen.
model.summary()

In [None]:
# Train for EPOCHS epochs. The test split doubles as validation data here, so
# the per-epoch val_* metrics are test-set numbers. Do not tune hyperparameters
# or pick checkpoints on them.
history = model.fit(
    x=train_ds,
    epochs=EPOCHS,
    validation_data=test_ds,
)

In [None]:
# Final held-out evaluation. evaluate() returns the loss followed by the
# compiled metrics, which here is just accuracy.
test_loss, test_acc = model.evaluate(x=test_ds)
print(f"✅ Test Accuracy: {test_acc:.4f}")