In [7]:
import torch
import torchvision
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms

# Define the ViT model
class ViT(nn.Module):
    def __init__(self, image_size, patch_size, num_classes, dim):
        super().__init__()
        assert image_size % patch_size == 0, "image size must be divisible by patch size"
        num_patches = (image_size // patch_size) ** 2
        patch_dim = 3 * patch_size ** 2
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.patch_embedding = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size)
        self.transformer = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=dim, nhead=4), num_layers=12)
        self.classification_head = nn.Linear(dim, num_classes)

    def forward(self, x):
        patches = self.patch_embedding(x).flatten(2).transpose(1, 2)
        embeddings = self.transformer(patches)
        cls_embedding = embeddings[:, 0]
        return self.classification_head(cls_embedding)

# Define the training parameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 64
num_epochs = 10
learning_rate = 0.001

# Load the CIFAR10 dataset
transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
train_dataset = torchvision.datasets.CIFAR10(root='~/data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)


#num_training_steps = num_epochs * len(train_loader)
#warmup_steps = 500
#cosine_decay_steps = num_training_steps - warmup_steps

# Initialize the model, optimizer, and loss function
model = ViT(image_size=32, patch_size=4, num_classes=10, dim=64).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

# Define the learning rate scheduler
#lr_scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=cosine_decay_steps, T_mult=1, eta_min=0.0001)
lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

print(f"Number of model parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f} million")
print(f"Running on: {device}")

# Evaluate the model on the test dataset
test_dataset = torchvision.datasets.CIFAR10(root='~/data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

def evaluate(model, test_loader, test_dataset):
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        print(f"Accuracy of the model on the {len(test_dataset)} test images: {100 * correct / total}%")


# Train the model

for epoch in range(num_epochs):
    model.train()
    for i, (images, labels) in enumerate(tqdm(train_loader)):
        # Move the images and labels to the device
        images = images.to(device)
        labels = labels.to(device)

        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Update the learning rate
    
    lr_scheduler.step()
    # Print the loss and accuracy after each epoch
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}, LR: {optimizer.param_groups[0]['lr']:.4f}")
    evaluate(model, test_loader, test_dataset)



Files already downloaded and verified
Number of model parameters: 3.38 million
Running on: cuda
Files already downloaded and verified


100%|██████████| 782/782 [00:45<00:00, 17.01it/s]


Epoch [1/10], Loss: 2.3065, LR: 0.0010
Accuracy of the model on the 10000 test images: 10.0%


100%|██████████| 782/782 [00:45<00:00, 17.02it/s]


Epoch [2/10], Loss: 2.3067, LR: 0.0009
Accuracy of the model on the 10000 test images: 10.0%


100%|██████████| 782/782 [00:46<00:00, 16.90it/s]


Epoch [3/10], Loss: 2.2931, LR: 0.0008
Accuracy of the model on the 10000 test images: 10.0%


100%|██████████| 782/782 [00:46<00:00, 16.88it/s]


Epoch [4/10], Loss: 2.3023, LR: 0.0007
Accuracy of the model on the 10000 test images: 10.0%


100%|██████████| 782/782 [00:46<00:00, 16.78it/s]


Epoch [5/10], Loss: 2.3356, LR: 0.0005
Accuracy of the model on the 10000 test images: 10.0%


100%|██████████| 782/782 [00:45<00:00, 17.19it/s]


Epoch [6/10], Loss: 2.3051, LR: 0.0003
Accuracy of the model on the 10000 test images: 10.0%


100%|██████████| 782/782 [00:45<00:00, 17.03it/s]


Epoch [7/10], Loss: 2.2979, LR: 0.0002
Accuracy of the model on the 10000 test images: 10.0%


100%|██████████| 782/782 [00:45<00:00, 17.24it/s]


Epoch [8/10], Loss: 2.2996, LR: 0.0001
Accuracy of the model on the 10000 test images: 10.0%


100%|██████████| 782/782 [00:45<00:00, 17.19it/s]


Epoch [9/10], Loss: 2.3053, LR: 0.0000
Accuracy of the model on the 10000 test images: 10.0%


100%|██████████| 782/782 [00:45<00:00, 17.27it/s]


Epoch [10/10], Loss: 2.3023, LR: 0.0000
Accuracy of the model on the 10000 test images: 10.0%
