# Training a Vision Transformer (ViT) Classifier with Transfer Learning

This notebook demonstrates how to train an image classifier in PyTorch with a **Vision Transformer (ViT) backbone**. We’ll use **transfer learning**: start from pretrained weights and fine-tune the model on our custom dataset.

To make training efficient and reliable, we’ll incorporate **callbacks** such as:

- **Best Model Checkpointing**: automatically save the model state with the lowest validation loss.
- **Early Stopping**: stop training when validation performance stops improving, which prevents overfitting and wasted compute.

By the end of this notebook, you’ll have a PyTorch-based image classification pipeline that’s modular, reproducible, and ready for experimentation on custom datasets.

In [None]:
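# Connect to Google Drive in case your data is stored there.
import os
from google.colab import drive
drive.mount('/content/gdrive')

# Create a short alias (/mydrive) for the Drive root. Shell commands run with "!"
# do not raise Python exceptions, so check the result explicitly instead of using try/except.
if not os.path.exists('/mydrive'):
  !ln -s /content/gdrive/My\ Drive/ /mydrive
print('Successful' if os.path.exists('/mydrive') else 'Not successful')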

In [None]:
!pip install torchinfo

In [None]:
!git clone https://github.com/AarohiSingla/Image-Classification-Using-Vision-transformer.git
%cd Image-Classification-Using-Vision-transformer

In [None]:
# Get the functions for training with callbacks.
url = (
    "https://raw.githubusercontent.com/tensorflow/models/refs/heads/master/"
    "official/projects/waste_identification_ml/fine_tuning/"
    "Pytorch_Image_Classifier/training_with_callbacks.py"
)
!wget {url} > /dev/null 2>&1

In [None]:
import matplotlib.pyplot as plt
import torch
import torchvision
from torch import nn
import glob
import requests
import os
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

from torchinfo import summary

from helper_functions import set_seeds
from helper_functions import plot_loss_curves
from going_modular.going_modular.predictions import pred_and_plot_image

from training_with_callbacks import EarlyStopping
import training_with_callbacks

device = "cuda" if torch.cuda.is_available() else "cpu"
# Number of DataLoader worker processes; used as the default in create_dataloaders().
NUM_WORKERS = os.cpu_count()

In [None]:
#@title Utils

# Turn the train/test image folders into DataLoaders.
# Returns (train_dataloader, test_dataloader, class_names).
def create_dataloaders(
    train_dir: str,
    test_dir: str,
    transform: transforms.Compose,
    batch_size: int,
    num_workers: int=NUM_WORKERS
):

  # Use ImageFolder to create dataset(s)
  train_data = datasets.ImageFolder(train_dir, transform=transform)
  test_data = datasets.ImageFolder(test_dir, transform=transform)

  # Get class names
  class_names = train_data.classes

  # Turn images into data loaders
  train_dataloader = DataLoader(
      train_data,
      batch_size=batch_size,
      shuffle=True,
      num_workers=num_workers,
      pin_memory=True,
  )
  test_dataloader = DataLoader(
      test_data,
      batch_size=batch_size,
      shuffle=False,
      num_workers=num_workers,
      pin_memory=True,
  )

  return train_dataloader, test_dataloader, class_names

In [None]:
# Get pretrained weights for ViT-Base.
pretrained_vit_weights = torchvision.models.ViT_B_16_Weights.DEFAULT

# Setup a ViT model instance with pretrained weights.
pretrained_vit = torchvision.models.vit_b_16(weights=pretrained_vit_weights).to(device)

# Freeze the base parameters.
for parameter in pretrained_vit.parameters():
    parameter.requires_grad = False

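# Replace the classifier head. We have two categories: "dairy" and "others".
class_names = ['dairy', 'others']

# Seed before creating the new head so its random initialization is reproducible.
set_seeds()

# ViT-B/16 produces 768-dimensional features; the new head maps them to one logit per class.
pretrained_vit.heads = nn.Linear(in_features=768, out_features=len(class_names)).to(device)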

In [None]:

# Print a model summary with torchinfo.
summary(model=pretrained_vit,
        input_size=(32, 3, 224, 224), # (batch_size, color_channels, height, width)
        # col_names=["input_size"], # uncomment for smaller output
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"]
)

print("Notice that only the output layer (heads) is trainable, whereas all the other layers are frozen.")
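Freezing works because every parameter of the loaded backbone has `requires_grad` set to `False`, while a freshly created `nn.Linear` head defaults to `requires_grad=True`. The sketch below applies the same freeze-then-replace pattern to a hypothetical toy model (small linear layers standing in for ViT-B/16) and counts the trainable parameters. For the real model, the new head has 768 × 2 + 2 = 1,538 trainable parameters.

```python
import torch
from torch import nn

# Hypothetical toy model: "backbone" stands in for the pretrained ViT layers.
toy_model = nn.Sequential()
toy_model.add_module("backbone", nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8)))
toy_model.add_module("heads", nn.Linear(8, 10))

# 1) Freeze every existing parameter, as the notebook does for the ViT.
for p in toy_model.parameters():
    p.requires_grad = False

# 2) Replace the head; new modules are trainable by default.
toy_model.heads = nn.Linear(8, 2)

def count_params(model, trainable_only=False):
    """Sum parameter counts, optionally only those with requires_grad=True."""
    return sum(p.numel() for p in model.parameters() if not trainable_only or p.requires_grad)

total = count_params(toy_model)
trainable = count_params(toy_model, trainable_only=True)
print(f"trainable: {trainable} / total: {total}")
```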

## Dataset

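The training code expects the data in the folder layout below, which is the layout `torchvision.datasets.ImageFolder` reads. Split the dataset into `train`, `valid`, and `test` folders, and inside each split create one subfolder per category containing that category's images. The subfolder names become the class labels, so use exactly the same names (including capitalization) in every split. The split folders themselves can be named differently, since their paths are set through `train_dir` and `valid_dir` below.

```
dataset/
├── train/
│   ├── category_1/
│   │   └── (images)
│   └── category_2/
│       └── (images)
├── valid/
│   ├── category_1/
│   │   └── (images)
│   └── category_2/
│       └── (images)
└── test/
    ├── category_1/
    │   └── (images)
    └── category_2/
        └── (images)
```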
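`ImageFolder` assigns label indices by sorting the class subfolder names, which is why the names must match across splits. The hypothetical `find_classes_sketch` below is modeled on that logic (it is not imported from torchvision) and shows how a capitalization mismatch between splits silently changes the labels. The folder names are placeholders.

```python
import os
import tempfile

def find_classes_sketch(split_dir):
    """Derive labels the way ImageFolder does: sorted subfolder names -> indices."""
    classes = sorted(entry.name for entry in os.scandir(split_dir) if entry.is_dir())
    return classes, {name: idx for idx, name in enumerate(classes)}

with tempfile.TemporaryDirectory() as root:
    # Build a throwaway layout where "valid" accidentally capitalizes one class.
    layout = {"train": ["others", "dairy"], "valid": ["Others", "dairy"]}
    for split, names in layout.items():
        for name in names:
            os.makedirs(os.path.join(root, split, name))
    train_classes, train_map = find_classes_sketch(os.path.join(root, "train"))
    valid_classes, valid_map = find_classes_sketch(os.path.join(root, "valid"))

print("train:", train_map)
print("valid:", valid_map)
```

In the `valid` split, `dairy` ends up with index 1 instead of 0 because uppercase letters sort before lowercase ones, so the model would be evaluated against swapped labels.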



In [2]:
# Set up directory paths to the training and validation images.
train_dir = '/content/Image-Classification-Using-Vision-transformer/train' # @param {type: "string", placeholder: "[train_dir]", isTemplate: true}
valid_dir = '/content/Image-Classification-Using-Vision-transformer/val'    # @param {type: "string", placeholder: "[valid_dir]", isTemplate: true}

In [None]:
# When using a pretrained model, it's important to transform your custom data
# the same way as the data the original model was trained on.
# Get the matching transforms directly from the pretrained ViT weights.
pretrained_vit_transforms = pretrained_vit_weights.transforms()
print(pretrained_vit_transforms)
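The printed transforms end with a normalization step that subtracts a per-channel mean and divides by a per-channel standard deviation. The sketch below shows that arithmetic and its inverse, assuming the standard ImageNet statistics used by torchvision's ImageNet-pretrained weights (compare them against the printed output for your weights). The helper names are hypothetical.

```python
import torch

# Assumed ImageNet statistics; confirm against print(pretrained_vit_transforms).
imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
imagenet_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

def normalize(img):
    """img: float tensor in [0, 1] with shape (3, H, W)."""
    return (img - imagenet_mean) / imagenet_std

def denormalize(img):
    """Invert normalize(), e.g. before displaying an image with plt.imshow."""
    return img * imagenet_std + imagenet_mean

x = torch.rand(3, 224, 224)   # stand-in for a resized, center-cropped image in [0, 1]
x_norm = normalize(x)
x_back = denormalize(x_norm)
print(x_norm.shape, torch.allclose(x, x_back, atol=1e-6))
```

Normalized values fall outside [0, 1] (roughly -2.1 to 2.6), which is why a batch straight from the DataLoader must be denormalized before plotting.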

In [None]:
# Set up the DataLoaders; the validation folder is used as the test set during training.
batch_size = 64
train_dataloader_pretrained, test_dataloader_pretrained, class_names = create_dataloaders(
    train_dir=train_dir,
    test_dir=valid_dir,
    transform=pretrained_vit_transforms,
    batch_size=batch_size,
)

In [None]:
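# Visualize one image to check that the data loaded correctly.

# Get a batch of images.
image_batch, label_batch = next(iter(train_dataloader_pretrained))

# Get a single image and its label from the batch.
image, label = image_batch[0], label_batch[0]

# Check the image shape (channels, height, width) and its label.
print(image.shape, label)

# The pretrained transforms normalize pixels with a per-channel mean/std;
# undo that so imshow receives values in [0, 1] and the colors look right.
mean = torch.tensor(pretrained_vit_transforms.mean).view(3, 1, 1)
std = torch.tensor(pretrained_vit_transforms.std).view(3, 1, 1)
plt.imshow((image * std + mean).clamp(0, 1).permute(1, 2, 0))
plt.title(class_names[label])
plt.axis(False);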

In [None]:
model_output_path = "/mydrive/LLM/pet_grade_bottles/best_vit_model" # @param {type: "string", placeholder: "[model output]", isTemplate: true}

early_stopper = EarlyStopping(
    patience=5,
    delta=0.001,
    verbose=True,
    base_path=model_output_path
)
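The `EarlyStopping` class comes from the downloaded `training_with_callbacks.py`, whose internals are not reproduced here. Under the common interpretation of these parameters (an assumption), an epoch counts as an improvement only if the validation loss drops below the best loss minus `delta`, and training stops after `patience` epochs without improvement. The hypothetical `fit_with_early_stopping` below sketches that logic inside a plain PyTorch loop; unlike the notebook, it keeps the best weights in memory rather than writing them to `base_path`, and it runs on synthetic data instead of the image DataLoaders.

```python
import copy

import torch
from torch import nn


def fit_with_early_stopping(model, train_batches, val_batches, loss_fn, optimizer,
                            epochs=200, patience=5, delta=0.001):
    """Hypothetical sketch; NOT the code in training_with_callbacks.py."""
    best_val_loss = float("inf")
    best_state = None              # in-memory stand-in for a checkpoint file
    epochs_without_improvement = 0
    history = {"train_loss": [], "val_loss": []}

    for epoch in range(epochs):
        # Training pass.
        model.train()
        train_loss = 0.0
        for X, y in train_batches:
            optimizer.zero_grad()
            loss = loss_fn(model(X), y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        history["train_loss"].append(train_loss / len(train_batches))

        # Validation pass without gradient tracking.
        model.eval()
        with torch.inference_mode():
            val_loss = sum(loss_fn(model(X), y).item() for X, y in val_batches) / len(val_batches)
        history["val_loss"].append(val_loss)

        # Callback logic: an improvement must beat the best loss by more than delta.
        if val_loss < best_val_loss - delta:
            best_val_loss = val_loss
            best_state = copy.deepcopy(model.state_dict())  # "save the best model"
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break  # early stop

    if best_state is not None:
        model.load_state_dict(best_state)  # roll back to the best epoch's weights
    history["best_val_loss"] = best_val_loss
    return history


# Usage on synthetic two-class data (a stand-in for the image DataLoaders).
torch.manual_seed(0)
X_all = torch.randn(256, 4)
y_all = (X_all[:, 0] > 0).long()
batches = [(X_all[i:i + 32], y_all[i:i + 32]) for i in range(0, 256, 32)]
toy_model = nn.Linear(4, 2)
toy_history = fit_with_early_stopping(
    toy_model, batches[:6], batches[6:], nn.CrossEntropyLoss(),
    torch.optim.AdamW(toy_model.parameters(), lr=1e-2), epochs=30,
)
print(len(toy_history["val_loss"]), toy_history["best_val_loss"])
```

Restoring the best state at the end means the returned model corresponds to the lowest recorded validation loss rather than to the last epoch, which is the point of combining checkpointing with early stopping.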

In [None]:
# Create the loss function.
# OPTIONAL: weight the classes to counteract class imbalance.
# Formula: weight = total_samples / (num_classes * samples_in_class)
# Example: weights = torch.tensor([42000 / (2 * 19000), 42000 / (2 * 23000)]).to(device)
#          loss_fn = torch.nn.CrossEntropyLoss(weight=weights)
loss_fn = torch.nn.CrossEntropyLoss()

# Create the optimizer and learning-rate scheduler.
optimizer = torch.optim.AdamW(params=pretrained_vit.parameters(), lr=1e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, eta_min=1e-6)
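The commented-out class-weight formula gives every class the same total weight. A small sketch with a hypothetical helper, `inverse_frequency_weights`, using the example counts from the comment (19,000 and 23,000 images):

```python
def inverse_frequency_weights(counts):
    """weight_c = total_samples / (num_classes * samples_in_class_c)"""
    total = sum(counts)
    num_classes = len(counts)
    return [total / (num_classes * c) for c in counts]

# Example counts from the comment: class 0 has 19,000 images, class 1 has 23,000.
class_counts = [19000, 23000]
class_weights = inverse_frequency_weights(class_counts)
print([round(w, 4) for w in class_weights])
```

The rarer class gets a weight above 1 and the more common class a weight below 1, so weight × count is 21,000 for both. To use the result, wrap it as `torch.tensor(class_weights, device=device)` and pass it as `torch.nn.CrossEntropyLoss(weight=...)`, in the same order as `class_names`.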

In [None]:
pretrained_vit_results = training_with_callbacks.train(
    model=pretrained_vit,
    train_dataloader=train_dataloader_pretrained,
    test_dataloader=test_dataloader_pretrained,
    optimizer=optimizer,
    loss_fn=loss_fn,
    epochs=200,
    device=device,
    early_stopping=early_stopper,
    scheduler=scheduler
)

# Plot the loss curves
plot_loss_curves(pretrained_vit_results)
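The scheduler uses `T_max=50` while `train()` is allowed up to 200 epochs. PyTorch documents cosine annealing with the closed form below. The sketch evaluates it with the notebook's values, assuming the scheduler is stepped once per epoch (how `training_with_callbacks.train` calls `scheduler.step()` is not shown here). The helper `cosine_annealing_lr` is hypothetical.

```python
import math

def cosine_annealing_lr(step, base_lr=1e-5, eta_min=1e-6, t_max=50):
    """eta_min + (base_lr - eta_min) * (1 + cos(pi * step / t_max)) / 2"""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max)) / 2

# Learning rate at a few epochs, assuming one scheduler step per epoch.
for step in (0, 25, 50, 75, 100):
    print(step, f"{cosine_annealing_lr(step):.2e}")
```

Under that assumption, the learning rate reaches `eta_min` at epoch 50 and then climbs back to `1e-5` by epoch 100. If a monotonic decay is intended over a long run, setting `T_max` to the maximum number of epochs avoids the rise.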