# Character-Level LSTM for The Office (US) Dialogue

This notebook trains a character-level LSTM language model to generate script-style dialogue in the format:

```
PERSON: Dialogue
```

It is fully runnable in Google Colab (upload `office_script_clean.txt` to the working directory first) and uses a GPU when one is available.

In [1]:
import math
import os
import random
import re
import time
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset, IterableDataset

# Seed every RNG the notebook relies on (Python, NumPy, PyTorch) with one value.
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)

# Prefer CUDA when it is present; otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE = torch.device(_device_name)
if _device_name == "cuda":
    # Only touched on GPU runs, exactly like the CPU path leaving it untouched.
    torch.backends.cudnn.benchmark = True
print(f"Device: {DEVICE}")

Device: cuda


In [None]:
# Input script and the location where the cleaned corpus will be written.
DATA_PATH = Path("office_script_clean.txt")
CLEANED_PATH = Path("office_script_cleaned.txt")

if not DATA_PATH.exists():
    # Walk the user through uploading the dataset before aborting the cell.
    upload_help = (
        "office_script_clean.txt not found in current directory.",
        "\nTo upload in Colab:",
        "  1. Click the folder icon in the left sidebar",
        "  2. Click the upload icon",
        "  3. Select office_script_clean.txt",
        "  4. Wait for upload to complete, then rerun this cell",
        "\nOr drag and drop the file into the Colab file browser",
    )
    for help_line in upload_help:
        print(help_line)
    missing_msg = (
        f"office_script_clean.txt not found at {DATA_PATH.absolute()}. "
        "Please upload it using Colab's Files sidebar (see instructions above) and rerun this cell."
    )
    raise FileNotFoundError(missing_msg)

print(f"Found {DATA_PATH}")
with DATA_PATH.open(encoding="utf-8") as script_file:
    raw_text = script_file.read()
print(f"Loaded {len(raw_text):,} characters from dataset")

def clean_line(line: str) -> str:
    """Normalize one raw script line to ``SPEAKER: dialogue`` form.

    Strips carriage returns, HTML-like tags, a leading ``M:SS`` / ``MM:SS`` /
    ``H:MM:SS`` / ``HH:MM:SS`` timestamp, ``[...]`` stage directions and
    ``(...)`` parentheticals. The text before the first remaining colon is
    treated as the speaker (upper-cased); the rest is the dialogue.

    Args:
        line: A single line of the raw script.

    Returns:
        The cleaned ``"SPEAKER: dialogue"`` string, or ``""`` when the line has
        no colon or ends up with an empty speaker or empty dialogue.
    """
    line = line.replace("\r", "")
    line = re.sub(r"<[^>]+>", "", line)
    # The seconds part is an optional group. With a plain alternation the
    # shorter "MM:SS" form wins, leaving ":SS" behind; the stray colon then
    # yields an empty speaker and the whole line is dropped.
    line = re.sub(r"^\s*\d{1,2}:\d{2}(?::\d{2})?\s*", "", line)
    line = re.sub(r"\[[^\]]*\]", "", line)
    line = re.sub(r"\([^\)]*\)", "", line)
    if ":" not in line:
        return ""
    speaker, dialogue = line.split(":", 1)
    speaker = speaker.strip().upper()
    dialogue = dialogue.strip()
    if not speaker or not dialogue:
        return ""
    # Pull punctuation back against the preceding word, then collapse the
    # whitespace runs left behind by the removals above.
    dialogue = re.sub(r"\s+([,.!?;:])", r"\1", dialogue)
    dialogue = re.sub(r"\s{2,}", " ", dialogue)
    return f"{speaker}: {dialogue}"

# Clean every raw line and keep only those that survived as "SPEAKER: dialogue".
lines = [kept for kept in map(clean_line, raw_text.split("\n")) if kept]

# Optionally separate consecutive lines with a visible marker instead of a newline.
ADD_SCENE_SEPARATOR = False
line_separator = "\n\n===\n\n" if ADD_SCENE_SEPARATOR else "\n"
corpus_text = line_separator.join(lines)

CLEANED_PATH.write_text(corpus_text, encoding="utf-8")
print(f"Lines kept: {len(lines)}")
print(f"Cleaned corpus saved to: {CLEANED_PATH}")

In [None]:
# Number of characters the model sees per training window.
CONTEXT_LEN = 120

text = CLEANED_PATH.read_text(encoding="utf-8")

# Character vocabulary: ids follow the sorted order of the distinct characters.
chars = sorted(set(text))
char_to_id = dict(zip(chars, range(len(chars))))
id_to_char = dict(enumerate(chars))

