## Implementing ViT from scratch
Training a Vision Transformer (ViT) model on the CIFAR-10 dataset

In [1]:
import os
import torch
import torchvision
import numpy as np
import torch.nn as nn

Steps:


*   prepare data
*   prepare patch embeddings from image input
*   add positional embeddings
*   attention head
*   multiple attention heads
*   feed forward network in transformer block
*   transformer block
*   classification head
*   training



# Data preparation

In [6]:
# Number of images per mini-batch; used by both the train and test DataLoaders.
BATCH_SIZE = 16

In [8]:
# Preprocessing applied to every image, in this order:
#   1. ToTensor   -> float tensor with pixel values scaled to the range (0, 1)
#   2. Resize     -> force a 32x32 spatial size
#   3. Normalize  -> (x - 0.5) / 0.5 per channel, i.e. values end up in (-1, 1),
#                    which is a friendlier range for the network's computations
transformations = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Resize((32, 32)),
        torchvision.transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# CIFAR-10 training split, downloaded into the current directory if missing.
train_set = torchvision.datasets.CIFAR10(
    root=".",
    train=True,
    download=True,
    transform=transformations,
)

Files already downloaded and verified


In [7]:
# Mini-batch iterator over the training split; reshuffled on every pass.
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

In [9]:
# CIFAR-10 test split, preprocessed with the same pipeline as the training data.
test_set = torchvision.datasets.CIFAR10(
    root=".",
    train=False,
    download=True,
    transform=transformations,
)
len(test_set)  # notebook display: number of test examples

Files already downloaded and verified


10000

In [10]:
# Mini-batch iterator over the test split; order is kept fixed (no shuffling).
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

# Embeddings

In [None]:
"""
The input image is split into non-overlapping patches, and each patch is mapped to an embedding vector.
This can be implemented with a Conv2d layer whose kernel size and stride are both equal to the patch size.
"""

In [31]:
class PatchEmbeddings(nn.Module):
  """Split an image into non-overlapping patches and embed each patch.

  A Conv2d whose kernel size and stride both equal ``patch_size`` visits every
  patch exactly once and projects it to ``embedding_size`` channels.

  Args:
    image_channels: number of channels of the input image.
    patch_size: side length of one square patch, in pixels.
    embedding_size: dimensionality of each patch embedding.
  """

  def __init__(self, image_channels, patch_size, embedding_size):
    super().__init__()
    self.image_channels = image_channels
    self.patch_size = patch_size
    self.embedding_size = embedding_size
    self.conv = nn.Conv2d(in_channels=self.image_channels,
                          out_channels=self.embedding_size,
                          kernel_size=self.patch_size,
                          stride=self.patch_size)

  # NOTE: this must be `forward`, not `__call__`. Overriding `__call__` bypasses
  # nn.Module.__call__, so registered forward/pre-forward hooks would never run.
  def forward(self, x):
    """Embed a batch of images.

    Args:
      x: tensor of shape (B, C, H, W).

    Returns:
      Tensor of shape (B, num_patches, embedding_size), where num_patches is
      the number of patch positions produced by the convolution.
    """
    x = self.conv(x)                                  # (B, embedding_size, patches_y, patches_x)
    x = x.flatten(start_dim=2, end_dim=-1)            # (B, embedding_size, num_patches)
    x = x.transpose(-1, -2)                           # (B, num_patches, embedding_size)
    return x

In [41]:
# add positional embeddings and cls token
# positional embeddings will be added in the form of an Embedding layer

class FinalEmbeddings(nn.Module):
  """Patch embeddings with a prepended CLS token and learned positional embeddings.

  Position 0 of the positional table belongs to the CLS token; positions
  1..num_patches belong to the image patches.

  Args:
    image_channels: number of channels of the input image.
    patch_size: side length of one square patch, in pixels.
    num_patches: number of patches produced per image.
    embedding_size: dimensionality of every token embedding.
  """

  def __init__(self, image_channels, patch_size, num_patches, embedding_size):
    super().__init__()
    self.patch_embeddings = PatchEmbeddings(image_channels, patch_size, embedding_size)
    self.num_patches = num_patches
    self.embedding_size = embedding_size
    self.positional_embedding = nn.Embedding(num_embeddings = self.num_patches + 1,
                                             embedding_dim = self.embedding_size,
                                             dtype = torch.float32)
    self.cls_embedding = nn.Parameter(torch.randn(1, 1, self.embedding_size))

  def forward(self, x):
    """Turn a batch of images into the token sequence fed to the encoder.

    Args:
      x: image tensor of shape (B, C, H, W).

    Returns:
      Tensor of shape (B, num_patches + 1, embedding_size).
    """
    x = self.patch_embeddings(x)                        # (B, num_patches, embedding_size)

    # Build the position indices on the same device as the activations; indices
    # created on the default device break the lookup once the model is moved
    # to an accelerator.
    position_ids = torch.arange(self.num_patches + 1, device=x.device)
    positional_embeddings = self.positional_embedding(position_ids)   # (num_patches + 1, embedding_size)
    positional_embeddings = positional_embeddings.unsqueeze(0)        # (1, num_patches + 1, embedding_size)

    batch_size = x.size(0)
    cls_embeddings = self.cls_embedding.expand(batch_size, 1, -1)     # one CLS token per example

    # prepend the CLS token to the patch embeddings
    x = torch.cat([cls_embeddings, x], dim=1)           # (B, num_patches + 1, embedding_size)

    # positional embeddings broadcast over the batch dimension
    return x + positional_embeddings

