# Text Generation Experiments

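This notebook provides an interactive guide to the text-generation component of GPT: how a trained (or freshly initialized) model turns a prompt into new tokens, and how the sampling settings shape the result.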


In [None]:
# Import necessary libraries
import torch
import numpy as np
import matplotlib.pyplot as plt
import sys
import os

# Make the project's `src` package importable: the project root is taken to be
# two directory levels above the current working directory.
working_dir = os.path.abspath('')
project_root = os.path.dirname(os.path.dirname(working_dir))
already_on_path = project_root in sys.path
if not already_on_path:
    # Put it first so the local checkout takes precedence during import
    sys.path.insert(0, project_root)

# Import generation utilities
from src.model.gpt import GPTModel
from src.config import GPTConfig
from src.generation.generate import generate_text
from src.data.tokenizer import get_tokenizer
import tiktoken

## Overview

This notebook explores different text generation strategies: greedy decoding, temperature sampling, and top-k sampling.

### 1. Load or Create a Model

In [None]:
# Seed PyTorch's global RNG once, before any weights are initialised, so that
# "Restart Kernel -> Run All" reproduces the same untrained weights and (assuming
# generate_text samples via PyTorch's RNG) the same generated text.
SEED = 42
torch.manual_seed(SEED)

# Fallback architecture, used when no checkpoint exists or the checkpoint does
# not carry its own config. Defined once so both code paths cannot drift apart.
default_config_kwargs = dict(
    vocab_size=50257,
    context_length=128,
    embedding_dimension=256,
    number_of_heads=4,
    number_of_layers=4,
)

# Try to load a trained model, or create a new one
checkpoint_path = os.path.join(project_root, "checkpoints", "best_model.pt")
if os.path.exists(checkpoint_path):
    print(f"Loading trained model from {checkpoint_path}...")
    # Load onto the CPU first; the model is moved to the target device below.
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True)

    # Prefer the architecture saved alongside the weights when it is available
    if 'config' in checkpoint:
        config = GPTConfig(**checkpoint['config'])
    else:
        config = GPTConfig(**default_config_kwargs)

    model = GPTModel(config)
    model.load_state_dict(checkpoint['model_state_dict'])
    print("Trained model loaded!")
else:
    print("No trained model found. Creating a new untrained model...")
    config = GPTConfig(**default_config_kwargs)
    model = GPTModel(config)
    print("Note: Untrained model will generate random text")

# The notebook only runs inference, so switch to evaluation mode in either case
model.eval()

# Use the GPU when one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
print(f"Using device: {device}")

In [None]:
# Initialize tokenizer
# `get_tokenizer` is a project helper (src/data/tokenizer.py); presumably it
# returns a tiktoken-style encoding for GPT-2 — TODO confirm.
tokenizer = get_tokenizer("gpt2")
# NOTE(review): this size should agree with the `vocab_size` in the model
# config; presumably ids beyond it would be invalid for the model — verify.
print(f"Tokenizer vocabulary size: {tokenizer.n_vocab}")

### 2. Basic Text Generation

In [None]:
# Baseline run: temperature 1.0 and no top-k restriction
prompt = "Once upon a time"
input_ids = tokenizer.encode(prompt)

print(f"Prompt: '{prompt}'")
print(f"Input tokens: {input_ids}")

# Keep the sampling settings together so they are easy to tweak
generation_settings = {
    "maximum_new_tokens": 20,
    "temperature": 1.0,
    "top_k_tokens": None,
}
output_ids = generate_text(model, input_ids, **generation_settings)
output_text = tokenizer.decode(output_ids)

print(f"\nGenerated text: {output_text}")
# Length difference = tokens appended to the prompt
# (assumes generate_text returns prompt + continuation — TODO confirm)
print(f"Generated tokens: {len(output_ids) - len(input_ids)} new tokens")

### 3. Temperature Sampling

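Temperature controls the randomness of generation: the logits are divided by the temperature before the softmax, so a lower temperature sharpens the distribution (more deterministic) and a higher temperature flattens it (more random).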

In [None]:
# Same prompt, one generation per temperature, everything else held fixed
prompt = "The cat sat"
input_ids = tokenizer.encode(prompt)
temperatures = [0.1, 0.5, 1.0, 1.5, 2.0]

print(f"Prompt: '{prompt}'\n")
print("=" * 60)

for temp in temperatures:
    output_ids = generate_text(
        model, input_ids, maximum_new_tokens=15, temperature=temp, top_k_tokens=None
    )
    output_text = tokenizer.decode(output_ids)
    # The blank line after each result keeps the runs visually separated
    print(f"Temperature {temp:3.1f}: {output_text}", end="\n\n")

In [None]:
# Visualize how temperature affects the probability distribution
model.eval()

# Only the forward pass needs no_grad: the resulting logits carry no autograd
# history, so the plotting code below does not have to sit inside the context.
with torch.no_grad():
    # Feed a single token and keep the logits at the last (only) position
    test_input = torch.tensor([[tokenizer.encode("The")[0]]], device=device)
    logits = model(test_input)[0, -1, :]  # [vocab_size]

# Fix the displayed tokens to the 20 most likely ones at temperature 1.0 so
# that every subplot shows the same tokens in the same order.
top_k = 20
top_indices = torch.topk(torch.softmax(logits, dim=-1), top_k).indices
top_tokens = [tokenizer.decode([token_id.item()]) for token_id in top_indices]

# Apply different temperatures
temperatures = [0.1, 0.5, 1.0, 2.0]

fig, axes = plt.subplots(2, 2, figsize=(14, 10))
axes = axes.flatten()

for ax, temp in zip(axes, temperatures):
    # Temperature rescales the logits before the softmax
    probs = torch.softmax(logits / temp, dim=-1)
    top_probs = probs[top_indices]

    ax.bar(range(len(top_probs)), top_probs.cpu().numpy())
    ax.set_xticks(range(len(top_tokens)))
    ax.set_xticklabels(top_tokens, rotation=45, ha='right', fontsize=8)
    ax.set_xlabel('Token', fontsize=10)
    ax.set_ylabel('Probability', fontsize=10)
    ax.set_title(f'Temperature = {temp}', fontsize=12, fontweight='bold')
    ax.grid(True, alpha=0.3, axis='y')

plt.suptitle('Effect of Temperature on Token Probabilities', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

### 4. Top-K Sampling

Top-k sampling restricts sampling to the k most likely tokens, reducing the chance of generating unlikely tokens.

In [None]:
# Same prompt and temperature, one generation per top-k setting
prompt = "In a far away land"
input_ids = tokenizer.encode(prompt)
top_k_values = [None, 10, 50, 100]  # None = no top-k restriction

print(f"Prompt: '{prompt}'\n")
print("=" * 60)

for top_k in top_k_values:
    output_ids = generate_text(
        model, input_ids, maximum_new_tokens=20, temperature=0.8, top_k_tokens=top_k
    )
    output_text = tokenizer.decode(output_ids)

    # Human-readable label for the current setting
    if top_k is None:
        k_str = "All tokens"
    else:
        k_str = f"Top-{top_k}"

    print(f"{k_str:15s}: {output_text}", end="\n\n")

### 5. Greedy Decoding (Temperature = 0)

Greedy decoding always picks the most likely token. This is deterministic but can be repetitive.

In [None]:
# Contrast greedy decoding with ordinary sampling on the same prompt
prompt = "The little girl"
input_ids = tokenizer.encode(prompt)

print(f"Prompt: '{prompt}'\n")

# Settings shared by both runs; only the temperature differs
shared_settings = dict(maximum_new_tokens=20, top_k_tokens=None)

# temperature=0.0 requests greedy decoding
# (assumes generate_text special-cases a zero temperature — TODO confirm)
output_ids_greedy = generate_text(model, input_ids, temperature=0.0, **shared_settings)
output_text_greedy = tokenizer.decode(output_ids_greedy)
print(f"Greedy (temp=0): {output_text_greedy}")

# Stochastic sampling for comparison
output_ids_temp = generate_text(model, input_ids, temperature=0.8, **shared_settings)
output_text_temp = tokenizer.decode(output_ids_temp)
print(f"Sampling (temp=0.8): {output_text_temp}")

### 6. Multiple Generations

Generate multiple samples to see the diversity of outputs.

In [None]:
# Generate several samples from the same prompt to show output diversity
prompt = "Once upon a time"
input_ids = tokenizer.encode(prompt)
num_samples = 5
temperature = 0.8

print(f"Prompt: '{prompt}'\n")
# Build the banner from the actual settings so it cannot go stale when
# num_samples or temperature is edited.
print(f"Generating {num_samples} samples with temperature={temperature}:\n")
print("=" * 60)

for i in range(num_samples):
    output_ids = generate_text(
        model,
        input_ids,
        maximum_new_tokens=25,
        temperature=temperature,
        top_k_tokens=50
    )
    output_text = tokenizer.decode(output_ids)
    print(f"Sample {i+1}: {output_text}\n")

### 7. Comparing Generation Strategies

Let's compare different combinations of temperature and top-k.

In [None]:
# Run one generation per (label, temperature, top-k) combination
prompt = "The sun was shining"
input_ids = tokenizer.encode(prompt)

# Each entry: (display name, temperature, top-k or None for no restriction)
strategies = [
    ("Greedy", 0.0, None),
    ("Low temp", 0.3, None),
    ("Medium temp", 0.8, None),
    ("High temp", 1.5, None),
    ("Top-k=10", 0.8, 10),
    ("Top-k=50", 0.8, 50),
    ("Top-k=100", 0.8, 100),
]

print(f"Prompt: '{prompt}'\n")
print("=" * 70)

for name, temp, top_k in strategies:
    output_ids = generate_text(
        model, input_ids, maximum_new_tokens=20, temperature=temp, top_k_tokens=top_k
    )
    output_text = tokenizer.decode(output_ids)
    # The blank line after each result keeps the strategies visually separated
    print(f"{name:15s} (temp={temp}, top_k={top_k}): {output_text}", end="\n\n")

### 8. Understanding Generation Step by Step

Let's manually trace through one generation step to understand how it works.

In [None]:
# One generation step written out by hand: forward pass -> temperature ->
# top-k filter -> softmax -> sample
prompt = "Hello"
input_ids = tokenizer.encode(prompt)
input_tensor = torch.tensor([input_ids], device=device)  # add a batch dimension

model.eval()
with torch.no_grad():
    # Forward pass over the whole prompt
    logits = model(input_tensor)  # [1, seq_len, vocab_size]

    # Only the final position predicts the next token
    last_logits = logits[0, -1]  # [vocab_size]

    # Temperature: divide the logits before normalising
    temperature = 0.8
    scaled_logits = last_logits / temperature

    # Top-k: everything below the k-th largest logit is pushed to -inf so it
    # receives zero probability after the softmax
    top_k = 50
    top_k_values, top_k_indices = scaled_logits.topk(min(top_k, scaled_logits.size(0)))
    threshold = top_k_values[-1]  # topk returns values sorted descending
    below_threshold = scaled_logits < threshold
    scaled_logits = scaled_logits.masked_fill(below_threshold, float('-inf'))

    # Normalise the surviving logits into a probability distribution
    probs = scaled_logits.softmax(dim=-1)

    # Ten most likely candidates, decoded one token at a time for display
    top_10_probs, top_10_indices = probs.topk(10)
    top_10_tokens = [tokenizer.decode([token_id.item()]) for token_id in top_10_indices]

    print(f"Prompt: '{prompt}'")
    print(f"\nTop 10 most likely next tokens:")
    for i, (prob, token) in enumerate(zip(top_10_probs, top_10_tokens)):
        print(f"  {i+1:2d}. '{token}' (probability: {prob.item():.4f})")

    # Draw the next token from the filtered distribution
    next_token = torch.multinomial(probs, num_samples=1)
    next_token_str = tokenizer.decode([next_token.item()])
    print(f"\nSampled token: '{next_token_str}' (ID: {next_token.item()})")