# Whole corpus as one 1-D tensor of character ids.
encoded = torch.tensor(list(map(char_to_id.__getitem__, text)), dtype=torch.long)

# First 90% of the corpus trains the model; the remaining tail validates it.
split_idx = int(0.9 * len(encoded))
train_data, val_data = encoded[:split_idx], encoded[split_idx:]

print(f"Vocabulary size: {len(chars)}")
print(f"Total characters: {len(encoded)}")
print(f"Context window length: {CONTEXT_LEN}")
print(f"Train size: {len(train_data)}")
print(f"Validation size: {len(val_data)}")
print(f"Device: {DEVICE}")

In [None]:
class RandomWindowBatchDataset(IterableDataset):
    """Yield ready-made batches of randomly positioned next-character windows.

    Each iteration step draws ``batch_size`` random start positions in ``data``
    and yields ``(x, y)``, both of shape ``(batch_size, context_len)``, where
    ``y`` is ``x`` shifted one position to the right. Batches are built with
    tensor indexing on whatever device ``data`` lives on.

    Args:
        data: 1-D tensor of token ids.
        context_len: Number of tokens per window.
        batch_size: Number of windows per yielded batch.
        steps_per_epoch: Number of batches produced by one pass of ``__iter__``.
    """

    def __init__(self, data: torch.Tensor, context_len: int, batch_size: int, steps_per_epoch: int):
        self.data = data
        self.context_len = context_len
        self.batch_size = batch_size
        self.steps_per_epoch = steps_per_epoch

    def __len__(self) -> int:
        """Return the number of batches yielded per epoch."""
        return self.steps_per_epoch

    def __iter__(self):
        data = self.data
        device = data.device
        # Largest start whose shifted target window still fits inside ``data``.
        max_start = len(data) - self.context_len - 1
        offsets = torch.arange(self.context_len, device=device)
        for _ in range(self.steps_per_epoch):
            # randint's upper bound is exclusive, so +1 keeps the final valid
            # window reachable (and lets data holding exactly one window work).
            starts = torch.randint(0, max_start + 1, (self.batch_size,), device=device)
            idx = starts[:, None] + offsets[None, :]
            x = data[idx]
            y = data[idx + 1]
            yield x, y


class SequentialWindowBatchDataset(IterableDataset):
    """Yield deterministic batches of consecutive next-character windows.

    Window starts are ``0, 1, 2, ...`` (stride 1); batch ``k`` holds the starts
    ``k * batch_size`` through ``(k + 1) * batch_size - 1``. Each yielded
    ``(x, y)`` has shape ``(batch_size, context_len)`` with ``y`` equal to ``x``
    shifted one position to the right.

    Args:
        data: 1-D tensor of token ids.
        context_len: Number of tokens per window.
        batch_size: Number of windows per yielded batch.
        steps: Number of batches produced by one pass of ``__iter__``.

    Raises:
        IndexError: On iteration, if ``steps * batch_size`` windows do not fit
            inside ``data``.
    """

    def __init__(self, data: torch.Tensor, context_len: int, batch_size: int, steps: int):
        self.data = data
        self.context_len = context_len
        self.batch_size = batch_size
        self.steps = steps

    def __len__(self) -> int:
        """Return the number of batches yielded per pass."""
        return self.steps

    def __iter__(self):
        data = self.data
        device = data.device
        # Largest start whose shifted target window still fits inside ``data``.
        max_start = len(data) - self.context_len - 1
        total = self.steps * self.batch_size
        # Fail before the first batch with a readable message instead of an
        # out-of-range index part-way through the pass.
        if total > max(max_start + 1, 0):
            raise IndexError(
                f"{total} sequential windows requested but only "
                f"{max(max_start + 1, 0)} fit in data of length {len(data)} "
                f"with context_len={self.context_len}"
            )
        offsets = torch.arange(self.context_len, device=device)
        for start in range(0, total, self.batch_size):
            starts = torch.arange(start, start + self.batch_size, device=device)
            idx = starts[:, None] + offsets[None, :]
            x = data[idx]
            y = data[idx + 1]
            yield x, y


BATCH_SIZE = 128

# On CUDA the encoded splits are moved to the GPU once, so batches are sliced
# there; on CPU the original tensors are used as they are.
on_gpu = DEVICE.type == "cuda"
train_data_device = train_data.to(DEVICE) if on_gpu else train_data
val_data_device = val_data.to(DEVICE) if on_gpu else val_data

# Last usable window start per split, and the batch count derived from it.
max_start_train = len(train_data_device) - CONTEXT_LEN - 1
max_start_val = len(val_data_device) - CONTEXT_LEN - 1
train_steps, val_steps = max_start_train // BATCH_SIZE, max_start_val // BATCH_SIZE

