# Chapter 5: Training and Text Generation

Welcome to the final notebook in our LLM from Scratch series! In this chapter, we'll learn how to **train** our GPT model and use it for **text generation**.

## What You'll Learn

1. **Training loop fundamentals**: Gradient descent, backpropagation, optimization
2. **Learning rate scheduling**: Warmup and cosine decay
3. **Gradient accumulation**: Training with larger effective batch sizes
4. **Training best practices**: Checkpointing, validation, monitoring
5. **Text generation strategies**: Greedy, sampling, top-k, top-p, beam search
6. **Advanced generation**: Temperature, repetition penalty, n-gram blocking
7. **Hands-on training**: Train a small GPT model on real data

This is where we see our model come to life!

## 1. Training Fundamentals

Training a language model means teaching it to **predict the next token**.

### The Training Objective:

Given text: "The cat sat on the mat"

```
Input:  [The] [cat] [sat] [on] [the]
Target: [cat] [sat] [on] [the] [mat]
```

At each position, predict the **next token**.
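
The shift between input and target can be sketched in a few lines. The whitespace split below is only a stand-in for a real tokenizer, chosen for illustration:

```python
# Toy illustration: a whitespace split stands in for a real tokenizer.
tokens = "The cat sat on the mat".split()

# Inputs drop the last token; targets drop the first one,
# so targets[i] is the token that follows inputs[i].
inputs = tokens[:-1]
targets = tokens[1:]

for position, (seen, nxt) in enumerate(zip(inputs, targets), start=1):
    context = " ".join(inputs[:position])
    print(f"{position}: context={context!r} -> predict {nxt!r}")
```

Each printed row is one training example: the model sees the context up to that position and is scored on the token that comes next.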

### Loss Function: Cross-Entropy

$$\text{Loss} = -\frac{1}{N}\sum_{i=1}^{N} \log P(\text{target}_i \mid \text{input}_{\le i})$$

- **Lower loss** = better predictions
- **Perplexity** = exp(loss) = "average branching factor"
- Perfect model: loss = 0, perplexity = 1
- Uniform random guessing: loss = ln(vocab_size) ≈ 10.8 and perplexity = vocab_size ≈ 50,000 (for a GPT-2-sized vocabulary of about 50k tokens)
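
As a rough check of these numbers, here is a minimal sketch in plain Python. The vocabulary size of 50,257 is an assumption (GPT-2's BPE vocabulary), and `cross_entropy` is a hypothetical helper, not part of `src.llm`:

```python
import math

def cross_entropy(target_probs):
    """Mean of log(1/p) over the probabilities given to the correct tokens.

    log(1/p) is the same quantity as -log(p).
    """
    return sum(math.log(1 / p) for p in target_probs) / len(target_probs)

vocab_size = 50_257  # assumption: GPT-2's BPE vocabulary size

# A model that guesses uniformly gives every target probability 1/vocab_size.
uniform_loss = cross_entropy([1 / vocab_size] * 5)
print(f"uniform: loss={uniform_loss:.2f}, perplexity={math.exp(uniform_loss):,.0f}")

# A model that always puts probability 1 on the correct token.
perfect_loss = cross_entropy([1.0] * 5)
print(f"perfect: loss={perfect_loss:.2f}, perplexity={math.exp(perfect_loss):.0f}")
```

Because perplexity is `exp(loss)`, a perplexity of 20 can be read as the model being about as uncertain as a uniform choice among 20 tokens.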

### Training Loop:

```python
for batch in train_data:
    input_ids, targets = batch["input_ids"], batch["targets"]

    # 1. Forward pass
    logits, loss = model(input_ids, targets=targets)

    # 2. Backward pass (compute gradients)
    loss.backward()

    # 3. Optimizer step (update weights), then reset gradients for the next batch
    optimizer.step()
    optimizer.zero_grad()
```

## 2. Learning Rate Scheduling

**Learning rate** controls how much to update weights:
- Too high: training unstable, doesn't converge
- Too low: training slow, gets stuck

### Warmup + Cosine Decay (the schedule used for GPT-3):

```
Learning Rate
    ↑
max │     ╱╲
    │    ╱  ╲___
    │   ╱       ╲___
    │  ╱            ╲___
min │ ╱                 ╲___
    └─────────────────────────→ Steps
      ↑         ↑
    Warmup    Cosine Decay
```

### Why This Works:

1. **Warmup** (first ~2k steps):
   - Start with a low LR (`max_lr / warmup_steps`)
   - Gradually increase to max LR
   - Prevents unstable early training

2. **Cosine Decay** (rest of training):
   - Smoothly decrease from max LR to min LR
   - Allows fine-tuning at the end
   - Better final performance

### Formula:

```python
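import math

if step < warmup_steps:
    # Linear warmup from max_lr / warmup_steps up to max_lr
    lr = max_lr * (step + 1) / warmup_steps
else:
    # Cosine decay from max_lr down to min_lr
    progress = (step - warmup_steps) / (max_steps - warmup_steps)
    lr = min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))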
```
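
To see the shape of the schedule in numbers, the formula can be wrapped in a function and evaluated at a few steps. This is a sketch under assumed values: `min_lr` at 10% of `max_lr` and `max_steps=500` are illustrative choices, and `lr_at_step` is a hypothetical helper rather than the `Trainer`'s own scheduler:

```python
import math

def lr_at_step(step, max_lr=3e-4, min_lr=3e-5, warmup_steps=50, max_steps=500):
    """Linear warmup to max_lr, then cosine decay down to min_lr."""
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    # Fraction of the decay phase completed, clamped so that steps
    # past max_steps stay at min_lr instead of rising again.
    progress = min(1.0, (step - warmup_steps) / (max_steps - warmup_steps))
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))

for step in (0, 25, 49, 50, 275, 500):
    print(f"step {step:>3}: lr = {lr_at_step(step):.2e}")
```

The clamp on `progress` is a design choice: without it, the cosine would start climbing back up if training ran past `max_steps`.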

## 3. Gradient Accumulation

**Problem**: Large batch sizes don't fit in GPU memory.

**Solution**: Accumulate gradients over multiple small batches.

### Without Gradient Accumulation:
```python
# Batch size = 32 (may not fit!)
_, loss = model(batch_32["input_ids"], targets=batch_32["targets"])
loss.backward()
optimizer.step()
optimizer.zero_grad()
```

### With Gradient Accumulation:
```python
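# Process 4 micro-batches of size 8 = effective batch size 32
accum_steps = 4
for micro_batch in micro_batches:  # assumed: a list of 4 batches of size 8
    _, loss = model(micro_batch["input_ids"], targets=micro_batch["targets"])
    (loss / accum_steps).backward()  # Scale loss, accumulate gradients

optimizer.step()  # Update once, after all micro-batches
optimizer.zero_grad()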
```

### Benefits:

- ✅ Train with large effective batch sizes
- ✅ Fits in limited GPU memory
- ✅ Better gradient estimates
- ✅ More stable training
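
Why does scaling the loss by the number of accumulation steps work? A small sketch with a hand-written gradient shows that the accumulated gradient matches the full-batch one. The one-parameter model, the data, and `mse_grad` are all made up for illustration:

```python
# Toy model: y_hat = w * x with mean-squared-error loss (pure Python, no autograd).
def mse_grad(w, batch):
    """Gradient of the mean squared error over `batch` with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)

w = 0.5
data = [(1.0, 2.0), (2.0, 4.1), (3.0, 5.9), (4.0, 8.2),
        (5.0, 9.8), (6.0, 12.1), (7.0, 14.0), (8.0, 15.9)]

# One large batch of 8 examples.
full_grad = mse_grad(w, data)

# Four micro-batches of 2 examples; each gradient is scaled by 1/4 and summed,
# mirroring a scaled loss followed by loss.backward().
accum_steps = 4
micro_batches = [data[i:i + 2] for i in range(0, len(data), 2)]
accumulated = 0.0
for micro_batch in micro_batches:
    accumulated += mse_grad(w, micro_batch) / accum_steps

print(f"full-batch gradient:  {full_grad:.6f}")
print(f"accumulated gradient: {accumulated:.6f}")
```

The match relies on equally sized micro-batches and a loss that is averaged per example. With uneven micro-batches, or padding that changes how many tokens each one contributes, the two gradients differ slightly.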

## 4. Text Generation Strategies

Once trained, we can generate text in multiple ways:

### 1. Greedy Decoding (Deterministic)
```python
next_token = torch.argmax(logits, dim=-1)  # Always pick most likely
```
- ✅ Fast, deterministic
- ❌ Repetitive, boring

### 2. Random Sampling
```python
probs = torch.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
```
- ✅ Diverse output
- ❌ Can be incoherent

### 3. Temperature Sampling
```python
logits = logits / temperature  # Scale before softmax
probs = torch.softmax(logits, dim=-1)
```
- Temperature = 1.0: Normal sampling
- Temperature < 1.0: More confident (sharper distribution)
- Temperature > 1.0: More random (flatter distribution)
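
The effect of temperature is easy to see on a small example. The sketch below uses plain Python and made-up logits for a four-token vocabulary; `softmax_with_temperature` is a hypothetical helper:

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    """Softmax over logits / temperature, shifted by the max for numerical stability."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.5, -1.0]  # made-up scores for a 4-token vocabulary
for temperature in (0.5, 1.0, 1.5):
    probs = softmax_with_temperature(logits, temperature)
    print(temperature, [round(p, 3) for p in probs])
```

At lower temperatures more of the probability mass moves to the top token; at higher temperatures it spreads toward the unlikely ones.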

### 4. Top-K Sampling
```python
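top_k_logits, top_k_indices = torch.topk(logits, k=50)
# Sample only from the k most likely tokens, then map back to vocabulary ids
probs = torch.softmax(top_k_logits, dim=-1)
next_token = top_k_indices.gather(-1, torch.multinomial(probs, num_samples=1))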
```
- ✅ Prevents sampling unlikely tokens
- ✅ Good balance of quality and diversity
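
A complete top-k step can be sketched without tensors. The logits are made up, and `top_k_sample` is a hypothetical helper written for this illustration:

```python
import math
import random

def top_k_sample(logits, k, rng=random):
    """Sample a token id from the k highest-scoring logits only."""
    # Indices of the k largest logits.
    top_ids = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    # Unnormalized softmax weights restricted to those k tokens.
    peak = max(logits[i] for i in top_ids)
    weights = [math.exp(logits[i] - peak) for i in top_ids]
    return rng.choices(top_ids, weights=weights, k=1)[0]

logits = [3.0, 2.5, 0.1, -1.0, -4.0]  # made-up scores for a 5-token vocabulary
rng = random.Random(0)
samples = [top_k_sample(logits, k=2, rng=rng) for _ in range(1000)]
print({token_id: samples.count(token_id) for token_id in sorted(set(samples))})
```

With `k=2`, only token ids 0 and 1 can ever be drawn, however many samples are taken.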

### 5. Top-P (Nucleus) Sampling
```python
# Sample from the smallest set with cumulative probability >= p
sorted_probs, sorted_indices = torch.sort(probs, descending=True)
cumsum = torch.cumsum(sorted_probs, dim=-1)
# Keep a token if the mass *before* it is still below p (always keeps the top token)
mask = (cumsum - sorted_probs) < p
```
- ✅ Adaptive vocabulary size
- ✅ Better than top-k for varying distributions
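
The "adaptive vocabulary size" point can be illustrated with two made-up distributions. This is a sketch in plain Python; `top_p_filter` is a hypothetical helper that keeps the token which crosses the threshold, so the nucleus is never empty:

```python
def top_p_filter(probs, p):
    """Return {token_id: renormalized prob} for the smallest set whose mass reaches p."""
    ranked = sorted(enumerate(probs), key=lambda pair: pair[1], reverse=True)
    kept, mass = [], 0.0
    for token_id, prob in ranked:
        kept.append((token_id, prob))
        mass += prob
        if mass >= p:  # stop once the nucleus covers at least p
            break
    return {token_id: prob / mass for token_id, prob in kept}

peaked = [0.92, 0.05, 0.02, 0.01]
flat = [0.30, 0.25, 0.25, 0.20]
print(top_p_filter(peaked, p=0.9))  # nucleus holds a single token
print(top_p_filter(flat, p=0.9))    # nucleus holds all four tokens
```

The same `p` keeps one candidate when the model is confident and four when it is not, which a fixed `k` cannot do.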

### 6. Beam Search
- Keep N best hypotheses at each step
- Explore multiple paths simultaneously
- Better for tasks requiring correctness (translation)
- More expensive than sampling
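
A toy example shows how keeping more than one hypothesis can beat greedy decoding. The `NEXT` table below is a hypothetical stand-in for a model whose next-token probabilities depend only on the last token, and the sketch leaves out end-of-sequence handling and length normalization:

```python
import math

# Hypothetical toy "model": next-token probabilities depend only on the last token.
NEXT = {
    "<s>": {"a": 0.6, "b": 0.4},
    "a":   {"x": 0.5, "y": 0.5},
    "b":   {"z": 0.9, "x": 0.1},
}

def beam_search(start, steps, num_beams):
    """Keep the num_beams highest-scoring sequences after each step."""
    beams = [([start], 0.0)]  # (tokens, cumulative log-probability)
    for _ in range(steps):
        candidates = []
        for tokens, score in beams:
            for token, prob in NEXT.get(tokens[-1], {}).items():
                candidates.append((tokens + [token], score + math.log(prob)))
        if not candidates:
            break
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:num_beams]
    return beams

greedy = beam_search("<s>", steps=2, num_beams=1)[0]
best = beam_search("<s>", steps=2, num_beams=2)[0]
print("greedy:", greedy[0], round(math.exp(greedy[1]), 2))
print("beam-2:", best[0], round(math.exp(best[1]), 2))
```

Greedy decoding commits to `a` because it is the likelier first token, while a beam of two keeps `b` alive long enough to find the sequence with the higher overall probability.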

## 5. Hands-On: Training a GPT Model

Let's train a small GPT model on sample data!

In [None]:
import sys
sys.path.append('..')

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, TensorDataset
import matplotlib.pyplot as plt
import seaborn as sns

from src.llm import (
    GPTModel,
    ModelConfig,
    Trainer,
    TrainingConfig,
    TextGenerator,
    GenerationConfig,
    Tokenizer,
)

torch.manual_seed(42)
print("Imports successful!")

### 5.1 Create Training Data

In [None]:
# Sample training texts (in practice, use much more data!)
train_texts = [
    "The quick brown fox jumps over the lazy dog.",
    "Machine learning is a subset of artificial intelligence.",
    "Transformers revolutionized natural language processing.",
    "GPT models use attention mechanisms for text generation.",
    "Deep learning requires large amounts of training data.",
] * 20  # Repeat for more data

val_texts = [
    "Neural networks learn patterns from data.",
    "Language models predict the next word.",
]

print(f"Training samples: {len(train_texts)}")
print(f"Validation samples: {len(val_texts)}")

### 5.2 Tokenize and Create DataLoaders

In [None]:
# Create tokenizer
tokenizer = Tokenizer()

# Tokenize
def tokenize_texts(texts, max_length=64):
    """Tokenize texts and create input-target pairs."""
    all_input_ids = []
    all_targets = []
    
    for text in texts:
        tokens = tokenizer.encode(text)
        
        # Skip if too short
        if len(tokens) < 2:
            continue
        
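        # Truncate to max_length (padding happens later, in pad_sequences)
        if len(tokens) > max_length:
            tokens = tokens[:max_length]
        
        # For language modeling the pairs are input = tokens[:-1], target = tokens[1:].
        # Here both lists hold the full sequence, which is only correct if the model
        # shifts targets by one position internally; otherwise store the shifted slices.
        all_input_ids.append(tokens)
        all_targets.append(tokens)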
    
    return all_input_ids, all_targets

train_inputs, train_targets = tokenize_texts(train_texts)
val_inputs, val_targets = tokenize_texts(val_texts)

print(f"Sample input: {train_inputs[0]}")
print(f"Decoded: {tokenizer.decode(train_inputs[0])}")

In [None]:
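# Pad sequences to same length
# Note: pad_value=0 may collide with a real token id, and padded target
# positions count toward the loss unless the model masks them.
def pad_sequences(sequences, pad_value=0):
    """Pad sequences on the right to the length of the longest one."""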
    max_len = max(len(seq) for seq in sequences)
    padded = []
    for seq in sequences:
        padded.append(seq + [pad_value] * (max_len - len(seq)))
    return torch.tensor(padded)

train_input_ids = pad_sequences(train_inputs)
train_target_ids = pad_sequences(train_targets)
val_input_ids = pad_sequences(val_inputs)
val_target_ids = pad_sequences(val_targets)

print(f"Train input shape: {train_input_ids.shape}")
print(f"Val input shape: {val_input_ids.shape}")

In [None]:
# Create datasets and data loaders
def collate_fn(batch):
    """Custom collate function for DataLoader."""
    input_ids = torch.stack([item[0] for item in batch])
    targets = torch.stack([item[1] for item in batch])
    return {"input_ids": input_ids, "targets": targets}

train_dataset = TensorDataset(train_input_ids, train_target_ids)
val_dataset = TensorDataset(val_input_ids, val_target_ids)

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=2, collate_fn=collate_fn)

print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")

### 5.3 Create Model and Optimizer

In [None]:
# Create a tiny model for fast training
model_config = ModelConfig(
    vocab_size=tokenizer.vocab_size,
    max_seq_len=64,
    d_model=128,
    n_layers=4,
    n_heads=4,
    d_ff=512,
    dropout=0.1,
)

model = GPTModel(model_config)
print(f"Model parameters: {model.num_parameters():,}")

# Create optimizer (AdamW is standard for transformers)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)

print("Model and optimizer created!")

### 5.4 Training Configuration

In [None]:
# Create training configuration
train_config = TrainingConfig(
    num_epochs=10,
    learning_rate=3e-4,
    warmup_steps=50,
    gradient_accumulation_steps=2,
    max_grad_norm=1.0,
    log_interval=5,
    eval_interval=20,
    save_interval=50,
    device="cpu",  # Use CPU for this demo
)

print("Training configuration:")
print(f"  Epochs: {train_config.num_epochs}")
print(f"  Learning rate: {train_config.learning_rate}")
print(f"  Warmup steps: {train_config.warmup_steps}")
print(f"  Gradient accumulation: {train_config.gradient_accumulation_steps}")
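
The `max_grad_norm=1.0` setting suggests gradient clipping by global norm. How `Trainer` applies it is not shown here, so treat the following as an assumption; in PyTorch the usual call is `torch.nn.utils.clip_grad_norm_`. The idea in plain Python, with made-up gradient values and a hypothetical `clip_by_global_norm` helper:

```python
import math

def clip_by_global_norm(grads, max_norm):
    """Rescale all gradients together when their combined L2 norm exceeds max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    scale = min(1.0, max_norm / (total_norm + 1e-6))
    return [g * scale for g in grads], total_norm

grads = [3.0, -4.0, 12.0]  # made-up gradient values; their L2 norm is 13
clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
norm_after = math.sqrt(sum(g * g for g in clipped))
print(f"norm before: {norm_before:.2f}, after: {norm_after:.2f}")
```

All gradients are scaled by the same factor, so the update direction is preserved and only its length is capped.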

### 5.5 Train the Model

In [None]:
# Create trainer
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    config=train_config,
    train_loader=train_loader,
    val_loader=val_loader,
    checkpoint_dir="../checkpoints"
)

print("Starting training...\n")

# Train
stats = trainer.train()

print("\nTraining complete!")
print(f"Final loss: {stats['final_train_loss']:.4f}")
print(f"Best validation loss: {stats['best_val_loss']:.4f}")

### 5.6 Visualize Training Progress

In [None]:
# Plot training curves
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# Training loss
history = trainer.metrics_history
ax1.plot(history['step'], history['train_loss'], label='Train Loss')
if history['val_loss']:
    # Validation runs less often than logging, so spread its points evenly
    # across the logged steps (x positions are approximate).
    stride = max(1, len(history['step']) // len(history['val_loss']))
    val_steps = history['step'][::stride][:len(history['val_loss'])]
    ax1.plot(val_steps, history['val_loss'][:len(val_steps)], label='Val Loss', marker='o')
ax1.set_xlabel('Step')
ax1.set_ylabel('Loss')
ax1.set_title('Training and Validation Loss')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Learning rate
ax2.plot(trainer.metrics_history['step'], trainer.metrics_history['learning_rate'])
ax2.set_xlabel('Step')
ax2.set_ylabel('Learning Rate')
ax2.set_title('Learning Rate Schedule')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 6. Text Generation with Trained Model

Now let's generate text with our trained model!

### 6.1 Simple Generation (Built-in Method)

In [None]:
# Generate with built-in method
model.eval()  # disable dropout while generating
prompt = "Machine learning"
input_ids = tokenizer.encode(prompt)
input_tensor = torch.tensor([input_ids])

print(f"Prompt: {prompt}\n")

for temp in [0.5, 1.0, 1.5]:
    generated = model.generate(
        input_tensor,
        max_new_tokens=15,
        temperature=temp,
        top_k=50
    )
    
    text = tokenizer.decode(generated[0].tolist())
    print(f"Temperature {temp}: {text}")
    print()

### 6.2 Advanced Generation (TextGenerator)

In [None]:
# Create text generator with different configurations
configs = [
    {"name": "Greedy", "config": GenerationConfig(max_new_tokens=20, do_sample=False)},
    {"name": "Top-K", "config": GenerationConfig(max_new_tokens=20, do_sample=True, top_k=50, temperature=0.8)},
    {"name": "Top-P", "config": GenerationConfig(max_new_tokens=20, do_sample=True, top_p=0.9, temperature=0.8)},
    {"name": "Beam Search", "config": GenerationConfig(max_new_tokens=20, num_beams=3)},
]

prompt = "The transformer"
input_ids = tokenizer.encode(prompt)
input_tensor = torch.tensor([input_ids])

print(f"Prompt: {prompt}\n")

for cfg in configs:
    generator = TextGenerator(model, cfg["config"])
    output = generator.generate(input_tensor)
    text = tokenizer.decode(output[0].tolist())
    print(f"{cfg['name']}: {text}")
    print()

### 6.3 Comparing Sampling Strategies

In [None]:
# Generate multiple samples with same prompt
prompt = "Deep learning"
input_ids = tokenizer.encode(prompt)
input_tensor = torch.tensor([input_ids])

config = GenerationConfig(
    max_new_tokens=15,
    do_sample=True,
    temperature=0.8,
    top_k=50,
)

generator = TextGenerator(model, config)

print(f"Prompt: {prompt}\n")
print("Generated variations (sampling):")
for i in range(5):
    output = generator.generate(input_tensor)
    text = tokenizer.decode(output[0].tolist())
    print(f"{i+1}. {text}")

## 7. Advanced: Repetition Penalty

In [None]:
# Without repetition penalty
config_no_penalty = GenerationConfig(
    max_new_tokens=30,
    do_sample=True,
    temperature=0.8,
    top_k=50,
    repetition_penalty=1.0,  # No penalty
)

# With repetition penalty
config_with_penalty = GenerationConfig(
    max_new_tokens=30,
    do_sample=True,
    temperature=0.8,
    top_k=50,
    repetition_penalty=1.2,  # Penalize repeats
)

prompt = "Artificial intelligence"
input_ids = tokenizer.encode(prompt)
input_tensor = torch.tensor([input_ids])

print(f"Prompt: {prompt}\n")

print("WITHOUT repetition penalty:")
gen1 = TextGenerator(model, config_no_penalty)
output1 = gen1.generate(input_tensor)
print(tokenizer.decode(output1[0].tolist()))

print("\nWITH repetition penalty (1.2):")
gen2 = TextGenerator(model, config_with_penalty)
output2 = gen2.generate(input_tensor)
print(tokenizer.decode(output2[0].tolist()))
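
What might a repetition penalty do to the logits? One common rule, following the CTRL paper listed under Further Reading with the usual adjustment for negative logits, is sketched below. Whether `TextGenerator` implements exactly this rule is an assumption, and `apply_repetition_penalty` is a hypothetical helper:

```python
def apply_repetition_penalty(logits, generated_ids, penalty):
    """Make already-generated tokens less likely (CTRL-style rule)."""
    adjusted = list(logits)
    for token_id in set(generated_ids):
        if adjusted[token_id] > 0:
            adjusted[token_id] /= penalty  # shrink positive scores
        else:
            adjusted[token_id] *= penalty  # push negative scores further down
    return adjusted

logits = [2.4, -1.0, 0.6, 1.2]  # made-up scores for a 4-token vocabulary
print(apply_repetition_penalty(logits, generated_ids=[0, 1, 0], penalty=1.2))
```

Tokens 0 and 1 have already been generated, so their scores drop, while tokens 2 and 3 are left untouched. A penalty of 1.0 changes nothing.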

## 8. Key Takeaways

Congratulations on completing the LLM from Scratch series! Let's recap:

### Training:

1. **Training objective**: Predict next token with cross-entropy loss
2. **LR scheduling**: Warmup + cosine decay for stable training
3. **Gradient accumulation**: Train with large effective batch sizes
4. **Best practices**: Checkpointing, validation, gradient clipping
5. **Monitoring**: Track loss, perplexity, learning rate

### Generation:

1. **Greedy decoding**: Fast but repetitive
2. **Sampling**: Diverse but can be incoherent
3. **Temperature**: Control randomness (0.5-1.5 typical)
4. **Top-K**: Sample from top K tokens (50-100 typical)
5. **Top-P (Nucleus)**: Adaptive vocabulary size (0.9-0.95 typical)
6. **Beam search**: Tracks several hypotheses to find higher-probability sequences, at higher cost than sampling
7. **Repetition penalty**: Reduce repetitive text (1.1-1.5 typical)

### What We Built:

1. **Chapter 1**: Tokenization (BPE encoding)
2. **Chapter 2**: Attention mechanism (multi-head self-attention)
3. **Chapter 3**: Transformer blocks (attention + FFN)
4. **Chapter 4**: Complete GPT model (embeddings + transformer + output)
5. **Chapter 5**: Training and generation (this chapter!)

## Next Steps

To go further:

1. **More data**: Train on larger datasets (Wikipedia, books, web text)
2. **Bigger models**: Scale up to GPT-2/GPT-3 sizes
3. **Fine-tuning**: Adapt to specific tasks (instruction following, chat)
4. **Optimization**: Mixed precision (FP16), distributed training
5. **Advanced techniques**: RLHF, PPO, DPO for alignment

## Congratulations!

You've built a complete GPT-style language model from scratch and understand:
- Tokenization and embeddings
- Attention mechanisms
- Transformer architecture
- Training procedures
- Text generation strategies

This is the foundation that powers ChatGPT, GPT-4, and other modern LLMs!

---

## Further Reading

### Training:
- [Adam: A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980)
- [Decoupled Weight Decay Regularization](https://arxiv.org/abs/1711.05101) (AdamW)
- [On Layer Normalization in the Transformer Architecture](https://arxiv.org/abs/2002.04745)

### Generation:
- [The Curious Case of Neural Text Degeneration](https://arxiv.org/abs/1904.09751) (Nucleus sampling)
- [Hierarchical Neural Story Generation](https://arxiv.org/abs/1805.04833) (Top-k sampling)
- [CTRL: A Conditional Transformer Language Model](https://arxiv.org/abs/1909.05858) (Repetition penalty)

### Advanced:
- [Training language models to follow instructions with human feedback](https://arxiv.org/abs/2203.02155) (InstructGPT)
- [Constitutional AI](https://arxiv.org/abs/2212.08073) (Anthropic's approach)
- [Direct Preference Optimization](https://arxiv.org/abs/2305.18290) (DPO)