In [1]:
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR100
import torchvision.transforms as transforms

# Data loading and splitting
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize for CIFAR-100
])

full_dataset = CIFAR100(root="./data", train=True, download=True, transform=transform)

# Split the dataset into training, validation, and testing sets
train_size = int(0.7 * len(full_dataset))
val_size = int(0.2 * len(full_dataset))
test_size = len(full_dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(full_dataset, [train_size, val_size, test_size])

batch_size = 256

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


100%|██████████| 169001437/169001437 [00:12<00:00, 13513895.54it/s]


Extracting ./data/cifar-100-python.tar.gz to ./data


In [4]:
import torch
import torch.nn as nn
from torchvision.models import vit_b_16
from torchvision.models import ViT_B_16_Weights
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
num_classes = 100  # CIFAR-100 has 100 classes
lr = 0.003 #1e-4
patch_size = 4 #8
num_layers = 12
num_heads = 16 #12
hidden_dim = 512 #768
mlp_dim = 3072
image_size = 32
# max_len = 100 # All sequences must be less than 1000 including class token
# channels = 3

num_epochs = 30

# Model, loss, and optimizer
model = vit_b_16(weights=ViT_B_16_Weights).to(device)

# Freeze most weights
for name, param in model.named_parameters():
    if not name.startswith('fc'):  # Exclude the final classification layer
        param.requires_grad = False

# Define the final classification layer to fine-tune
model.fc = nn.Linear(hidden_dim, num_classes).to(device)
model.fc.requires_grad = True  # Unfreeze the final layer

criterion = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=lr)
scheduler = StepLR(optimizer, step_size=5, gamma=0.5)  # Optional: learning rate scheduler

# Training loop
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        inputs, labels = inputs.to(device), labels.to(device)

        # Forward pass
        with torch.set_grad_enabled(True):  # Enable gradient computation
            outputs = model(inputs)
            loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.requires_grad = True
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Optional: Adjust learning rate
    scheduler.step()

    # Print average loss per epoch
    average_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {average_loss:.4f}")

    # Validation loop
    model.eval()
    correct = 0
    total = 0
    total_val_loss = 0

    with torch.no_grad():
        for inputs, labels in tqdm(val_loader, desc="Validation"):
            inputs, labels = inputs.to(device), labels.to(device)

            # Ensure gradient computation is disabled during validation
            with torch.set_grad_enabled(False):
                outputs = model(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                loss = criterion(outputs, labels)

                total_val_loss += loss.item()

    val_accuracy = correct / total
    average_val_loss = total_val_loss / len(val_loader)
    print(f"Validation Accuracy: {val_accuracy * 100:.2f}%, , Validation Loss: {average_val_loss:.4f}")

# Save the trained model
torch.save(model.state_dict(), "vision_transformer_model_finetuned.pth")

Epoch 1/30: 100%|██████████| 137/137 [07:40<00:00,  3.36s/it]


Epoch 1/30, Loss: 8.7368


Validation: 100%|██████████| 40/40 [01:57<00:00,  2.94s/it]


Validation Accuracy: 0.52%, , Validation Loss: 8.7409


Epoch 2/30:   1%|          | 1/137 [00:07<17:54,  7.90s/it]


KeyboardInterrupt: ignored

In [None]:
# Evaluation loop
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for inputs, labels in tqdm(test_loader, desc="Evaluating"):
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = correct / total
print(f"Test Accuracy: {accuracy * 100:.2f}%")