In [5]:
## Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision import models
import matplotlib.pyplot as plt
import numpy as np

In [6]:
# Vision Transformer (ViT) Implementation

class PatchEmbedding(nn.Module):
    """Cut an image into non-overlapping square patches and embed each one.

    A convolution whose kernel size and stride both equal ``patch_size``
    projects every patch to an ``embed_dim``-sized vector in a single pass.

    Args:
        img_size: Image side length used to pre-compute ``num_patches``
            (the count is ``(img_size // patch_size) ** 2``).
        patch_size: Side length of each square patch.
        in_channels: Number of channels in the input image.
        embed_dim: Size of the embedding produced for each patch.
    """
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        patches_per_side = img_size // patch_size
        self.patch_size = patch_size
        self.num_patches = patches_per_side ** 2
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Embed the patches of ``x`` and return them as a token sequence.

        Returns:
            Tensor of shape (B, H' * W', embed_dim), where H' and W' are the
            spatial dimensions of the convolution output.
        """
        feature_map = self.proj(x)                    # (B, embed_dim, H', W')
        flat = feature_map.flatten(start_dim=2)       # (B, embed_dim, H' * W')
        return flat.transpose(1, 2)                   # (B, H' * W', embed_dim)

class MultiHeadSelfAttention(nn.Module):
    """Multi-Head Self-Attention mechanism for the Transformer Encoder.

    A single linear layer produces queries, keys and values for all heads;
    scaled dot-product attention is computed per head and the concatenated
    head outputs go through a final linear projection.

    Args:
        embed_dim: Size of the last dimension of the input tokens.
        num_heads: Number of attention heads; must divide ``embed_dim``.

    Raises:
        AssertionError: If ``embed_dim`` is not divisible by ``num_heads``.
    """
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == embed_dim, "Embedding size must be divisible by num_heads"

        # sqrt(head_dim) as a plain Python float, computed once. Keeping the
        # scale free of NumPy means forward() uses only torch operations and
        # does not recompute a constant on every call.
        self.sqrt_head_dim = self.head_dim ** 0.5

        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.fc_out = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        """Apply self-attention to a batch of token sequences.

        Args:
            x: Tensor of shape (B, N, C); C must equal ``embed_dim``.

        Returns:
            A tuple ``(out, attention)`` where ``out`` has shape (B, N, C) and
            ``attention`` holds the softmax-normalised weights with shape
            (B, num_heads, N, N).
        """
        B, N, C = x.shape
        # (B, N, 3*C) -> (B, N, 3, heads, head_dim) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        queries, keys, values = qkv.unbind(0)

        # Scaled dot-product scores, normalised over the key dimension.
        attention = (queries @ keys.transpose(-2, -1)) / self.sqrt_head_dim
        attention = F.softmax(attention, dim=-1)

        # (B, heads, N, head_dim) -> (B, N, heads, head_dim) -> (B, N, C)
        out = (attention @ values).transpose(1, 2).reshape(B, N, C)
        out = self.fc_out(out)
        return out, attention

In [7]:
class TransformerEncoderBlock(nn.Module):
    """One encoder block: self-attention, then a two-layer GELU MLP.

    Each sub-layer is wrapped in a residual connection, and LayerNorm is
    applied to the sum (normalisation comes after the residual add).

    Args:
        embed_dim: Token embedding size.
        num_heads: Number of attention heads.
        mlp_dim: Hidden width of the MLP.
        dropout: Dropout probability applied to the MLP output.
    """
    def __init__(self, embed_dim, num_heads, mlp_dim, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        feed_forward = [
            nn.Linear(embed_dim, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, embed_dim),
        ]
        self.mlp = nn.Sequential(*feed_forward)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run the block and return ``(tokens, attention_weights)``."""
        attended, weights = self.attn(x)
        # NOTE(review): dropout is applied to the MLP branch only; the
        # attention branch is added to the residual undropped.
        x = self.norm1(x + attended)
        x = self.norm2(x + self.dropout(self.mlp(x)))
        return x, weights

class VisionTransformer(nn.Module):
    """Vision Transformer (ViT) classifier that also exposes attention maps.

    The image is split into patch tokens, a learnable class token is
    prepended, learnable position embeddings are added, and the sequence is
    passed through a stack of encoder blocks. The class token's final state
    is layer-normalised and fed to a linear classification head.

    Args:
        img_size: Side length of the (square) input image.
        patch_size: Side length of each square patch.
        in_channels: Number of image channels.
        embed_dim: Token embedding size.
        num_heads: Attention heads per encoder block.
        mlp_dim: Hidden width of each block's MLP.
        num_layers: Number of encoder blocks.
        num_classes: Number of output logits.
    """
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768, num_heads=8, mlp_dim=2048, num_layers=6, num_classes=10):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)

        # One position per patch plus one for the class token.
        seq_len = (img_size // patch_size) ** 2 + 1
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, seq_len, embed_dim))

        blocks = []
        for _ in range(num_layers):
            blocks.append(TransformerEncoderBlock(embed_dim, num_heads, mlp_dim))
        self.encoder_layers = nn.ModuleList(blocks)

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Classify a batch of images.

        Returns:
            A tuple ``(logits, attention_maps)``: ``logits`` is the output of
            the classification head for each image, and ``attention_maps`` is
            a list with one attention-weight tensor per encoder block, in
            layer order.
        """
        batch_size = x.size(0)
        patch_tokens = self.patch_embed(x)
        cls = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat((cls, patch_tokens), dim=1)
        tokens = tokens + self.pos_embedding

        attention_maps = []
        for block in self.encoder_layers:
            tokens, weights = block(tokens)
            attention_maps.append(weights)

        # Only the class token (position 0) feeds the classifier.
        cls_state = self.norm(tokens[:, 0])
        return self.head(cls_state), attention_maps

In [8]:
# Test Example: push one random image through the model and check the shapes.
img_size = 224
patch_size = 16
model = VisionTransformer(img_size=img_size, patch_size=patch_size, num_classes=10)
test_input = torch.randn(1, 3, img_size, img_size)  # Simulated image input
output, attn_maps = model(test_input)

# Assert the shapes instead of only printing them, so a regression fails the
# cell under Restart & Run All rather than passing silently.
expected_tokens = (img_size // patch_size) ** 2 + 1  # patches + class token
assert output.shape == (1, 10), output.shape
assert len(attn_maps) == len(model.encoder_layers), len(attn_maps)
assert attn_maps[0].shape[0] == 1, attn_maps[0].shape
assert attn_maps[0].shape[-2:] == (expected_tokens, expected_tokens), attn_maps[0].shape

print("Test Output Shape:", output.shape)
print("Attention Map Shape:", attn_maps[0].shape)  # Display first layer attention map shape


Test Output Shape: torch.Size([1, 10])
Attention Map Shape: torch.Size([1, 8, 197, 197])
