# Predicting the Next Word — Try it in PyTorch

This is an **optional** hands-on companion to [Chapter 6](https://robennals.github.io/ai-explained/06-next-word-prediction). You'll build bigram models, see why lookup tables hit a wall, and train a neural network that generalizes through embeddings.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

## Bigram Model: Counting What Comes Next

The simplest next-word predictor: for every word, count which words follow it and how often. The result is just a big lookup table of frequencies.

In [None]:
# A tiny corpus of children's stories
corpus = [
    "the cat sat on the mat",
    "the dog sat on the rug",
    "the cat chased the dog",
    "the dog chased the cat",
    "she was very happy",
    "he was very sad",
    "she was very kind",
    "once upon a time there was a little girl",
    "once upon a time there was a little boy",
    "once upon a time there was a little cat",
    "the little girl was happy",
    "the little boy was sad",
]

# Count bigrams: how often does word B follow word A?
from collections import defaultdict, Counter

bigram_counts = defaultdict(Counter)
for sentence in corpus:
    words = sentence.split()
    for a, b in zip(words, words[1:]):
        bigram_counts[a][b] += 1

# Show predictions for some words
for word in ["the", "was", "once", "little"]:
    total = sum(bigram_counts[word].values())
    print(f"After '{word}':")
    for next_word, count in bigram_counts[word].most_common(5):
        print(f"  {next_word:>8}: {count/total:.0%}")
    print()
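
The same lookup table can also be stored as a matrix, which is closer to how the rest of this notebook works in PyTorch. The sketch below is a minimal illustration, not part of the original notebook: `toy_corpus`, `index`, `counts`, and `bigram_probs` are names invented here, and the three-sentence corpus is hypothetical so the whole table stays small enough to inspect.

```python
import torch

# Hypothetical three-sentence corpus, kept tiny so the table is easy to inspect.
toy_corpus = ["the cat sat", "the dog sat", "the cat ran"]

vocab = sorted({w for s in toy_corpus for w in s.split()})
index = {w: i for i, w in enumerate(vocab)}

# counts[i, j] = number of times word j follows word i
counts = torch.zeros(len(vocab), len(vocab))
for s in toy_corpus:
    words = s.split()
    for a, b in zip(words, words[1:]):
        counts[index[a], index[b]] += 1

# Turn each row into probabilities. clamp(min=1) avoids dividing by zero
# for words that never have a successor (such as "sat" here).
row_totals = counts.sum(dim=1, keepdim=True).clamp(min=1)
bigram_probs = counts / row_totals

for word in ["the", "cat"]:
    row = bigram_probs[index[word]]
    nonzero = {vocab[j]: round(row[j].item(), 2)
               for j in range(len(vocab)) if row[j] > 0}
    print(word, nonzero)
```

Each row of `bigram_probs` is the next-word distribution for one word, so the table needs one row and one column per vocabulary word.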

## Generating from a Bigram Model

We can chain predictions together to generate text. Each pair of words sounds fine, but the whole thing is a random walk through common word pairs.

In [None]:
import random

def generate_bigram(start_word, length=15):
    words = [start_word]
    current = start_word
    for _ in range(length):
        if current not in bigram_counts:
            break
        options = bigram_counts[current]
        total = sum(options.values())
        # Sample proportionally to frequency
        r = random.random() * total
        cumulative = 0
        for word, count in options.items():
            cumulative += count
            if r <= cumulative:
                words.append(word)
                current = word
                break
    return " ".join(words)

print("Generated sentences:")
for start in ["the", "once", "she"]:
    for i in range(3):
        print(f"  {generate_bigram(start)}")
    print()

print("Each word pair is reasonable, but the whole sentence is nonsense!")
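
The cumulative-sum loop above is one way to sample in proportion to frequency. As a hedged alternative, the standard library's `random.choices` does the same weighted draw in one call. The `options` counts and the `sample_next` helper below are hypothetical, made up for this sketch rather than taken from the corpus.

```python
import random
from collections import Counter

# Hypothetical successor counts for one word, shaped like a
# bigram_counts[word] entry.
options = Counter({"cat": 3, "dog": 2, "little": 2, "mat": 1})

def sample_next(options, rng=random):
    """Pick one successor with probability proportional to its count."""
    words = list(options.keys())
    weights = list(options.values())
    return rng.choices(words, weights=weights, k=1)[0]

rng = random.Random(0)  # seeded so reruns give the same draws
n_draws = 8000
draws = Counter(sample_next(options, rng) for _ in range(n_draws))
for word, n in draws.most_common():
    print(f"{word:>8}: {n / n_draws:.2f}")
```

With enough draws the printed frequencies should approach 3/8, 2/8, 2/8, and 1/8.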

## The N-gram Wall

More context should help, right? Trigrams (3 words) are better than bigrams, and 4-grams are better still. But the number of possible sequences explodes so fast that most sequences are never seen — even in huge datasets.

In [None]:
vocab_size = 50_000  # Realistic English vocabulary

print("Number of possible n-gram entries:")
print(f"{'N':>4}  {'Combinations':>20}  {'Comparison'}")
print("-" * 60)

comparisons = [
    (2, "manageable"),
    (3, "large but feasible"),
    (4, "bigger than Wikipedia"),
    (5, "more than all books ever written"),
    (10, "far more than stars in the observable universe"),
]

for n, comparison in comparisons:
    combos = vocab_size ** n
    print(f"{n:>4}  {combos:>20.2e}  {comparison}")

print(f"\nWith a vocab of {vocab_size:,} words, 10-grams need {vocab_size**10:.0e} entries.")
print("There aren't enough books in the world to fill that table!")
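
The sparsity problem shows up even at toy scale. The sketch below counts how many distinct n-grams a small corpus actually contains and compares that with the number of possible ones. The three-sentence `toy_corpus` and the `ngram_coverage` helper are assumptions made for illustration.

```python
from collections import Counter

# Hypothetical mini-corpus; any list of sentences would do.
toy_corpus = [
    "the cat sat on the mat",
    "the dog sat on the rug",
    "the cat chased the dog",
]

tokens = [s.split() for s in toy_corpus]
vocab = {w for words in tokens for w in words}

def ngram_coverage(n):
    """Return (distinct n-grams seen, possible n-grams) for the toy corpus."""
    seen = Counter()
    for words in tokens:
        for i in range(len(words) - n + 1):
            seen[tuple(words[i:i + n])] += 1
    return len(seen), len(vocab) ** n

for n in (1, 2, 3, 4):
    seen, possible = ngram_coverage(n)
    print(f"n={n}: saw {seen} of {possible} possible ({seen / possible:.2%})")
```

Every single word is covered, but the fraction of observed sequences falls quickly as `n` grows, because the corpus stays the same size while the table grows by a factor of the vocabulary size at each step.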

## Neural Network to the Rescue

Instead of a lookup table, use a neural network. Feed it the *embeddings* of the previous words, pass through a hidden layer, and output a probability for each word in the vocabulary.

The key advantage: the network **generalizes**. Because "cat" and "dog" appear in similar contexts, training pushes their embeddings close together. A network that learns "the cat → sat" will therefore tend to make similar predictions for "the dog", even for contexts it never saw during training.
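
To make the embedding step concrete, here is a small sketch of what an embedding layer does on its own. It shows that `nn.Embedding` is a row lookup (equivalent to multiplying a one-hot vector by the weight matrix) and how the looked-up vectors are flattened before the hidden layer. The sizes and word ids are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, embed_dim = 5, 3            # hypothetical sizes for illustration
embedding = nn.Embedding(vocab_size, embed_dim)

ids = torch.tensor([[0, 3]])            # one context made of two word ids
looked_up = embedding(ids)              # shape (1, 2, 3): one vector per word

# The same result via one-hot vectors times the weight matrix
one_hot = F.one_hot(ids, num_classes=vocab_size).float()   # shape (1, 2, 5)
via_matmul = one_hot @ embedding.weight                    # shape (1, 2, 3)
print(torch.allclose(looked_up, via_matmul))

# Flatten the two word vectors into one input vector for a dense layer
flat = looked_up.view(looked_up.size(0), -1)               # shape (1, 6)
print(flat.shape)
```

Because the rows of the weight matrix are ordinary trainable parameters, gradient descent is free to move the vectors of words that behave alike closer together.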

In [None]:
# Build a tiny vocabulary and training data
sentences = [
    ["the", "cat", "sat"],
    ["the", "dog", "sat"],
    ["the", "cat", "ran"],
    ["the", "dog", "ran"],
    ["the", "bird", "flew"],
    ["a", "cat", "sat"],
    ["a", "dog", "ran"],
    ["the", "girl", "smiled"],
    ["the", "boy", "smiled"],
    ["the", "girl", "ran"],
    ["the", "boy", "ran"],
    ["she", "was", "happy"],
    ["he", "was", "happy"],
    ["she", "was", "sad"],
    ["he", "was", "sad"],
]

# Create vocabulary
all_words = sorted(set(w for s in sentences for w in s))
word2id = {w: i for i, w in enumerate(all_words)}
id2word = {i: w for w, i in word2id.items()}
vocab_size = len(all_words)
print(f"Vocabulary ({vocab_size} words): {all_words}")

# Training data: context of 2 words → predict 3rd
context_len = 2
X = torch.tensor([[word2id[w] for w in s[:context_len]] for s in sentences])
Y = torch.tensor([word2id[s[context_len]] for s in sentences])
print(f"Training examples: {len(X)}")
print(f"Example: {[id2word[i.item()] for i in X[0]]} → {id2word[Y[0].item()]}")

In [None]:
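# Define the model: embedding → flatten → dense (ReLU) → dense (logits).
# The softmax is not part of the model; CrossEntropyLoss applies it during
# training, and F.softmax is applied explicitly at prediction time.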
class NextWordModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, context_len, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.fc1 = nn.Linear(context_len * embed_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        e = self.embedding(x)            # (batch, context_len, embed_dim)
        e = e.view(e.size(0), -1)        # flatten to (batch, context_len * embed_dim)
        h = F.relu(self.fc1(e))          # hidden layer with ReLU
        return self.fc2(h)               # scores for each word in vocab

torch.manual_seed(42)
model = NextWordModel(vocab_size, embed_dim=8, context_len=2, hidden_dim=16)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

n_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {n_params}")
print(f"Architecture: Embedding(dim=8) → Dense(16, ReLU) → Dense({vocab_size})")
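
The parameter count can be checked by hand, which also makes the contrast with a lookup table concrete. This is a sketch under stated assumptions: `count_params` is a helper invented here, and the vocabulary size of 16 assumes the 16 distinct words in the toy sentences above.

```python
import torch.nn as nn

def count_params(vocab_size, embed_dim, context_len, hidden_dim):
    """Parameter count by hand for Embedding -> Linear -> Linear."""
    embedding = vocab_size * embed_dim
    fc1 = context_len * embed_dim * hidden_dim + hidden_dim   # weights + biases
    fc2 = hidden_dim * vocab_size + vocab_size                # weights + biases
    return embedding + fc1 + fc2

# Assumed sizes: 16-word vocabulary, 8-dim embeddings, 2 context words, 16 hidden units
V, D, C, H = 16, 8, 2, 16
layers = nn.ModuleList([nn.Embedding(V, D), nn.Linear(C * D, H), nn.Linear(H, V)])
actual = sum(p.numel() for p in layers.parameters())

print(count_params(V, D, C, H), actual)          # 672 672

# Same architecture with a 50,000-word vocabulary, versus a trigram table
print(f"{count_params(50_000, D, C, H):,}")      # 1,250,272
print(f"{50_000 ** 3:.2e}")                      # entries in a full trigram table
```

The network's size grows linearly with the vocabulary, while a full table over two-word contexts grows with its cube.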

In [None]:
# Train
losses = []
for epoch in range(300):
    logits = model(X)
    loss = criterion(logits, Y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
    if (epoch + 1) % 100 == 0:
        print(f"Epoch {epoch+1}: loss = {loss.item():.4f}")

plt.figure(figsize=(8, 3))
plt.plot(losses)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training loss")
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
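
To interpret the loss values, it helps to see what `CrossEntropyLoss` computes. The sketch below reproduces it by hand on made-up logits and shows the loss of a model that knows nothing. The logits, targets, and the vocabulary size of 16 are assumptions chosen for illustration.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size = 16                          # assumed, matching the toy vocabulary
logits = torch.randn(4, vocab_size)      # hypothetical scores for 4 examples
targets = torch.tensor([3, 0, 7, 15])    # hypothetical correct word ids

# Built-in: log-softmax plus negative log-likelihood, averaged over the batch
builtin = F.cross_entropy(logits, targets)

# By hand: log-probabilities, pick the entry for the correct word, negate, average
log_probs = F.log_softmax(logits, dim=-1)
manual = -log_probs[torch.arange(len(targets)), targets].mean()
print(builtin.item(), manual.item())

# A model with no preference (all logits equal) scores ln(vocab_size)
uniform = F.cross_entropy(torch.zeros(4, vocab_size), targets)
print(uniform.item(), math.log(vocab_size))
```

A training loss well below `ln(vocab_size)` therefore means the model is doing better than guessing uniformly. On this dataset the loss cannot reach zero, because identical contexts such as "the cat" are followed by different words.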

In [None]:
# Make predictions
model.eval()

test_contexts = [
    ["the", "cat"],
    ["the", "dog"],
    ["the", "bird"],
    ["the", "girl"],
    ["the", "boy"],
    ["she", "was"],
    ["he", "was"],
]

print("Predictions:")
with torch.no_grad():
    for ctx in test_contexts:
        ids = torch.tensor([[word2id[w] for w in ctx]])
        logits = model(ids)
        probs = F.softmax(logits, dim=-1)
        top3 = probs.topk(3, dim=-1)
        preds = [(id2word[i.item()], p.item()) for i, p in zip(top3.indices[0], top3.values[0])]
        pred_str = ", ".join(f"{w}({p:.0%})" for w, p in preds)
        print(f"  {ctx[0]} {ctx[1]} → {pred_str}")

## Generalization Through Embeddings

The whole point of using a neural network: similar words produce similar predictions automatically. Let's visualize the learned embeddings to see *why* — similar words ended up close together in embedding space.

In [None]:
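# Extract the learned embeddings as a NumPy array.
# detach() is required: the weight is a parameter with requires_grad=True,
# so calling .numpy() on it directly raises a RuntimeError, even inside no_grad().
embeddings = model.embedding.weight.detach().cpu().numpy()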

# Simple 2D projection using SVD (like PCA)
centered = embeddings - embeddings.mean(axis=0)
U, S, Vt = np.linalg.svd(centered, full_matrices=False)
coords = centered @ Vt[:2].T

plt.figure(figsize=(8, 6))
for i, word in enumerate(all_words):
    plt.plot(coords[i, 0], coords[i, 1], 'o', markersize=10)
    plt.annotate(word, (coords[i, 0] + 0.02, coords[i, 1] + 0.02), fontsize=11)

plt.title("Learned embeddings (projected to 2D)\nSimilar words cluster together")
plt.xlabel("Principal component 1")
plt.ylabel("Principal component 2")
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print("Animals (cat, dog, bird) should cluster. People (girl, boy) should cluster.")
print("The network learned these groupings automatically — nobody told it cats and dogs are similar!")
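
A 2D projection can hide or distort distances, so it can help to measure closeness directly in the full embedding space. The sketch below ranks words by cosine similarity. The vectors are written by hand purely for illustration and are not learned values; with the trained model you could pass `model.embedding.weight.detach()` and `all_words` instead. The `nearest` helper is a name invented here.

```python
import torch
import torch.nn.functional as F

# Hypothetical 3-dimensional embeddings, written by hand for illustration.
words = ["cat", "dog", "girl", "boy"]
vectors = torch.tensor([
    [0.9, 0.1, 0.0],   # cat
    [0.8, 0.2, 0.1],   # dog
    [0.1, 0.9, 0.2],   # girl
    [0.0, 0.8, 0.3],   # boy
])

def nearest(query, words, vectors):
    """Rank the other words by cosine similarity to `query`, closest first."""
    q = vectors[words.index(query)]
    sims = F.cosine_similarity(q.unsqueeze(0), vectors, dim=1)
    ranked = sorted(zip(words, sims.tolist()), key=lambda t: -t[1])
    return [(w, s) for w, s in ranked if w != query]

for query in ["cat", "girl"]:
    neighbours = ", ".join(f"{w} ({s:.2f})" for w, s in nearest(query, words, vectors))
    print(f"{query}: {neighbours}")
```

Cosine similarity compares directions rather than lengths, which is a common choice for embeddings. With a model trained on so few sentences, the learned neighbours may be noisier than in this hand-made example.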

## Temperature Sampling

When generating text, we don't always want the most likely word — that makes boring, repetitive text. **Temperature** controls randomness: low temperature = safe and predictable, high temperature = creative and wild.

In [None]:
def predict_with_temperature(model, context, temperature=1.0):
    """Sample a next word using temperature-scaled probabilities."""
    ids = torch.tensor([[word2id[w] for w in context]])
    with torch.no_grad():
        logits = model(ids)
        # Divide logits by temperature before softmax
        scaled = logits / max(temperature, 0.01)
        probs = F.softmax(scaled, dim=-1).squeeze()
    # Sample from the distribution
    idx = torch.multinomial(probs, 1).item()
    return id2word[idx]

# Show how temperature affects predictions
context = ["the", "cat"]
for temp in [0.1, 0.5, 1.0, 2.0]:
    samples = [predict_with_temperature(model, context, temp) for _ in range(20)]
    counts = Counter(samples)
    dist = ", ".join(f"{w}: {c}" for w, c in counts.most_common())
    print(f"Temperature {temp:.1f}: {dist}")

print("\nLow temp → almost always picks the top word. High temp → more variety.")
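
Sampling adds noise, so the effect of temperature is easier to see on the probabilities themselves. This sketch applies temperature scaling to a fixed, made-up set of logits. The `temperature_probs` helper and the three logit values are assumptions for illustration.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([2.0, 1.0, 0.0])   # hypothetical scores for three words

def temperature_probs(logits, temperature):
    """Softmax of logits divided by temperature."""
    return F.softmax(logits / temperature, dim=-1)

for t in (0.1, 0.5, 1.0, 2.0, 10.0):
    probs = temperature_probs(logits, t)
    print(f"T={t:>4}: {[round(p, 3) for p in probs.tolist()]}")
```

Dividing by a small temperature stretches the gaps between logits, so nearly all the probability lands on the top word. Dividing by a large temperature shrinks the gaps, pushing the distribution toward uniform.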

---

*This notebook accompanies [Chapter 6: Predicting the Next Word](https://robennals.github.io/ai-explained/06-next-word-prediction). The interactive widgets in the web version let you explore these concepts visually — including a real neural network trained on 50,000 children's stories.*