# Module 07: Sampling & Generation - Interactive Notebook

This notebook provides a hands-on walkthrough of text generation and sampling strategies for transformer language models.

**Topics Covered:**
1. Greedy sampling (deterministic)
2. Temperature sampling (controlled randomness)
3. Top-k sampling (fixed-size filtering)
4. Top-p/nucleus sampling (adaptive filtering)
5. Combined sampling strategies
6. The TextGenerator interface
7. Diversity vs. quality trade-off
8. EOS handling and batch generation

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import sys
from collections import Counter

# Add project root to path
sys.path.insert(0, str(Path.cwd().parent.parent.parent))

from tiny_transformer.sampling import (
    greedy_sample,
    temperature_sample,
    top_k_sample,
    top_p_sample,
    combined_sample,
    TextGenerator,
    GeneratorConfig
)
from tiny_transformer.model import TinyTransformerLM, get_model_config
from tiny_transformer.training import CharTokenizer, set_seed

set_seed(42)
print("✓ Imports successful")

## 1. Understanding Sampling Basics

Language models output **logits** (unnormalized scores) for each token in the vocabulary. To generate text, we need to convert these logits into actual token selections.
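Softmax exponentiates each logit and divides by the sum of the exponentials, $p_i = e^{z_i} / \sum_j e^{z_j}$. Below is a minimal pure-Python sketch of that formula. It uses only the standard library and is not PyTorch's implementation, and `softmax_list` is a hypothetical helper name. It subtracts the largest logit first, which leaves the result unchanged but keeps `exp()` from overflowing:

```python
import math

def softmax_list(logits):
    """Convert a list of logits into probabilities that sum to 1."""
    # Subtracting the max cancels out in the ratio below but keeps exp() from overflowing.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Same logits as the example cell below, as a plain list
toy_logits = [1.5, 2.8, 2.5, 0.5, 1.0, 0.8, 0.3, 0.2]
toy_probs = softmax_list(toy_logits)
print([round(q, 3) for q in toy_probs])
print(softmax_list([1000.0, 1000.0]))  # → [0.5, 0.5] (a naive math.exp(1000) would overflow)
```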

Let's start with a simple example using mock logits.

In [None]:
# Create simple example logits
# Higher logits = higher probability after softmax
vocab = ['the', 'cat', 'dog', 'sat', 'ran', 'jumped', 'on', 'quickly']
vocab_size = len(vocab)

# Example logits for next word prediction
logits = torch.tensor([[
    1.5,   # 'the'
    2.8,   # 'cat' - highest logit
    2.5,   # 'dog'
    0.5,   # 'sat'
    1.0,   # 'ran'
    0.8,   # 'jumped'
    0.3,   # 'on'
    0.2    # 'quickly'
]])

# Convert to probabilities
probs = F.softmax(logits, dim=-1)

# Visualize
plt.figure(figsize=(10, 5))
plt.bar(vocab, probs[0].numpy())
plt.xlabel('Token')
plt.ylabel('Probability')
plt.title('Token Probability Distribution')
plt.xticks(rotation=45)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print("Token probabilities:")
for token, prob in zip(vocab, probs[0]):
    print(f"  {token:10s}: {prob:.4f}")

## 2. Greedy Sampling

**Greedy sampling** always picks the token with the highest probability.

- **Pros:** Deterministic, reproducible
- **Cons:** Repetitive, lacks diversity
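Softmax preserves the ordering of its inputs, so greedy selection is simply an argmax over the raw logits. The following is a pure-Python sketch. `greedy_pick` is a hypothetical helper, not the `greedy_sample` function from `tiny_transformer.sampling`:

```python
def greedy_pick(logits):
    """Return the index of the largest logit (ties go to the lowest index)."""
    # No softmax needed: it does not change which entry is largest.
    # max() returns the first maximal item it encounters, so ties resolve to the lowest index.
    return max(range(len(logits)), key=lambda i: logits[i])

toy_vocab = ['the', 'cat', 'dog', 'sat', 'ran', 'jumped', 'on', 'quickly']
toy_logits = [1.5, 2.8, 2.5, 0.5, 1.0, 0.8, 0.3, 0.2]
print(toy_vocab[greedy_pick(toy_logits)])  # → cat
```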

In [None]:
# Greedy sampling: argmax
greedy_token = greedy_sample(logits)
print(f"Greedy selection: {vocab[greedy_token.item()]}")
print(f"Selected probability: {probs[0, greedy_token].item():.4f}")

# Test determinism
print("\nDeterminism test (10 samples):")
samples = [greedy_sample(logits).item() for _ in range(10)]
print(f"All samples identical: {len(set(samples)) == 1}")
print(f"All samples: {[vocab[s] for s in samples]}")

## 3. Temperature Sampling

**Temperature** controls randomness by scaling logits before softmax:

$$\text{probs} = \text{softmax}(\text{logits} / T)$$

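- **T → 0**: Approaches greedy (argmax). T = 0 itself would divide by zero, so code usually routes it to greedy decoding
- **0 < T < 1**: Sharper (more peaked) distribution
- **T = 1**: Original distribution
- **T > 1**: Flatter distribution, more randomness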
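The formula above can be sketched in pure Python. The sketch divides the logits by T, applies a numerically stable softmax, and draws one index with `random.choices`. `softmax_with_temperature` and `sample_with_temperature` are hypothetical helpers for illustration, not the library's `temperature_sample`:

```python
import math
import random

def softmax_with_temperature(logits, temperature):
    """Softmax of logits / temperature (temperature must be > 0)."""
    if temperature <= 0:
        raise ValueError("temperature must be positive; use argmax for T = 0")
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def sample_with_temperature(logits, temperature, rng=random):
    """Draw one token index from the temperature-scaled distribution."""
    weights = softmax_with_temperature(logits, temperature)
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]

toy_logits = [1.5, 2.8, 2.5, 0.5, 1.0, 0.8, 0.3, 0.2]
for t in (0.1, 1.0, 2.0):
    top = max(softmax_with_temperature(toy_logits, t))
    print(f"T={t}: probability of the top token = {top:.3f}")

rng = random.Random(0)  # seeded for reproducibility
print([sample_with_temperature(toy_logits, 1.0, rng) for _ in range(10)])
```

Lower T concentrates mass on the top token, and higher T spreads it out. This is the same effect the plots below show with `F.softmax(logits / temp)`.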

In [None]:
# Compare different temperatures
temperatures = [0.1, 0.5, 1.0, 2.0]

fig, axes = plt.subplots(2, 2, figsize=(14, 10))
axes = axes.flatten()

for idx, temp in enumerate(temperatures):
    # Compute temperature-scaled probabilities
    scaled_probs = F.softmax(logits / temp, dim=-1)
    
    axes[idx].bar(vocab, scaled_probs[0].numpy())
    axes[idx].set_xlabel('Token')
    axes[idx].set_ylabel('Probability')
    axes[idx].set_title(f'Temperature = {temp}')
    axes[idx].tick_params(axis='x', rotation=45)
    axes[idx].grid(True, alpha=0.3)
    
    # Calculate entropy (measure of randomness)
    entropy = -(scaled_probs * torch.log(scaled_probs + 1e-10)).sum().item()
    axes[idx].text(0.02, 0.98, f'Entropy: {entropy:.3f}', 
                   transform=axes[idx].transAxes, verticalalignment='top',
                   bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

plt.tight_layout()
plt.show()

print("Observation: Lower temperature → more peaked distribution → less randomness")
print("             Higher temperature → flatter distribution → more randomness")

In [None]:
# Sample with different temperatures
set_seed(42)

print("Temperature sampling (20 samples each):\n")
for temp in [0.1, 0.5, 1.0, 2.0]:
    samples = []
    for _ in range(20):
        token = temperature_sample(logits, temperature=temp)
        samples.append(vocab[token.item()])
    
    # Count frequencies
    freq = Counter(samples)
    
    print(f"T = {temp}:")
    print(f"  Unique tokens: {len(freq)}/{vocab_size}")
    print(f"  Most common: {freq.most_common(3)}")
    print()

## 4. Top-K Sampling

**Top-k sampling** keeps only the K tokens with the highest probability, renormalizes their probabilities, and samples from them.

- Prevents sampling from very low-probability tokens
- Fixed candidate-set size K, regardless of how peaked or flat the distribution is
- Common values: K = 10, 50, 100
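A minimal pure-Python sketch of the filter-then-renormalize step follows. Logits outside the top K are set to `-inf`, so softmax assigns them probability 0. `top_k_filter` and `softmax_list` are hypothetical helpers, and ties are broken by lower index, which may differ from `top_k_sample`:

```python
import math

def top_k_filter(logits, k):
    """Keep the k largest logits and set the rest to -inf (probability 0 after softmax)."""
    if k <= 0:
        raise ValueError("k must be a positive integer")
    # Indices sorted by logit, highest first; sorted() is stable, so ties keep index order.
    ranked = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    keep = set(ranked[:k])  # slicing past the end simply keeps everything
    return [x if i in keep else -math.inf for i, x in enumerate(logits)]

def softmax_list(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]  # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]

toy_vocab = ['the', 'cat', 'dog', 'sat', 'ran', 'jumped', 'on', 'quickly']
toy_logits = [1.5, 2.8, 2.5, 0.5, 1.0, 0.8, 0.3, 0.2]
probs_k2 = softmax_list(top_k_filter(toy_logits, 2))
print({w: round(q, 3) for w, q in zip(toy_vocab, probs_k2) if q > 0})
# → {'cat': 0.574, 'dog': 0.426}
```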

In [None]:
# Visualize top-k filtering
k_values = [1, 2, 4, 8]

fig, axes = plt.subplots(2, 2, figsize=(14, 10))
axes = axes.flatten()

original_probs = F.softmax(logits, dim=-1)

for idx, k in enumerate(k_values):
    # Get top-k tokens
    topk_probs, topk_indices = torch.topk(original_probs, k, dim=-1)
    
    # Create filtered distribution (zeros for non-top-k)
    filtered_probs = torch.zeros_like(original_probs)
    filtered_probs.scatter_(1, topk_indices, topk_probs)
    
    # Renormalize
    filtered_probs = filtered_probs / filtered_probs.sum(dim=-1, keepdim=True)
    
    axes[idx].bar(vocab, filtered_probs[0].numpy(), alpha=0.7, label='Top-k')
    axes[idx].bar(vocab, original_probs[0].numpy(), alpha=0.3, label='Original')
    axes[idx].set_xlabel('Token')
    axes[idx].set_ylabel('Probability')
    axes[idx].set_title(f'Top-k = {k}')
    axes[idx].tick_params(axis='x', rotation=45)
    axes[idx].legend()
    axes[idx].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Sample with different k values
set_seed(42)

print("Top-k sampling (20 samples each):\n")
for k in [1, 2, 4, 8]:
    samples = []
    for _ in range(20):
        token = top_k_sample(logits, k=k, temperature=1.0)
        samples.append(vocab[token.item()])
    
    freq = Counter(samples)
    
    print(f"k = {k}:")
    print(f"  Unique tokens: {len(freq)}/{vocab_size}")
    print(f"  Most common: {freq.most_common(3)}")
    print()

## 5. Top-P (Nucleus) Sampling

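**Top-p sampling** (nucleus sampling) sorts tokens by probability, keeps the smallest set whose cumulative probability reaches p (the *nucleus*), renormalizes, and samples from that set.

- Candidate-set size adapts to the distribution: few tokens when the model is confident, many when it is uncertain
- Unlike top-k, no single fixed K has to suit every context
- Common values: p = 0.9, 0.95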
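The following pure-Python sketch shows the adaptive behavior on two made-up distributions: one peaked, one flat. It includes the token that crosses the threshold. `top_p_filter`, `nucleus_size`, and the example logits are hypothetical and used for illustration only:

```python
import math

def softmax_list(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def top_p_filter(probs, p):
    """Zero out everything outside the nucleus and renormalize.

    The nucleus is the smallest set of most-likely tokens whose total
    probability reaches p; the token that crosses p is included.
    """
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = set(), 0.0
    for i in order:
        kept.add(i)
        cumulative += probs[i]
        if cumulative >= p:
            break
    total = sum(probs[i] for i in kept)
    return [probs[i] / total if i in kept else 0.0 for i in range(len(probs))]

def nucleus_size(probs, p):
    """Number of tokens that survive top-p filtering."""
    return sum(1 for q in top_p_filter(probs, p) if q > 0)

peaked = softmax_list([5.0, 1.0, 0.5, 0.2, 0.1])  # one dominant token
flat = softmax_list([1.0, 0.9, 0.8, 0.7, 0.6])    # nearly uniform
print(nucleus_size(peaked, 0.9), nucleus_size(flat, 0.9))  # → 1 5
```

With the same p = 0.9, the confident distribution keeps a single token while the uncertain one keeps all five. A fixed K cannot provide this behavior.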

In [None]:
# Visualize top-p filtering
p_values = [0.5, 0.7, 0.9, 0.95]

fig, axes = plt.subplots(2, 2, figsize=(14, 10))
axes = axes.flatten()

original_probs = F.softmax(logits, dim=-1)

for idx, p in enumerate(p_values):
    # Sort probabilities
    sorted_probs, sorted_indices = torch.sort(original_probs, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
    
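    # Find cutoff: keep a token if the cumulative probability *before* it is
    # still below p, so the token that crosses p is included (the smallest
    # set whose total probability reaches p)
    cutoff_mask = (cumulative_probs - sorted_probs) < p
    # Include at least one token
    cutoff_mask[..., 0] = True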
    
    # Filter
    filtered_probs = torch.zeros_like(original_probs)
    filtered_probs.scatter_(1, sorted_indices, sorted_probs * cutoff_mask.float())
    
    # Renormalize
    filtered_probs = filtered_probs / (filtered_probs.sum(dim=-1, keepdim=True) + 1e-10)
    
    num_kept = cutoff_mask.sum().item()
    
    axes[idx].bar(vocab, filtered_probs[0].numpy(), alpha=0.7, label=f'Top-p (kept {num_kept})')
    axes[idx].bar(vocab, original_probs[0].numpy(), alpha=0.3, label='Original')
    axes[idx].set_xlabel('Token')
    axes[idx].set_ylabel('Probability')
    axes[idx].set_title(f'Top-p = {p} (nucleus size: {num_kept})')
    axes[idx].tick_params(axis='x', rotation=45)
    axes[idx].legend()
    axes[idx].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("Observation: Larger p → bigger nucleus (more tokens kept)")
print("             For a fixed p, top-p adapts to the distribution shape:")
print("             peaked distributions → fewer tokens, flat → more tokens")

In [None]:
# Sample with different p values
set_seed(42)

print("Top-p sampling (20 samples each):\n")
for p in [0.5, 0.7, 0.9, 0.95]:
    samples = []
    for _ in range(20):
        token = top_p_sample(logits, p=p, temperature=1.0)
        samples.append(vocab[token.item()])
    
    freq = Counter(samples)
    
    print(f"p = {p}:")
    print(f"  Unique tokens: {len(freq)}/{vocab_size}")
    print(f"  Most common: {freq.most_common(3)}")
    print()

## 6. Combined Sampling

**Combined sampling** applies multiple filters in sequence:
1. Temperature scaling
2. Top-k filtering
3. Top-p filtering

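Many generation libraries and APIs expose these knobs together. For example, Hugging Face `transformers`' `generate()` accepts `temperature`, `top_k`, and `top_p`. Order matters: temperature does not change the ranking, so it leaves the top-k set unchanged, but it does change which tokens fall inside the top-p nucleus.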
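A pure-Python sketch of the pipeline in the order listed above (temperature, top-k, top-p, then sample) follows. `combined_pick` is hypothetical, and the library's `combined_sample` may differ in ordering or tie handling:

```python
import math
import random

def combined_pick(logits, temperature=1.0, k=None, p=None, rng=random):
    """Temperature -> top-k -> top-p -> sample; returns a token index."""
    if temperature <= 0:
        raise ValueError("temperature must be > 0; use argmax for greedy decoding")
    # 1. Temperature scaling
    scaled = [x / temperature for x in logits]
    order = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)
    # 2. Top-k: keep only the k highest-scoring indices
    if k is not None:
        order = order[:k]
    # Softmax over the surviving candidates only (max-subtracted for stability)
    m = scaled[order[0]]
    exps = [math.exp(scaled[i] - m) for i in order]
    total = sum(exps)
    probs = [e / total for e in exps]
    # 3. Top-p: keep the shortest most-likely prefix whose mass reaches p
    if p is not None:
        cumulative, cut = 0.0, len(order)
        for j, q in enumerate(probs):
            cumulative += q
            if cumulative >= p:
                cut = j + 1
                break
        order, probs = order[:cut], probs[:cut]
    # 4. Sample; random.choices accepts unnormalized weights
    return rng.choices(order, weights=probs, k=1)[0]

toy_logits = [1.5, 2.8, 2.5, 0.5, 1.0, 0.8, 0.3, 0.2]
rng = random.Random(0)
picks = [combined_pick(toy_logits, temperature=0.8, k=6, p=0.95, rng=rng) for _ in range(30)]
print(sorted(set(picks)))  # indices actually drawn; never outside the top 6
```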

In [None]:
# Compare different sampling configurations
configs = [
    {"name": "Greedy", "temp": 0.0, "k": None, "p": None},
    {"name": "Low temp", "temp": 0.5, "k": None, "p": None},
    {"name": "Temp + Top-k", "temp": 0.8, "k": 4, "p": None},
    {"name": "Temp + Top-p", "temp": 0.8, "k": None, "p": 0.9},
    {"name": "All combined", "temp": 0.8, "k": 6, "p": 0.95},
]

set_seed(42)

print("Sampling strategy comparison (30 samples each):\n")
for config in configs:
    samples = []
    for _ in range(30):
        if config["temp"] == 0.0:
            token = greedy_sample(logits)
        else:
            token = combined_sample(
                logits, 
                temperature=config["temp"],
                k=config["k"],
                p=config["p"]
            )
        samples.append(vocab[token.item()])
    
    freq = Counter(samples)
    
    print(f"{config['name']}:")
    print(f"  Config: temp={config['temp']}, k={config['k']}, p={config['p']}")
    print(f"  Unique tokens: {len(freq)}/{vocab_size}")
    print(f"  Distribution: {freq.most_common()}")
    print()

## 7. TextGenerator - High-Level Interface

The `TextGenerator` class provides a convenient interface for autoregressive text generation.
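Autoregressive generation is a loop. At each step, the current sequence is scored, one next token is picked, and that token is appended before the loop repeats. The following is a pure-Python sketch of that loop, not TextGenerator's implementation. `generate_sketch` is hypothetical, and a hard-coded bigram table stands in for the model. A real transformer conditions on the whole sequence up to its context length, whereas this toy looks only at the last token:

```python
import math
import random

def generate_sketch(next_token_logits, prompt_ids, max_new_tokens, temperature=1.0, rng=random):
    """Append max_new_tokens tokens to prompt_ids and return the full sequence."""
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        logits = next_token_logits(ids)  # scores for the token that follows ids
        if temperature == 0:
            nxt = max(range(len(logits)), key=lambda i: logits[i])  # greedy
        else:
            m = max(logits)
            weights = [math.exp((x - m) / temperature) for x in logits]
            nxt = rng.choices(range(len(logits)), weights=weights, k=1)[0]
        ids.append(nxt)  # the new token becomes part of the next step's input
    return ids

# Hypothetical stand-in for a model: character bigram scores over the vocab "ab"
toy_chars = "ab"
BIGRAM = {0: [0.0, 3.0],  # after 'a', 'b' is much more likely
          1: [3.0, 0.0]}  # after 'b', 'a' is much more likely

def toy_next_token_logits(ids):
    return BIGRAM[ids[-1]]

out = generate_sketch(toy_next_token_logits, [0], max_new_tokens=6, temperature=0)
print("".join(toy_chars[i] for i in out))  # → abababa
```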

In [None]:
# Toy training corpus for the generation demos (short and deliberately repetitive)
text = "hello world this is a test of the text generation system. " * 10

# Create tokenizer
tokenizer = CharTokenizer()
tokenizer.fit(text)

print(f"Vocabulary size: {tokenizer.vocab_size}")
print(f"Characters: {sorted(tokenizer.vocab.keys())}")

In [None]:
# Create and train a tiny model (just for demo purposes)
# In practice, you'd load a pre-trained checkpoint

from tiny_transformer.training import TextDataset, Trainer, TrainerConfig
from torch.utils.data import DataLoader

# Create tiny model
config = get_model_config('tiny')
model = TinyTransformerLM(
    vocab_size=tokenizer.vocab_size,
    **config
)

# Quick training (just to get sensible outputs)
tokens = tokenizer.encode(text)
dataset = TextDataset(tokens, seq_len=32)
train_loader = DataLoader(dataset, batch_size=4, shuffle=True)

trainer_config = TrainerConfig(
    learning_rate=1e-3,
    max_steps=100,
    log_interval=20,
    device='cpu'
)

trainer = Trainer(model, train_loader, config=trainer_config)
print("Training for 100 steps...")
trainer.train()
print("\n✓ Training complete")

In [None]:
# Use TextGenerator with different configurations
model.eval()

# Starting prompt
prompt = "hello "
prompt_tokens = torch.tensor([tokenizer.encode(prompt)])

print(f"Prompt: '{prompt}'\n")
print("="*70)

# Try different sampling strategies
strategies = [
    {"name": "Greedy", "config": GeneratorConfig(max_new_tokens=30, do_sample=False)},
    {"name": "Temperature=0.5", "config": GeneratorConfig(max_new_tokens=30, temperature=0.5)},
    {"name": "Temperature=1.0", "config": GeneratorConfig(max_new_tokens=30, temperature=1.0)},
    {"name": "Top-k=5", "config": GeneratorConfig(max_new_tokens=30, temperature=0.8, top_k=5)},
    {"name": "Top-p=0.9", "config": GeneratorConfig(max_new_tokens=30, temperature=0.8, top_p=0.9)},
    {"name": "Combined", "config": GeneratorConfig(max_new_tokens=30, temperature=0.8, top_k=10, top_p=0.95)},
]

for strategy in strategies:
    generator = TextGenerator(model, strategy["config"], device='cpu')
    
    with torch.no_grad():
        output_tokens = generator.generate(prompt_tokens)
    
    output_text = tokenizer.decode(output_tokens[0].tolist())
    
    print(f"\n{strategy['name']}:")
    print(f"  {output_text}")

print("\n" + "="*70)

## 8. Diversity vs Quality Trade-off

Different sampling strategies offer different trade-offs between:
- **Quality**: Coherence, grammaticality
- **Diversity**: Variety, creativity

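Below, diversity is measured with the *uniqueness ratio*: the fraction of samples (all from the same prompt) that are distinct strings. It does not measure quality, so read the sample outputs as well.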

In [None]:
def measure_diversity(model, generator_config, prompt_tokens, tokenizer, num_samples=10):
    """Measure diversity of generated samples."""
    generator = TextGenerator(model, generator_config, device='cpu')
    
    texts = []
    with torch.no_grad():
        for _ in range(num_samples):
            output_tokens = generator.generate(prompt_tokens)
            text = tokenizer.decode(output_tokens[0].tolist())
            texts.append(text)
    
    # Calculate metrics
    unique_texts = len(set(texts))
    avg_length = np.mean([len(t) for t in texts])
    
    # Character-level diversity
    all_chars = ''.join(texts)
    unique_chars = len(set(all_chars))
    
    return {
        'unique_samples': unique_texts,
        'total_samples': num_samples,
        'uniqueness_ratio': unique_texts / num_samples,
        'avg_length': avg_length,
        'unique_chars': unique_chars,
        'samples': texts[:3]  # First 3 for inspection
    }

# Compare diversity
configs_to_test = [
    ("Greedy", GeneratorConfig(max_new_tokens=20, do_sample=False)),
    ("Temp=0.3", GeneratorConfig(max_new_tokens=20, temperature=0.3)),
    ("Temp=0.8", GeneratorConfig(max_new_tokens=20, temperature=0.8)),
    ("Temp=1.5", GeneratorConfig(max_new_tokens=20, temperature=1.5)),
    ("Top-p=0.9", GeneratorConfig(max_new_tokens=20, temperature=0.8, top_p=0.9)),
]

results = []
for name, config in configs_to_test:
    print(f"Testing {name}...")
    metrics = measure_diversity(model, config, prompt_tokens, tokenizer, num_samples=10)
    results.append((name, metrics))

# Display results
print("\n" + "="*70)
print("DIVERSITY ANALYSIS")
print("="*70 + "\n")

for name, metrics in results:
    print(f"{name}:")
    print(f"  Unique samples: {metrics['unique_samples']}/{metrics['total_samples']}")
    print(f"  Uniqueness ratio: {metrics['uniqueness_ratio']:.2%}")
    print(f"  Avg length (chars, incl. prompt): {metrics['avg_length']:.1f}")
    print(f"  Unique characters: {metrics['unique_chars']}")
    print("  Sample outputs:")
    for i, sample in enumerate(metrics['samples'], 1):
        print(f"    {i}. {sample[:50]}...")
    print()

In [None]:
# Visualize diversity vs temperature
temperatures = [0.1, 0.3, 0.5, 0.7, 0.9, 1.1, 1.3, 1.5]
diversity_scores = []

for temp in temperatures:
    config = GeneratorConfig(max_new_tokens=20, temperature=temp)
    metrics = measure_diversity(model, config, prompt_tokens, tokenizer, num_samples=20)
    diversity_scores.append(metrics['uniqueness_ratio'])

plt.figure(figsize=(10, 6))
plt.plot(temperatures, diversity_scores, marker='o', linewidth=2, markersize=8)
plt.xlabel('Temperature')
plt.ylabel('Uniqueness Ratio')
plt.title('Diversity vs Temperature')
plt.grid(True, alpha=0.3)
plt.axhline(y=1.0, color='g', linestyle='--', alpha=0.5, label='All samples unique')
plt.legend()
plt.tight_layout()
plt.show()

print("Observation: Higher temperature → more diversity (but possibly less coherence)")

## 9. EOS Token Handling

TextGenerator can stop generation early when an end-of-sequence (EOS) token is generated.
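The stopping rule can be sketched as a check inside the generation loop. This sketch assumes the EOS token is appended before stopping; whether TextGenerator keeps or drops it is not shown here. `generate_until_eos` and the counting "model" are hypothetical:

```python
def generate_until_eos(next_token, prompt_ids, max_new_tokens, eos_token_id=None):
    """Extend prompt_ids one token at a time; stop at max_new_tokens or right after EOS."""
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        tok = next_token(ids)  # any decoding rule: greedy, sampling, ...
        ids.append(tok)
        if eos_token_id is not None and tok == eos_token_id:
            break  # in this sketch the EOS token is kept in the output
    return ids

def count_up(ids):
    """Hypothetical deterministic 'model': always predicts last token + 1."""
    return ids[-1] + 1

print(generate_until_eos(count_up, [0], max_new_tokens=10, eos_token_id=3))  # → [0, 1, 2, 3]
print(len(generate_until_eos(count_up, [0], max_new_tokens=10)))             # → 11
```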

In [None]:
# Demonstrate EOS handling.
# The toy corpus contains no newline, so use '.' (which ends every
# repetition of the training sentence) as the stop token instead.
eos_char = '.'
eos_token_id = tokenizer.vocab.get(eos_char, None)

if eos_token_id is not None:
    print(f"Using {eos_char!r} as EOS token (id={eos_token_id})\n")

    # Generate with EOS handling; the limit is long enough to reach the '.'
    max_new_tokens = 80
    config = GeneratorConfig(
        max_new_tokens=max_new_tokens,
        temperature=0.8,
        eos_token_id=eos_token_id
    )

    generator = TextGenerator(model, config, device='cpu')

    with torch.no_grad():
        output_tokens = generator.generate(prompt_tokens)

    output_text = tokenizer.decode(output_tokens[0].tolist())
    # The output includes the prompt, so subtract it to count new tokens
    num_new = len(output_tokens[0]) - len(prompt_tokens[0])

    print(f"Generated text (stops at {eos_char!r}):")
    print(f"  {output_text}")
    print(f"\nGenerated {num_new} new tokens (limit was {max_new_tokens})")
else:
    print(f"{eos_char!r} not in vocabulary, skipping EOS demo")

## 10. Batch Generation

TextGenerator supports generating multiple sequences in parallel.
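When rows in a batch can hit EOS at different steps, one common pattern works as follows. Keep a per-row "finished" flag. Append a padding id to rows that are done so the batch stays rectangular. Stop once every row has finished. The sketch below assumes that pattern; `generate_batch_sketch`, `pad_id`, and the counting "model" are hypothetical, and TextGenerator may handle batches differently:

```python
def generate_batch_sketch(next_tokens, batch_ids, max_new_tokens, eos_token_id, pad_id):
    """Batched generation with per-row EOS; finished rows are padded with pad_id."""
    rows = [list(r) for r in batch_ids]  # all prompts must have the same length
    finished = [False] * len(rows)
    for _ in range(max_new_tokens):
        proposals = next_tokens(rows)  # one proposed token per row, computed together
        for i, tok in enumerate(proposals):
            if finished[i]:
                rows[i].append(pad_id)  # keep every row the same length
            else:
                rows[i].append(tok)
                finished[i] = tok == eos_token_id
        if all(finished):
            break  # no row needs more tokens
    return rows

def count_up_batch(rows):
    """Hypothetical 'model': each row predicts its last token + 1."""
    return [r[-1] + 1 for r in rows]

print(generate_batch_sketch(count_up_batch, [[0], [2]], max_new_tokens=10, eos_token_id=3, pad_id=-1))
# → [[0, 1, 2, 3], [2, 3, -1, -1]]
```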

In [None]:
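# Create multiple prompts.
# torch.tensor() needs a rectangular (non-ragged) batch and no padding is done
# here, so every prompt must encode to the same number of tokens (5 chars each).
prompts = [
    "hello",
    "this ",
    "test "
]
assert len({len(tokenizer.encode(p)) for p in prompts}) == 1, "prompts must be the same length"

prompt_tokens_batch = torch.tensor([tokenizer.encode(p) for p in prompts])
print(f"Batch shape: {prompt_tokens_batch.shape}\n")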

# Generate for all prompts at once
config = GeneratorConfig(
    max_new_tokens=20,
    temperature=0.8,
    top_p=0.9
)

generator = TextGenerator(model, config, device='cpu')

with torch.no_grad():
    output_tokens_batch = generator.generate(prompt_tokens_batch)

print("Batch generation results:\n")
for i, (prompt, output_tokens) in enumerate(zip(prompts, output_tokens_batch)):
    output_text = tokenizer.decode(output_tokens.tolist())
    print(f"{i+1}. Prompt: '{prompt}'")
    print(f"   Output: '{output_text}'\n")

## Summary

In this notebook, we explored:

1. **Greedy Sampling** - Deterministic, always picks highest probability token
2. **Temperature Sampling** - Controls randomness by scaling logits
3. **Top-K Sampling** - Filters to K highest probability tokens
4. **Top-P (Nucleus) Sampling** - Dynamically filters by cumulative probability
5. **Combined Sampling** - Uses temperature + top-k + top-p together
6. **TextGenerator** - High-level interface for text generation
7. **Diversity Analysis** - Trade-offs between quality and diversity
8. **EOS Handling** - Early stopping with end-of-sequence tokens
9. **Batch Generation** - Efficient parallel generation

### Key Takeaways:

- **Greedy** is deterministic but prone to repetition
- **Temperature** controls randomness (lower = more deterministic)
- **Top-k** excludes low-probability tokens (fixed candidate-set size)
- **Top-p** adapts to the distribution shape (variable candidate-set size)
- **Combined** (temperature + top-k and/or top-p) is a common combination in practice
- Higher diversity often comes at the cost of coherence

### Recommended Settings:

These are rough starting points, not tuned values; adjust them for your model and task.

- **Factual text**: temperature=0.3-0.5, top-p=0.9
- **Creative writing**: temperature=0.8-1.0, top-p=0.95
- **Code generation**: temperature=0.2-0.4, top-p=0.9

### Next Steps:

- Try the exercises in `exercises/exercises.py`
- Experiment with different sampling strategies
- Train on larger datasets (Shakespeare, WikiText)
- Move on to Module 08 (Engineering Practices)