In [None]:
# Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import time

# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# -----------------------------
# 1️⃣ Hyperparameters
# -----------------------------
# Seed the global torch RNG (CPU and all CUDA devices) before any model init,
# DataLoader shuffling or random augmentation draws, so runs are repeatable.
# NOTE(review): cuDNN kernels may still be nondeterministic on GPU.
seed = 42
torch.manual_seed(seed)

num_classes = 10
num_epochs = 5
batch_size = 32
learning_rate = 0.0001  # Lower LR for fine-tuning
weight_decay = 1e-4  # Regularization passed to the optimizer
momentum = 0.9  # Only relevant if the optimizer is switched to SGD; currently unused

# -----------------------------
# 2️⃣ Load CIFAR-10 Dataset
# -----------------------------
# The torchvision ViT-B/16 ImageNet-1k weights were trained on inputs normalized
# with ImageNet channel statistics; reusing them keeps CIFAR-10 inputs on the
# scale the pretrained backbone expects (instead of the generic 0.5/0.5).
imagenet_mean = (0.485, 0.456, 0.406)
imagenet_std = (0.229, 0.224, 0.225)

train_transform = transforms.Compose([
    # Run the CIFAR-10 AutoAugment policy at native 32x32 resolution: the policy
    # targets CIFAR-sized images, and each op touches 49x fewer pixels than it
    # would after upscaling to 224x224 (cheaper DataLoader workers).
    transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
    # NOTE(review): the default scale=(0.08, 1.0) can keep a tiny patch of a
    # 32x32 image; this is a likely reason training accuracy trails test
    # accuracy in the logged run - consider a larger lower bound.
    transforms.RandomResizedCrop(224),  # Random crop, resized to the 224x224 ViT input
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std),
])

test_transform = transforms.Compose([
    transforms.Resize(224),  # CIFAR images are square, so this yields 224x224
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std),
])

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)

# Pinned host memory only speeds up copies to a CUDA device; skip it on CPU.
pin = device.type == "cuda"
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=pin)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=pin)

# -----------------------------
# 3️⃣ Load Pretrained ViT Model
# -----------------------------
# `pretrained=True` is deprecated since torchvision 0.13; the weights enum is the
# supported spelling and selects the same ImageNet-1k checkpoint.
model = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1)
# Replace the ImageNet classification head with a freshly initialized CIFAR-10 head.
model.heads.head = nn.Linear(model.heads.head.in_features, num_classes)
model = model.to(device)

# Loss Function & Optimizer
criterion = nn.CrossEntropyLoss()
# NOTE(review): Adam's weight_decay is an L2 term added to the gradient (coupled);
# AdamW applies decoupled decay, but switching would need weight_decay re-tuned.
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# -----------------------------
# 4️⃣ Fine-Tuning Function
# -----------------------------
training_accuracies = []  # Per-epoch training accuracy (%), appended by train()

def train():
    """Fine-tune the global `model` on `train_loader` for `num_epochs` epochs.

    Reads the module-level `model`, `train_loader`, `criterion`, `optimizer`,
    `device` and `num_epochs`. After each epoch it prints the mean per-sample
    loss and the training accuracy, and appends that accuracy to
    `training_accuracies`.

    Training accuracy is measured on augmented batches with the model in train
    mode, so it can legitimately sit below test accuracy.

    Returns:
        list[float]: `training_accuracies` (the same list object, updated in place).
    """
    model.train()
    for epoch in range(num_epochs):
        correct = 0
        total = 0
        running_loss = 0.0

        for images, labels in train_loader:
            # non_blocking lets copies from pinned host memory overlap GPU compute.
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            outputs = model(images)  # Forward pass
            loss = criterion(outputs, labels)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            n = labels.size(0)
            # criterion returns the batch mean; weight by batch size so a smaller
            # final batch (drop_last is not set) does not skew the epoch average.
            running_loss += loss.item() * n
            total += n
            correct += (outputs.argmax(dim=1) == labels).sum().item()

        epoch_loss = running_loss / total
        training_accuracy = 100 * correct / total
        training_accuracies.append(training_accuracy)
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {epoch_loss:.4f}, Training Accuracy: {training_accuracy:.2f}%")
    return training_accuracies

# -----------------------------
# 5️⃣ Evaluation Function
# -----------------------------
def test():
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)  # Move to GPU
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f'Accuracy on the test set: {accuracy:.2f}%')

# -----------------------------
# 6️⃣ Main Loop (Train & Test)
# -----------------------------
if __name__ == "__main__":
    # CUDA kernels are launched asynchronously, so synchronize before reading the
    # clock to make each interval cover all queued GPU work. perf_counter is a
    # monotonic, high-resolution clock intended for measuring durations.
    use_cuda = device.type == "cuda"

    # Measure training time
    if use_cuda:
        torch.cuda.synchronize()
    start_time = time.perf_counter()
    train()
    if use_cuda:
        torch.cuda.synchronize()
    end_time = time.perf_counter()
    training_time = end_time - start_time
    print(f"Training time: {training_time:.2f} seconds")

    # Measure inference time. This is wall-clock time for test(), so it also
    # includes DataLoader loading/transform cost, not only forward passes.
    if use_cuda:
        torch.cuda.synchronize()
    start_time = time.perf_counter()
    test()
    if use_cuda:
        torch.cuda.synchronize()
    end_time = time.perf_counter()
    inference_time = end_time - start_time
    print(f"Inference time: {inference_time:.2f} seconds for the entire test set")

    # Calculate per-sample inference time
    per_sample_inference_time = inference_time / len(test_dataset)
    print(f"Inference time per sample: {per_sample_inference_time:.6f} seconds")


Using device: cuda
Files already downloaded and verified
Files already downloaded and verified
Epoch [1/5], Loss: 0.7835, Training Accuracy: 73.15%
Epoch [2/5], Loss: 0.6436, Training Accuracy: 77.72%
Epoch [3/5], Loss: 0.6023, Training Accuracy: 79.13%
Epoch [4/5], Loss: 0.5829, Training Accuracy: 79.77%
Epoch [5/5], Loss: 0.5676, Training Accuracy: 80.43%
Training time: 8424.16 seconds
Accuracy on the test set: 93.45%
Inference time: 106.41 seconds for the entire test set
Inference time per sample: 0.010641 seconds


In [None]:
import os

# Persist only the learned parameters (the state_dict), not the whole module object.
model_path = "vit_cifar10.pth"
torch.save(model.state_dict(), model_path)

# Report the on-disk footprint, converting bytes to MB (1 MB here = 2**20 bytes).
model_size = os.stat(model_path).st_size / 2**20
print(f"Model size: {model_size:.2f} MB")

Model size: 327.38 MB
