# 6. The Complete Transformer

**Assembling all components into a working language model**

We now assemble all of our components into a working decoder-only (GPT-style) transformer. The result is a complete language model that can be trained to predict the next token in a sequence.

**What is "decoder-only"?** The original transformer paper paired an encoder (which reads the input) with a decoder (which generates the output) and applied the pair to machine translation. GPT-style language models keep only the decoder stack, dropping the encoder and the decoder's cross-attention to it. What remains is causally masked self-attention: each token can attend only to itself and earlier tokens, never to later ones, which is exactly the constraint next-token generation needs.
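To make the causal constraint concrete, here is a small sketch that builds the same kind of mask this notebook uses (`torch.triu` with 1 above the diagonal) and applies it to a toy matrix of attention scores. The all-zero scores are an illustrative assumption chosen so the effect of the mask is easy to read.

```python
import torch
import torch.nn.functional as F

seq_len = 4

# 1 above the diagonal marks "future" positions that must be blocked
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)

# Toy scores: every query scores every key equally (illustrative assumption)
scores = torch.zeros(seq_len, seq_len)

# Blocked positions get -inf, which softmax turns into exactly 0 weight
weights = F.softmax(scores.masked_fill(mask == 1, float('-inf')), dim=-1)
print(weights)
# Row i spreads its weight evenly over positions 0..i and gives 0 to later positions
```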

## How Data Flows Through the Model

1. **Token Embedding:** Convert input token IDs (integers) to dense vectors

2. **Positional Encoding:** Add a learned position embedding to each token vector. Attention by itself ignores token order, so this is how the model knows where each token is

3. **Transformer Blocks (×N):** Stack N identical blocks (the `GPTModel` default is 6 and the demo below uses 4; GPT-3 uses 96). Each block refines the representations with causal self-attention followed by a feed-forward network

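4. **Final LayerNorm:** Each block normalizes its *inputs* (pre-norm), so the residual stream leaving the last block is never normalized; one last LayerNorm does that before the output projection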

5. **Output Projection:** Project from d_model dimensions to vocabulary size, giving us scores (logits) for every possible next token
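The five steps can be traced on tensor shapes alone. This sketch uses tiny, arbitrary sizes and omits the transformer blocks for brevity (an assumption: each block maps `(batch, seq, d_model)` back to the same shape, so skipping them does not change the shapes).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, d_model, max_seq_len = 50, 8, 16   # toy sizes, chosen arbitrarily
batch, seq = 2, 5

token_emb = nn.Embedding(vocab_size, d_model)
pos_emb = nn.Embedding(max_seq_len, d_model)   # learned position embeddings
ln_f = nn.LayerNorm(d_model)
output_proj = nn.Linear(d_model, vocab_size)

ids = torch.randint(0, vocab_size, (batch, seq))  # (2, 5) integer token ids
x = token_emb(ids)                                # 1. (2, 5, 8)
x = x + pos_emb(torch.arange(seq))                # 2. (5, 8) broadcasts over the batch
# 3. transformer blocks omitted here; each would return the same (2, 5, 8) shape
x = ln_f(x)                                       # 4. (2, 5, 8)
logits = output_proj(x)                           # 5. (2, 5, 50): one score per vocabulary entry
print(ids.shape, logits.shape)
```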

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Components from previous notebooks

class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
    
    def forward(self, x):
        return self.embedding(x)


class PositionalEncoding(nn.Module):
    """Learned positional embeddings."""
    def __init__(self, d_model, max_seq_len=5000):
        super().__init__()
        self.pos_embedding = nn.Embedding(max_seq_len, d_model)
    
    def forward(self, x):
        batch_size, seq_len, _ = x.shape
        positions = torch.arange(seq_len, device=x.device)
        return x + self.pos_embedding(positions)


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 1, float('-inf'))
        
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        context = torch.matmul(attn, V)
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.W_o(context)


class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        return self.dropout(self.linear2(F.gelu(self.linear1(x))))


class TransformerBlock(nn.Module):
    """Pre-norm block: x + Attention(LayerNorm(x)), then x + FFN(LayerNorm(x))."""
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.ffn = FeedForward(d_model, d_ff, dropout)
        # Residual dropout for the attention branch only: FeedForward already
        # applies dropout to its output, so a second dropout there would double it.
        self.dropout1 = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        x = x + self.dropout1(self.attention(self.norm1(x), mask=mask))
        x = x + self.ffn(self.norm2(x))
        return x

In [None]:
class GPTModel(nn.Module):
    """A complete decoder-only transformer language model."""
    
    def __init__(
        self, 
        vocab_size, 
        d_model=512, 
        num_heads=8,
        num_layers=6, 
        d_ff=2048, 
        max_seq_len=5000, 
        dropout=0.1
    ):
        super().__init__()
        
        # Token and positional embeddings
        self.token_embedding = TokenEmbedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_seq_len)
        self.dropout = nn.Dropout(dropout)
        
        # Stack of transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        
        # Final layer norm and output projection
        self.ln_f = nn.LayerNorm(d_model)
        self.output_proj = nn.Linear(d_model, vocab_size)
        
        # Store config
        self.vocab_size = vocab_size
        self.d_model = d_model
    
    def create_causal_mask(self, seq_len):
        """Create mask to prevent attending to future positions."""
        return torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
    
    def forward(self, x, mask=None):
        """
        Args:
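            x: (batch, seq_len) - token indices, with seq_len <= max_seq_len
            mask: optional (seq_len, seq_len) mask where 1 marks a blocked position;
                  defaults to a causal mask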
        
        Returns:
            logits: (batch, seq_len, vocab_size)
        """
        # Create causal mask if not provided
        if mask is None:
            mask = self.create_causal_mask(x.size(1)).to(x.device)
        
        # 1. Embed tokens and add positions
        x = self.token_embedding(x)      # (batch, seq) → (batch, seq, d_model)
        x = self.pos_encoding(x)
        x = self.dropout(x)
        
        # 2. Pass through all transformer blocks
        for block in self.blocks:
            x = block(x, mask=mask)
        
        # 3. Final normalization and projection to vocabulary
        x = self.ln_f(x)
        logits = self.output_proj(x)     # (batch, seq, d_model) → (batch, seq, vocab_size)
        
        return logits

In [None]:
# Create a small model for demonstration
model = GPTModel(
    vocab_size=10000,
    d_model=256,
    num_heads=4,
    num_layers=4,
    d_ff=1024,
    max_seq_len=512
)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"Model configuration:")
print(f"  vocab_size: 10,000")
print(f"  d_model: 256")
print(f"  num_heads: 4")
print(f"  num_layers: 4")
print(f"  d_ff: 1,024")
print(f"\nTotal parameters: {total_params:,}")
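The printed total can be checked by hand. The sketch below counts weights and biases for each layer type in the configuration above (`nn.Linear` and `nn.LayerNorm` both carry biases by default; dropout has no parameters). Note that `output_proj` is a separate matrix in this notebook; GPT-2, by contrast, ties its output projection to the token embedding.

```python
vocab_size, d_model, num_layers, d_ff, max_seq_len = 10_000, 256, 4, 1_024, 512

def linear(n_in, n_out):
    return n_in * n_out + n_out            # weight + bias

layer_norm = 2 * d_model                   # scale (weight) + bias
attention = 4 * linear(d_model, d_model)   # W_q, W_k, W_v, W_o
ffn = linear(d_model, d_ff) + linear(d_ff, d_model)
block = 2 * layer_norm + attention + ffn

total = (
    vocab_size * d_model            # token embedding
    + max_seq_len * d_model         # learned position embedding
    + num_layers * block            # transformer blocks
    + layer_norm                    # final LayerNorm (ln_f)
    + linear(d_model, vocab_size)   # output projection (not tied to the embedding)
)
print(f"{total:,}")  # 8,420,624
```

Roughly 5.1M of these parameters sit in the token embedding and output projection alone, which is why small-vocabulary demos are dominated by embeddings.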

In [None]:
# Forward pass
batch_size = 2
seq_len = 16

# Random token IDs (simulating tokenized text)
tokens = torch.randint(0, 10000, (batch_size, seq_len))

# Get logits
logits = model(tokens)

print(f"Input tokens: {tokens.shape}")
print(f"Output logits: {logits.shape}")
print(f"\nLogits are scores for each token in vocabulary:")
print(f"  For position 0: {logits[0, 0, :5].detach().numpy().round(2)} ... (10,000 values)")

## What Are Logits?

The model outputs *logits*: raw, unnormalized scores for each token in the vocabulary. A higher score means the model considers that token more likely to come next.

We can convert these to probabilities with softmax, then either:
- Pick the highest-probability token (greedy decoding)
- Sample from the distribution (less repetitive, more varied output)
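A minimal sketch of both strategies for a single position. The `temperature` parameter is an addition not used elsewhere in this notebook: it divides the logits before softmax, so values below 1 sharpen the distribution toward the greedy choice and values above 1 flatten it. `pick_next_token` is a hypothetical helper name.

```python
import torch
import torch.nn.functional as F

def pick_next_token(logits, greedy=False, temperature=1.0):
    """Choose one token id from a 1-D tensor of logits with shape (vocab_size,)."""
    if greedy:
        return int(torch.argmax(logits))                 # highest score wins
    probs = F.softmax(logits / temperature, dim=-1)
    return int(torch.multinomial(probs, num_samples=1))  # draw one id according to probs

torch.manual_seed(0)
scores = torch.tensor([2.0, 1.0, 0.1])
print(pick_next_token(scores, greedy=True))  # 0
print(pick_next_token(scores))               # 0, 1 or 2, weighted by softmax(scores)
```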

In [None]:
# Convert logits to probabilities
probs = F.softmax(logits[0, -1], dim=-1)  # Last position

# Top 5 most likely next tokens
top_probs, top_indices = torch.topk(probs, 5)

print("Top 5 predicted next tokens (before training):")
for prob, idx in zip(top_probs, top_indices):
    print(f"  Token {idx.item()}: {prob.item():.4f} probability")

print("\n(Untrained weights spread probability thinly over the whole vocabulary, so this ranking reflects random initialization, not language.)")
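Generation repeats that choice in a loop: run the model, take the logits at the last position, pick a token, append it, and feed the longer sequence back in. The sketch below assumes any model that maps `(batch, seq)` token ids to `(batch, seq, vocab_size)` logits, as `GPTModel` does; to stay self-contained it uses a tiny stand-in model and greedy decoding. `generate` is a hypothetical helper name.

```python
import torch
import torch.nn as nn

@torch.no_grad()
def generate(model, tokens, max_new_tokens, max_seq_len):
    """Greedily extend `tokens` (batch, seq) by max_new_tokens ids."""
    model.eval()                                   # disable dropout for inference
    for _ in range(max_new_tokens):
        context = tokens[:, -max_seq_len:]         # crop to what the position embedding supports
        logits = model(context)                    # (batch, seq, vocab_size)
        next_id = logits[:, -1].argmax(dim=-1, keepdim=True)  # (batch, 1)
        tokens = torch.cat([tokens, next_id], dim=1)
    return tokens

# Stand-in model for illustration only: token ids -> vectors -> vocabulary scores
torch.manual_seed(0)
vocab_size = 20
toy_model = nn.Sequential(nn.Embedding(vocab_size, 8), nn.Linear(8, vocab_size))

prompt = torch.tensor([[1, 2, 3]])
out = generate(toy_model, prompt, max_new_tokens=5, max_seq_len=16)
print(out.shape)  # torch.Size([1, 8])
```

Because `generate` calls `model.eval()`, call `model.train()` before resuming training with the same model.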

## Training Overview

During training, we feed the model sequences of text and ask it to predict the next token at each position:

1. **Forward pass:** Compute logits for all positions
2. **Loss:** Compare predictions against actual next tokens using cross-entropy loss
3. **Backward pass:** Use backpropagation to compute gradients
4. **Update:** Adjust all weights using an optimizer (like AdamW)

After training on billions of tokens, the model learns to predict plausible next tokens from the preceding context.
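Cross-entropy at one position is the negative log of the probability the model assigned to the correct next token. This small pure-Python sketch (the helper name `cross_entropy_one` is hypothetical) also shows why a uniform prediction over V tokens costs exactly ln V:

```python
import math

def cross_entropy_one(logits, target):
    """-log softmax(logits)[target], computed stably via log-sum-exp."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

# Confident and correct: small loss
print(round(cross_entropy_one([4.0, 0.0, 0.0], target=0), 4))
# Uniform over 10,000 tokens (all logits equal): loss = ln(10000)
print(round(cross_entropy_one([0.0] * 10_000, target=123), 4))  # 9.2103
```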

In [None]:
# Simple training step example
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Input: [token0, token1, token2, ...]
# Target: [token1, token2, token3, ...]  (shifted by 1)
input_tokens = tokens[:, :-1]
target_tokens = tokens[:, 1:]

# Forward pass
logits = model(input_tokens)

# Compute loss (cross-entropy)
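loss = F.cross_entropy(
    logits.reshape(-1, model.vocab_size),  # (batch * seq, vocab_size)
    target_tokens.reshape(-1)              # (batch * seq,); tokens[:, 1:] is non-contiguous, so .view(-1) would fail
)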

print(f"Input shape: {input_tokens.shape}")
print(f"Target shape: {target_tokens.shape}")
print(f"Loss: {loss.item():.4f}")
print(f"\nLoss of a uniform prediction over 10,000 tokens: ln(10000) = {math.log(10000):.2f}")
print("(An untrained model's loss should land close to this value.)")
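A shape detail in the loss computation: slicing off the first column, as `tokens[:, 1:]` does, returns a non-contiguous view, and `.view(-1)` cannot flatten such a tensor without copying, so it raises an error. `.reshape(-1)` copies when it has to. A quick check, using its own small tensor:

```python
import torch

ids = torch.arange(32).view(2, 16)   # contiguous (2, 16) tensor
shifted = ids[:, 1:]                 # (2, 15) slice that shares storage with ids
print(shifted.is_contiguous())       # False

try:
    shifted.view(-1)
except RuntimeError as err:
    print("view failed:", type(err).__name__)

flat = shifted.reshape(-1)           # works: copies into a new contiguous tensor
print(flat.shape)                    # torch.Size([30])
```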

In [None]:
# Backward pass and update
optimizer.zero_grad()
loss.backward()
optimizer.step()

print("Completed one training step!")
print("Real training repeats this step many thousands of times on batches of real text.")
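Putting the four steps in a loop on one fixed batch is a common sanity check: a correct setup should drive the loss down quickly. So that it runs in isolation, this sketch uses a tiny stand-in bigram model (an `nn.Embedding` whose rows are next-token logits) instead of `GPTModel`; the model, sizes, learning rate, and step count are illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size = 50

# Stand-in model: row i of the embedding holds the next-token logits after token i
bigram = nn.Embedding(vocab_size, vocab_size)
opt = torch.optim.AdamW(bigram.parameters(), lr=1e-1)

batch = torch.randint(0, vocab_size, (4, 33))   # one fixed batch of random token ids
inputs, targets = batch[:, :-1], batch[:, 1:]   # predict token t+1 from token t

losses = []
for step in range(100):
    step_logits = bigram(inputs)                                  # 1. forward: (4, 32, vocab_size)
    step_loss = F.cross_entropy(step_logits.reshape(-1, vocab_size),
                                targets.reshape(-1))              # 2. loss
    opt.zero_grad()
    step_loss.backward()                                          # 3. gradients
    opt.step()                                                    # 4. update
    losses.append(step_loss.item())

print(f"first loss {losses[0]:.2f}, last loss {losses[-1]:.2f}")
```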

## Model Scale

The demo model above uses 4 layers with d_model=256: small enough to train quickly on modest hardware while still exercising every component. For comparison:

| Model | Layers | d_model | Parameters |
|-------|--------|---------|------------|
| This notebook | 4 | 256 | ~8.4M |
| GPT-2 Small | 12 | 768 | 117M |
| GPT-2 Large | 36 | 1280 | 774M |
| GPT-3 | 96 | 12288 | 175B |

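Broadly the same components (token and position embeddings, stacked attention and feed-forward blocks, a final LayerNorm, and an output projection) appear at every one of these scales; the larger models mainly add layers, widen d_model, and train on far more data.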
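A common rule of thumb connects the table's columns: with `d_ff = 4 * d_model` (true of the demo here and of GPT-2 and GPT-3), each block holds about 12·d_model² weights (4·d_model² in attention, 8·d_model² in the feed-forward network), ignoring biases, LayerNorms, and embeddings. A quick check:

```python
def approx_block_params(num_layers, d_model):
    """Non-embedding parameters, assuming d_ff = 4 * d_model and ignoring biases/LayerNorm."""
    attention = 4 * d_model ** 2           # W_q, W_k, W_v, W_o
    feed_forward = 2 * 4 * d_model ** 2    # d_model -> 4*d_model -> d_model
    return num_layers * (attention + feed_forward)

for name, layers, d_model in [("This notebook", 4, 256), ("GPT-2 Small", 12, 768),
                              ("GPT-2 Large", 36, 1280), ("GPT-3", 96, 12288)]:
    print(f"{name:14s} ~{approx_block_params(layers, d_model) / 1e6:,.1f}M non-embedding params")
```

For GPT-3 this gives about 174B, close to the table's 175B. For the smaller models the embeddings make up most of the gap; GPT-2's 50,257-token vocabulary at d_model=768 alone adds about 38.6M parameters.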

## Next: Training at Scale

We'll look at gradient accumulation and validation strategies for stable training on hobby hardware.