In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets

# Define the ViT architecture
class ViT(nn.Module):
    def __init__(self, image_size, patch_size, num_classes, hidden_size, num_heads, num_layers):
        super(ViT, self).__init__()
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2
        self.patch_embedding = nn.Conv2d(3, hidden_size, kernel_size=patch_size, stride=patch_size)
        self.transformer_encoder = nn.TransformerEncoder(nn.TransformerEncoderLayer(hidden_size, num_heads), num_layers)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        x = self.patch_embedding(x)
        x = x.flatten(2).transpose(1, 2)
        x = self.transformer_encoder(x)
        x = x.mean(1)
        x = self.fc(x)
        return x

# Load the CIFAR-10 dataset
transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True)

# Instantiate the ViT model
model = ViT(image_size=32, patch_size=4, num_classes=10, hidden_size=256, num_heads=8, num_layers=6)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Train the model for 10 epochs
for epoch in range(10):
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 100 == 99:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 100))
            running_loss = 0.0

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data\cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:12<00:00, 13481051.05it/s]


Extracting ./data\cifar-10-python.tar.gz to ./data
[1,   100] loss: 2.375
[1,   200] loss: 2.318
[1,   300] loss: 2.316
[2,   100] loss: 2.313
[2,   200] loss: 2.315
[2,   300] loss: 2.311
[3,   100] loss: 2.311
[3,   200] loss: 2.313
[3,   300] loss: 2.309
[4,   100] loss: 2.308
[4,   200] loss: 2.310
[4,   300] loss: 2.308
[5,   100] loss: 2.308
[5,   200] loss: 2.309
[5,   300] loss: 2.307
[6,   100] loss: 2.307
[6,   200] loss: 2.307
[6,   300] loss: 2.307
[7,   100] loss: 2.305
[7,   200] loss: 2.307
[7,   300] loss: 2.305
[8,   100] loss: 2.306
[8,   200] loss: 2.305
[8,   300] loss: 2.305
[9,   100] loss: 2.305
[9,   200] loss: 2.305


KeyboardInterrupt: 