# Transformer from Scratch

This notebook shows a simplified implementation of a Transformer model from scratch, focusing on the encoder portion and self-attention. We’ll use toy data for demonstration.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is linearly projected to queries, keys and values, each
    projection is split into ``num_heads`` heads of width
    ``embed_size // num_heads``, attention is computed independently per
    head, and the heads are concatenated and passed through a final linear
    layer.

    Args:
        embed_size: Size of the last dimension of the input and the output.
        num_heads: Number of attention heads. Must be positive and divide
            ``embed_size`` evenly.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``embed_size``.
    """

    def __init__(self, embed_size: int, num_heads: int) -> None:
        super().__init__()
        # Raise explicitly instead of using ``assert``: assertions are removed
        # under ``python -O``, which would let an invalid configuration through
        # and only fail later inside ``view`` with an unrelated shape error.
        if num_heads <= 0 or embed_size % num_heads != 0:
            raise ValueError(
                "Embedding size must be divisible by num_heads "
                f"(got embed_size={embed_size}, num_heads={num_heads})"
            )

        self.embed_size = embed_size
        self.num_heads = num_heads
        self.head_dim = embed_size // num_heads

        self.query = nn.Linear(embed_size, embed_size)
        self.key = nn.Linear(embed_size, embed_size)
        self.value = nn.Linear(embed_size, embed_size)

        self.fc_out = nn.Linear(embed_size, embed_size)

    def _split_heads(self, t: torch.Tensor) -> torch.Tensor:
        """Reshape ``(N, seq, embed_size)`` to ``(N, heads, seq, head_dim)``."""
        N, seq_length, _ = t.shape
        return t.view(N, seq_length, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(N, seq_length, embed_size)``.

        Returns:
            Tensor of shape ``(N, seq_length, embed_size)``.
        """
        N, seq_length, _ = x.shape

        Q = self._split_heads(self.query(x))
        K = self._split_heads(self.key(x))
        V = self._split_heads(self.value(x))

        # Scaled dot-product attention. Dividing by sqrt(head_dim) keeps the
        # logits in a range where softmax does not saturate.
        energy = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attention = torch.softmax(energy, dim=-1)  # (N, heads, seq, seq)

        out = torch.matmul(attention, V)  # (N, heads, seq_length, head_dim)

        # Bring heads next to head_dim and merge them back into embed_size;
        # ``contiguous`` is required before ``view`` after the transpose.
        out = out.transpose(1, 2).contiguous()
        out = out.view(N, seq_length, self.embed_size)

        return self.fc_out(out)

class TransformerEncoderBlock(nn.Module):
    """One encoder layer: self-attention then a position-wise feed-forward net.

    Each sub-layer is wrapped as ``LayerNorm(input + Dropout(sublayer(input)))``,
    i.e. a residual connection followed by layer normalisation.

    Args:
        embed_size: Size of the last dimension of the input and the output.
        num_heads: Number of heads passed to the self-attention layer.
        forward_expansion: The feed-forward hidden layer has
            ``forward_expansion * embed_size`` units.
        dropout: Dropout probability applied to each sub-layer output.
    """

    def __init__(self, embed_size, num_heads, forward_expansion, dropout):
        super().__init__()
        hidden_size = forward_expansion * embed_size

        # Sub-layer 1: multi-head self-attention, with its own LayerNorm.
        self.attention = MultiHeadSelfAttention(embed_size, num_heads)
        self.norm1 = nn.LayerNorm(embed_size)

        # Sub-layer 2: expand -> ReLU -> project back, with its own LayerNorm.
        self.norm2 = nn.LayerNorm(embed_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, embed_size),
        )

        # A single Dropout module is shared by both residual branches.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run both sub-layers on ``x`` and return a tensor of the same shape."""
        attended = self.norm1(x + self.dropout(self.attention(x)))
        transformed = self.dropout(self.feed_forward(attended))
        return self.norm2(attended + transformed)

if __name__ == "__main__":
    # Smoke test on random toy data: one batch through one encoder block.
    embed_size, num_heads = 32, 4
    forward_expansion, dropout = 4, 0.1

    # Random input of shape (batch_size, seq_len, embed_size) = (2, 10, 32).
    batch_size, seq_len = 2, 10
    x = torch.rand(batch_size, seq_len, embed_size)

    encoder_block = TransformerEncoderBlock(
        embed_size=embed_size,
        num_heads=num_heads,
        forward_expansion=forward_expansion,
        dropout=dropout,
    )
    out = encoder_block(x)

    # The block preserves the input shape.
    print(out.shape)  # Expected: torch.Size([2, 10, 32])
