In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.optim as optim

In [2]:
class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and linearly embed each patch.

    The projection is a Conv2d whose kernel size and stride both equal the patch
    size, which is equivalent to flattening every patch and applying one shared
    linear layer.

    Args:
        img_size: Input image size, either an int (square image) or a
            ``(height, width)`` pair.
        patch_size: Patch size, either an int (square patch) or a
            ``(height, width)`` pair.
        in_channels: Number of channels in the input image.
        embed_dim: Dimension of each patch embedding.

    Attributes:
        grid_size: ``(rows, cols)`` of the patch grid.
        num_patches: Total number of patches, ``rows * cols``.
    """

    def __init__(self, img_size, patch_size, in_channels, embed_dim):
        super(PatchEmbedding, self).__init__()
        img_h, img_w = self._as_pair(img_size)
        patch_h, patch_w = self._as_pair(patch_size)
        # Floor division mirrors Conv2d with stride == kernel_size: trailing
        # pixels that do not fill a whole patch are dropped.
        self.grid_size = (img_h // patch_h, img_w // patch_w)
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.patch_size = patch_size
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    @staticmethod
    def _as_pair(value):
        """Return ``value`` as a ``(height, width)`` tuple; an int is used for both."""
        if isinstance(value, int):
            return value, value
        height, width = value
        return height, width

    def forward(self, x):
        """Embed images ``x`` of shape (batch, in_channels, H, W) as patch tokens.

        Returns:
            Tensor of shape (batch, num_patches, embed_dim).
        """
        x = self.proj(x)       # (batch, embed_dim, H // patch_h, W // patch_w)
        x = x.flatten(2)       # (batch, embed_dim, num_patches)
        x = x.transpose(1, 2)  # (batch, num_patches, embed_dim)
        return x

class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention over a batch-first token sequence.

    Wraps ``nn.MultiheadAttention`` (built with its default sequence-first
    layout) and transposes inputs/outputs so callers work with
    ``(batch, seq_len, embed_dim)`` tensors.

    Args:
        embed_dim: Token embedding dimension; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability on the attention weights. Defaults to 0.0,
            i.e. no attention dropout, matching the previous behaviour.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.0):
        super(MultiHeadSelfAttention, self).__init__()
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (batch, seq_len, embed_dim).

        Returns:
            Tensor with the same shape as ``x``.
        """
        # nn.MultiheadAttention (batch_first=False) expects (seq_len, batch, embed_dim).
        x = x.transpose(0, 1)
        # The attention map is discarded, so don't ask PyTorch to materialise
        # and head-average it (need_weights defaults to True).
        attn_output, _ = self.attention(x, x, x, need_weights=False)
        return attn_output.transpose(0, 1)  # (batch, seq_len, embed_dim)

class TransformerEncoderLayer(nn.Module):
    """One Transformer encoder block: self-attention followed by a feed-forward net.

    Each sub-block (attention, then a Linear-ReLU-Linear FFN) sits inside a
    residual connection, with dropout on the sub-block output and a LayerNorm.

    Args:
        embed_dim: Token embedding dimension.
        num_heads: Number of attention heads.
        hidden_dim: Width of the feed-forward hidden layer.
        dropout: Dropout probability applied to each sub-block output before
            the residual add.
        norm_first: If False (default, the original behaviour) use post-norm,
            ``x = norm(x + sublayer(x))``. If True use pre-norm,
            ``x = x + sublayer(norm(x))``, the arrangement used by ViT; pre-norm
            tends to be easier to optimise for deep stacks.
    """

    def __init__(self, embed_dim, num_heads, hidden_dim, dropout=0.1, norm_first=False):
        super(TransformerEncoderLayer, self).__init__()
        self.norm_first = norm_first
        self.mhsa = MultiHeadSelfAttention(embed_dim, num_heads)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embed_dim),
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run the block on ``x`` of shape (batch, seq_len, embed_dim); the shape is preserved."""
        if self.norm_first:
            # Pre-norm: normalise the sub-block input; the residual path stays un-normalised.
            x = x + self.dropout(self.mhsa(self.norm1(x)))
            x = x + self.dropout(self.ffn(self.norm2(x)))
        else:
            # Post-norm: add the residual first, then normalise (Add & Norm).
            x = self.norm1(x + self.dropout(self.mhsa(x)))
            x = self.norm2(x + self.dropout(self.ffn(x)))
        return x

class VisionTransformer(nn.Module):
    """Vision Transformer (ViT) image classifier.

    Pipeline: patch embedding -> prepend a learnable [CLS] token -> add learnable
    positional embeddings -> stack of ``TransformerEncoderLayer``s -> LayerNorm +
    Linear head applied to the final [CLS] token.

    The default hyper-parameters are the ViT-Base/16 sizes (768-d tokens,
    12 heads, 12 layers, 3072-d MLP, 16x16 patches at 224x224).

    Args:
        img_size: Input image size, forwarded to ``PatchEmbedding``.
        patch_size: Patch size, forwarded to ``PatchEmbedding``.
        in_channels: Number of input image channels.
        num_classes: Number of output logits.
        embed_dim: Token embedding dimension.
        num_heads: Attention heads per encoder layer.
        depth: Number of encoder layers.
        hidden_dim: Feed-forward hidden width inside each encoder layer.
        dropout: Dropout probability after the positional embedding and inside
            every encoder layer (default 0.1, as before).
    """

    def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=1000,
                 embed_dim=768, num_heads=12, depth=12, hidden_dim=3072, dropout=0.1):
        super(VisionTransformer, self).__init__()

        # Patch embedding: image -> (batch, num_patches, embed_dim) tokens
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        self.num_patches = self.patch_embed.num_patches

        # Learnable [CLS] token (zero-initialised) and one positional embedding
        # per token (CLS + every patch).
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim))
        # Small random init (std 0.02, as in the reference ViT) so positions are
        # distinguishable from the first step instead of all starting at zero.
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        self.pos_dropout = nn.Dropout(dropout)

        # Transformer encoder stack
        self.encoder_layers = nn.ModuleList([
            TransformerEncoderLayer(embed_dim, num_heads, hidden_dim, dropout=dropout)
            for _ in range(depth)
        ])

        # Classification head applied to the final [CLS] token
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes)
        )

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Images of shape (batch, in_channels, H, W) at the resolution the
               model was built for (``img_size``).

        Returns:
            Logits of shape (batch, num_classes).

        Raises:
            ValueError: If the input resolution yields a different number of
                patches than ``pos_embed`` was sized for.
        """
        batch_size = x.size(0)

        # Patch + position embedding
        x = self.patch_embed(x)  # (batch, num_patches, embed_dim)
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # (batch, 1, embed_dim)
        x = torch.cat((cls_tokens, x), dim=1)  # (batch, num_patches + 1, embed_dim)
        if x.size(1) != self.pos_embed.size(1):
            raise ValueError(
                f"Expected {self.pos_embed.size(1) - 1} patches but got {x.size(1) - 1}; "
                "the input resolution must match the img_size the model was built with."
            )
        x = self.pos_dropout(x + self.pos_embed)

        # Transformer encoder
        for layer in self.encoder_layers:
            x = layer(x)

        # Classify from the final [CLS] token only.
        return self.mlp_head(x[:, 0])