# Multi Head attention block

In [53]:
# a single attention head

class AttentionHead(nn.Module):
  """A single scaled dot-product self-attention head.

  Args:
    embedding_size: dimensionality of the input tokens.
    head_size: dimensionality of the query/key/value projections.
  """

  def __init__(self, embedding_size, head_size):
    super().__init__()
    self.head_size = head_size
    self.embedding_size = embedding_size
    self.query = nn.Linear(self.embedding_size, self.head_size, bias=True)
    self.key = nn.Linear(self.embedding_size, self.head_size, bias=True)
    self.value = nn.Linear(self.embedding_size, self.head_size, bias=True)

  def forward(self, x):
    """Apply self-attention to a token sequence.

    Args:
      x: tensor of shape (B, num_tokens, embedding_size).

    Returns:
      Tensor of shape (B, num_tokens, head_size).
    """
    # query, key and value are all projected from the same input tensor
    query = self.query(x)       # (B, num_tokens, head_size)
    key = self.key(x)           # (B, num_tokens, head_size)
    value = self.value(x)       # (B, num_tokens, head_size)

    # Row i holds the scores of query i against every key, so the softmax over
    # the last dimension normalises over keys: softmax(Q K^T / sqrt(d)).
    scores = (query @ key.transpose(-1, -2)) / (self.head_size ** 0.5)   # (B, num_tokens, num_tokens)
    weights = nn.functional.softmax(scores, dim=-1)

    # weighted sum of the values: (B, T, T) @ (B, T, head_size) -> (B, T, head_size)
    return weights @ value

In [54]:
# multihead attention block

class MultiHeadAttn(nn.Module):
  """Multi-head self-attention built from independent AttentionHead modules.

  Each head projects to ``embedding_size // num_heads`` features; the head
  outputs are concatenated and projected back to ``embedding_size``.

  Args:
    num_heads: number of parallel attention heads.
    embedding_size: dimensionality of the input and output tokens.
  """

  def __init__(self, num_heads, embedding_size):
    super().__init__()
    self.num_heads = num_heads
    self.embedding_size = embedding_size
    self.head_size = embedding_size // num_heads
    self.all_head_size = self.head_size * num_heads

    heads = []
    for _ in range(num_heads):
      heads.append(AttentionHead(embedding_size, self.head_size))
    self.ma_head = nn.ModuleList(heads)

    self.projection = nn.Linear(self.all_head_size, embedding_size)

  def forward(self, x):
    """Map (B, num_tokens, embedding_size) to a tensor of the same shape."""
    per_head = [head(x) for head in self.ma_head]   # num_heads x (B, num_tokens, head_size)
    concatenated = torch.cat(per_head, dim=-1)      # (B, num_tokens, all_head_size)
    return self.projection(concatenated)            # (B, num_tokens, embedding_size)

# Feed forward network in transformer block

In [55]:
# a simple 2-layer MLP

class FeedForward(nn.Module):
  """Position-wise two-layer MLP: Linear -> LeakyReLU -> Linear.

  Args:
    embedding_size: dimensionality of the input and output tokens.
    intermediate_size: width of the hidden layer.
  """

  def __init__(self, embedding_size, intermediate_size):
    super().__init__()
    self.embedding_size = embedding_size
    self.intermediate_size = intermediate_size
    self.layer1 = nn.Linear(embedding_size, intermediate_size)
    self.activation = nn.LeakyReLU()
    self.layer2 = nn.Linear(intermediate_size, embedding_size)

  def forward(self, x):
    """Map (..., embedding_size) to (..., embedding_size) via the hidden layer."""
    hidden = self.activation(self.layer1(x))   # (..., intermediate_size)
    return self.layer2(hidden)                 # (..., embedding_size)

# Transformer block

In [56]:
# a single transformer block
"""
a single block will contain:
1. multi head attention layer
2. feed forward network
3. layer normalization and skip connection before and after the multi head attention block
"""

'\na single block will contain:\n1. multi head attention layer\n2. feed forward network\n3. layer normalization and skip connection before and after the multi head attention block\n'

