# Imports

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, RandomHorizontalFlip, ColorJitter
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import numpy as np
import os


# Device
# Keep the torch.device object itself in DEVICE. print() returns None, so
# wrapping the constructor in print() would bind DEVICE to None and every
# later `.to(DEVICE)` call would receive None instead of a device.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE: the model has to be moved as well (model.to(DEVICE)) before batches
# placed on DEVICE are fed to it.
print(DEVICE)

cuda


# Hyperparameters

In [12]:
# Training configuration shared by the cells below.
NUM_CLASSES = 37  # output width of the replacement classification head
IMG_SIZE = 224    # images are resized to IMG_SIZE x IMG_SIZE
BATCH_SIZE = 4    # samples per DataLoader batch
EPOCHS = 5        # number of passes over the training loader


# Dataset - Oxford-IIIT Pet

In [18]:
# Both pipelines resize to IMG_SIZE x IMG_SIZE, convert to a tensor and
# normalise each channel with mean 0.5 / std 0.5. Only the training pipeline
# adds a random horizontal flip.
# A ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1) step
# is currently disabled in the training pipeline.
transform_test = Compose(
    [
        Resize((IMG_SIZE, IMG_SIZE)),
        ToTensor(),
        Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

transform_train = Compose(
    [
        Resize((IMG_SIZE, IMG_SIZE)),
        RandomHorizontalFlip(),
        ToTensor(),
        Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

In [15]:
# Training data: no split is passed, so the dataset class's default split is
# used. Samples are shuffled on every pass.
trainset = datasets.OxfordIIITPet(
    root="./data",
    transform=transform_train,
    target_types="category",
    download=True,
)
trainloader = DataLoader(
    trainset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
)

# Evaluation data: the "test" split, iterated in a fixed order.
testset = datasets.OxfordIIITPet(
    root="./data",
    split="test",
    transform=transform_test,
    target_types="category",
    download=True,
)
testloader = DataLoader(
    testset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=2,
)

# Loading ViT-B/16

In [21]:
# Build timm's "vit_base_patch16_224" architecture and request its pretrained weights.
model = timm.create_model("vit_base_patch16_224", pretrained=True)

# Interpolate Positional Embeddings

In [19]:
def interpolate_pos_embed(model, new_img_size=224):
    """Resize ``model.pos_embed`` so it matches a ``new_img_size`` square input.

    The first token of ``model.pos_embed`` is treated as the class-token
    embedding and kept unchanged. The remaining tokens are treated as a
    square grid of patch embeddings and bilinearly resampled to the grid
    implied by ``new_img_size`` and the model's patch size.

    Args:
        model: Object exposing ``pos_embed`` of shape (1, 1 + N, dim) and
            ``patch_embed.patch_size`` (an int or a sequence whose first
            entry is used).
        new_img_size: Side length, in pixels, of the square images the model
            will receive. The grid side is ``new_img_size // patch_size``.

    Raises:
        ValueError: If the patch embeddings do not form a square grid, or if
            ``new_img_size`` is smaller than one patch.

    Note:
        Only the positional embedding is replaced. NOTE(review): the patch
        embedding layer may also need to be told about a new input size —
        confirm against the model implementation in use.
    """
    pos_embed = model.pos_embed
    cls_token = pos_embed[:, :1, :]
    patch_pos_embed = pos_embed[:, 1:, :]

    num_orig_patches = patch_pos_embed.shape[1]
    orig_size = int(round(num_orig_patches ** 0.5))
    if orig_size * orig_size != num_orig_patches:
        raise ValueError(
            f"expected a square grid of patch embeddings, got {num_orig_patches} tokens"
        )

    # Derive the target grid from the requested image size rather than from
    # the model's current patch count, which would always reproduce the
    # existing grid.
    patch_size = model.patch_embed.patch_size
    if isinstance(patch_size, (tuple, list)):
        patch_size = patch_size[0]
    new_size = int(new_img_size) // int(patch_size)
    if new_size < 1:
        raise ValueError(
            f"new_img_size={new_img_size} is smaller than the patch size {patch_size}"
        )

    # Resampling is a one-off weight edit, so keep it out of autograd.
    with torch.no_grad():
        # (1, N, dim) -> (1, dim, H, W), the layout F.interpolate expects.
        grid = patch_pos_embed.reshape(1, orig_size, orig_size, -1).permute(0, 3, 1, 2)
        grid = F.interpolate(
            grid, size=(new_size, new_size), mode='bilinear', align_corners=False
        )
        # Back to a flat token sequence: (1, H * W, dim).
        resized = grid.permute(0, 2, 3, 1).reshape(1, new_size * new_size, -1)
        new_pos_embed = torch.cat([cls_token, resized], dim=1)

    # Preserve whether the embedding was trainable or frozen.
    model.pos_embed = nn.Parameter(new_pos_embed, requires_grad=pos_embed.requires_grad)


# Apply the positional-embedding interpolation to `model`, leaving
# new_img_size at its default of 224.
interpolate_pos_embed(model)

# Classification Head

In [22]:
# Replace the classifier with dropout followed by a linear layer that outputs
# NUM_CLASSES logits.
# If this cell is re-run, model.head is already the nn.Sequential built here,
# which has no in_features attribute; read the width from its last layer
# instead so the cell stays re-runnable.
current_head = model.head
if isinstance(current_head, nn.Sequential):
    head_in_features = current_head[-1].in_features
else:
    head_in_features = current_head.in_features

model.head = nn.Sequential(
    nn.Dropout(0.3),
    nn.Linear(head_in_features, NUM_CLASSES),
)

# Training



In [23]:
# Loss function, optimiser and learning-rate scheduler for training.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-5)
# mode="max" means the scheduler expects a metric that should increase.
# It only has an effect if scheduler.step(metric) is called.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="max",
    patience=2,
)

In [24]:
# Resume from a previously saved state_dict if one is present.
# map_location="cpu" lets a checkpoint written on a GPU machine load on a
# CPU-only one; load_state_dict then copies the values into the model's
# existing parameters.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
if os.path.exists("best_vit_pet.pth"):
    model.load_state_dict(torch.load("best_vit_pet.pth", map_location="cpu"))
    print("Loaded checkpoint")

Loaded checkpoint


In [25]:
# Train and save best checkpoint (no scheduler, no patience)

# Every batch below is moved to DEVICE, so the model has to be moved too;
# otherwise its weights stay wherever they were created.
model.to(DEVICE)

train_accs, test_accs = [], []  # one accuracy value per completed epoch
best_test_acc = 0

for epoch in range(EPOCHS):
    # ---- training pass ----
    model.train()
    correct, total = 0, 0
    for x, y in trainloader:
        x, y = x.to(DEVICE), y.to(DEVICE)
        optimizer.zero_grad()
        out = model(x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        # Running training accuracy, measured on the outputs used for the update.
        correct += (out.argmax(1) == y).sum().item()
        total += y.size(0)
    train_acc = correct / total
    train_accs.append(train_acc)

    # ---- evaluation pass: eval mode, no gradients ----
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in testloader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            out = model(x)
            preds = out.argmax(1)
            correct += (preds == y).sum().item()
            total += y.size(0)
    test_acc = correct / total
    test_accs.append(test_acc)

    print(f"Epoch {epoch+1}/{EPOCHS} | Train Acc: {train_acc:.4f} | Test Acc: {test_acc:.4f}")

    # Save best model checkpoint.
    # best_test_acc starts at 0 on every run of this cell, so the first epoch
    # with a non-zero test accuracy overwrites any existing checkpoint file.
    if test_acc > best_test_acc:
        best_test_acc = test_acc
        torch.save(model.state_dict(), "best_vit_pet.pth")
        print(" Saved best model")


KeyboardInterrupt: 

# Evaluation

In [None]:
# Accuracy per epoch for the training and test loaders.
# Each curve gets its own 1-based epoch axis: after an interrupted run the two
# lists can differ in length, and a shared x-axis would then fail to plot.
plt.plot(range(1, len(train_accs) + 1), train_accs, label="Train")
plt.plot(range(1, len(test_accs) + 1), test_accs, label="Test")
# The model created in this notebook is "vit_base_patch16_224", i.e. ViT-B/16.
plt.title("ViT-B/16 on Oxford Pet Dataset")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.legend()
plt.grid(True)
plt.show()

In [None]:
# Collect predictions and labels over the test loader. No visible cell defines
# all_preds / all_labels before this point, so they are built here.
model.eval()
# Feed batches to whichever device the model's weights currently live on.
eval_device = next(model.parameters()).device
pred_batches, label_batches = [], []
with torch.no_grad():
    for x, y in testloader:
        logits = model(x.to(eval_device))
        pred_batches.append(logits.argmax(1).cpu())
        label_batches.append(y.cpu())

all_preds = torch.cat(pred_batches)
all_labels = torch.cat(label_batches)
cm = confusion_matrix(all_labels.numpy(), all_preds.numpy())
ConfusionMatrixDisplay.from_predictions(
    all_labels.numpy(), all_preds.numpy(), cmap="Blues", xticks_rotation=45
)
plt.title("Confusion Matrix")
plt.show()

# Visualization

In [None]:
# Class-token attention overlay for the last transformer block. This uses raw
# attention weights, not gradients, and requires the block's attention module
# to expose get_attention_map().
def show_attention_map(model, dataloader):
    """Overlay the last block's class-token attention on one image.

    Takes the first image of the first batch from ``dataloader``, runs it
    through ``model`` in eval mode, averages the class-token-to-patch
    attention of ``model.blocks[-1].attn`` over heads, upsamples it to
    IMG_SIZE x IMG_SIZE and draws it over the de-normalised image.

    Args:
        model: Model with a ``blocks`` sequence whose last element has an
            ``attn`` module providing ``get_attention_map()``.
        dataloader: Iterable yielding ``(images, labels)`` batches.

    Returns:
        None. A matplotlib figure is shown as a side effect. If the attention
        module has no ``get_attention_map`` a warning is issued and nothing is
        drawn; an empty ``dataloader`` draws nothing either.
    """
    import warnings

    model.eval()

    first_batch = next(iter(dataloader), None)
    if first_batch is None:
        return  # nothing to visualise

    attn_module = model.blocks[-1].attn
    if not hasattr(attn_module, 'get_attention_map'):
        # Say so explicitly instead of returning silently with no figure.
        warnings.warn(
            "show_attention_map: model.blocks[-1].attn has no "
            "get_attention_map(); no attention overlay was drawn."
        )
        return

    x, _ = first_batch
    x = x[:1].to(DEVICE)  # keep a single image, batch dimension retained
    with torch.no_grad():
        _ = model(x)  # forward pass so the attention module sees this image
        attn = attn_module.get_attention_map()
        # assumes attn is (batch, heads, tokens, tokens) — TODO confirm.
        # Select image 0, query token 0, drop key token 0, then average heads.
        attn_map = attn[0, :, 0, 1:].mean(0)
        # Flat patch vector -> square grid.
        attn_map = attn_map.reshape(int(attn_map.shape[0] ** 0.5), -1)
        attn_map = F.interpolate(
            attn_map.unsqueeze(0).unsqueeze(0),
            size=(IMG_SIZE, IMG_SIZE),
            mode='bilinear'
        )[0, 0]
        img = x[0].permute(1, 2, 0).cpu().numpy()
        # Undo a mean 0.5 / std 0.5 normalisation for display.
        img = (img * 0.5 + 0.5).clip(0, 1)
        plt.imshow(img)
        plt.imshow(attn_map.cpu(), cmap='jet', alpha=0.4)
        plt.title("Attention Overlay")
        plt.axis("off")
        plt.show()
