# Init

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from transformers import ViTForImageClassification, ViTFeatureExtractor
import os


# Define device

In [None]:
# Prefer CUDA when the runtime reports it is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Hyperparameters

In [None]:
batch_size = 32
learning_rate = 1e-4
num_epochs = 5
checkpoint_dir = './checkpoints'

# Create checkpoint directory if it doesn't exist

In [None]:
# Create the checkpoint folder (and any missing parents); an already-existing
# directory is not treated as an error.
os.makedirs(name=checkpoint_dir, exist_ok=True)

# Load dataset

In [None]:
# Preprocessing shared by both splits: upsample to 224x224, convert to a
# float tensor, then shift each channel from [0, 1] to [-1, 1]
# ((x - 0.5) / 0.5).
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

# CIFAR-10 train and test splits, downloaded into ./data if not already present.
train_dataset = datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform
)
test_dataset = datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform
)

# Only the training loader reshuffles each epoch; evaluation order is fixed.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# Load Vision Transformer (ViT)

In [None]:
# Pretrained ViT-Base/16 backbone with a classification head sized for
# 10 output classes.
model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224-in21k',
    num_labels=10,
)
model.to(device)

# Define loss function and optimizer

In [None]:
# Adam over every trainable tensor of the model, using the configured step size.
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)
# Standard multi-class loss applied to the raw (unnormalised) class logits.
criterion = nn.CrossEntropyLoss()

# Function to evaluate model accuracy

In [None]:
def evaluate(model, dataloader, device=None):
    """Return the top-1 classification accuracy of *model* over *dataloader*.

    Args:
        model: Module whose forward output exposes a ``.logits`` tensor with
            class scores along dim 1 (presumably shape (batch, num_classes)).
        dataloader: Iterable of ``(images, labels)`` batches.
        device: Device each batch is moved to. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters), so
            inputs always land next to the weights without relying on a
            module-level global.

    Returns:
        Fraction of correctly classified samples, in ``[0, 1]``.

    Raises:
        ValueError: If *dataloader* yields no samples (accuracy is undefined).

    The model's train/eval mode is restored on exit, so calling this from a
    training loop does not leave the model stuck in eval mode.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for images, labels in dataloader:
                images, labels = images.to(device), labels.to(device)
                logits = model(images).logits
                predicted = logits.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        # Restore whichever mode the caller had, even if a batch raised.
        model.train(was_training)

    if total == 0:
        raise ValueError("dataloader yielded no samples; accuracy is undefined")
    return correct / total


# Training loop with checkpointing

In [None]:
# Fine-tune for num_epochs passes. After each pass, write a checkpoint and
# report test-split accuracy.
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0  # sum of per-batch mean losses over this epoch
    num_batches = 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        # Forward pass: take the class scores from the model's output object.
        outputs = model(images).logits
        loss = criterion(outputs, labels)

        # Backward pass: clear stale gradients, backpropagate, update weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    # running_loss grows with the number of batches, so report the per-batch
    # mean instead. It stays comparable across batch sizes and dataset sizes.
    avg_loss = running_loss / num_batches if num_batches else float('nan')

    # Save checkpoint. 'loss' keeps its original meaning (summed batch losses)
    # for existing checkpoint readers; 'avg_loss' is the mean batch loss.
    checkpoint_path = os.path.join(checkpoint_dir, f'vit_epoch_{epoch + 1}.pth')
    torch.save({'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': running_loss,
                'avg_loss': avg_loss}, checkpoint_path)
    print(f"Epoch {epoch + 1}/{num_epochs}, Avg loss: {avg_loss:.4f}, Checkpoint saved at {checkpoint_path}")

    # Evaluate on test_loader, which wraps the CIFAR-10 test split (train=False).
    accuracy = evaluate(model, test_loader)
    print(f"Test Accuracy after Epoch {epoch + 1}: {accuracy * 100:.2f}%")

print("Training complete!")