train_dataset = RandomWindowBatchDataset(
    train_data_device, CONTEXT_LEN, BATCH_SIZE, train_steps
)
val_dataset = SequentialWindowBatchDataset(
    val_data_device, CONTEXT_LEN, BATCH_SIZE, val_steps
)

# batch_size=None: the datasets already yield complete batches.
train_loader = DataLoader(train_dataset, batch_size=None)
val_loader = DataLoader(val_dataset, batch_size=None)

print(f"Train batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")

In [None]:
class CharLSTM(nn.Module):
    """Character-level language model: embedding -> stacked LSTM -> linear head.

    Args:
        vocab_size: Number of distinct character ids (embedding rows and
            output logits).
        embed_dim: Size of each character embedding.
        hidden_dim: LSTM hidden state size.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability between LSTM layers. Ignored when
            ``num_layers`` is 1, because there is no inter-layer connection.
    """

    def __init__(self, vocab_size: int, embed_dim: int, hidden_dim: int, num_layers: int, dropout: float):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # nn.LSTM applies dropout only between stacked layers; passing a
        # non-zero value with a single layer has no effect and emits a
        # UserWarning, so zero it out in that case.
        inter_layer_dropout = dropout if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            dropout=inter_layer_dropout,
            batch_first=True,
        )
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        """Run the model on a batch of character ids.

        Args:
            x: Integer id tensor; the LSTM is batch-first, so the layout is
                ``(batch, seq_len)``.
            hidden: Optional LSTM state from a previous call; ``None`` lets the
                LSTM start from its default initial state.

        Returns:
            ``(logits, hidden)`` where ``logits`` has one vocabulary-sized
            vector per input position and ``hidden`` is the updated LSTM state.
        """
        x = self.embedding(x)
        out, hidden = self.lstm(x, hidden)
        logits = self.fc(out)
        return logits, hidden

# Architecture hyperparameters.
EMBED_DIM, HIDDEN_DIM = 256, 512
NUM_LAYERS = 2
DROPOUT = 0.25

# Build the network with one output logit per vocabulary character, then
# place it on the selected device.
model = CharLSTM(len(chars), EMBED_DIM, HIDDEN_DIM, NUM_LAYERS, DROPOUT)
model = model.to(DEVICE)

print(model)

In [None]:
# Training hyperparameters.
EPOCHS = 15
LEARNING_RATE = 2e-3
CLIP_NORM = 1.0  # max gradient norm handed to clip_grad_norm_ below

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()

# Mixed precision (autocast + gradient scaling) is enabled only on CUDA.
use_amp = DEVICE.type == "cuda"
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

# Per-epoch history; these lists are read by the plotting and sampling cells.
train_losses = []
val_losses = []
train_ppls = []
val_ppls = []

for epoch in range(1, EPOCHS + 1):
    epoch_start = time.time()

    # ---- Training pass ----
    model.train()
    total_train_loss = 0.0
    for x, y in train_loader:
        x = x.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=use_amp):
            # The hidden state is discarded: every batch starts from a fresh state.
            logits, _ = model(x)
            # Flatten to (positions, vocab) vs (positions,) for cross-entropy.
            loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1))

        # Order matters: scale -> backward -> unscale -> clip -> step -> update.
        # Unscaling first makes CLIP_NORM apply to the real gradient magnitudes.
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_NORM)
        scaler.step(optimizer)
        scaler.update()

        total_train_loss += loss.item()

    # Epoch loss is the average of the per-batch losses; perplexity is exp(loss).
    avg_train_loss = total_train_loss / len(train_loader)
    train_losses.append(avg_train_loss)
    train_ppl = math.exp(avg_train_loss)
    train_ppls.append(train_ppl)

    # ---- Validation pass (no gradients, dropout disabled via eval()) ----
    model.eval()
    total_val_loss = 0.0
    with torch.no_grad():
        for x, y in val_loader:
            x = x.to(DEVICE, non_blocking=True)
            y = y.to(DEVICE, non_blocking=True)
            with torch.cuda.amp.autocast(enabled=use_amp):
                logits, _ = model(x)
                loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1))
            total_val_loss += loss.item()

    avg_val_loss = total_val_loss / len(val_loader)
    val_losses.append(avg_val_loss)
    val_ppl = math.exp(avg_val_loss)
    val_ppls.append(val_ppl)
    # Wall-clock time for the whole epoch, training plus validation.
    epoch_time = time.time() - epoch_start

    print(
        f"Epoch {epoch:02d} | "
        f"train loss {avg_train_loss:.4f} | train ppl {train_ppl:.4f} | "
        f"val loss {avg_val_loss:.4f} | val ppl {val_ppl:.4f} | "
        f"time {epoch_time:.1f}s",
        flush=True,
    )

