In [10]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import timm
from torchvision.datasets import CIFAR100
from sklearn.model_selection import train_test_split
from PIL import Image

In [3]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
#using cifar-100 instead of cifar-10
cifar_dataset = CIFAR100(root="./data", train=True, transform=transform, download=True)  # Use CIFAR100
subset_size = 5000

cifar_dataset = torch.utils.data.random_split(cifar_dataset, [subset_size, len(cifar_dataset) - subset_size])[0]

train_set, val_set = train_test_split(cifar_dataset, test_size=0.1, random_state=42)
batch_size = 32

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=4)


Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data\cifar-100-python.tar.gz


100%|█████████████████████████████████████████████████████████████████| 169001437/169001437 [41:12<00:00, 68340.47it/s]


Extracting ./data\cifar-100-python.tar.gz to ./data


In [6]:
class ModifiedViT(nn.Module):
    def __init__(self, num_classes=100):  # Change num_classes to 100
        super(ModifiedViT, self).__init__()

        self.model = timm.create_model("vit_small_patch16_224", pretrained=True)
        self.model.head = nn.Linear(self.model.head.in_features, num_classes)

    def forward(self, x):
        return self.model(x)

modified_model = ModifiedViT().to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(modified_model.parameters(), lr=2e-4)

num_epochs = 5
accumulation_steps = 4


model.safetensors:   0%|          | 0.00/88.2M [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


In [7]:
for epoch in range(num_epochs):
    modified_model.train()
    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = modified_model(images)
        loss = criterion(outputs, labels)
        loss.backward()

        if (i + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

    modified_model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)

            outputs = modified_model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        accuracy = correct / total
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item():.4f}, Acc: {accuracy:.4f}")

Epoch 1/5, Loss: 4.1944, Acc: 0.0900
Epoch 2/5, Loss: 3.1988, Acc: 0.3220
Epoch 3/5, Loss: 1.4789, Acc: 0.4260
Epoch 4/5, Loss: 1.3204, Acc: 0.5220
Epoch 5/5, Loss: 1.5496, Acc: 0.5920


In [18]:
def classify_modified_image(model, image_path, transform, class_labels):
    image = Image.open(image_path).convert("RGB")
    image = transform(image).unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        output = model(image)

    _, predicted_class_index = torch.max(output, 1)

    predicted_class_label = class_labels[predicted_class_index.item()]

    return predicted_class_label

In [19]:
cifar100_classes = cifar100_info.classes

image_path = r'F:\Uni Work\CV Lab\Lab 11\dog.jpg'
predicted_class_label = classify_modified_image(modified_model, image_path, transform, cifar100_classes)
print(f"The predicted class label for the image is: {predicted_class_label}")

The predicted class label for the image is: bear
