# Complete Transformer Implementation: From Attention to Generation

## 🎯 Introduction

Welcome! This notebook takes you from understanding the individual components to implementing a complete, working transformer. You'll build the same core architecture that underlies GPT, BERT, and T5.

### 🧠 What You'll Build

This comprehensive implementation covers:
- **Multi-head self-attention**: The core innovation that revolutionized AI
- **Position encoding**: How transformers understand sequence order without recurrence
- **Feed-forward networks**: The computation powerhouse within each layer
- **Layer normalization**: Stabilizing training in deep networks
- **Complete transformer blocks**: Putting all components together
- **Text generation**: Making your transformer actually generate coherent text

### 🎓 Prerequisites

- Solid understanding of PyTorch tensors, modules, and training loops
- Familiarity with neural network architectures (MLPs, basic concepts)
- Basic knowledge of attention mechanisms and sequence modeling
- Understanding of matrix operations and linear algebra

### 🚀 Why Transformers Changed Everything

Transformers solved fundamental limitations of previous architectures:
- **Parallelization**: Unlike RNNs, all positions are computed simultaneously during training
- **Long-range dependencies**: Attention directly connects distant positions
- **Scalability**: Architecture scales efficiently to massive models
- **Transfer learning**: Pre-trained models work across many tasks
- **Interpretability**: Attention weights offer a partial view of what the model focuses on

---

## 📚 Table of Contents

1. **[Attention Mechanism Deep Dive](#attention-mechanism-deep-dive)** - The heart of transformers
2. **[Position Encoding & Embeddings](#position-encoding-embeddings)** - Giving transformers spatial awareness
3. **[Transformer Block Architecture](#transformer-block-architecture)** - Complete encoder/decoder components
4. **[Full Model Implementation](#full-model-implementation)** - Putting everything together
5. **[Text Generation & Inference](#text-generation-inference)** - Making your transformer talk

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import math

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## Attention Mechanism Deep Dive

### 🔍 The Revolutionary Breakthrough

Self-attention is the core innovation that made transformers possible. It allows every position in a sequence to directly attend to every other position, solving the fundamental limitation of RNNs that processed sequences step by step.
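
In equation form, the computation implemented in the next cell is the standard scaled dot-product attention. Here $d_k$ denotes the key dimension; for the single-head module below it equals `d_model`:

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V
$$

The softmax is taken over the last dimension, so each row of the weight matrix is a probability distribution over the positions being attended to.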

In [None]:
# =============================================================================
# SELF-ATTENTION: THE FOUNDATION OF TRANSFORMERS
# =============================================================================

print("🔍 Self-Attention Mechanism")
print("=" * 50)

class SelfAttention(nn.Module):
    """
    Single-head self-attention mechanism.
    
    The core idea: For each position, compute how much to attend to 
    every other position, then aggregate information accordingly.
    """
    
    def __init__(self, d_model, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        
        # Linear projections for Q, K, V from the same input
        # This is what makes it "self"-attention
        self.query_proj = nn.Linear(d_model, d_model, bias=False)
        self.key_proj = nn.Linear(d_model, d_model, bias=False)
        self.value_proj = nn.Linear(d_model, d_model, bias=False)
        
        # Output projection to combine attention results
        self.output_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        
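        # Scale factor 1/sqrt(d_k); with a single head, d_k == d_model.
        # Keeps dot products from growing with dimension and saturating the softmax
        self.scale = 1.0 / math.sqrt(d_model)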
    
    def forward(self, x, mask=None):
        """
        Args:
            x: Input tensor [batch_size, seq_len, d_model]
            mask: Optional attention mask [batch_size, seq_len, seq_len];
                  positions where mask == 0 are blocked
        
        Returns:
            output: Attention output [batch_size, seq_len, d_model]
            attention_weights: [batch_size, seq_len, seq_len]
        """
        batch_size, seq_len, d_model = x.shape
        
        # Step 1: Project input to Query, Key, Value
        # All come from the same input x - that's why it's "self"-attention
        Q = self.query_proj(x)  # [batch, seq_len, d_model]
        K = self.key_proj(x)    # [batch, seq_len, d_model] 
        V = self.value_proj(x)  # [batch, seq_len, d_model]
        
        print(f"Input shape: {x.shape}")
        print(f"Q, K, V shapes: {Q.shape}")
        
        # Step 2: Compute attention scores
        # This determines how much each position attends to every other position
        scores = torch.matmul(Q, K.transpose(-2, -1))  # [batch, seq_len, seq_len]
        scores = scores * self.scale  # Scale to prevent extreme softmax values
        
        print(f"Attention scores shape: {scores.shape}")
        print(f"Attention scores represent: each position's affinity to every other position")
        
        # Step 3: Apply mask if provided (for causal/padding masking)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # Step 4: Convert scores to probabilities
        attention_weights = torch.softmax(scores, dim=-1)  # [batch, seq_len, seq_len]
        attention_weights = self.dropout(attention_weights)
        
        print(f"Attention weights shape: {attention_weights.shape}")
        print(f"Each row sums to 1.0 (probability distribution over positions)")
        
        # Step 5: Apply attention to values
        # This is where the actual "information mixing" happens
        output = torch.matmul(attention_weights, V)  # [batch, seq_len, d_model]
        
        # Step 6: Final output projection
        output = self.output_proj(output)
        
        print(f"Final output shape: {output.shape}")
        print(f"✓ Each position is now a weighted combination of all positions")
        
        return output, attention_weights

# Demonstrate self-attention with a simple example
print(f"\n🎯 Self-Attention in Action")
print("=" * 50)

# Create a simple 3-token sequence
batch_size, seq_len, d_model = 1, 3, 8
x = torch.randn(batch_size, seq_len, d_model)

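# Create attention module and switch to eval mode so dropout is disabled
# (in training mode, dropout zeroes some weights and rows no longer sum to 1)
attention = SelfAttention(d_model)
attention.eval()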

# Apply self-attention
output, weights = attention(x)

print(f"\nAttention weights analysis:")
print(f"Shape: {weights.shape}")
print(f"Weights matrix (how much each position attends to others):")

# Make weights more readable
weights_2d = weights[0].detach()  # Remove batch dimension
print("     Pos0   Pos1   Pos2")
for i in range(seq_len):
    row_str = f"Pos{i}: "
    for j in range(seq_len):
        row_str += f"{weights_2d[i, j].item():.3f}  "
    print(row_str)

print(f"\nInterpretation:")
print(f"• Each row shows how much that position attends to all positions")
print(f"• Diagonal elements: how much each position attends to itself")
print(f"• Off-diagonal: how much positions attend to each other")
print(f"• All rows sum to 1.0 (probability distributions)")
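
The scaling step in `SelfAttention.forward` can be checked by hand. The following is a minimal, dependency-free sketch (plain Python rather than PyTorch) using made-up scores and an illustrative dimension `d_k = 64`; the helper name `softmax` and all the numbers are assumptions chosen for the example, not values from the model above.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

d_k = 64                          # illustrative key dimension (assumption)
raw_scores = [8.0, 0.0, -8.0]     # made-up unscaled dot products
scale = 1.0 / math.sqrt(d_k)      # same formula as self.scale above

unscaled = softmax(raw_scores)
scaled = softmax([s * scale for s in raw_scores])

print("unscaled:", [round(p, 4) for p in unscaled])
print("scaled:  ", [round(p, 4) for p in scaled])
```

Without scaling, nearly all the probability mass lands on one position; with scaling, the distribution stays softer, which keeps gradients through the softmax from vanishing.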

## Using PyTorch's Built-in Transformer

In [None]:
class SimpleTransformer(nn.Module):
    def __init__(self, vocab_size, d_model, n_heads, n_layers, d_ff, max_seq_len):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        # Register as a buffer so the encoding follows the model across devices
        self.register_buffer(
            "pos_encoding", self.create_positional_encoding(max_seq_len, d_model)
        )
        
        # PyTorch's built-in transformer
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_ff,
            dropout=0.1,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, n_layers)
        
        self.output_projection = nn.Linear(d_model, vocab_size)
        
    def create_positional_encoding(self, max_seq_len, d_model):
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len).unsqueeze(1).float()
        
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * 
                           -(math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        return pe.unsqueeze(0)  # Add batch dimension
        
    def forward(self, x, mask=None):
        seq_len = x.size(1)
        
        # Embedding and positional encoding
        x = self.embedding(x) * math.sqrt(self.d_model)
        x = x + self.pos_encoding[:, :seq_len, :].to(x.device)
        
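        # Transformer layers
        # Note: src_key_padding_mask follows PyTorch's convention: shape
        # [batch, seq_len], with True marking padding positions to ignore
        x = self.transformer(x, src_key_padding_mask=mask)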
        
        # Output projection
        return self.output_projection(x)

# Create a simple transformer
vocab_size = 1000
d_model = 128
n_heads = 8
n_layers = 4
d_ff = 512
max_seq_len = 100

model = SimpleTransformer(vocab_size, d_model, n_heads, n_layers, d_ff, max_seq_len)

# Test with random input
batch_size = 4
seq_len = 20
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))

output = model(input_ids)
print(f"Input shape: {input_ids.shape}")
print(f"Output shape: {output.shape}")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
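
To make the arithmetic inside `create_positional_encoding` easier to inspect, here is a small dependency-free sketch of the same sinusoidal scheme in plain Python. The function name `sinusoidal_encoding` and the tiny sizes are assumptions chosen for illustration; the frequency schedule is meant to mirror `div_term` above.

```python
import math

def sinusoidal_encoding(max_seq_len, d_model):
    """Return a max_seq_len x d_model table of sinusoidal position encodings."""
    table = [[0.0] * d_model for _ in range(max_seq_len)]
    for pos in range(max_seq_len):
        for i in range(0, d_model, 2):
            # Same frequency schedule as div_term: 10000^(-i / d_model)
            freq = math.exp(-math.log(10000.0) * i / d_model)
            table[pos][i] = math.sin(pos * freq)        # even columns: sine
            if i + 1 < d_model:
                table[pos][i + 1] = math.cos(pos * freq)  # odd columns: cosine
    return table

pe_table = sinusoidal_encoding(max_seq_len=4, d_model=4)
for pos, row in enumerate(pe_table):
    print(pos, [round(v, 4) for v in row])
```

Position 0 always encodes to alternating 0 and 1 (sine and cosine of zero), and later columns oscillate more slowly, which is what lets the model distinguish nearby from distant positions.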

## Masking: Causal and Padding

In [None]:
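def create_causal_mask(seq_len):
    """Create a causal (lower-triangular) mask: 1 = may attend, 0 = blocked.

    This matches the `mask == 0` check in SelfAttention above. PyTorch's
    nn.Transformer* layers treat boolean masks the other way round
    (True = blocked), so pass `create_causal_mask(n) == 0` as `tgt_mask`.
    """
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask

def create_padding_mask(sequences, pad_token_id=0):
    """Create a padding mask: True = real token, False = padding.

    PyTorch's `*_key_padding_mask` arguments expect True = padding,
    so pass `~create_padding_mask(...)` to the built-in layers.
    """
    return (sequences != pad_token_id)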

# Example of causal masking for autoregressive generation
seq_len = 5
causal_mask = create_causal_mask(seq_len)
print("Causal mask:")
print(causal_mask.numpy())

# Example of padding mask
sequences = torch.tensor([
    [1, 2, 3, 4, 5],    # No padding
    [1, 2, 3, 0, 0],    # Padded with 0s
    [1, 0, 0, 0, 0]     # Heavily padded
])
padding_mask = create_padding_mask(sequences)
print("\nSequences:")
print(sequences.numpy())
print("Padding mask (True = real token, False = padding):")
print(padding_mask.numpy())

# Using masks in transformer
class MaskedTransformer(nn.Module):
    def __init__(self, vocab_size, d_model, n_heads, n_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=n_heads,
            batch_first=True
        )
        self.transformer = nn.TransformerDecoder(decoder_layer, n_layers)
        self.output_proj = nn.Linear(d_model, vocab_size)
        
    def forward(self, tgt, memory, tgt_mask=None, tgt_key_padding_mask=None):
        tgt = self.embedding(tgt)
        output = self.transformer(
            tgt, memory, 
            tgt_mask=tgt_mask, 
            tgt_key_padding_mask=tgt_key_padding_mask
        )
        return self.output_proj(output)

print("\nMasked transformer class defined for autoregressive generation")
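
As a rough illustration of what a causal mask does to attention weights, the sketch below applies a lower-triangular mask to a made-up 3x3 score matrix in plain Python. The helper names `build_causal_mask` and `masked_softmax` and the scores themselves are assumptions for this example; the masking rule (blocked entries become `-inf` before the softmax) follows the `SelfAttention` class earlier in the notebook.

```python
import math

def build_causal_mask(seq_len):
    """Lower-triangular mask: 1 = may attend, 0 = blocked (future position)."""
    return [[1 if j <= i else 0 for j in range(seq_len)] for i in range(seq_len)]

def masked_softmax(scores, mask_row):
    """Softmax over one row of scores, with blocked entries set to -inf first."""
    masked = [s if keep else float("-inf") for s, keep in zip(scores, mask_row)]
    m = max(masked)  # the diagonal is never blocked, so m is finite
    exps = [math.exp(s - m) for s in masked]  # exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]

toy_scores = [[0.5, 1.0, 2.0],
              [0.0, 0.0, 3.0],
              [1.0, 1.0, 1.0]]  # made-up attention scores
toy_mask = build_causal_mask(3)
toy_weights = [masked_softmax(row, m) for row, m in zip(toy_scores, toy_mask)]

for i, row in enumerate(toy_weights):
    print(f"Pos{i}:", [round(w, 3) for w in row])
# Pos0: [1.0, 0.0, 0.0]
# Pos1: [0.5, 0.5, 0.0]
# Pos2: [0.333, 0.333, 0.333]
```

Position 0 can only attend to itself, and the large score of 3.0 in row 1 has no effect because it belongs to a future position.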

## Next-Token Prediction Demo
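
The demo in this section trains on (input, target) windows where the target is the input shifted one character to the right. Here is a minimal plain-Python sketch of that windowing on the word "hello"; `make_windows`, `sample`, and the window size of 3 are hypothetical names and values chosen for illustration.

```python
def make_windows(token_ids, window):
    """Slide a window over token_ids; the target is the input shifted by one."""
    pairs = []
    for i in range(len(token_ids) - window):
        pairs.append((token_ids[i:i + window], token_ids[i + 1:i + window + 1]))
    return pairs

sample = "hello"
sample_vocab = sorted(set(sample))                      # ['e', 'h', 'l', 'o']
sample_ids = [sample_vocab.index(ch) for ch in sample]  # [1, 0, 2, 2, 3]

windows = make_windows(sample_ids, window=3)
for inp, tgt in windows:
    print(inp, "->", tgt)
# [1, 0, 2] -> [0, 2, 2]
# [0, 2, 2] -> [2, 2, 3]
```

Because the target at position `t` is the input at position `t + 1`, a model that can attend to later positions could simply copy the answer, which is why next-token training is normally paired with a causal mask.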

In [None]:
# Simple character-level next-token prediction
class CharacterLevelTransformer(nn.Module):
    def __init__(self, vocab_size, d_model=64, n_heads=4, n_layers=2):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_model * 4,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, n_layers)
        self.output_proj = nn.Linear(d_model, vocab_size)
        
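    def forward(self, x, mask=None):
        # mask: optional [seq_len, seq_len] attention mask; for boolean masks,
        # True marks positions that may NOT be attended to
        x = self.embedding(x) * math.sqrt(self.d_model)
        x = self.transformer(x, mask=mask)
        return self.output_proj(x)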

# Create toy dataset
text = "hello world this is a simple example"
chars = sorted(list(set(text)))
char_to_idx = {ch: i for i, ch in enumerate(chars)}
idx_to_char = {i: ch for i, ch in enumerate(chars)}
vocab_size = len(chars)

print(f"Vocabulary: {chars}")
print(f"Vocab size: {vocab_size}")

# Convert text to indices
encoded = [char_to_idx[ch] for ch in text]
print(f"Encoded text: {encoded[:10]}...")  # First 10 tokens

# Create training data (input: n chars, target: next char)
def create_sequences(encoded_text, seq_len):
    inputs, targets = [], []
    for i in range(len(encoded_text) - seq_len):
        inputs.append(encoded_text[i:i+seq_len])
        targets.append(encoded_text[i+1:i+seq_len+1])  # Shifted by 1
    return torch.tensor(inputs), torch.tensor(targets)

seq_len = 8
inputs, targets = create_sequences(encoded, seq_len)
print(f"\nTraining data shapes: {inputs.shape}, {targets.shape}")
print(f"Example input: {inputs[0]} -> target: {targets[0]}")

# Create and train model
model = CharacterLevelTransformer(vocab_size)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

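# Boolean causal mask: True marks future positions that must not be attended to.
# Without it, position i could read its own target (the token at i+1) from the input.
train_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

# Training loop (simplified)
model.train()
for epoch in range(100):
    optimizer.zero_grad()
    
    # Forward pass
    output = model(inputs, mask=train_mask)  # Shape: (batch, seq_len, vocab_size)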
    
    # Reshape for loss calculation
    loss = criterion(output.view(-1, vocab_size), targets.view(-1))
    
    # Backward pass
    loss.backward()
    optimizer.step()
    
    if epoch % 20 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

# Generate text
def generate_text(model, start_text, max_len=20, temperature=1.0):
    model.eval()
    with torch.no_grad():
        # Convert start text to indices
        current = [char_to_idx[ch] for ch in start_text]
        generated = current.copy()
        
        for _ in range(max_len):
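            # Take last seq_len characters
            input_seq = torch.tensor(current[-seq_len:]).unsqueeze(0)
            
            # Get predictions with a causal mask (True = blocked) so each
            # position only sees itself and earlier positions
            n = input_seq.size(1)
            gen_mask = torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)
            output = model(input_seq, mask=gen_mask)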
            
            # Get next token probabilities
            next_token_logits = output[0, -1, :] / temperature
            probs = F.softmax(next_token_logits, dim=-1)
            
            # Sample next token
            next_token = torch.multinomial(probs, 1).item()
            
            current.append(next_token)
            generated.append(next_token)
            
            # Stop if we generate the same character many times (simple stop condition)
            if len(set(generated[-5:])) == 1:
                break
        
        return ''.join([idx_to_char[idx] for idx in generated])

# Generate some text
start = "hello"
generated = generate_text(model, start, max_len=15)
print(f"\nGenerated text starting with '{start}': '{generated}'")

print("\nNext-token prediction demo completed!")
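
The `temperature` argument of `generate_text` divides the logits before the softmax. The sketch below shows, in plain Python with made-up logits, how that changes the sampling distribution; `softmax_with_temperature` and `toy_logits` are hypothetical names for this illustration, and the seeded `random.Random` stands in for `torch.multinomial`.

```python
import math
import random

def softmax_with_temperature(logits, temperature=1.0):
    """Divide logits by temperature, then apply a numerically stable softmax."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

toy_logits = [2.0, 1.0, 0.0]  # made-up next-token logits

for t in (0.5, 1.0, 2.0):
    dist = softmax_with_temperature(toy_logits, t)
    print(f"T={t}:", [round(p, 3) for p in dist])

# Draw one token index from the T=1.0 distribution (seeded for repeatability)
rng = random.Random(0)
dist = softmax_with_temperature(toy_logits, 1.0)
token = rng.choices(range(len(dist)), weights=dist, k=1)[0]
print("sampled index:", token)
```

Temperatures below 1 concentrate probability on the highest-scoring token (more deterministic output), while temperatures above 1 flatten the distribution (more varied output).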