# Vision Transformer (ViT) from Scratch

This notebook implements a Vision Transformer based on the paper:
["An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale"](https://arxiv.org/abs/2010.11929)

## Architecture Overview
1. **Patch Embedding**: Split image into fixed-size patches and linearly embed them
2. **Position Embedding**: Add learnable position embeddings
3. **Transformer Encoder**: Stack of multi-head self-attention and MLP blocks
4. **Classification Head**: Linear head on the final [CLS] token representation for classification

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

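# Pick the best available device: CUDA GPU, then Apple Silicon (MPS), then CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")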
print(f"Using device: {device}")

Using device: mps


## 1. Patch Embedding

Convert an image into a sequence of flattened patches, then project to embedding dimension.

In [5]:
class PatchEmbedding(nn.Module):
    """Split image into patches and embed them.
    
    Args:
        img_size: Size of input image (assumed square)
        patch_size: Size of each patch (assumed square)
        in_channels: Number of input channels (3 for RGB)
        embed_dim: Embedding dimension
    """
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2
        
        # Conv2d is equivalent to splitting into patches and linear projection
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size
        )
    
    def forward(self, x):
        # x: (B, C, H, W)
        x = self.proj(x)  # (B, embed_dim, n_patches_h, n_patches_w)
        x = x.flatten(2)  # (B, embed_dim, n_patches)
        x = x.transpose(1, 2)  # (B, n_patches, embed_dim)
        return x
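
The comment in `PatchEmbedding` states that a strided `Conv2d` is equivalent to cutting the image into patches and applying one linear projection. A minimal sketch that checks this numerically, assuming small illustrative sizes (`img_size=8`, `patch_size=4`, `embed_dim=6` are not the notebook's defaults) and using `F.unfold` to extract the patches explicitly:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Illustrative sizes only (assumptions, not the notebook's defaults)
B, C, img_size, patch_size, embed_dim = 2, 3, 8, 4, 6

conv = nn.Conv2d(C, embed_dim, kernel_size=patch_size, stride=patch_size)
x = torch.randn(B, C, img_size, img_size)

# Path 1: strided convolution, as in PatchEmbedding.forward
out_conv = conv(x).flatten(2).transpose(1, 2)  # (B, n_patches, embed_dim)

# Path 2: explicit patch extraction followed by a linear projection
patches = F.unfold(x, kernel_size=patch_size, stride=patch_size)  # (B, C*P*P, n_patches)
patches = patches.transpose(1, 2)                                 # (B, n_patches, C*P*P)
weight = conv.weight.reshape(embed_dim, -1)                       # (embed_dim, C*P*P)
out_linear = patches @ weight.T + conv.bias                       # (B, n_patches, embed_dim)

max_diff = (out_conv - out_linear).abs().max().item()
print(out_conv.shape, max_diff)
```

The two paths should agree up to floating-point rounding, which is why the convolution can stand in for an explicit reshape plus `nn.Linear`.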

## 2. Multi-Head Self-Attention

The core attention mechanism that allows patches to attend to each other.

In [6]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention mechanism.
    
    Args:
        embed_dim: Embedding dimension
        n_heads: Number of attention heads
        dropout: Dropout rate
    """
    def __init__(self, embed_dim=768, n_heads=12, dropout=0.0):
        super().__init__()
        self.embed_dim = embed_dim
        self.n_heads = n_heads
        assert embed_dim % n_heads == 0, "embed_dim must be divisible by n_heads"
        self.head_dim = embed_dim // n_heads
        
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.scale = self.head_dim ** -0.5
    
    def forward(self, x):
        B, N, C = x.shape
        
        # Generate Q, K, V
        qkv = self.qkv(x).reshape(B, N, 3, self.n_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, n_heads, N, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]
        
        # Attention scores
        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, n_heads, N, N)
        attn = attn.softmax(dim=-1)
        attn = self.dropout(attn)
        
        # Apply attention to values
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.dropout(x)
        
        return x
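
The forward pass above computes `softmax(Q K^T / sqrt(head_dim)) V` independently for each head. As a sanity check, the sketch below compares that manual computation against `F.scaled_dot_product_attention`, which is available in PyTorch 2.0 and later (an assumption about the installed version). The tensor sizes are illustrative, and dropout is left out so the two paths are deterministic:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Illustrative sizes (assumptions): batch 2, 3 heads, 5 tokens, head_dim 8
B, n_heads, N, head_dim = 2, 3, 5, 8
q = torch.randn(B, n_heads, N, head_dim)
k = torch.randn(B, n_heads, N, head_dim)
v = torch.randn(B, n_heads, N, head_dim)

# Manual path, mirroring MultiHeadAttention.forward with dropout disabled
scale = head_dim ** -0.5
attn = (q @ k.transpose(-2, -1)) * scale  # (B, n_heads, N, N)
attn = attn.softmax(dim=-1)               # each row sums to 1
out_manual = attn @ v                     # (B, n_heads, N, head_dim)

# Fused path (requires PyTorch >= 2.0); its default scale is 1/sqrt(head_dim)
out_fused = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(out_manual, out_fused, atol=1e-5))
```

The explicit version is kept in this notebook because the attention visualization later needs the `(N, N)` weight matrix, which the fused call does not return.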

## 3. MLP (Feed-Forward Network)

Two-layer MLP with GELU activation, applied after attention in each transformer block.

In [7]:
class MLP(nn.Module):
    """Feed-forward network with GELU activation.
    
    Args:
        embed_dim: Input/output dimension
        mlp_ratio: Ratio to determine hidden dimension
        dropout: Dropout rate
    """
    def __init__(self, embed_dim=768, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        hidden_dim = int(embed_dim * mlp_ratio)
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x

## 4. Transformer Encoder Block

A single transformer block: LayerNorm -> Attention -> Residual -> LayerNorm -> MLP -> Residual

In [8]:
class TransformerBlock(nn.Module):
    """Transformer encoder block with pre-norm architecture.
    
    Args:
        embed_dim: Embedding dimension
        n_heads: Number of attention heads
        mlp_ratio: MLP hidden dimension ratio
        dropout: Dropout rate
    """
    def __init__(self, embed_dim=768, n_heads=12, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim, n_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim, mlp_ratio, dropout)
    
    def forward(self, x):
        # Pre-norm architecture (as in original ViT)
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

## 5. Complete Vision Transformer

Putting it all together: Patch embedding + [CLS] token + Position embedding + Transformer blocks + Classification head

In [9]:
class VisionTransformer(nn.Module):
    """Vision Transformer for image classification.
    
    Args:
        img_size: Input image size (assumed square)
        patch_size: Patch size (assumed square)
        in_channels: Number of input channels
        n_classes: Number of classification classes
        embed_dim: Embedding dimension
        depth: Number of transformer blocks
        n_heads: Number of attention heads
        mlp_ratio: MLP hidden dimension ratio
        dropout: Dropout rate
    """
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_channels=3,
        n_classes=1000,
        embed_dim=768,
        depth=12,
        n_heads=12,
        mlp_ratio=4.0,
        dropout=0.0
    ):
        super().__init__()
        
        # Patch embedding
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        n_patches = self.patch_embed.n_patches
        
        # Learnable [CLS] token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        
        # Learnable position embeddings (for patches + CLS token)
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, embed_dim))
        self.pos_dropout = nn.Dropout(dropout)
        
        # Transformer encoder blocks
        self.blocks = nn.Sequential(*[
            TransformerBlock(embed_dim, n_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])
        
        # Final layer norm
        self.norm = nn.LayerNorm(embed_dim)
        
        # Classification head
        self.head = nn.Linear(embed_dim, n_classes)
        
        # Initialize weights
        self._init_weights()
    
    def _init_weights(self):
        # Initialize position embeddings and CLS token
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        
        # Initialize linear layers and layer norms
        self.apply(self._init_module_weights)
    
    def _init_module_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
    
    def forward(self, x):
        B = x.shape[0]
        
        # Patch embedding
        x = self.patch_embed(x)  # (B, n_patches, embed_dim)
        
        # Prepend [CLS] token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)  # (B, n_patches + 1, embed_dim)
        
        # Add position embeddings
        x = x + self.pos_embed
        x = self.pos_dropout(x)
        
        # Transformer blocks
        x = self.blocks(x)
        
        # Final norm and classification
        x = self.norm(x)
        cls_output = x[:, 0]  # Take [CLS] token output
        logits = self.head(cls_output)
        
        return logits
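
The token bookkeeping in `forward` is easy to get wrong, so here is a small shape trace of the [CLS] and position-embedding steps in isolation. It is a sketch with stand-in tensors (random patch tokens, zero-initialized parameters) and assumes the CIFAR-style sizes used later in the notebook:

```python
import torch

# Illustrative sizes (assumptions): 32x32 input with 4x4 patches -> 64 patches
B, n_patches, embed_dim = 2, 64, 192

patch_tokens = torch.randn(B, n_patches, embed_dim)    # stand-in for patch_embed(x)
cls_token = torch.zeros(1, 1, embed_dim)               # one shared [CLS] vector
pos_embed = torch.zeros(1, n_patches + 1, embed_dim)   # one row per token

# expand() returns a broadcast view, so every sample shares the same [CLS] vector
cls_tokens = cls_token.expand(B, -1, -1)               # (B, 1, embed_dim)
tokens = torch.cat([cls_tokens, patch_tokens], dim=1)  # (B, n_patches + 1, embed_dim)

# pos_embed has batch size 1 and broadcasts across the batch dimension
tokens = tokens + pos_embed

# After the encoder, only the token at index 0 feeds the classification head
cls_output = tokens[:, 0]                              # (B, embed_dim)
print(tokens.shape, cls_output.shape)
```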

## 6. Model Variants

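Helper functions to create standard ViT variants (ViT-Tiny, ViT-Small, ViT-Base, ViT-Large). The parameter counts in the docstrings refer to the default configuration (224x224 images, 16x16 patches, 1000 classes); smaller images or fewer classes give slightly smaller models.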

In [10]:
def vit_tiny(img_size=224, patch_size=16, n_classes=1000, **kwargs):
    """ViT-Tiny: 5.7M parameters"""
    return VisionTransformer(
        img_size=img_size, patch_size=patch_size, n_classes=n_classes,
        embed_dim=192, depth=12, n_heads=3, **kwargs
    )

def vit_small(img_size=224, patch_size=16, n_classes=1000, **kwargs):
    """ViT-Small: 22M parameters"""
    return VisionTransformer(
        img_size=img_size, patch_size=patch_size, n_classes=n_classes,
        embed_dim=384, depth=12, n_heads=6, **kwargs
    )

def vit_base(img_size=224, patch_size=16, n_classes=1000, **kwargs):
    """ViT-Base: 86M parameters"""
    return VisionTransformer(
        img_size=img_size, patch_size=patch_size, n_classes=n_classes,
        embed_dim=768, depth=12, n_heads=12, **kwargs
    )

def vit_large(img_size=224, patch_size=16, n_classes=1000, **kwargs):
    """ViT-Large: 307M parameters"""
    return VisionTransformer(
        img_size=img_size, patch_size=patch_size, n_classes=n_classes,
        embed_dim=1024, depth=24, n_heads=16, **kwargs
    )

## 7. Test the Model

Let's verify our implementation works with a dummy input.

In [11]:
# Create a small ViT for testing
model = vit_tiny(img_size=32, patch_size=4, n_classes=10)
model = model.to(device)

# Count parameters
n_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {n_params:,}")

# Test forward pass
dummy_input = torch.randn(2, 3, 32, 32).to(device)
output = model(dummy_input)
print(f"Input shape: {dummy_input.shape}")
print(f"Output shape: {output.shape}")

Model parameters: 5,362,762
Input shape: torch.Size([2, 3, 32, 32])
Output shape: torch.Size([2, 10])
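
The parameter count printed above can be reproduced by hand. The sketch below assumes the module layout defined in this notebook (fused `qkv` projection, two `LayerNorm`s per block, a single linear head); `vit_param_count` is a hypothetical helper written for this illustration, not part of the notebook's model code:

```python
def vit_param_count(img_size, patch_size, n_classes, embed_dim, depth,
                    in_channels=3, mlp_ratio=4.0):
    """Count parameters of the VisionTransformer defined above, term by term."""
    d = embed_dim
    hidden = int(d * mlp_ratio)
    n_patches = (img_size // patch_size) ** 2

    patch_embed = in_channels * patch_size * patch_size * d + d  # Conv2d weight + bias
    cls_token = d
    pos_embed = (n_patches + 1) * d

    qkv = d * 3 * d + 3 * d                         # fused Q, K, V projection
    proj = d * d + d                                # attention output projection
    mlp = (d * hidden + hidden) + (hidden * d + d)  # fc1 + fc2
    norms = 2 * (2 * d)                             # two LayerNorms per block
    block = qkv + proj + mlp + norms

    final_norm = 2 * d
    head = d * n_classes + n_classes
    return patch_embed + cls_token + pos_embed + depth * block + final_norm + head


# CIFAR-10 configuration from the test cell above
cifar_count = vit_param_count(32, 4, 10, embed_dim=192, depth=12)
# Default ViT-Tiny configuration (224x224, 16x16 patches, 1000 classes)
default_count = vit_param_count(224, 16, 1000, embed_dim=192, depth=12)
print(f"{cifar_count:,}")
print(f"{default_count:,}")
```

The first call gives 5,362,762, matching the printed model size; the second gives 5,717,416, which is the roughly 5.7M quoted in the `vit_tiny` docstring. The number of heads does not appear because splitting `embed_dim` across heads changes no weight shapes.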


## 8. Train on CIFAR-10

Let's train our ViT on CIFAR-10 to verify it learns.

In [None]:
# Data augmentation and normalization
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
])

# Load CIFAR-10
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)

print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")

In [None]:
def train_epoch(model, train_loader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
    
    return total_loss / len(train_loader), 100. * correct / total


def evaluate(model, test_loader, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            total_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    
    return total_loss / len(test_loader), 100. * correct / total

In [None]:
# Create model for CIFAR-10 (32x32 images, 10 classes)
model = vit_tiny(img_size=32, patch_size=4, n_classes=10, dropout=0.1)
model = model.to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.1)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)  # T_max matches n_epochs in the training loop

print("Training ViT-Tiny on CIFAR-10")
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

In [None]:
# Training loop
n_epochs = 20
train_losses, test_losses = [], []
train_accs, test_accs = [], []

for epoch in range(n_epochs):
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    scheduler.step()
    
    train_losses.append(train_loss)
    test_losses.append(test_loss)
    train_accs.append(train_acc)
    test_accs.append(test_acc)
    
    print(f"Epoch {epoch+1:2d}/{n_epochs} | "
          f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | "
          f"Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.2f}%")

In [None]:
# Plot training curves
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

ax1.plot(train_losses, label='Train')
ax1.plot(test_losses, label='Test')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training and Test Loss')
ax1.legend()

ax2.plot(train_accs, label='Train')
ax2.plot(test_accs, label='Test')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy (%)')
ax2.set_title('Training and Test Accuracy')
ax2.legend()

plt.tight_layout()
plt.show()

## 9. Visualize Attention Maps

Let's visualize what the model is attending to.

In [None]:
def get_attention_maps(model, x):
    """Return a list with one attention tensor per block, each of shape (B, n_heads, N, N)."""
    attention_maps = []
    
    # Register forward hooks that recompute the post-softmax attention weights
    # from each attention module's input (the module itself does not return them)
    def hook_fn(module, inputs, output):
        x_in = inputs[0]
        B, N, C = x_in.shape
        qkv = module.qkv(x_in).reshape(B, N, 3, module.n_heads, module.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) * module.scale
        attn = attn.softmax(dim=-1)
        attention_maps.append(attn.detach().cpu())
    
    hooks = []
    for block in model.blocks:
        hook = block.attn.register_forward_hook(hook_fn)
        hooks.append(hook)
    
    model.eval()
    with torch.no_grad():
        _ = model(x)
    
    for hook in hooks:
        hook.remove()
    
    return attention_maps
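
`get_attention_maps` relies on PyTorch forward hooks. For readers who have not used them, the sketch below shows the mechanism on its own: the hook receives the module, a tuple of its positional inputs, and its output, and it stops firing once the returned handle is removed. The `nn.Linear` layer is only a hypothetical stand-in for `block.attn`:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

captured = []

def hook_fn(module, inputs, output):
    # inputs is a tuple of the positional arguments passed to forward()
    captured.append((inputs[0].shape, output.shape))

layer = nn.Linear(4, 2)  # hypothetical stand-in for block.attn
handle = layer.register_forward_hook(hook_fn)

with torch.no_grad():
    layer(torch.randn(3, 4))  # the hook fires after forward() returns

handle.remove()               # detach the hook so later calls are not recorded
with torch.no_grad():
    layer(torch.randn(3, 4))

print(len(captured), captured[0])
```

Removing the hooks matters in a notebook: a forgotten hook keeps appending tensors on every later forward pass, including during training.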

In [None]:
# Get a sample image
sample_img, sample_label = test_dataset[0]
sample_img_batch = sample_img.unsqueeze(0).to(device)

# Get attention maps
attn_maps = get_attention_maps(model, sample_img_batch)

# Visualize attention from CLS token to patches in the last layer
last_attn = attn_maps[-1][0]  # (n_heads, N, N)
cls_attn = last_attn[:, 0, 1:]  # Attention from CLS to patches

# Average across heads
cls_attn_avg = cls_attn.mean(0)  # (n_patches,)
n_patches_side = int(cls_attn_avg.shape[0] ** 0.5)
cls_attn_map = cls_attn_avg.reshape(n_patches_side, n_patches_side)

# Plot
fig, axes = plt.subplots(1, 3, figsize=(12, 4))

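# Original image: (C, H, W) -> (H, W, C), then min-max rescale the normalized
# pixel values to [0, 1] so imshow can display them
img_np = sample_img.permute(1, 2, 0).numpy()
img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min())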
axes[0].imshow(img_np)
axes[0].set_title(f'Original (Label: {test_dataset.classes[sample_label]})')
axes[0].axis('off')

# Attention map
axes[1].imshow(cls_attn_map.numpy(), cmap='viridis')
axes[1].set_title('CLS Attention (Last Layer)')
axes[1].axis('off')

# Overlay: upsample each patch-level weight to a patch_size x patch_size pixel block
patch_size = model.patch_embed.patch_size
attn_resized = np.kron(cls_attn_map.numpy(), np.ones((patch_size, patch_size)))
axes[2].imshow(img_np)
axes[2].imshow(attn_resized, cmap='jet', alpha=0.5)
axes[2].set_title('Attention Overlay')
axes[2].axis('off')

plt.tight_layout()
plt.show()
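
The overlay depends on `np.kron` to turn the patch-level attention grid into a pixel-level one. A toy example of that step, with made-up values and a 2x2 block size chosen only to keep the output small (the notebook itself uses 4x4 patches):

```python
import numpy as np

# Toy 2x2 "attention map"; the values are illustrative only
attn_map = np.array([[0.1, 0.2],
                     [0.3, 0.4]])
patch_size = 2  # assumption for this toy example

# The Kronecker product with a block of ones repeats every entry
# patch_size times along both axes (nearest-neighbour upsampling)
upsampled = np.kron(attn_map, np.ones((patch_size, patch_size)))

print(upsampled.shape)
print(upsampled)
```

Because this is nearest-neighbour upsampling, the overlay shows hard patch boundaries; a smoother map would need an interpolating resize instead.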

## 10. Experiment Ideas

Now that you have a working ViT, here are some things to try:

1. **Different patch sizes**: Try `patch_size=2` or `patch_size=8` (256 or 16 patches per 32x32 image, versus 64 with `patch_size=4`) and see how it affects accuracy and training time
2. **Model scaling**: Compare `vit_tiny`, `vit_small`, and `vit_base` on CIFAR-10
3. **Data augmentation**: Add more augmentation like `RandAugment` or `Mixup`
4. **Learning rate warmup**: Add warmup for more stable training
5. **Different datasets**: Try CIFAR-100 or download a subset of ImageNet
6. **Positional encoding variations**: Try sinusoidal instead of learned position embeddings
7. **Attention visualization**: Compare attention patterns at different layers
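
As a starting point for the warmup idea in the list above, here is one possible schedule: a linear ramp followed by cosine decay, written as a plain multiplier function. The function name `warmup_cosine` and the choice of 5 warmup epochs are assumptions made for this sketch, not settings from the notebook:

```python
import math

def warmup_cosine(epoch, warmup_epochs=5, total_epochs=20):
    """Multiplier for the base learning rate at a given (0-indexed) epoch."""
    if epoch < warmup_epochs:
        # Linear ramp: 1/warmup_epochs, 2/warmup_epochs, ..., 1.0
        return (epoch + 1) / warmup_epochs
    # Cosine decay from 1.0 towards 0 over the remaining epochs
    progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

base_lr = 3e-4  # same base rate as the AdamW optimizer used for training
schedule = [base_lr * warmup_cosine(e) for e in range(20)]
print([f"{lr:.2e}" for lr in schedule])
```

To use it in the training cells, the function can be passed as `lr_lambda` to `torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_cosine)` in place of `CosineAnnealingLR`, keeping the per-epoch `scheduler.step()` call unchanged.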

In [None]:
# Example: Compare different patch sizes
for patch_size in [2, 4, 8]:
    test_model = vit_tiny(img_size=32, patch_size=patch_size, n_classes=10)
    n_params = sum(p.numel() for p in test_model.parameters())
    n_patches = (32 // patch_size) ** 2
    print(f"Patch size: {patch_size} | Patches: {n_patches} | Parameters: {n_params:,}")
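
The parameter count barely changes with patch size, so it hides the real cost. A rough sketch of what does change, assuming the 32x32 CIFAR-10 resolution and counting only the size of one head's attention matrix as a proxy for compute and memory:

```python
img_size = 32  # CIFAR-10 resolution used throughout this notebook

rows = []
for patch_size in [2, 4, 8]:
    n_tokens = (img_size // patch_size) ** 2 + 1  # patches plus the [CLS] token
    attn_entries = n_tokens ** 2                  # entries in one head's (N, N) attention matrix
    patch_dim = 3 * patch_size ** 2               # raw values per RGB patch before projection
    rows.append((patch_size, n_tokens, attn_entries, patch_dim))
    print(f"patch_size={patch_size} | tokens={n_tokens} | "
          f"attention entries per head={attn_entries:,} | values per patch={patch_dim}")
```

Halving the patch size roughly quadruples the token count and increases the attention matrix by about sixteen times, which is why `patch_size=2` trains much more slowly than `patch_size=8` even though the models have similar parameter counts.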