print(f"Final validation perplexity: {val_ppls[-1]:.4f}")
print("Expected final validation perplexity range: 2.5-3.5")

In [None]:
# Side-by-side panels: cross-entropy loss on the left, perplexity on the right.
# Each entry is (y-axis label, (series, legend label), ...).
curve_panels = (
    ("Cross-entropy loss", (train_losses, "Train loss"), (val_losses, "Val loss")),
    ("Perplexity", (train_ppls, "Train perplexity"), (val_ppls, "Val perplexity")),
)

plt.figure(figsize=(12, 4))

for panel_pos, (y_label, *curves) in enumerate(curve_panels, start=1):
    plt.subplot(1, 2, panel_pos)
    for series, legend_label in curves:
        plt.plot(series, label=legend_label)
    plt.xlabel("Epoch")
    plt.ylabel(y_label)
    plt.legend()

plt.tight_layout()
plt.show()

In [None]:
def generate_script(seed_text, temperature, num_tokens_to_generate):
    """Sample a continuation of ``seed_text`` from the trained character model.

    Uses the notebook-level ``model``, ``char_to_id``, ``id_to_char`` and
    ``DEVICE``. The model is switched to eval mode and left there.

    Args:
        seed_text: Non-empty prompt; every character must be in ``char_to_id``.
        temperature: Positive divisor applied to the logits before softmax.
            Lower values sharpen the distribution, higher values flatten it.
        num_tokens_to_generate: Number of characters to sample after the seed.

    Returns:
        ``seed_text`` followed by the sampled characters.

    Raises:
        ValueError: If ``temperature`` is not positive, ``seed_text`` is empty,
            or ``seed_text`` contains a character outside the vocabulary.
    """
    if temperature <= 0:
        raise ValueError("temperature must be > 0")
    if not seed_text:
        # An empty seed would build an empty, non-integer tensor and fail
        # inside the embedding lookup with an opaque error.
        raise ValueError("seed_text must contain at least one character")
    for ch in seed_text:
        if ch not in char_to_id:
            raise ValueError(f"Character not in vocab: {repr(ch)}")

    model.eval()
    generated = list(seed_text)

    input_ids = torch.tensor(
        [char_to_id[ch] for ch in seed_text], dtype=torch.long, device=DEVICE
    ).unsqueeze(0)
    hidden = None

    with torch.no_grad():
        # Warm up the hidden state on everything except the last seed
        # character. That character is the first input of the sampling loop,
        # so feeding it here as well would make the model consume it twice.
        if input_ids.size(1) > 1:
            _, hidden = model(input_ids[:, :-1], hidden)
        current_id = input_ids[:, -1:]

        for _ in range(num_tokens_to_generate):
            logits, hidden = model(current_id, hidden)
            logits = logits[:, -1, :] / temperature
            probs = torch.softmax(logits, dim=-1)
            next_id = torch.multinomial(probs, num_samples=1)
            next_char = id_to_char[next_id.item()]
            generated.append(next_char)
            current_id = next_id

    return "".join(generated)

In [None]:
seed_text = "MICHAEL: I have an idea for today's meeting.\nJIM: "
divider = "\n" + "=" * 80

# One 400-character sample per temperature, printed as it is produced.
samples = {}
for temp in (0.3, 0.7, 1.0):
    sample = generate_script(
        seed_text=seed_text,
        temperature=temp,
        num_tokens_to_generate=400,
    )
    samples[temp] = sample
    print(divider)
    print(f"Temperature: {temp}")
    print(sample)

# Report the sample chosen as the best along with the final validation perplexity.
best_temp = 0.7
best_sample = samples[best_temp]
print(divider)
print("Best sample")
print(f"Seed text: {seed_text}")
print(f"Temperature: {best_temp}")
print(f"Validation perplexity: {val_ppls[-1]:.4f}")
print(best_sample)

## Temperature Comparison Notes

- **Coherence:** 0.3 is usually most coherent and on-format; 1.0 is most likely to drift.
- **Creativity:** 1.0 tends to be more surprising/quirky; 0.7 balances novelty and structure.
- **Repetition:** 0.3 can repeat phrases; 1.0 repeats less but can be noisier.
- **Grammatical stability:** 0.3 is most stable, 0.7 acceptable, 1.0 may fragment.

After running the notebook, update these notes if your observed samples differ.