```{contents}
```

## Vision Transformers (ViT)

A **Vision Transformer (ViT)** is a neural network that applies the **Transformer architecture**, originally designed for NLP, to **image understanding**.

Instead of using convolutional layers like CNNs, ViTs process an image as a **sequence of patches**, similar to a sequence of words in a sentence.

---

### WHY ViTs WERE INVENTED

CNNs (ResNet, VGG, MobileNet) dominated vision because they:

* detect local patterns (edges, textures)
* apply the same small convolution filters at every location, each with a fixed, local receptive field

But CNNs struggle with:

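* capturing **long-range dependencies** between distant image regions
* building **global context** in early layers
* scaling: their built-in locality bias tends to help less as datasets and models grow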

Transformers naturally capture **global relationships** with **self-attention**.

So researchers asked:
**What if we treat an image like a text sequence?**

This became the Vision Transformer.

---

### HOW VISION TRANSFORMERS WORK (Step-by-Step)

---

#### **Split the Image into Patches**

Example:

* Image size: **224 × 224**
* Patch size: **16 × 16**
* Total patches = (224/16)² = 14² = **196 patches**

Each patch is then flattened into a vector.

Think of patches as "visual tokens."
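The patch arithmetic above can be checked with a small NumPy sketch. The helper name `image_to_patches` and the channels-last `(H, W, C)` layout are assumptions made for illustration; NumPy is used here only to keep the example lightweight.

```python
import numpy as np

def image_to_patches(image, patch_size):
    """Split an (H, W, C) image into flattened, non-overlapping patches."""
    h, w, c = image.shape
    assert h % patch_size == 0 and w % patch_size == 0, "image must divide evenly"
    gh, gw = h // patch_size, w // patch_size  # patches per column / per row
    # (gh, p, gw, p, C) -> (gh, gw, p, p, C) -> (gh * gw, p * p * C)
    patches = image.reshape(gh, patch_size, gw, patch_size, c)
    patches = patches.transpose(0, 2, 1, 3, 4)
    return patches.reshape(gh * gw, patch_size * patch_size * c)

image = np.random.rand(224, 224, 3)
patches = image_to_patches(image, patch_size=16)
print(patches.shape)  # (196, 768)
```

Each row of `patches` is one "visual token": row 1, for example, holds the pixels of the 16 × 16 square immediately to the right of the top-left patch.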

---

#### **Convert Each Patch to an Embedding (Linear Layer)**

For each patch:

1. Flatten: 16×16×3 → 768 numbers
2. Linear projection:

   ```
   patch_embedding = W * flattened_patch
   ```

Now each patch becomes a vector in a high-dimensional space (like word embeddings).
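As a rough sketch of this projection, the snippet below applies one shared weight matrix to all 196 flattened patches. The embedding width of 768, the random initialisation, and the bias term `b` are assumptions for illustration; the formula above omits the bias.

```python
import numpy as np

rng = np.random.default_rng(0)

patch_dim = 16 * 16 * 3   # 768 values per flattened patch
embed_dim = 768           # assumed model width; it only happens to equal patch_dim

# Stand-in for the 196 flattened patches of one image.
flattened_patches = rng.standard_normal((196, patch_dim))

# Learnable parameters of the projection (randomly initialised here).
W = rng.standard_normal((embed_dim, patch_dim)) * 0.02
b = np.zeros(embed_dim)   # bias is an assumption, not part of the formula above

# The same projection is applied to every patch with one matrix multiply.
patch_embeddings = flattened_patches @ W.T + b
print(patch_embeddings.shape)  # (196, 768)
```

Because `W` is shared across patches, the number of parameters does not depend on how many patches the image contains.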

---

#### **Add Positional Embeddings**

Self-attention by itself has no notion of token order: shuffling the patches would simply shuffle the outputs.
So we add **positional embeddings** to encode patch location:

```
patch_embedding + position_embedding
```

This tells the model:

* where each patch came from
* how patches relate spatially
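The addition can be sketched in NumPy as follows. The shapes (a batch of 2, 196 patches, width 768) and the random values are assumptions; in a trained model the position table is learned.

```python
import numpy as np

rng = np.random.default_rng(0)
num_patches, embed_dim = 196, 768

patch_embeddings = rng.standard_normal((2, num_patches, embed_dim))  # batch of 2
# One vector per position, shared by every image in the batch.
position_embedding = rng.standard_normal((1, num_patches, embed_dim)) * 0.02

tokens = patch_embeddings + position_embedding  # broadcasts over the batch axis
print(tokens.shape)  # (2, 196, 768)

# The same patch content placed at two different positions yields different tokens.
same_patch = patch_embeddings[0, 0]
at_pos_0 = same_patch + position_embedding[0, 0]
at_pos_5 = same_patch + position_embedding[0, 5]
print(np.allclose(at_pos_0, at_pos_5))  # False
```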

---

#### **Feed Patch Embeddings into Transformer Encoder**

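The encoder uses the same building blocks as BERT, except that ViT applies LayerNorm before each sub-layer (pre-norm) rather than after it: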

* Multi-head Self-Attention
* LayerNorm
* Feed Forward Network
* Residual Connections

Self-attention allows each patch to "look at" every other patch:

> “Which other parts of the image matter to understanding this patch?”

This gives ViTs **global receptive fields** from the beginning.
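To make "every patch looks at every other patch" concrete, here is a minimal single-head scaled dot-product attention in NumPy. The function names, the width of 64, and the random weights are assumptions for illustration; real ViTs use several heads and learned weights.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(x, w_q, w_k, w_v):
    """Single-head scaled dot-product self-attention over tokens x of shape (N, D)."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    scores = q @ k.T / np.sqrt(k.shape[-1])  # (N, N): every patch vs. every patch
    weights = softmax(scores, axis=-1)       # each row sums to 1
    return weights @ v, weights

rng = np.random.default_rng(0)
n_tokens, dim = 196, 64
x = rng.standard_normal((n_tokens, dim))
w_q, w_k, w_v = (rng.standard_normal((dim, dim)) * dim ** -0.5 for _ in range(3))

out, weights = self_attention(x, w_q, w_k, w_v)
print(out.shape, weights.shape)  # (196, 64) (196, 196)
```

Row `i` of `weights` says how strongly patch `i` attends to each of the 196 patches, which is why the receptive field is global from the first layer.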

---

#### **Use CLS Token for Classification (like BERT)**

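A special learnable vector **[CLS]** is prepended to the sequence of patch embeddings.

After the encoder, the output embedding at the [CLS] position serves as a summary of the entire image.

Finally, the [CLS] output is passed through a classification head (an MLP or a single linear layer) to predict the label.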
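The bookkeeping around the [CLS] token can be sketched as below. This is a shape-level illustration only: the encoder is replaced by an identity stand-in, and the sizes, random weights, and single linear head are assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, num_patches, embed_dim, num_classes = 2, 196, 768, 10

patch_tokens = rng.standard_normal((batch, num_patches, embed_dim))
cls_token = rng.standard_normal((1, 1, embed_dim)) * 0.02  # learnable in a real model

# Prepend one copy of [CLS] to every sequence in the batch.
cls = np.broadcast_to(cls_token, (batch, 1, embed_dim))
tokens = np.concatenate([cls, patch_tokens], axis=1)
print(tokens.shape)  # (2, 197, 768)

# The Transformer encoder would run here; identity is used as a stand-in.
encoded = tokens

# Read out position 0 and map it to class scores with a linear head.
w_head = rng.standard_normal((embed_dim, num_classes)) * 0.02
logits = encoded[:, 0] @ w_head
print(logits.shape)  # (2, 10)
```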

---

### REMARKABLY SIMPLE ARCHITECTURE

A Vision Transformer has:

```
Patch Embedding
+ Transformer Encoder Layers
+ MLP Classifier
```

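* No convolutional feature hierarchy (a strided `Conv2d` may still be used as an efficient way to compute the linear patch projection)
* No pooling
* No multi-scale feature maps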

---

### WHY VISION TRANSFORMERS WORK WELL

#### ✔ 1. **Global Context From the Start**

Self-attention lets ViTs focus on:

* long-range dependencies
* relationships between objects
* global structure

CNNs gain wide context only in deeper layers, as the receptive field grows through stacked convolutions and downsampling.
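A back-of-the-envelope sketch shows how slowly context grows in a plain convolutional stack. It assumes stride-1 convolutions with no pooling or dilation, which is a simplification: real CNNs downsample, so their receptive fields grow faster than this. The helper name `receptive_field` is hypothetical.

```python
def receptive_field(num_layers, kernel_size=3):
    """Receptive field (pixels per side) of stacked stride-1 convolutions."""
    return 1 + num_layers * (kernel_size - 1)

for n in (1, 5, 20):
    print(n, receptive_field(n))
# 1 3
# 5 11
# 20 41

# Layers of 3x3 stride-1 convolutions needed to span a 224-pixel-wide image:
layers_needed = -(-(224 - 1) // 2)  # ceiling division
print(layers_needed)  # 112
```

A single self-attention layer, by contrast, connects all patches directly.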

---

#### ✔ 2. **Scales Extremely Well**

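ViT accuracy keeps improving as the amount of training data grows.
When pre-trained on large datasets (ImageNet-21k, JFT-300M), ViTs match or **outperform comparable CNNs**; when trained from scratch on smaller datasets, CNNs often remain stronger.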

---

#### ✔ 3. **Uniform Architecture**

Same Transformer blocks for:

* images
* text
* multimodal models (CLIP, Flamingo, Gemini)

This makes ViTs ideal for **multimodal AI**.

---

### LIMITATIONS OF ViTs

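| Issue                                       | Why                                                                               |
| ------------------------------------------- | --------------------------------------------------------------------------------- |
| **Needs lots of data**                      | Lacks the built-in inductive biases of CNNs (locality, translation equivariance)  |
| **Patch-level processing may miss details** | Especially small objects                                                          |
| **High compute cost**                       | Attention is O(N²) in the number of patches N                                     |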

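Later architectures address some of these issues: Swin Transformer restricts attention to local windows and builds a feature hierarchy, while hybrid models place convolutional stages before the Transformer. ConvNeXt, often mentioned alongside them, is a pure CNN that borrows design choices from Transformers.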
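The quadratic cost in the table can be made tangible with a few lines of arithmetic. The helper name `attention_cost` is hypothetical, and the count ignores the extra [CLS] token and the number of heads.

```python
def attention_cost(img_size, patch_size):
    """Return (number of patch tokens N, entries in one N x N attention matrix)."""
    n = (img_size // patch_size) ** 2
    return n, n * n

for img_size, patch_size in [(224, 16), (224, 8), (448, 16)]:
    n, entries = attention_cost(img_size, patch_size)
    print(f"{img_size}px, patch {patch_size}: N={n}, N^2={entries}")
# 224px, patch 16: N=196, N^2=38416
# 224px, patch 8: N=784, N^2=614656
# 448px, patch 16: N=784, N^2=614656
```

Halving the patch size (or doubling the image side) quadruples the number of tokens and makes each attention matrix 16 times larger.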

---

### TYPES OF ViTs

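| Model                    | Idea                                                                          |
| ------------------------ | ----------------------------------------------------------------------------- |
| **ViT**                  | Basic Transformer encoder applied to image patches                            |
| **DeiT**                 | Data-efficient ViT: trains on ImageNet-1k using augmentation and distillation |
| **Swin Transformer**     | Hierarchical ViT with attention in local, shifted windows                     |
| **ViT-Huge / ViT-Giant** | Scaled-up ViT variants aimed at state-of-the-art accuracy                     |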

---

### A SIMPLE INTUITION

> ViT slices an image into small squares, treats each square like a word, and uses attention to understand how all squares relate to each other.



```python
import torch
import torch.nn as nn

# ----------------------------------------------------
# 1. PATCH EMBEDDING
# ----------------------------------------------------
class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.patch_size = patch_size

        # Create patch embedding using a Conv2d layer
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size  # non-overlapping patches
        )

    def forward(self, x):
        # x shape: (B, 3, H, W)
        x = self.proj(x)         # (B, embed_dim, H/patch, W/patch)
        x = x.flatten(2)         # flatten: (B, embed_dim, N_patches)
        x = x.transpose(1, 2)    # (B, N_patches, embed_dim)
        return x


# ----------------------------------------------------
# 2. MULTI-HEAD SELF ATTENTION
# ----------------------------------------------------
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim=768, num_heads=12):
        super().__init__()
        self.mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

    def forward(self, x):
        # Self-attention: Q = K = V = x
        attn_output, _ = self.mha(x, x, x)
        return attn_output


# ----------------------------------------------------
# 3. TRANSFORMER ENCODER BLOCK
# ----------------------------------------------------
class TransformerEncoderBlock(nn.Module):
    def __init__(self, embed_dim=768, num_heads=12, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads)

        self.norm2 = nn.LayerNorm(embed_dim)

        hidden_dim = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, embed_dim),
        )

    def forward(self, x):
        # Attention with residual
        x = x + self.attn(self.norm1(x))

        # MLP with residual
        x = x + self.mlp(self.norm2(x))
        return x


# ----------------------------------------------------
# 4. COMPLETE VISION TRANSFORMER (ViT)
# ----------------------------------------------------
class VisionTransformer(nn.Module):
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_channels=3,
        num_classes=10,
        embed_dim=768,
        depth=6,
        num_heads=12,
    ):
        super().__init__()

        # Patch Embedding
        self.patch_embed = PatchEmbedding(
            img_size, patch_size, in_channels, embed_dim
        )

        num_patches = (img_size // patch_size) ** 2

        # CLS token (learnable)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # Positional Embedding (learnable)
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, embed_dim)
        )

        # Transformer Encoder Layers
        self.blocks = nn.ModuleList([
            TransformerEncoderBlock(embed_dim, num_heads)
            for _ in range(depth)
        ])

        # Final classification head
        self.norm = nn.LayerNorm(embed_dim)
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        B = x.shape[0]

        x = self.patch_embed(x)  # (B, N_patches, embed_dim)

        # Add CLS token at position 0
        cls = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls, x), dim=1)

        # Add positional embedding
        x = x + self.pos_embed

        # Pass through transformer blocks
        for block in self.blocks:
            x = block(x)

        # Use CLS token output for classification
        cls_output = self.norm(x[:, 0])
        return self.fc(cls_output)


# ----------------------------------------------------
# 5. TEST THE MODEL
# ----------------------------------------------------
if __name__ == "__main__":
    model = VisionTransformer()
    img = torch.randn(2, 3, 224, 224)  # batch of 2 images
    out = model(img)

    print("Output shape:", out.shape)  # (2, num_classes)
```

```
Output shape: torch.Size([2, 10])
```
