# Vision Transformer

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import math

# Patch Embedding Layer

In [2]:
class PatchEmbedding(nn.Module):
    """Turn an image batch into a sequence of patch embeddings.

    A single ``Conv2d`` whose kernel size and stride both equal
    ``patch_size`` projects every non-overlapping patch to an
    ``embed_dim``-dimensional vector.

    Args:
        img_size: Side length of the (square) input image, used only to
            compute ``n_patches``.
        patch_size: Side length of each square patch.
        in_channels: Number of channels of the input image.
        embed_dim: Size of the embedding produced for every patch.
    """

    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()

        patches_per_side = img_size // patch_size
        self.n_patches = patches_per_side ** 2
        self.patch_size = patch_size
        self.embed_dim = embed_dim

        # kernel_size == stride, so the convolution windows never overlap:
        # each output position corresponds to exactly one patch.
        self.projection = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Embed ``x`` of shape (batch, in_channels, H, W).

        Returns a tensor of shape (batch, num_patches, embed_dim). The shape
        of the intermediate convolution output is printed on every call.
        """
        x = self.projection(x)
        print(f'Projection shape: {x.shape}')

        # Collapse the spatial grid into one axis, then move it in front of
        # the embedding axis: (B, E, h, w) -> (B, E, h*w) -> (B, h*w, E).
        return x.flatten(2).transpose(1, 2)


In [3]:
# Smoke test: run one random tensor shaped like a single 3-channel 224x224
# image through a PatchEmbedding built with its default arguments.
fake_input = torch.randn(1, 3, 224, 224)

patch_embed = PatchEmbedding()
# `out` is rebound by the later demo cells, which feed it onward.
out = patch_embed(fake_input)
print(f'Output shape: {out.shape}')

Projection shape: torch.Size([1, 768, 14, 14])
Output shape: torch.Size([1, 196, 768])


# Transformer Block

**Multi-Head Self-Attention**

In [4]:
class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Queries, keys and values are all linear projections of the same input.
    The embedding axis is split evenly across ``num_heads`` heads, attention
    is computed per head, and the heads are merged again before a final
    linear projection.

    Args:
        embed_dim: Size of the input and output embedding.
        num_heads: Number of attention heads; must divide ``embed_dim``.
    """

    def __init__(self, embed_dim=768, num_heads=8):
        super().__init__()

        # Each head receives an equal slice of the embedding.
        assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by number of heads"

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Input projections (query, key, value) followed by the output
        # projection applied after the heads are merged.
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def _split_heads(self, tensor, batch_size, seq_length):
        """Reshape (B, N, E) into (B, num_heads, N, head_dim)."""
        per_head = tensor.view(batch_size, seq_length, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (batch, seq_length, embed_dim).

        Returns a tensor with the same shape as ``x``. The shape of the query
        tensor is printed before and after it is split into heads.
        """
        batch_size, seq_length, embed_dim = x.size()

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        print(f'Q shape: {q.shape}')

        q = self._split_heads(q, batch_size, seq_length)
        k = self._split_heads(k, batch_size, seq_length)
        v = self._split_heads(v, batch_size, seq_length)

        print(f'Q shape after reshape: {q.shape}')

        # Dot-product similarities, scaled by sqrt(head_dim) and normalised
        # over the key axis, weight the values.
        scale = math.sqrt(self.head_dim)
        scores = torch.matmul(q, k.transpose(-2, -1)) / scale
        weights = torch.softmax(scores, dim=-1)
        context = torch.matmul(weights, v)

        # (B, heads, N, head_dim) -> (B, N, heads, head_dim) -> (B, N, E)
        merged = context.transpose(1, 2).contiguous()
        merged = merged.view(batch_size, seq_length, embed_dim)

        return self.out_proj(merged)

In [5]:
# Smoke test: apply a default-configured attention layer to the current `out`
# tensor and rebind `out` to the result.
multi_head_attn = MultiHeadSelfAttention()

out = multi_head_attn(out)

print(f'Output shape after attention: {out.shape}')

Q shape: torch.Size([1, 196, 768])
Q shape after reshape: torch.Size([1, 8, 196, 96])
Output shape after attention: torch.Size([1, 196, 768])


**Feed Forward Network**

In [6]:
class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Dropout -> Linear.

    Args:
        embed_dim: Size of the input and output embedding.
        expansion_factor: The hidden layer has
            ``embed_dim * expansion_factor`` units.
        dropout: Dropout probability applied to the hidden activations.
            Defaults to 0.1, the value that was previously hard-coded.
    """

    def __init__(self, embed_dim=768, expansion_factor=4, dropout=0.1):
        super().__init__()

        hidden_dim = embed_dim * expansion_factor
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Transform ``x`` along its last axis; the output shape equals the input shape."""
        x = self.fc1(x)
        x = F.gelu(x)
        # Dropout is only active in training mode.
        x = self.dropout(x)
        x = self.fc2(x)

        return x

In [7]:
# Smoke test: apply a default-configured feed-forward network to the current
# `out` tensor and rebind `out` to the result.
feed_forward = FeedForward()
out = feed_forward(out)

print(f'Output shape after feed-forward: {out.shape}')

Output shape after feed-forward: torch.Size([1, 196, 768])


**Transformer Encoder Block (1 layer)**

In [8]:
class Encoder(nn.Module):
    """A single Transformer encoder block.

    The block has two sub-layers, self-attention and a feed-forward network.
    Each sub-layer receives a layer-normalised copy of its input, and its
    output is added back onto the un-normalised input (residual connection).

    Args:
        embed_dim: Size of the token embedding.
        num_heads: Number of attention heads.
    """

    def __init__(self, embed_dim=768, num_heads=8):
        super().__init__()
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.ff = FeedForward(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Run both sub-layers on ``x`` and return a tensor of the same shape."""
        # Attention sub-layer: normalise, attend, add the residual.
        attended = self.attn(self.norm1(x))
        x = x + attended

        # Feed-forward sub-layer: normalise, transform, add the residual.
        transformed = self.ff(self.norm2(x))
        return x + transformed

In [9]:
# Smoke test: apply one default-configured encoder block to the current `out`
# tensor and rebind `out` to the result.
enc = Encoder()
out = enc(out)

print(f'Output shape after encoder: {out.shape}')

Q shape: torch.Size([1, 196, 768])
Q shape after reshape: torch.Size([1, 8, 196, 96])
Output shape after encoder: torch.Size([1, 196, 768])


**Transformer Encoder (12 layers)**

In [None]:
class TransformerEncoder(nn.Module):
    """A stack of ``depth`` encoder blocks followed by a final LayerNorm.

    Args:
        embed_dim: Size of the token embedding.
        num_heads: Number of attention heads in every block.
        depth: Number of encoder blocks in the stack.
    """

    def __init__(self, embed_dim=768, num_heads=8, depth=12):
        super().__init__()

        blocks = []
        for _ in range(depth):
            blocks.append(Encoder(embed_dim, num_heads))
        self.layers = nn.ModuleList(blocks)

        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Pass ``x`` through every block in order, then normalise the result."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return self.norm(hidden)


In [11]:
# Smoke test of the full encoder stack on random token embeddings. The input
# here is freshly sampled; it does not reuse `out` from the cells above.
embed_dim = 768   # like BERT-base or ViT-base
num_heads = 12    # 768 / 12 = 64 dimensions per head
depth = 12        # 12 transformer blocks
seq_len = 128     # number of tokens
batch_size = 4

x = torch.randn(batch_size, seq_len, embed_dim)

model = TransformerEncoder(embed_dim, num_heads, depth)
# Every attention layer in the stack prints its Q shapes during this call.
output = model(x)

print(output.shape)  # ➤ (4, 128, 768)


Q shape: torch.Size([4, 128, 768])
Q shape after reshape: torch.Size([4, 12, 128, 64])
Q shape: torch.Size([4, 128, 768])
Q shape after reshape: torch.Size([4, 12, 128, 64])
Q shape: torch.Size([4, 128, 768])
Q shape after reshape: torch.Size([4, 12, 128, 64])
Q shape: torch.Size([4, 128, 768])
Q shape after reshape: torch.Size([4, 12, 128, 64])
Q shape: torch.Size([4, 128, 768])
Q shape after reshape: torch.Size([4, 12, 128, 64])
Q shape: torch.Size([4, 128, 768])
Q shape after reshape: torch.Size([4, 12, 128, 64])
Q shape: torch.Size([4, 128, 768])
Q shape after reshape: torch.Size([4, 12, 128, 64])
Q shape: torch.Size([4, 128, 768])
Q shape after reshape: torch.Size([4, 12, 128, 64])
Q shape: torch.Size([4, 128, 768])
Q shape after reshape: torch.Size([4, 12, 128, 64])
Q shape: torch.Size([4, 128, 768])
Q shape after reshape: torch.Size([4, 12, 128, 64])
Q shape: torch.Size([4, 128, 768])
Q shape after reshape: torch.Size([4, 12, 128, 64])
Q shape: torch.Size([4, 128, 768])
Q shape 