In [None]:
pip install torch torchvision

In [None]:
# Imports
import os
import time

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models

In [None]:
# Pick the compute device in order of preference: CUDA GPU, then Apple-Silicon
# MPS, then CPU. The MPS check is only reached when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# 1️⃣ Hyperparameters
num_classes = 10          # 10 classes for MNIST digits (0-9)
num_epochs = 2            # Number of epochs (adjust as needed)
batch_size = 32
learning_rate = 0.0001    # Lower LR for fine-tuning
weight_decay = 1e-4       # Regularization factor
seed = 42                 # Fixed seed so a fresh Restart & Run All is repeatable

# Seed torch's global RNG before any random work happens further down
# (data shuffling, random augmentations, initialisation of the new head).
torch.manual_seed(seed)

In [None]:
# 2️⃣ Define Transforms for MNIST
# MNIST images are 28x28 and grayscale.
# We resize them to 224x224 and convert to 3 channels (RGB) to match the ViT input requirements.
# Training-time pipeline with light augmentation.
# NOTE: there is deliberately no horizontal flip. A mirrored digit is not a
# valid example of its class (e.g. a flipped "3" or "7"), so flipping would
# train the model on shapes that never occur in the un-flipped test set.
train_transform = transforms.Compose([
    transforms.Resize(224),  # MNIST images are square, so this yields 224x224
    transforms.Grayscale(num_output_channels=3),  # Convert grayscale to 3-channel RGB
    transforms.RandomRotation(15),  # Random rotation within ±15 degrees
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),  # Random translation (up to 10% per axis)
    transforms.ToTensor(),  # Convert image to tensor
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize pixel values to [-1, 1]
])

# Deterministic evaluation pipeline: same resizing, channel expansion and
# normalisation as training, but without any random augmentation.
test_transform = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

In [None]:
# Load MNIST dataset (downloaded into ./data on first use)
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=train_transform)
test_dataset  = datasets.MNIST(root='./data', train=False, download=True, transform=test_transform)

# Create DataLoaders for training and testing.
# Training: shuffle each epoch and drop the trailing partial batch so every
# optimisation step sees a full batch.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
# Testing: keep the final partial batch. With drop_last=True the last
# len(test_dataset) % batch_size images would never be evaluated, so the
# reported accuracy would not cover the whole test set.
test_loader  = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False)

In [None]:
# 3️⃣ Load Pretrained ViT Model and Adjust the Classification Head
# We load a pretrained ViT model and modify its classification head to output 10 classes.
# `pretrained=True` is deprecated in torchvision in favour of the `weights=`
# argument; IMAGENET1K_V1 is the checkpoint the legacy flag maps to. Older
# torchvision releases that lack the weights enum fall back to the legacy flag.
vit_weights_enum = getattr(models, "ViT_B_16_Weights", None)
if vit_weights_enum is not None:
    model = models.vit_b_16(weights=vit_weights_enum.IMAGENET1K_V1)
else:
    model = models.vit_b_16(pretrained=True)
# Swap the ImageNet classifier for a freshly initialised layer with one output per digit.
model.heads.head = nn.Linear(model.heads.head.in_features, num_classes)  # Update the final layer for MNIST
model = model.to(device)

In [None]:
# Loss: cross-entropy over the class logits.
criterion = nn.CrossEntropyLoss()
# Optimizer: Adam over every model parameter (the whole network is fine-tuned),
# using the learning rate and weight decay set in the hyperparameter cell.
optimizer = optim.Adam(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)

# 4️⃣ Training Function
def train():
    """Fine-tune the notebook-level `model` on `train_loader` for `num_epochs` epochs.

    Uses the globals `model`, `train_loader`, `criterion`, `optimizer`,
    `device` and `num_epochs`. After every epoch it prints the mean per-batch
    loss and the training accuracy. Accuracy is taken from the same forward
    pass used for the loss, i.e. before that batch's optimizer step.
    """
    model.train()  # Switch the model to training mode
    for epoch in range(num_epochs):
        running_loss = 0.0
        num_correct = 0
        num_seen = 0

        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            # Forward pass
            logits = model(inputs)
            batch_loss = criterion(logits, targets)
            running_loss += batch_loss.item()

            # Backward pass and parameter update
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            # Running accuracy statistics for this epoch
            predicted = torch.max(logits, 1).indices
            num_seen += targets.size(0)
            num_correct += (predicted == targets).sum().item()

        training_accuracy = 100 * num_correct / num_seen
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}, Training Accuracy: {training_accuracy:.2f}%")

# 5️⃣ Evaluation Function
def test():
    """Evaluate the notebook-level `model` on `test_loader` and print its accuracy.

    Uses the globals `model`, `test_loader` and `device`. Runs with gradients
    disabled and returns nothing; the accuracy (in percent) is only printed.
    """
    model.eval()  # Switch the model to evaluation mode
    num_correct = 0
    num_seen = 0
    with torch.no_grad():  # No gradients are needed for inference
        for inputs, targets in test_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            predicted = torch.max(model(inputs), 1).indices
            num_seen += targets.size(0)
            num_correct += (predicted == targets).sum().item()

    accuracy = 100 * num_correct / num_seen
    print(f'Accuracy on the test set: {accuracy:.2f}%')

# 6️⃣ Main Loop (Training & Testing)
if __name__ == "__main__":
    # time.perf_counter() is a monotonic, high-resolution clock intended for
    # measuring elapsed intervals; time.time() follows the wall clock and can
    # jump if the system time is adjusted mid-run.

    # Measure training time
    start_time = time.perf_counter()
    train()
    end_time = time.perf_counter()
    training_time = end_time - start_time
    print(f"Training time: {training_time:.2f} seconds")

    # Measure inference time on the test set (this includes data loading and
    # the per-image transforms, since test() iterates over test_loader)
    start_time = time.perf_counter()
    test()
    end_time = time.perf_counter()
    inference_time = end_time - start_time
    print(f"Inference time: {inference_time:.2f} seconds for the entire test set")

    # Calculate per-sample inference time. Divide by the number of samples
    # test() actually processed: if the loader drops its last partial batch,
    # that is fewer than len(test_dataset).
    evaluated_samples = (
        len(test_loader) * test_loader.batch_size
        if test_loader.drop_last
        else len(test_dataset)
    )
    per_sample_inference_time = inference_time / evaluated_samples
    print(f"Inference time per sample: {per_sample_inference_time:.6f} seconds")


In [None]:
# Save only the model's parameters (state_dict) to a file in the working directory
model_path = "vit_mnist.pth"
torch.save(model.state_dict(), model_path)

# Measure the size of the saved file; bytes / 1024**2, reported with an "MB" label
model_size = os.path.getsize(model_path) / (1024 * 1024)
print(f"Model size: {model_size:.2f} MB")