# Training GPT on Shakespeare

Time to bring our model to life by training it on text!

In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import os

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Seed torch's global RNG so batch sampling, weight init, dropout and text sampling
# are reproducible under Restart Kernel -> Run All.
SEED = 1337
torch.manual_seed(SEED)

# TODO: `GPT` is used by the model cell below but is not defined in this notebook.
# Paste or import the GPT implementation from the previous notebook here; otherwise
# the model cell fails on a fresh kernel with a NameError.

Using device: cpu


## Load and Prepare Shakespeare Data

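We'll use character-level tokenization for simplicity: every distinct character in the corpus becomes one token. This keeps the vocabulary tiny and needs no external tokenizer.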

In [14]:
# Read the raw corpus into a single string.
with open('../data/tiny_shakespeare.txt', encoding='utf-8') as corpus_file:
    text = corpus_file.read()

print(f"Dataset length: {len(text):,} characters")
print(f"First 200 characters:\n{text[:200]}")

# The vocabulary is every distinct character, in sorted order, so token ids are
# deterministic for a given corpus.
chars = sorted(set(text))
vocab_size = len(chars)
print(f"\nVocabulary size: {vocab_size}")
print(f"Characters: {''.join(chars)}")

# Lookup tables between token ids and characters.
itos = dict(enumerate(chars))                  # id -> character
stoi = {ch: idx for idx, ch in itos.items()}   # character -> id

# Tokenizer helpers: convert between strings and lists of integer token ids.
def encode(s):
    """Convert string `s` into a list of token ids using the global `stoi` table.

    Raises KeyError for any character that is not in the vocabulary.
    """
    return list(map(stoi.__getitem__, s))

def decode(tokens):
    """Convert an iterable of integer token ids back into a string via the global `itos`."""
    return ''.join(map(itos.__getitem__, tokens))

# Sanity-check the tokenizer: encoding then decoding must reproduce the input exactly.
test_string = "Hello World!"
encoded = encode(test_string)
decoded = decode(encoded)
print(f"\nTokenization test:")
print(f"Original: {test_string}")
print(f"Encoded: {encoded}")
print(f"Decoded: {decoded}")
assert decoded == test_string, "encode/decode round trip failed"
assert len(encoded) == len(test_string)  # character-level: one token per character

Dataset length: 1,115,394 characters
First 200 characters:
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you

Vocabulary size: 65
Characters: 
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz

Tokenization test:
Original: Hello World!
Encoded: [20, 43, 50, 50, 53, 1, 35, 53, 56, 50, 42, 2]
Decoded: Hello World!


## Create Training and Validation Splits

In [15]:
# Tokenize the whole corpus once into a 1-D int64 tensor of token ids.
data = torch.tensor(encode(text), dtype=torch.long)
print(f"Encoded dataset shape: {data.shape}")

# Hold out the last 10% of the corpus (a contiguous tail, not a random sample) for validation.
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

print(f"Train size: {len(train_data):,}")
print(f"Val size: {len(val_data):,}")

# Create data loader
def get_batch(split, batch_size=32, block_size=128):
    """Sample a random batch of (input, target) windows from one split.

    Args:
        split: 'train' to sample from the global ``train_data``, 'val' for ``val_data``.
        batch_size: number of windows to sample.
        block_size: length of each window (context length).

    Returns:
        (x, y) tensors of shape [batch_size, block_size], moved to the global ``device``.
        ``y[:, t]`` is the token that follows ``x[:, t]`` in the corpus (next-token targets).

    Raises:
        ValueError: for an unknown ``split``, or a split too short to hold one window.
    """
    if split == 'train':
        data = train_data
    elif split == 'val':
        data = val_data
    else:
        raise ValueError(f"split must be 'train' or 'val', got {split!r}")

    # Each window needs block_size inputs plus one extra token for the last target.
    max_start = len(data) - block_size
    if max_start < 1:
        raise ValueError(
            f"{split} split has {len(data)} tokens; "
            f"need at least block_size + 1 = {block_size + 1}"
        )

    # Random start offsets in [0, max_start), so every target index stays < len(data).
    ix = torch.randint(max_start, (batch_size,))

    # Gather all windows with one indexing op instead of a Python loop + torch.stack.
    positions = ix.unsqueeze(1) + torch.arange(block_size)  # [batch_size, block_size]
    x = data[positions]
    y = data[positions + 1]

    return x.to(device), y.to(device)

# Smoke-test the sampler on a tiny batch and show one (input, target) pair.
x, y = get_batch('train', batch_size=4, block_size=8)
print(f"\nBatch shapes - X: {x.shape}, Y: {y.shape}")

# Targets are the inputs shifted left by one token: y[:, t] == x[:, t + 1].
assert torch.equal(x[:, 1:], y[:, :-1]), "targets are not the inputs shifted by one"

# Show an example
print("\nExample from batch:")
print(f"Input:  {decode(x[0].tolist())}")
print(f"Target: {decode(y[0].tolist())}")

Encoded dataset shape: torch.Size([1115394])
Train size: 1,003,854
Val size: 111,540

Batch shapes - X: torch.Size([4, 8]), Y: torch.Size([4, 8])

Example from batch:
Input:   may los
Target: may lose


## Initialize Model and Training Setup

In [16]:
# --- Training configuration ---
batch_size = 64       # sequences per optimizer step
block_size = 256      # context length in tokens
learning_rate = 3e-4
max_iters = 5000      # total optimizer steps
eval_interval = 500   # evaluate every this many steps
eval_iters = 200      # batches averaged per split in each evaluation
# NOTE(review): each evaluation runs 2 * eval_iters forward passes. On CPU the saved run
# needed ~29 minutes to get past step 0, so consider smaller values when device is cpu.

# --- Model: a small GPT ---
# NOTE(review): GPT is not defined in this notebook as shown; it is expected to be
# copied in from the previous notebook before this cell runs.
model = GPT(
    vocab_size=vocab_size,
    d_model=384,
    n_heads=6,
    n_layers=6,
    max_len=block_size,   # presumably sizes the positional table; keep >= block_size
    dropout=0.2,
).to(device)

total_params = sum(param.numel() for param in model.parameters())
print(f"Model parameters: {total_params:,}")

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

Model parameters: 10,697,472


## Training Loop with Loss Tracking

In [None]:
@torch.no_grad()
def estimate_loss():
    """Estimate the mean cross-entropy loss on the train and val splits.

    Averages ``eval_iters`` random batches per split, using the global ``model``,
    ``get_batch``, ``batch_size`` and ``block_size``. The model runs in eval mode
    (dropout off), and its previous train/eval mode is restored afterwards.

    Returns:
        dict mapping 'train' and 'val' to a 0-d tensor holding the mean loss.
    """
    was_training = model.training
    model.eval()
    out = {}
    mask_cache = {}  # causal masks keyed by sequence length; built once, then reused
    try:
        for split in ['train', 'val']:
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                X, Y = get_batch(split, batch_size, block_size)
                seq_len = X.shape[1]
                if seq_len not in mask_cache:
                    # (1, T, T) lower-triangular ones: row t is 1 for columns <= t
                    # (presumably 1 = "may attend" in the GPT implementation).
                    mask_cache[seq_len] = torch.tril(
                        torch.ones(seq_len, seq_len, device=X.device)
                    ).unsqueeze(0)
                logits = model(X, mask_cache[seq_len])
                # Flatten [B, T, C] logits and [B, T] targets for cross_entropy.
                loss = F.cross_entropy(
                    logits.reshape(-1, logits.size(-1)), Y.reshape(-1)
                )
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        # Restore whatever mode the caller had, instead of forcing train mode.
        model.train(was_training)
    return out

# Training loop: one AdamW step per iteration, with periodic train/val evaluation.
train_losses = []
val_losses = []

# get_batch always returns exactly block_size columns, so the causal mask is
# loop-invariant: build it once instead of on every step.
mask = torch.tril(torch.ones(block_size, block_size, device=device)).unsqueeze(0)

print("Starting training...")
for step in tqdm(range(max_iters)):  # `step`, not `iter`, to avoid shadowing the builtin
    # Sample batch
    xb, yb = get_batch('train', batch_size, block_size)

    # Forward pass: flatten [B, T, C] logits and [B, T] targets for cross_entropy.
    logits = model(xb, mask)
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), yb.reshape(-1))

    # Backward pass
    optimizer.zero_grad(set_to_none=True)
    loss.backward()

    # Clip the global gradient norm to 1.0 for training stability.
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

    optimizer.step()

    # Evaluate every eval_interval steps and always at the final step.
    if step % eval_interval == 0 or step == max_iters - 1:
        losses = estimate_loss()
        # float() accepts both 0-d tensors and Python numbers.
        train_losses.append(float(losses['train']))
        val_losses.append(float(losses['val']))
        print(f"\nStep {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

Starting training...


  0%|                                                                                       | 1/5000 [29:04<2422:40:17, 1744.67s/it]


Step 0: train loss 3.6353, val loss 3.6677


 10%|████████▌                                                                            | 501/5000 [1:48:37<130:34:01, 104.48s/it]


Step 500: train loss 2.1655, val loss 2.1882


 20%|████████████████▍                                                                 | 1001/5000 [5:08:47<3511:28:28, 3161.12s/it]


Step 1000: train loss 1.6847, val loss 1.8418


 30%|█████████████████████████▏                                                          | 1501/5000 [13:12:34<99:06:25, 101.97s/it]


Step 1500: train loss 1.4792, val loss 1.6842


 40%|█████████████████████████████████▏                                                 | 2001/5000 [14:31:52<301:26:15, 361.85s/it]


Step 2000: train loss 1.3728, val loss 1.6030


 50%|██████████████████████████████████████████                                          | 2501/5000 [15:08:24<80:21:59, 115.77s/it]


Step 2500: train loss 1.3123, val loss 1.5621


 60%|██████████████████████████████████████████████████▍                                 | 3001/5000 [17:13:45<62:13:33, 112.06s/it]


Step 3000: train loss 1.2615, val loss 1.5349


 70%|██████████████████████████████████████████████████████████▊                         | 3501/5000 [17:52:07<67:35:32, 162.33s/it]


Step 3500: train loss 1.2221, val loss 1.5251


 71%|████████████████████████████████████████████████████████████▊                         | 3535/5000 [17:53:43<1:08:04,  2.79s/it]

## Visualize Training Progress

In [None]:
# Plot training curves.
# Evaluations are recorded at every multiple of eval_interval plus the final step
# (max_iters - 1). Derive the x-axis from the number of evaluations actually recorded,
# so the plot still works when training was interrupted early (as in the saved run).
steps = np.array([min(i * eval_interval, max_iters - 1) for i in range(len(train_losses))])

if not train_losses:
    print("No losses recorded yet - run the training cell first.")
else:
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(steps, [float(v) for v in train_losses], label='Train Loss')
    ax.plot(steps, [float(v) for v in val_losses], label='Val Loss')
    ax.set_xlabel('Training Steps')
    ax.set_ylabel('Cross Entropy Loss')
    ax.set_title('GPT Training Progress')
    ax.legend()
    ax.grid(True)
    plt.show()

    print(f"Final train loss: {float(train_losses[-1]):.4f}")
    print(f"Final val loss: {float(val_losses[-1]):.4f}")

## Text Generation

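Now the fun part - let's generate some Shakespeare!

*Note:* the sample output saved below has execution count `In [12]`, which is earlier than the training cells. It was therefore produced before training and is essentially random characters. Re-run this cell after training to see samples from the trained model.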

In [12]:
@torch.no_grad()
def generate(model, idx, max_new_tokens, temperature=1.0, top_k=None):
    """Autoregressively extend ``idx`` by ``max_new_tokens`` sampled tokens.

    Args:
        model: called as ``model(tokens, mask)``; must return logits of shape [B, T, vocab].
        idx: initial context token ids, shape [batch_size, seq_len].
        max_new_tokens: number of tokens to append.
        temperature: divides the logits before sampling. 0 means greedy (argmax)
            decoding; negative values are rejected.
        top_k: if given, sample only among the ``top_k`` most likely tokens.

    Returns:
        Tensor of shape [batch_size, seq_len + max_new_tokens] (the context plus new tokens).

    The model runs in eval mode, and its previous train/eval mode is restored on exit.
    Contexts longer than the global ``block_size`` are cropped to their last
    ``block_size`` tokens.
    """
    if temperature < 0:
        raise ValueError(f"temperature must be >= 0, got {temperature}")

    was_training = model.training
    model.eval()
    try:
        for _ in range(max_new_tokens):
            # Keep at most the last block_size tokens (the model's context length).
            idx_cond = idx[:, -block_size:]

            # Causal (1, T, T) lower-triangular mask, built on the same device as the input.
            seq_len = idx_cond.size(1)
            mask = torch.tril(torch.ones(seq_len, seq_len, device=idx.device)).unsqueeze(0)

            # Only the last position's logits predict the next token.
            logits = model(idx_cond, mask)[:, -1, :]

            if temperature == 0:
                # Greedy decoding; avoids dividing by zero.
                idx_next = logits.argmax(dim=-1, keepdim=True)
            else:
                logits = logits / temperature
                if top_k is not None:
                    # Mask out everything below the k-th largest logit in each row.
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits = logits.masked_fill(logits < v[:, [-1]], float('-inf'))
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)

            idx = torch.cat((idx, idx_next), dim=1)
    finally:
        model.train(was_training)

    return idx

# Sample a 200-token continuation for each prompt (temperature 0.8, top-k 40).
prompts = [
    "\n",                   # a single newline character (not an empty prompt)
    "ROMEO:",
    "To be or not to be",
    "What is",
]

print("Generated samples:\n" + "="*50)
for prompt in prompts:
    # Shape [1, len(prompt)]: a batch containing the single encoded prompt.
    context = torch.tensor([encode(prompt)], dtype=torch.long, device=device)

    sample = generate(model, context, max_new_tokens=200, temperature=0.8, top_k=40)

    print(f"\nPrompt: {prompt}")
    print(f"Generated: {decode(sample[0].tolist())}")
    print("-"*50)

Generated samples:

Prompt: 

Generated: 
s
kfD'xI$?y $fYNxnH$lHXdcn:;o!skSSqlxjENWabQlqZ!$OUYESl- GUXLhH 3tNZYHNUFYJOxXaSu?KwR&DOcce$LLDB3St O&Bw3!tKR!tSxqfwhhmJHTiWJIBaKHh$jdi;N,BOx;hfnnf3X&xdHHEvy'x?yR3M:s,TSEiETCBBDwUUf$wGh&Y?BZ,vmLtawCiX
--------------------------------------------------

Prompt: ROMEO:
Generated: ROMEO:dkmxEy.lEgigwfd,kBZuMYxEMHZ,Y,qS,!waEuUfd!HaxUbYLFt d:HttTdNd?ZOatQEd;'NN'nmnROTNHl'HkVnVCvLyRLV&sw$t;SSdNOBBhN:in$p?dwdKBFROJ&t$lXqvJ!BhwOSxlSw3JIE.h&ESdatZJNhd&Wq:UnWaTQ$aZyOBHk?tn3:,cfHs:HU!M,w,xcW
--------------------------------------------------

Prompt: To be or not to be
Generated: To be or not to beoZK!Xg?sUXa,$JEMWwk:bXbikN:jdqKEwBYw KWIYTJdgwf'EOqEUZlaGRftZXQguuNuuEG,NBXJZvHwUV g;xrg$xBHBdJ!Ewv&x: JRfIUOY !SBKEB$HdnDOUd?daEiOcahlhOdlxYq!qd:XaBVCTxdUwCfNw?fqEO:UZBSPLwGt&Ylpsywat gfUy3bUTO?eY3aU
--------------------------------------------------

Prompt: What is
Generated: What isUUZEd3D.U?ox;vTNE.XiddZnJaEM.SlHqloWimWH$!cS3TQYqJEWlUUlUSWWftZVlXhUBXED?targ