# Building a Vision Transformer (ViT) from Scratch

In this tutorial, you'll implement each component of the Vision Transformer step by step. Each module is provided as a stub with descriptions to guide your implementation.

## Setup (Just run)

In [8]:
import sys
from pathlib import Path
from dataclasses import dataclass
import pytest
import torch as t
import torch.nn as nn
from torch import Tensor
from jaxtyping import Float
import functools
import nanoclip.vit as nanoclip_vit


def replace_nanoclip_implementation(cls):
    """Class decorator that installs ``cls`` into ``nanoclip.vit`` under its own name.

    The attribute ``nanoclip_vit.<cls.__name__>`` is rebound to ``cls``, so a
    lookup of that name on the module afterwards yields the notebook's class.

    Args:
        cls: The class to install.

    Returns:
        ``cls`` itself. Returning the class, rather than a function that
        forwards to its constructor, keeps the decorated name a real type:
        ``isinstance`` checks, subclassing and class-attribute access all keep
        working, and anything that later re-installs the decorated name into
        ``nanoclip.vit`` installs a class, not a plain function.
    """
    setattr(nanoclip_vit, cls.__name__, cls)
    return cls

def remove_test_module(module_name):
    """Evict every cached module whose name begins with ``module_name``.

    Matching is a plain string-prefix test against the keys of
    ``sys.modules``, so the module itself, its submodules and any other module
    that merely shares the prefix are all dropped.

    Args:
        module_name: Prefix of the module names to forget.
    """
    # Snapshot the matches first: sys.modules must not change size while it is
    # being iterated.
    stale_names = tuple(
        filter(lambda cached: cached.startswith(module_name), sys.modules)
    )
    for stale in stale_names:
        del sys.modules[stale]


def run_test(test_name, test_func):
    """Install an object into ``nanoclip.vit`` and run one test from ``tests/test_vit_clip.py``.

    Args:
        test_name: Name of the test to select inside ``tests/test_vit_clip.py``.
        test_func: Object to install on ``nanoclip_vit`` under its own
            ``__name__``. Despite the parameter name, the notebook passes the
            class under test here.

    The value returned by ``pytest.main`` is discarded.
    """
    # Drop cached copies of the test module, presumably so the next run
    # re-imports it and sees the replacement installed below.
    remove_test_module("test_vit_clip")
    setattr(nanoclip_vit, test_func.__name__, test_func)

    # The tests directory is expected one level above the current working
    # directory.
    project_root = Path.cwd().resolve().parent
    target = f"{project_root}/tests/test_vit_clip.py::{test_name}"
    pytest_args = ["-v", "-p", "no:cacheprovider", "-s", target]
    pytest.main(pytest_args)


# Bit of a hack: ViTConfig doesn't seem to be able to be overwritten
def test_vit_config(vit_config_cls):
    """Check the derived fields of a config built by ``vit_config_cls``.

    A config for 224x224 images cut into 16x16 patches, with ``d_model=768``
    and 12 heads, is constructed and its computed attributes are asserted.

    Args:
        vit_config_cls: Config class to instantiate with keyword arguments.

    Raises:
        AssertionError: If a derived attribute has an unexpected value.
    """
    settings = {
        "n_layers": 12,
        "d_model": 768,
        "d_proj": 512,
        "image_res": (224, 224),
        "patch_size": (16, 16),
        "n_heads": 12,
        "norm_data": ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    }
    cfg = vit_config_cls(**settings)

    # 224 / 16 = 14 patches along each image axis.
    assert cfg.num_patches == (14, 14)
    # 14 * 14 patches plus one CLS token.
    assert cfg.seq_length == 197
    # 768 // 12
    assert cfg.d_head == 64

## 1. ViT Configuration

In [None]:
@dataclass
class ViTConfig:
    """Hyperparameters for the Vision Transformer (exercise stub).

    The first seven fields are required. ``mlp_mult`` and ``causal_attn`` have
    defaults. ``d_head``, ``num_patches`` and ``seq_length`` default to
    ``None`` and are meant to be filled in by ``__post_init__``.
    """

    # --- Required settings ---
    n_layers: int
    d_model: int
    d_proj: int
    image_res: tuple[int, int]  # presumably (height, width) -- confirm
    patch_size: tuple[int, int]  # presumably (height, width) -- confirm
    n_heads: int
    # (mean, std), three floats each
    norm_data: tuple[tuple[float, float, float], tuple[float, float, float]]

    # --- Optional settings ---
    mlp_mult: int = 4
    causal_attn: bool = False

    # --- Derived values: placeholders until __post_init__ assigns them ---
    d_head: int = None  # type: ignore
    num_patches: tuple[int, int] = None  # type: ignore
    seq_length: int = None  # type: ignore

    def __post_init__(self):
        """
        Fill in the derived attributes from the initial config (to be implemented):

        1. num_patches: Number of patches along each image axis, based on
           image_res and patch_size (e.g. 224x224 with 16x16 patches -> (14, 14)).
        2. seq_length: Total number of patches + 1 for the CLS token
           (e.g. 14 * 14 + 1 = 197).
        3. d_head: Dimension of each head. Use d_model & n_heads
           (e.g. 768 // 12 = 64).
        """
        raise NotImplementedError()

# ViTConfig is checked by calling the test function directly instead of going
# through run_test("test_vit_config", ViTConfig): overriding ViTConfig inside
# nanoclip.vit does not seem to take effect.
test_vit_config(ViTConfig)

## 2. Patch Embeddings

In [None]:
class PatchEmbedding(nn.Module):
    """Turns an image batch into a sequence of token embeddings (exercise stub).

    Parameters, all allocated with ``t.empty`` and therefore uninitialised
    until values are loaded or assigned:

    - ``class_embedding``: shape ``(d_model,)``.
    - ``patch_embedding``: shape ``(d_model, 3 * patch_size[0] * patch_size[1])``.
    - ``position_embedding``: shape ``(seq_length, d_model)``.
    """

    def __init__(self, cfg: ViTConfig):
        super().__init__()
        self.cfg = cfg

        # Length of one flattened patch: 3 channels times the pixels per patch.
        pixels_per_patch = cfg.patch_size[0] * cfg.patch_size[1]
        flat_patch_dim = 3 * pixels_per_patch

        self.class_embedding = nn.Parameter(t.empty(cfg.d_model))
        self.patch_embedding = nn.Parameter(t.empty(cfg.d_model, flat_patch_dim))
        self.position_embedding = nn.Parameter(t.empty(cfg.seq_length, cfg.d_model))

    def forward(
        self, pixel_values: Float[Tensor, "batch 3 height width"]
    ) -> Float[Tensor, "batch seq d_model"]:
        """
        Embed an image batch as a token sequence (to be implemented):

        1. Rearrange pixel_values into a sequence of flattened patches:
          - Hint: Should be of shape
            (batch, num_patches, 3 * patch_height * patch_width).
          - The flattening order presumably has to match the layout of
            patch_embedding -- check nanoclip/vit.py.
        2. Project each flattened patch to d_model with patch_embedding.
        3. Attach the class embedding as one extra token, so the sequence
           length becomes cfg.seq_length (presumably as the first token --
           confirm against nanoclip/vit.py).
        4. Add the position embeddings.
        """
        # Implementation steps here (refer to lines 70-96 in nanoclip/vit.py)
        raise NotImplementedError()

# Install the PatchEmbedding defined above into nanoclip.vit and run test_patch_embedding.
run_test("test_patch_embedding", PatchEmbedding)

## 3. Attention

ViTs use **full attention**: you _should not_ use a causal mask.

There are many different ways to implement attention -- I tend to like a very explicit einops operation.

#### Q: Why should a ViT use full attention instead of a causal mask?

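> Answer: We want each patch to fully attend to all other patches. The final output is read from the (first) CLS token, which under a causal mask could attend only to itself and would never see any of the image patches.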


In [None]:
class Attention(nn.Module):
    """Multi-head self-attention over the token sequence (exercise stub)."""

    def __init__(self, cfg: ViTConfig):
        """
        Set up the four projections used by attention.

        Query, key, value and output projections are each a biased
        ``nn.Linear`` mapping d_model -> d_model.
        """
        super().__init__()
        self.cfg = cfg

        width = cfg.d_model
        self.q_proj = nn.Linear(width, width, bias=True)
        self.k_proj = nn.Linear(width, width, bias=True)
        self.v_proj = nn.Linear(width, width, bias=True)
        self.out_proj = nn.Linear(width, width, bias=True)

    def forward(
        self, x: Float[Tensor, "batch seq d_model"]
    ) -> Float[Tensor, "batch seq d_model"]:
        """
        Apply self-attention to x (to be implemented):

        1. Project the input to query, key, and value with q_proj / k_proj / v_proj
        2. Reshape q, k, v to split d_model into (n_heads, d_head)
        3. Compute (full, unmasked) attention scores between all positions
           - standard attention scales the scores by 1 / sqrt(d_head);
             confirm against the reference implementation
        4. Apply softmax over the key dimension to get attention weights
        5. Use the attention weights to take a weighted sum of the values
        6. Merge the heads back into d_model and apply out_proj
        """
        # (Solution: 109-141 nanoclip/vit.py)
        raise NotImplementedError()

# Install the Attention defined above into nanoclip.vit and run test_attn.
run_test("test_attn", Attention)


## 4. MLP

In [None]:
class MLP(nn.Module):
    """Position-wise feed-forward network of a transformer block (exercise stub)."""

    def __init__(self, cfg: ViTConfig):
        """
        Set up the two linear layers of the MLP.

        ``up_proj`` widens d_model to d_model * mlp_mult and ``down_proj``
        maps back to d_model; both have a bias.
        """
        super().__init__()
        self.cfg = cfg

        hidden_width = cfg.d_model * cfg.mlp_mult
        self.up_proj = nn.Linear(cfg.d_model, hidden_width, bias=True)
        self.down_proj = nn.Linear(hidden_width, cfg.d_model, bias=True)

    def forward(
        self, x: Float[Tensor, "batch seq d_model"]
    ) -> Float[Tensor, "batch seq d_model"]:
        """
        Run x through the MLP (to be implemented):

        1. Apply up_proj
        2. Apply the activation function (quick_gelu -- commonly defined as
           x * sigmoid(1.702 * x); confirm against nanoclip/vit.py)
        3. Apply down_proj
        """
        # Implementation steps here (refer to lines 151-155 in nanoclip/vit.py)
        raise NotImplementedError()

# Install the MLP defined above into nanoclip.vit and run test_mlp.
run_test("test_mlp", MLP)

## 5. Transformer Block

Now we're putting it all together.

In [None]:
@replace_nanoclip_implementation
class TransformerBlock(nn.Module):
    """One transformer layer: attention and MLP sub-layers, each with its own LayerNorm (exercise stub)."""

    def __init__(self, cfg: ViTConfig):
        super().__init__()
        self.cfg = cfg

        width = cfg.d_model
        # Submodules are created in the original order so registration order
        # (and therefore state_dict key order) is unchanged.
        self.attn = Attention(cfg)
        self.ln1 = nn.LayerNorm(width)
        self.mlp = MLP(cfg)
        self.ln2 = nn.LayerNorm(width)

    def forward(
        self, x: Float[Tensor, "batch seq d_model"]
    ) -> Float[Tensor, "batch seq d_model"]:
        """
        Apply the block to x (to be implemented):

        1. Layer norm -> attention on the input, then add the result to the
           residual stream
        2. Layer norm -> MLP on that sum, then add the result to the
           residual stream

        Presumably ln1 pairs with attn and ln2 with mlp -- confirm against
        nanoclip/vit.py.
        """
        # Implementation steps here (refer to lines 167-170 in nanoclip/vit.py)
        raise NotImplementedError()

# Install the TransformerBlock defined above into nanoclip.vit and run test_xfmer_block.
run_test("test_xfmer_block", TransformerBlock)

## 6. Assembling the Vision Transformer (ViT) Model


In [None]:
@replace_nanoclip_implementation
class ViT(nn.Module):
    """Full Vision Transformer: patch embedding, a stack of transformer blocks and a projection head (exercise stub)."""

    def __init__(self, cfg: ViTConfig):
        super().__init__()
        self.cfg = cfg

        width = cfg.d_model
        # Submodules are created in the original order so registration order
        # (and therefore state_dict key order) is unchanged.
        self.embed = PatchEmbedding(cfg)
        self.pre_ln = nn.LayerNorm(width)
        layers = [TransformerBlock(cfg) for _ in range(cfg.n_layers)]
        self.blocks = nn.ModuleList(layers)
        self.post_ln = nn.LayerNorm(width)
        # Bias-free projection from d_model to the output dimension d_proj.
        self.out_proj = nn.Linear(width, cfg.d_proj, bias=False)

    def forward(
        self, pixel_values: Float[Tensor, "batch 3 height width"]
    ) -> Float[Tensor, "batch d_proj"]:
        """
        Encode an image batch into one d_proj vector per image (to be implemented):

        1. Apply the patch embeddings (embed)
        2. Apply the pre-layer normalization (pre_ln)
        3. Apply every transformer block in order
        4. Select the CLS token
        5. Apply the post-layer normalization (post_ln)
        6. Apply the output projection (out_proj)
        """
        # Implementation steps here (refer to lines 183-192 in nanoclip/vit.py)
        raise NotImplementedError()

# Install the ViT defined above into nanoclip.vit and run the end-to-end test test_e2e.
run_test("test_e2e", ViT)


## Conclusion

You've set up the structure for implementing each component of the Vision Transformer. Complete each method by replacing the `NotImplementedError` with the appropriate code as guided by the docstrings. Once all components are implemented, you can proceed to integrate and test the ViT model.