##### ARTI 560 - Computer Vision

## Image Classification with Vision Transformer (ViT)

### Introduction

In this notebook, we build an image classification model to recognize five different flower categories from the TF Flowers dataset using a modern deep learning architecture called the Vision Transformer (ViT).

Unlike traditional Convolutional Neural Networks (CNNs), Vision Transformers apply the Transformer architecture (originally designed for Natural Language Processing) to images: each image is split into small fixed-size patches, and attention mechanisms learn the relationships between those patches. This approach performs strongly on many computer vision tasks.
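
To make the patch idea concrete, here is a minimal pure-Python sketch of splitting an image into non-overlapping patches. It is an illustration only, not the code used inside the pretrained model. The helper name `split_into_patches` and the toy 4×4 image are hypothetical, and the 16×16 patch size is an assumption read off the preset name `vit_base_patch16_224_imagenet` used later in this notebook.

```python
def split_into_patches(image, patch_size):
    """Split an H x W image (nested lists) into non-overlapping square patches.

    Patches are returned in row-major order: left to right, then top to bottom.
    """
    height, width = len(image), len(image[0])
    assert height % patch_size == 0 and width % patch_size == 0
    patches = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            patch = [row[left:left + patch_size]
                     for row in image[top:top + patch_size]]
            patches.append(patch)
    return patches


# Toy 4x4 single-channel image with pixel values 0..15
toy_image = [[r * 4 + c for c in range(4)] for r in range(4)]
toy_patches = split_into_patches(toy_image, patch_size=2)
print(len(toy_patches))   # 4
print(toy_patches[0])     # [[0, 1], [4, 5]]

# Same arithmetic for a 224x224 RGB image cut into 16x16 patches (assumed sizes)
side, patch = 224, 16
num_patches = (side // patch) ** 2      # patches per image
values_per_patch = patch * patch * 3    # raw pixel values in one RGB patch
print(num_patches, values_per_patch)    # 196 768
```

Under these assumptions each image becomes a sequence of 196 patches, which plays the role that a sequence of word tokens plays in a language Transformer.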

We will use a pretrained ViT model (trained on ImageNet) and adapt it to our flower dataset using transfer learning, which allows us to achieve good results with limited training time and a relatively small dataset.

### Overview

This notebook follows the steps below:

1. **Load the Dataset**

We load the TF Flowers dataset from TensorFlow Datasets (TFDS) and split it into:

- Training set (70%)
- Validation set (15%)
- Test set (15%)

2. **Preprocess the Images**

Since the pretrained ViT model expects images of size **224×224**, we:

- Resize all images to 224×224
- Normalize pixel values to the range [0, 1]

3. **Create Efficient Data Pipelines**

We use the tf.data API to build efficient pipelines with:

- Shuffling (for training)
- Batching
- Prefetching (to improve performance)

4. **Apply Data Augmentation**

To improve generalization, we apply augmentation layers such as:

- Random horizontal flip
- Random rotation
- Random zoom


5. **Build a ViT Model with a Frozen Backbone**

We load a pretrained Vision Transformer model using keras_hub and freeze the backbone so that:

- The pretrained feature extractor remains unchanged
- Only the classification head learns to adapt to the flower classes

6. **Train and Evaluate**

Finally, we train the model for a few epochs and evaluate its performance on the unseen test set to report the final accuracy.

In [None]:
# Import libraries
import tensorflow as tf
import keras  # Keras 3, which keras_hub requires
from keras import layers
import keras_hub  # provides pretrained models such as ViT and BERT
import tensorflow_datasets as tfds
import numpy as np


# Hyperparameters
image_size = 224  # input resolution expected by the ViT preset
batch_size = 32
num_classes = 5
epochs = 5
learning_rate = 1e-4

In [None]:
# Load TF Flowers dataset with proper splits
(ds_train, ds_val, ds_test), ds_info = tfds.load(
    "tf_flowers",
    split=["train[:70%]", "train[70%:85%]", "train[85%:]"],
    as_supervised=True, # Returns (image, label) pairs.
    with_info=True # Returns metadata like: number of samples, label names and image shape
)



Downloading and preparing dataset Unknown size (download: Unknown size, generated: Unknown size, total: Unknown size) to /root/tensorflow_datasets/tf_flowers/3.0.1...


Dataset tf_flowers downloaded and prepared to /root/tensorflow_datasets/tf_flowers/3.0.1. Subsequent calls will reuse this data.
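
As a quick check on what the 70/15/15 split means in practice, the sketch below estimates how many images and batches each subset contains. It rests on two assumptions that are not stated in this notebook: that the `tf_flowers` train split holds 3,670 images, and that percentage boundaries are rounded to the nearest index. The helper `split_sizes` is hypothetical and is not part of TFDS.

```python
import math


def split_sizes(total, boundaries_pct):
    """Sizes of consecutive slices cut at the given percentage boundaries.

    Boundaries are rounded to the nearest index (an assumption about how
    percent slicing resolves fractional positions).
    """
    cuts = [0] + [(total * pct + 50) // 100 for pct in boundaries_pct] + [total]
    return [hi - lo for lo, hi in zip(cuts, cuts[1:])]


total_examples = 3670   # assumed size of the tf_flowers "train" split
batch = 32              # same value as batch_size in the hyperparameters

train_n, val_n, test_n = split_sizes(total_examples, [70, 85])
for name, n in [("train", train_n), ("val", val_n), ("test", test_n)]:
    print(f"{name}: {n} images -> {math.ceil(n / batch)} batches")
```

With these assumptions the estimate is 81 training batches and 18 batches each for validation and test. In the notebook itself, `ds_info.splits["train"].num_examples` and `ds_train.cardinality()` report the actual figures, which are the ones to trust.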


In [None]:
# Prepare dataset

def preprocess(image, label):
    image = tf.image.resize(image, [image_size, image_size])  # bilinear resize; the result is already float32
    image = tf.cast(image, tf.float32) / 255.0  # scale pixel values from [0, 255] to [0, 1]
    return image, label

train_ds = (ds_train
            .shuffle(1000)
            .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
            .batch(batch_size)
            .prefetch(tf.data.AUTOTUNE))  # prepare the next batch on the CPU while the accelerator trains on the current one

val_ds = (ds_val
          .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
          .batch(batch_size)
          .prefetch(tf.data.AUTOTUNE))

test_ds = (ds_test
           .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
           .batch(batch_size)
           .prefetch(tf.data.AUTOTUNE))


# Data augmentation layers
data_augmentation = keras.Sequential([
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.02),
    layers.RandomZoom(0.2, 0.2),
], name="data_augmentation")

In [None]:
# Build ViT model (frozen backbone)

def vit_model(in_shape=(image_size, image_size, 3), num_classes=num_classes):
    inputs = keras.Input(shape=in_shape)
    x = data_augmentation(inputs)  # apply augmentation
    # Pretrained ViT classifier
    vit = keras_hub.models.ViTImageClassifier.from_preset(
        "vit_base_patch16_224_imagenet",
        num_classes=num_classes
    )
    vit.backbone.trainable = False  # freeze backbone
    outputs = vit(x)
    model = keras.Model(inputs, outputs, name="vit_flowers_frozen")
    return model

model = vit_model()

Downloading from https://www.kaggle.com/api/v1/models/keras/vit/keras/vit_base_patch16_224_imagenet/2/download/config.json...


100%|██████████| 593/593 [00:00<00:00, 1.19MB/s]
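
Freezing the backbone leaves very few weights to train. The sketch below estimates that number under two assumptions that this notebook does not confirm: that the classification head is a single dense layer, and that it reads a 768-dimensional feature vector (the usual width of a ViT-Base model). The helper `dense_param_count` is hypothetical.

```python
def dense_param_count(in_features, out_features):
    """Weights plus biases of one fully connected layer."""
    return in_features * out_features + out_features


hidden_dim = 768      # assumed ViT-Base feature width
flower_classes = 5    # same value as num_classes in the hyperparameters

head_params = dense_param_count(hidden_dim, flower_classes)
print(head_params)    # 3845
```

If those assumptions hold, the "Trainable params" line printed by `model.summary()` should be close to this figure, while the frozen backbone accounts for nearly all of the non-trainable parameters. A much larger trainable count would suggest the backbone was not actually frozen.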


In [None]:
# Compile model

model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"]
)

model.summary()
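
The loss above is configured with `from_logits=True`, meaning the model is expected to output raw scores rather than probabilities. As an illustration of what that loss computes for one example, here is a minimal pure-Python sketch. It is not the Keras implementation, and the function name and sample logits are hypothetical.

```python
import math


def sparse_ce_from_logits(logits, label):
    """Cross-entropy for one example: -log(softmax(logits)[label]).

    Uses the log-sum-exp trick so large logits do not overflow.
    """
    top = max(logits)
    log_sum_exp = top + math.log(sum(math.exp(z - top) for z in logits))
    return log_sum_exp - logits[label]


uniform_logits = [0.0, 0.0, 0.0, 0.0, 0.0]      # no preference among 5 classes
confident_logits = [0.0, 0.0, 8.0, 0.0, 0.0]    # strongly favors class 2

chance_loss = sparse_ce_from_logits(uniform_logits, label=2)
right_loss = sparse_ce_from_logits(confident_logits, label=2)
wrong_loss = sparse_ce_from_logits(confident_logits, label=0)
print(round(chance_loss, 4))   # 1.6094, i.e. ln(5)
print(right_loss < 0.01, wrong_loss > 8.0)
```

A useful reference point when reading the training log is the chance-level loss for five classes, ln(5) ≈ 1.609: values well below it indicate the head has learned something beyond guessing.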

In [None]:
# Train model

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs
)

Epoch 1/5
81/81 ━━━━━━━━━━━━━━━━━━━━ 63s 587ms/step - accuracy: 0.2573 - loss: 2.0269 - val_accuracy: 0.5808 - val_loss: 1.1386
Epoch 2/5
81/81 ━━━━━━━━━━━━━━━━━━━━ 47s 580ms/step - accuracy: 0.6306 - loss: 1.0656 - val_accuracy: 0.8022 - val_loss: 0.6890
Epoch 3/5
81/81 ━━━━━━━━━━━━━━━━━━━━ 48s 593ms/step - accuracy: 0.7879 - loss: 0.6783 - val_accuracy: 0.8584 - val_loss: 0.4988
Epoch 4/5
81/81 ━━━━━━━━━━━━━━━━━━━━ 81s 586ms/step - accuracy: 0.8414 - loss: 0.5264 - val_accuracy: 0.8820 - val_loss: 0.4012
Epoch 5/5
81/81 ━━━━━━━━━━━━━━━━━━━━ 49s 603ms/step - accuracy: 0.8619 - loss: 0.4525 - val_accuracy: 0.8966 - val_loss: 0.3417


In [None]:
# Evaluate on test set

test_loss, test_acc = model.evaluate(test_ds)
print("Final Test Accuracy:", test_acc)

18/18 ━━━━━━━━━━━━━━━━━━━━ 8s 443ms/step - accuracy: 0.9193 - loss: 0.2910
Final Test Accuracy: 0.918181836605072
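
To read the accuracy as a count of images, the sketch below converts it back to correct and incorrect predictions. The test-set size of 550 is an assumption (15% of an assumed 3,670 images, with boundary rounding that this notebook does not state). In the notebook, `ds_test.cardinality()` gives the real count.

```python
reported_accuracy = 0.918181836605072   # value printed above
assumed_test_images = 550               # assumption, see the note above

correct = round(reported_accuracy * assumed_test_images)
wrong = assumed_test_images - correct
print(correct, wrong)                   # 505 45
print(correct / assumed_test_images)    # close to the reported accuracy
```

Under that assumption the model misclassifies about 45 test images. The small difference between `505 / 550` and the printed value is what one would expect from the metric being stored in 32-bit floating point.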
