### Implementing a Vision Transformer from scratch



In [None]:
# Let's import the necessary libraries

import torch
import torch.nn as nn

In [None]:
# Hyperparameters
IMG_SIZE = 32 # CIFAR image resolution (32 x 32)
PATCH_SIZE = 8 # Each patch will be 8 x 8 pixels
IN_CHANNELS = 3 # RGB
EMBED_DIM = 64
NUM_HEADS = 4
DEPTH = 4 # Number of transformer encoder blocks
NUM_CLASSES = 10
BATCH_SIZE = 64
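
These values fix a few derived sizes that appear throughout the notebook. A minimal sketch of that arithmetic follows; the names `num_patches`, `seq_len`, `patch_dim`, and `head_dim` are introduced here for illustration, and the constants are re-declared only so the snippet runs on its own.

```python
# Constants copied from the hyperparameter cell so this snippet is self-contained.
IMG_SIZE, PATCH_SIZE, IN_CHANNELS = 32, 8, 3
EMBED_DIM, NUM_HEADS = 64, 4

# A 32 x 32 image cut into 8 x 8 patches gives a 4 x 4 grid of patches.
num_patches = (IMG_SIZE // PATCH_SIZE) ** 2

# One extra position is reserved for the class (CLS) token added later.
seq_len = num_patches + 1

# Raw values in one patch before projection: 8 * 8 pixels * 3 channels.
patch_dim = PATCH_SIZE * PATCH_SIZE * IN_CHANNELS

# Each attention head works on an equal slice of the embedding.
head_dim = EMBED_DIM // NUM_HEADS

print(num_patches, seq_len, patch_dim, head_dim)
```

So each 192-value patch is projected to a 64-dimensional embedding, and every head attends over 17 tokens.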

### Patch Embedding

The first task in ViT is to convert the images into patches.
Let's create a patch embedding layer that splits the images into patches and then projects each patch to a vector embedding.

**NOTE: Why are we using `Conv2d`?**


1. Task One: The "Cutter" (Splitting into Patches) ✂️

The `Conv2d` layer replaces the need for a manual loop to chop up the image. It uses the Stride and Kernel Size to determine how the image is divided.

• The Logic: By setting `stride = kernel_size`, we force the filter to jump exactly the width of one patch after processing it. This ensures no overlap between patches.

• Example:

  • Image Size: 16x16 pixels.

  • Kernel Size: 4 (This defines the patch size as 4x4).

  • Stride: 4 (The jump size).

The Process:

1. The filter lands on the first 4x4 block (top-left).

2. It processes it.

3. It jumps exactly 4 pixels to the right.

4. It lands perfectly on the next 4x4 block, skipping nothing and overlapping nothing.

Result: You get a grid of 16 patches (4 rows x 4 columns).
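
For comparison, the manual "cutting" that `Conv2d` saves us from can be sketched with `Tensor.unfold`. This is only an illustration of the 16x16 example above; the variable names are hypothetical and are not used elsewhere in the notebook.

```python
import torch

# A dummy 16x16 RGB image with distinct values so patches are easy to compare.
img = torch.arange(3 * 16 * 16, dtype=torch.float32).reshape(1, 3, 16, 16)
patch = 4

# Slide a window of size 4 with step 4 over height, then over width.
# Shape after both unfolds: (B, C, rows, cols, patch_h, patch_w) = (1, 3, 4, 4, 4, 4)
windows = img.unfold(2, patch, patch).unfold(3, patch, patch)

# Move the grid dimensions forward and merge them: (B, 16, C, patch_h, patch_w)
patches = windows.permute(0, 2, 3, 1, 4, 5).reshape(1, -1, 3, patch, patch)

print(patches.shape)
# Patch 1 is the second block in the top row: rows 0-3, columns 4-7.
print(torch.equal(patches[0, 1], img[0, :, 0:4, 4:8]))
```

The strided convolution visits exactly these 16 non-overlapping blocks, in the same row-by-row order.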

2. Task Two: The "Translator" (Creating Embeddings) 🗣️

We don't just want pixel grids; we want vectors (lists of numbers) that represent the content of those patches.

We define an Embedding Dimension (e.g., `embed_dim = 64`). This means we want every 4x4 patch to be summarized by exactly 64 numbers.

To do this, the `Conv2d` layer creates 64 separate filters. Each filter is a unique "feature detector."

• The Filter Shape: Since your patch is 4x4 with 3 color channels (RGB), every single filter has the shape 4x4x3.

How the Embedding is built for ONE patch:

• Filter 1: A block of weights (4x4x3). It overlays the patch, multiplies all the pixel values by its weights, sums them up, and produces 1 single number.

• Filter 2: A different block of weights (4x4x3). It looks at the same patch and produces a 2nd number.

• ... (repeating this process) ...

• Filter 64: It looks at the patch and produces the 64th number.

The Result: For that single patch, you now have a stack of 64 numbers. That is your Patch Embedding Vector. 💎
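
The "multiply by the weights and sum" description can be checked directly. The sketch below assumes the same 16x16 example; it recomputes the embedding of the top-left patch by hand from the layer's own `weight` and `bias` and compares it with what `Conv2d` returns. Note that PyTorch stores each filter channel-first, as `(3, 4, 4)`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
img = torch.randn(1, 3, 16, 16)
conv = nn.Conv2d(3, 64, kernel_size=4, stride=4)  # weight shape: (64, 3, 4, 4)

# Top-left 4x4 patch across all three channels: shape (3, 4, 4)
top_left = img[0, :, 0:4, 0:4]

# For each of the 64 filters: elementwise multiply, sum everything, add the bias.
manual = (conv.weight * top_left).sum(dim=(1, 2, 3)) + conv.bias  # shape (64,)

# The convolution's output at grid position (0, 0) is the same 64 numbers.
from_conv = conv(img)[0, :, 0, 0]

print(manual.shape, torch.allclose(manual, from_conv, atol=1e-5))
```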

3. The Grand Finale (The Output Shape) 🏁

After the `Conv2d` operation finishes running over your 16x16 image:

1. Patches Created: 16 patches (arranged in a 4x4 grid).

2. Embedding per Patch: 64 numbers.

3. Final Output Shape: `(Batch_Size, 64, 4, 4)`.

This tensor contains 16 patches, where each patch is now represented by a deep vector of 64 features instead of raw pixels!
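
A transformer expects a sequence of tokens rather than a 4x4 grid, so the grid is flattened and the axes are swapped. A small sketch of the shapes involved, again for the 16x16 example and with an arbitrary batch size of 2:

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 64, kernel_size=4, stride=4)
x = torch.randn(2, 3, 16, 16)           # (B, C, H, W)

feat = conv(x)                           # (B, 64, 4, 4): one 64-vector per grid cell
tokens = feat.flatten(2)                 # (B, 64, 16): merge the 4x4 grid into 16 positions
tokens = tokens.transpose(1, 2)          # (B, 16, 64): one row per patch

print(feat.shape, tokens.shape)
# Token 5 corresponds to grid cell (row 1, col 1), since 5 = 1 * 4 + 1.
print(torch.equal(tokens[0, 5], feat[0, :, 1, 1]))
```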

In [None]:
class PatchEmbedding(nn.Module):

    # For example, take a 16 x 16 pixel image and a patch size of 4:
    # img_size = 16, patch_size = 4, in_channels = 3 (RGB channels).
    # embed_dim is the length of the vector each patch is projected to.
    def __init__(self, img_size= IMG_SIZE,
                 patch_size = PATCH_SIZE,
                 in_channels = IN_CHANNELS,
                 embed_dim = EMBED_DIM):

      super().__init__()
      self.patch_size = patch_size


      # Splitting the image into patches:
      # we could split the image into patches with manual slicing,
      # but a strided convolution is a simpler shortcut.
      self.projection = nn.Conv2d(in_channels, embed_dim,
                                  kernel_size = patch_size,
                                  stride = patch_size)

      self.num_patches = (img_size // patch_size) ** 2


    def forward(self, x):

      # Shapes below are for the 16 x 16 RGB example with patch_size = 4.
      # before projection: x.shape = [B, 3, 16, 16]
      x = self.projection(x)

      # after projection: x.shape = [B, embed_dim, 4, 4]

      x = x.flatten(2)

      # after flatten: x.shape = [B, embed_dim, 16]

      x = x.transpose(1, 2)

      # after transpose: x.shape = [B, 16, embed_dim]
      return x

### Encoder block implementation

Each encoder block contains layer normalization, multi-head attention, and a small feed-forward network, with a residual connection around the attention and feed-forward sub-layers.



The encoder block is straightforward and follows the original Vision Transformer paper: layer normalization is applied before each sub-layer (pre-norm).
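
Before the full implementation, here is a rough sketch of the data flow through one block. As a stand-in for the hand-written attention defined below, it uses PyTorch's built-in `nn.MultiheadAttention`; that swap and all variable names are assumptions made only to keep the snippet short.

```python
import torch
import torch.nn as nn

embed_dim, num_heads = 64, 4
norm1, norm2 = nn.LayerNorm(embed_dim), nn.LayerNorm(embed_dim)

# Stand-in for the custom MultiHeadAttention class in this notebook.
attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

# Feed-forward network with a hidden layer 4x wider than the embedding.
mlp = nn.Sequential(
    nn.Linear(embed_dim, 4 * embed_dim),
    nn.GELU(),
    nn.Linear(4 * embed_dim, embed_dim),
)

x = torch.randn(2, 17, embed_dim)        # (B, tokens, embed_dim)

h = norm1(x)                             # normalize first (pre-norm)
attn_out, _ = attn(h, h, h)              # self-attention: query = key = value
x = x + attn_out                         # first residual connection

x = x + mlp(norm2(x))                    # normalize, feed-forward, second residual

print(x.shape)
```

The block maps a `(B, N, embed_dim)` tensor to a tensor of the same shape, which is what lets several blocks be stacked.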

In [None]:
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.0):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        # The dimension of each individual attention head
        self.head_dim = embed_dim // num_heads

        # Ensure the embedding dimension can be evenly divided by the number of heads
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        # The scaling factor (1 / sqrt(d_k)) to prevent dot products from getting too large
        self.scale = self.head_dim ** -0.5

        # A single linear layer to calculate Q, K, and V simultaneously (3 * embed_dim)
        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=True)
        self.attn_drop = nn.Dropout(dropout)

        # The final linear projection to mix the concatenated heads back together
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(dropout)

    def forward(self, x):
        # B = Batch Size
        # N = Sequence Length (Number of patches + 1 for CLS token)
        # C = Embedding Dimension
        B, N, C = x.shape

        # 1. Project input to Q, K, V
        # Resulting shape: (B, N, 3 * C)
        qkv = self.qkv(x)

        # 2. Reshape and Permute to separate the heads
        # Step A: Reshape to (B, N, 3, num_heads, head_dim)
        qkv = qkv.reshape(B, N, 3, self.num_heads, self.head_dim)

        # Step B: Permute to (3, B, num_heads, N, head_dim)
        # We bring the '3' to the front to easily unpack Q, K, and V
        qkv = qkv.permute(2, 0, 3, 1, 4)

        # Unpack the tensor into our Query, Key, and Value tensors
        # Each has shape: (B, num_heads, N, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # 3. Calculate Attention Scores (Q * K^T)
        # We transpose the last two dimensions of K to do the matrix multiplication
        # Resulting shape: (B, num_heads, N, N)
        attn = (q @ k.transpose(-2, -1)) * self.scale

        # 4. Apply Softmax to get attention probabilities
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # 5. Multiply Attention Scores by Values (attn * V)
        # Resulting shape: (B, num_heads, N, head_dim)
        x = attn @ v

        # 6. Concatenate the heads back together
        # We transpose back to (B, N, num_heads, head_dim) and then flatten the last two dims
        # Resulting shape: (B, N, C)
        x = x.transpose(1, 2).reshape(B, N, C)

        # 7. Final linear projection and dropout
        x = self.proj(x)
        x = self.proj_drop(x)

        return x
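
The reshape and permute steps are the least obvious part of this class. The sketch below replays them on a tiny random tensor, with small made-up sizes and no learned layers, so that the intermediate shapes can be inspected.

```python
import torch

torch.manual_seed(0)
B, N, C, num_heads = 2, 5, 8, 2
head_dim = C // num_heads

# Pretend this is the output of the fused qkv linear layer: (B, N, 3 * C)
qkv = torch.randn(B, N, 3 * C)

# Split into Q, K, V and into heads: each is (B, num_heads, N, head_dim)
qkv = qkv.reshape(B, N, 3, num_heads, head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)

# Scaled dot-product scores, then softmax over the key axis: (B, num_heads, N, N)
attn = (q @ k.transpose(-2, -1)) * head_dim ** -0.5
attn = attn.softmax(dim=-1)

# Weighted sum of values per head, then merge heads back: (B, N, C)
per_head = attn @ v
out = per_head.transpose(1, 2).reshape(B, N, C)

print(q.shape, attn.shape, out.shape)
# Every row of attention weights sums to 1.
print(torch.allclose(attn.sum(dim=-1), torch.ones(B, num_heads, N)))
```

After merging, the first `head_dim` channels of `out` come from head 0 and the next `head_dim` from head 1, which is why a final linear projection is needed to mix information across heads.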

In [None]:
class EncoderBlock(nn.Module):

  # mlp_ratio = 4 is standard in transformers:
  # the MLP hidden layer is usually 4x wider than the embedding dimension
  def __init__(self, embed_dim = EMBED_DIM,
               num_heads = NUM_HEADS,
               mlp_ratio = 4.0,
               dropout = 0.0):
    super().__init__()

    self.norm1 = nn.LayerNorm(embed_dim)
    self.attn = MultiHeadAttention(embed_dim, num_heads, dropout=dropout)


    self.norm2 = nn.LayerNorm(embed_dim)

    hidden_features = int(embed_dim * mlp_ratio)

    self.mlp = nn.Sequential(
        nn.Linear(embed_dim, hidden_features),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(hidden_features, embed_dim),
        nn.Dropout(dropout)
    )


  def forward(self, x):

    attn_input = self.norm1(x)

    attn_output = self.attn(attn_input)
    x = x + attn_output

    mlp_input = self.norm2(x)

    mlp_output = self.mlp(mlp_input)

    x = x + mlp_output

    return x




### Vision Transformer model

Now we assemble all the classes.

In [None]:
class VisionTransformer(nn.Module):
  def __init__(self,
               img_size = IMG_SIZE,
               patch_size = PATCH_SIZE,
               num_classes = NUM_CLASSES,
               embed_dim = EMBED_DIM,
               depth = DEPTH,
               num_heads = NUM_HEADS,
               in_channels =  IN_CHANNELS,
               dropout = 0.1):

    super().__init__()

    self.patch_embedding = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)

    num_patches = self.patch_embedding.num_patches

    self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
    self.pos_embedding = nn.Parameter(torch.zeros(1, 1 + num_patches, embed_dim))
    nn.init.trunc_normal_(self.pos_embedding, std=0.02)


    self.dropout = nn.Dropout(dropout)


    self.encoder = nn.Sequential(*[EncoderBlock(embed_dim, num_heads, dropout=dropout) for _ in range(depth)])


    self.norm = nn.LayerNorm(embed_dim)
    self.mlp_head = nn.Sequential(
        nn.Linear(embed_dim, num_classes)
    )


  def forward(self, x):

        # Extract patches and project
        x = self.patch_embedding(x)
        B = x.size(0)

        # Prepend the CLS token
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # Add positional embeddings
        x = x + self.pos_embedding

        # Apply dropout after positional embedding
        x = self.dropout(x)

        # Pass the sequence through the stacked encoder blocks
        x = self.encoder(x)

        # Extract the CLS token output (the 0th index)
        x = x[:, 0]

        # Apply the final norm and pass through the classification head
        x = self.norm(x)
        out = self.mlp_head(x)

        return out
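
The class token and positional embedding handling in `forward` can be followed on its own with dummy tensors. This is a shape-only sketch: the parameters are left at zero and the names mirror the ones above purely for readability.

```python
import torch
import torch.nn as nn

B, num_patches, embed_dim = 2, 16, 64
patch_tokens = torch.randn(B, num_patches, embed_dim)   # output of the patch embedding

cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
pos_embedding = nn.Parameter(torch.zeros(1, 1 + num_patches, embed_dim))

# expand() repeats the single CLS token across the batch without copying memory.
cls_tokens = cls_token.expand(B, -1, -1)                 # (B, 1, embed_dim)

# Prepend it so that position 0 of every sequence is the CLS token.
x = torch.cat((cls_tokens, patch_tokens), dim=1)         # (B, 17, embed_dim)

# The positional embedding has a leading 1, so it broadcasts over the batch.
x = x + pos_embedding

# After the encoder, only position 0 is passed to the classification head.
cls_out = x[:, 0]                                        # (B, embed_dim)

print(x.shape, cls_out.shape)
```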


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import matplotlib.pyplot as plt

In [None]:

# --- DATA PREP ---
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), # CIFAR-10 standard normalization
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])


trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)

trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)



In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:

def train_and_get_acc(embed_dim, num_heads, depth, epochs):
    print(f"Training: Dim={embed_dim}, Heads={num_heads}, Depth={depth}")

    # Instantiate VisionTransformer with the provided parameters
    model = VisionTransformer(embed_dim=embed_dim, num_heads=num_heads, depth=depth).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    acc_history = []

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        correct_epoch = 0
        total_epoch = 0
        print(f"\nEpoch {epoch+1}")

        for batch_idx, (images, labels) in enumerate(trainloader):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            preds = outputs.argmax(dim=1)
            correct = (preds == labels).sum().item()

            correct_epoch += correct
            total_epoch += labels.size(0)


        # End of Epoch Training Summary
        epoch_train_acc = 100.0 * correct_epoch / total_epoch
        print(f"==> Epoch {epoch+1} Summary: Total Loss = {total_loss:.4f}, Training Accuracy = {epoch_train_acc:.2f}%")

        # Validation after each epoch
        model.eval()
        correct_test = 0
        total_test = 0

        with torch.no_grad():
            for images, labels in testloader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                preds = outputs.argmax(dim=1)
                correct_test += (preds == labels).sum().item()
                total_test += labels.size(0)

        test_acc = 100.0 * correct_test / total_test
        acc_history.append(test_acc)
        print(f"==> Test Accuracy after Epoch {epoch+1}: {test_acc:.2f}%")

    return acc_history
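
The accuracy bookkeeping inside the loop comes down to an `argmax` over the class dimension followed by a comparison with the labels. A toy example with hand-written logits for four samples and three classes (the values are made up for illustration):

```python
import torch

# Rows are samples, columns are class scores (logits).
logits = torch.tensor([[2.0, 0.1, -1.0],
                       [0.3, 0.2, 1.5],
                       [0.0, 3.0, 0.5],
                       [1.0, 0.9, 0.8]])
labels = torch.tensor([0, 2, 1, 2])

preds = logits.argmax(dim=1)                  # predicted class per sample
correct = (preds == labels).sum().item()      # number of matches as a Python int
acc = 100.0 * correct / labels.size(0)        # percentage over this batch

print(preds.tolist(), correct, acc)
```

Here the last sample is misclassified (predicted 0, label 2), so three of four are correct. In the training function the same counts are accumulated over all batches before dividing.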

In [None]:
# --- RUN EXPERIMENTS ---
# 1. Val Accuracy vs Epochs
acc_vs_epoch = train_and_get_acc(embed_dim=64, num_heads=4, depth=4, epochs=50)

# 2. Val Accuracy vs Heads (Fixed Epochs=15, Dim=64, Depth=4)
#heads_list = [2, 8, 16]
#acc_vs_heads = [train_and_get_acc(64, h, 4, 15)[-1] for h in heads_list]

# 3. Val Accuracy vs Embed Dim (Fixed Epochs=15, Heads=4)
#dims_list = [16, 32, 128]
#acc_vs_dims = [train_and_get_acc(d, 4, 4, 15)[-1] for d in dims_list]

Training: Dim=64, Heads=4, Depth=4

Epoch 1
==> Epoch 1 Summary: Total Loss = 1456.4128, Training Accuracy = 29.80%
==> Test Accuracy after Epoch 1: 34.15%

Epoch 2
==> Epoch 2 Summary: Total Loss = 1318.9920, Training Accuracy = 37.25%
==> Test Accuracy after Epoch 2: 39.37%

Epoch 3
==> Epoch 3 Summary: Total Loss = 1241.0304, Training Accuracy = 41.40%
==> Test Accuracy after Epoch 3: 45.11%

Epoch 4
==> Epoch 4 Summary: Total Loss = 1187.3029, Training Accuracy = 44.35%
==> Test Accuracy after Epoch 4: 47.38%

Epoch 5
==> Epoch 5 Summary: Total Loss = 1146.4641, Training Accuracy = 46.46%
==> Test Accuracy after Epoch 5: 49.31%

Epoch 6
==> Epoch 6 Summary: Total Loss = 1113.6682, Training Accuracy = 47.82%
==> Test Accuracy after Epoch 6: 50.79%

Epoch 7
==> Epoch 7 Summary: Total Loss = 1079.9252, Training Accuracy = 49.95%
==> Test Accuracy after Epoch 7: 52.89%

Epoch 8
==> Epoch 8 Summary: Total Loss = 1051.5142, Training Accuracy = 51.18%
==> Test Accuracy after Epoch 8: 54.1

In [None]:
# Got around 69% accuracy with the following parameters: embed_dim=64, num_heads=4, depth=4, epochs=50
# Try uncommenting the val-accuracy-vs-heads and val-accuracy-vs-embed-dim experiments above to find better parameters.