In [57]:
class TransformerBlock(nn.Module):
  """One pre-norm transformer encoder block.

  Computes, in order:
    x = x + MultiHeadAttn(LayerNorm(x))
    x = x + FeedForward(LayerNorm(x))

  Args:
    num_heads: number of attention heads.
    embedding_size: dimensionality of the tokens.
    intermediate_size: hidden width of the feed forward network.
  """

  def __init__(self, num_heads, embedding_size, intermediate_size):
    super().__init__()
    self.num_heads = num_heads
    self.embedding_size = embedding_size
    self.intermediate_size = intermediate_size

    self.layernorm1 = nn.LayerNorm(self.embedding_size)
    self.mha = MultiHeadAttn(self.num_heads, self.embedding_size)
    self.layernorm2 = nn.LayerNorm(self.embedding_size)
    self.feed_forward = FeedForward(self.embedding_size, self.intermediate_size)

  def forward(self, x):
    """Map (B, num_tokens, embedding_size) to a tensor of the same shape."""
    # The skip connection must carry the *un-normalised* input. Normalising x
    # in place before adding would replace the residual stream with its
    # layer-normed version in every block.
    x = x + self.mha(self.layernorm1(x))            # attention sub-layer + skip connection
    x = x + self.feed_forward(self.layernorm2(x))   # feed forward sub-layer + skip connection
    return x

# Encoder

In [58]:
class Encoder(nn.Module):
  """A stack of ``num_layers`` identically configured TransformerBlock modules.

  Args:
    num_layers: number of transformer blocks applied in sequence.
    num_heads: attention heads per block.
    embedding_size: dimensionality of the tokens.
    intermediate_size: hidden width of each block's feed forward network.
  """

  def __init__(self, num_layers, num_heads, embedding_size, intermediate_size):
    super().__init__()
    self.num_layers = num_layers
    self.num_heads = num_heads
    self.embedding_size = embedding_size
    self.intermediate_size = intermediate_size

    # build the blocks one by one and register them as sub-modules
    stack = []
    for _ in range(num_layers):
      stack.append(TransformerBlock(num_heads, embedding_size, intermediate_size))
    self.blocks = nn.ModuleList(stack)

  def forward(self, x):
    """Pass ``x`` through every block in order and return the final output."""
    hidden = x
    for layer in self.blocks:
      hidden = layer(hidden)
    return hidden

# Model

In [59]:
class SimpleVIT(nn.Module):
  """Minimal Vision Transformer classifier.

  Pipeline: FinalEmbeddings -> Encoder -> linear classifier on the CLS token.

  Args:
    num_classes: number of output classes.
    image_channels: number of channels of the input image.
    patch_size: side length of one square patch, in pixels.
    num_patches: number of patches produced per image.
    embedding_size: dimensionality of the tokens.
    intermediate_size: hidden width of the feed forward networks.
    num_layers: number of transformer blocks in the encoder.
    num_heads: attention heads per transformer block.
  """

  def __init__(self, num_classes, image_channels, patch_size, num_patches, embedding_size, intermediate_size, num_layers, num_heads):
    super().__init__()
    # keep the configuration around for inspection
    self.num_classes = num_classes
    self.image_channels = image_channels
    self.patch_size = patch_size
    self.num_patches = num_patches
    self.embedding_size = embedding_size
    self.intermediate_size = intermediate_size
    self.num_layers = num_layers
    self.num_heads = num_heads

    # sub-modules
    self.embeddings = FinalEmbeddings(image_channels, patch_size, num_patches, embedding_size)
    self.encoder = Encoder(num_layers, num_heads, embedding_size, intermediate_size)
    self.classifier = nn.Linear(embedding_size, num_classes)

  def forward(self, x):
    """Return class logits of shape (B, num_classes) for images of shape (B, C, H, W)."""
    tokens = self.embeddings(x)             # (B, num_patches + 1, embedding_size)
    encoded = self.encoder(tokens)          # (B, num_patches + 1, embedding_size)

    # the CLS token sits at sequence position 0; it alone feeds the classifier
    cls_representation = encoded[:, 0, :]   # (B, embedding_size)
    return self.classifier(cls_representation)   # (B, num_classes)


# Training

In [61]:
# --- optimisation hyperparameters ---
EPOCHS = 10
learning_rate = 0.001

# --- data description ---
num_classes = 10
image_channels = 3
image_dim = 32

# --- model hyperparameters ---
patch_size = 4
num_patches = (image_dim // patch_size) ** 2   # patches per side, squared
embedding_size = 48
intermediate_size = embedding_size * 4
num_layers = 4
num_heads = 4

model = SimpleVIT(
    num_classes,
    image_channels,
    patch_size,
    num_patches,
    embedding_size,
    intermediate_size,
    num_layers,
    num_heads,
)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
loss_fn = nn.CrossEntropyLoss()

In [None]:
train_losses, test_losses, accuracies = [], [], []

for epoch in range(EPOCHS):
  for batch in train_loader:
    x, y = batch[0], batch[1]
    optimizer.zero_grad()
    output_logits = model(x)
    loss = loss_fn(output_logits, y)
    loss.backward()
    optimizer.step()
    train_losses.append(loss.item())