# Review of google-research/vision_transformer
GitHub: https://github.com/google-research/vision_transformer

In [1]:
import torch

print("Torch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Device name:", torch.cuda.get_device_name(0))
    print("Device count:", torch.cuda.device_count())
    print("Current device:", torch.cuda.current_device())

Torch version: 2.6.0+cu124
CUDA available: True
Device name: Tesla T4
Device count: 1
Current device: 0


In [10]:
import torch
import timm
import torchvision
import torchvision.transforms as transforms
from torch import nn, optim
from torch.utils.data import DataLoader, Subset
from torch.utils.data import random_split

# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])


full_train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_size = int(0.9 * len(full_train_set))
val_size = len(full_train_set) - train_size
train_set, val_set = random_split(full_train_set, [train_size, val_size])
test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

subset_indices = torch.arange(1000).tolist()
train_subset = Subset(train_set, subset_indices)
test_subset = Subset(test_set, subset_indices)
val_subset = Subset(val_set, subset_indices)

train_loader = DataLoader(train_subset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_subset, batch_size=32, shuffle=False)
val_loader = DataLoader(val_subset, batch_size=32, shuffle=False)

model = timm.create_model('vit_base_patch16_224', pretrained=True)
model.head = nn.Linear(model.head.in_features, 10) # a transfer learning for 10 classes
model.to(device)

# Freeze all parameters except the final classification head
for param in model.parameters():
    param.requires_grad = False

# Enable training only for the last layer
for param in model.head.parameters():
    param.requires_grad = True

# Check trainable parameters to verify
for name, param in model.named_parameters():
    if param.requires_grad:
        print(f'Trainable: {name}')

# Count trainable parameters
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'Total trainable parameters: {trainable_params}')


Using device: cuda
Trainable: head.weight
Trainable: head.bias
Total trainable parameters: 7690


In [2]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

num_epochs = 10

for epoch in range(num_epochs):
    # Training
    model.train()
    train_loss = 0.0
    correct_train = 0
    total_train = 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total_train += labels.size(0)
        correct_train += (predicted == labels).sum().item()

    train_accuracy = 100 * correct_train / total_train
    print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%")

    # Validation
    model.eval()
    val_loss = 0.0
    correct_val = 0
    total_val = 0

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total_val += labels.size(0)
            correct_val += (predicted == labels).sum().item()

    val_accuracy = 100 * correct_val / total_val
    print(f"Epoch {epoch+1}/{num_epochs} - Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%")

    # Testing
    correct_test = 0
    total_test = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total_test += labels.size(0)
            correct_test += (predicted == labels).sum().item()

    test_accuracy = 100 * correct_test / total_test
    print(f"Epoch {epoch+1}/{num_epochs} - Test Accuracy: {test_accuracy:.2f}%\n")

Epoch 1/10 - Train Loss: 72.6494, Train Accuracy: 18.20%
Epoch 1/10 - Validation Loss: 59.8302, Validation Accuracy: 36.80%
Epoch 1/10 - Test Accuracy: 39.10%

Epoch 2/10 - Train Loss: 51.1208, Train Accuracy: 56.10%
Epoch 2/10 - Validation Loss: 42.6098, Validation Accuracy: 71.70%
Epoch 2/10 - Test Accuracy: 72.50%

Epoch 3/10 - Train Loss: 36.1759, Train Accuracy: 79.80%
Epoch 3/10 - Validation Loss: 30.7975, Validation Accuracy: 85.40%
Epoch 3/10 - Test Accuracy: 85.70%

Epoch 4/10 - Train Loss: 26.1716, Train Accuracy: 89.50%
Epoch 4/10 - Validation Loss: 23.1346, Validation Accuracy: 90.60%
Epoch 4/10 - Test Accuracy: 89.20%

Epoch 5/10 - Train Loss: 19.9073, Train Accuracy: 93.30%
Epoch 5/10 - Validation Loss: 18.1755, Validation Accuracy: 92.50%
Epoch 5/10 - Test Accuracy: 91.70%

Epoch 6/10 - Train Loss: 15.9705, Train Accuracy: 95.50%
Epoch 6/10 - Validation Loss: 14.9198, Validation Accuracy: 93.80%
Epoch 6/10 - Test Accuracy: 93.20%

Epoch 7/10 - Train Loss: 12.8852, Train 