In [3]:
# # Instantiate the model
# vit = VisionTransformer(img_size=224, patch_size=16, num_classes=10)

# # Test the model with a dummy input
# dummy_input = torch.randn(1, 3, 224, 224)  # (batch_size, channels, height, width)
# output = vit(dummy_input)
# print(output.shape)  # Expected output: (1, 10)

torch.Size([1, 10])


In [4]:
# Preprocessing. CIFAR-10 images are 32x32, so upsample them to the 224x224
# resolution the ViT is configured for; ToTensor yields values in [0, 1] and the
# per-channel (x - 0.5) / 0.5 normalisation maps them to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])


def _cifar10_split(train):
    """Download (if missing) one CIFAR-10 split into ./data and wrap it in a DataLoader.

    Only the training split is shuffled; the test split keeps a fixed order.
    Returns a ``(dataset, loader)`` pair.
    """
    dataset = torchvision.datasets.CIFAR10(root='./data', train=train, download=True, transform=transform)
    loader = DataLoader(dataset, batch_size=32, shuffle=train)
    return dataset, loader


train_dataset, train_loader = _cifar10_split(train=True)
test_dataset, test_loader = _cifar10_split(train=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:04<00:00, 42.1MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [5]:
# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate the ViT model and move it to the device
vit_model = VisionTransformer(img_size=224, patch_size=16, num_classes=10).to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(vit_model.parameters(), lr=0.001)

# Training Loop
def train_vit(model, train_loader, criterion, optimizer, device, epochs=5):
    """Train ``model`` for ``epochs`` passes over ``train_loader``.

    After every epoch, prints the sample-weighted mean training loss and the
    accuracy on the training batches. Both are accumulated while the weights
    are still being updated, so they describe the epoch, not the final model.

    Args:
        model: Classifier returning logits of shape (batch, num_classes).
        train_loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(logits, labels)``; assumed to
            return a batch-mean scalar (e.g. ``nn.CrossEntropyLoss`` default).
        optimizer: Optimizer over ``model``'s parameters.
        device: Device the batches are moved to (the model should already be there).
        epochs: Number of passes over ``train_loader``.

    Returns:
        A list with one ``(epoch_loss, accuracy_percent)`` tuple per epoch.

    Raises:
        ValueError: If ``train_loader`` yields no samples.
    """
    history = []
    model.train()  # training mode (e.g. dropout active)
    for epoch in range(epochs):
        loss_sum = 0.0
        correct = 0
        total = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_size = labels.size(0)
            # The loss is a per-batch mean; weight it by batch size so a smaller
            # final batch does not skew the epoch average.
            loss_sum += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size

        if total == 0:
            raise ValueError("train_loader yielded no samples")

        epoch_loss = loss_sum / total
        accuracy = 100 * correct / total
        history.append((epoch_loss, accuracy))
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss:.4f}, Accuracy: {accuracy:.2f}%")
    return history

# Run 5 training epochs on CIFAR-10 (the trailing ';' stops Jupyter from echoing
# a return value below the per-epoch log).
train_vit(
    model=vit_model,
    train_loader=train_loader,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    epochs=5,
);

Epoch [1/5], Loss: 2.3619, Accuracy: 10.14%
Epoch [2/5], Loss: 2.3085, Accuracy: 10.07%
Epoch [3/5], Loss: 2.3047, Accuracy: 10.10%
Epoch [4/5], Loss: 2.3049, Accuracy: 10.17%
Epoch [5/5], Loss: 2.3044, Accuracy: 10.05%


In [6]:
def evaluate_vit(model, test_loader, device):
    """Compute, print and return the top-1 accuracy (percent) of ``model`` on ``test_loader``.

    Switches the model to eval mode (it stays there afterwards) and disables
    autograd for the pass.

    Args:
        model: Classifier returning logits of shape (batch, num_classes).
        test_loader: Iterable of ``(images, labels)`` batches.
        device: Device the batches are moved to.

    Returns:
        Accuracy as a float percentage in [0, 100].

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        raise ValueError("test_loader yielded no samples")

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy

# Measure accuracy on the held-out CIFAR-10 test split (';' suppresses any echoed return value).
evaluate_vit(model=vit_model, test_loader=test_loader, device=device);

Test Accuracy: 10.00%
