In [1]:
import numpy as np

import json
import math
import os
import re
from collections import Counter
from typing import Any, List, Optional, Tuple, Dict
from datetime import datetime
import time

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

In [2]:
DEVICE = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
n_epochs = 10

In [3]:
def simple_tokenize(text: str) -> List[str]:
    text = text.lower()
    text = re.sub(r"([.,!?;:()\"'])", r" \1 ", text)
    text = re.sub(r"\s+", " ", text).strip()
    return text.split()


def load_poems() -> str:
    with open("poems.txt", "r", encoding="utf-8") as f:
        return f.read()


def build_vocab(tokens: List[str], min_freq: int = 1) -> Tuple[Dict[str, int], Dict[int, str]]:
    counts = Counter(tokens)
    vocab = ["<pad>", "<unk>", "<bos>", "<eos>"]
    for w, c in counts.items():
        if c >= min_freq and w not in vocab:
            vocab.append(w)
    stoi = {w: i for i, w in enumerate(vocab)}
    itos = {i: w for w, i in stoi.items()}
    return stoi, itos


def tokens_to_ids(tokens: List[str], stoi: Dict[str, int]) -> List[int]:
    unk = stoi["<unk>"]
    return [stoi.get(t, unk) for t in tokens]


def make_sequences(ids: List[int], seq_len: int) -> Tuple[List[List[int]], List[List[int]]]:
    X, Y = [], []
    for i in range(0, len(ids) - seq_len):
        x = ids[i:i+seq_len]
        y = ids[i+1:i+seq_len+1]
        X.append(x)
        Y.append(y)
    return X, Y


def text_quality_metrics(text: str) -> Dict[str, Any]:
    toks = simple_tokenize(text)
    if not toks:
        return {
            "token_count": 0,
            "unique_tokens": 0,
            "unique_ratio": 0.0,
            "repeat_2gram_ratio": 0.0,
            "repeat_3gram_ratio": 0.0,
            "top_tokens": [],
        }

    counts = Counter(toks)
    token_count = len(toks)
    unique_tokens = len(counts)
    unique_ratio = unique_tokens / max(token_count, 1)

    def repeat_ngram_ratio(n: int) -> float:
        if len(toks) < n:
            return 0.0
        ngrams = list(zip(*[toks[i:] for i in range(n)]))
        c = Counter(ngrams)
        repeats = sum(v - 1 for v in c.values() if v > 1)
        return repeats / max(len(ngrams), 1)

    return {
        "token_count": int(token_count),
        "unique_tokens": int(unique_tokens),
        "unique_ratio": float(unique_ratio),
        "repeat_2gram_ratio": float(repeat_ngram_ratio(2)),
        "repeat_3gram_ratio": float(repeat_ngram_ratio(3)),
        "top_tokens": [(w, int(c)) for w, c in counts.most_common(10)],
    }


In [4]:
def softmax(x):
    x = x - np.max(x)
    e = np.exp(x)
    return e / np.sum(e)


def one_hot(idx, V):
    v = np.zeros((V, 1))
    v[idx] = 1.0
    return v


class ScratchRNN:
    def __init__(self, vocab_size, hidden_size=64, lr=1e-2, seed=42):
        rng = np.random.default_rng(seed)
        self.V = vocab_size
        self.H = hidden_size
        self.lr = lr

        self.Wxh = rng.normal(0, 0.01, (self.H, self.V))
        self.Whh = rng.normal(0, 0.01, (self.H, self.H))
        self.Why = rng.normal(0, 0.01, (self.V, self.H))
        self.bh = np.zeros((self.H, 1))
        self.by = np.zeros((self.V, 1))

    def forward(self, inputs, hprev):
        xs, hs, ys, ps = {}, {}, {}, {}
        hs[-1] = hprev

        for t, idx in enumerate(inputs):
            xs[t] = one_hot(idx, self.V)
            hs[t] = np.tanh(self.Wxh @ xs[t] + self.Whh @ hs[t-1] + self.bh)
            ys[t] = self.Why @ hs[t] + self.by
            ps[t] = softmax(ys[t].ravel()).reshape(-1, 1)
        return xs, hs, ps

    def loss_and_grads(self, inputs, targets, hprev):
        xs, hs, ps = self.forward(inputs, hprev)

        loss = 0.0
        for t in range(len(inputs)):
            loss += -np.log(ps[t][targets[t], 0] + 1e-12)

        dWxh = np.zeros_like(self.Wxh)
        dWhh = np.zeros_like(self.Whh)
        dWhy = np.zeros_like(self.Why)
        dbh = np.zeros_like(self.bh)
        dby = np.zeros_like(self.by)

        dhnext = np.zeros((self.H, 1))

        for t in reversed(range(len(inputs))):
            dy = ps[t].copy()

            dy[targets[t]] -= 1.0
            dWhy += dy @ hs[t].T
            dby += dy

            dh = self.Why.T @ dy + dhnext
            dhraw = (1 - hs[t] * hs[t]) * dh
            dbh += dhraw
            dWxh += dhraw @ xs[t].T
            dWhh += dhraw @ hs[t-1].T
            dhnext = self.Whh.T @ dhraw

        for d in [dWxh, dWhh, dWhy, dbh, dby]:
            np.clip(d, -5, 5, out=d)

        hlast = hs[len(inputs)-1]
        return loss, (dWxh, dWhh, dWhy, dbh, dby), hlast

    def step(self, grads):
        dWxh, dWhh, dWhy, dbh, dby = grads
        self.Wxh -= self.lr * dWxh
        self.Whh -= self.lr * dWhh
        self.Why -= self.lr * dWhy
        self.bh -= self.lr * dbh
        self.by -= self.lr * dby

    def sample(self, start_idx, itos, length=30, temperature=1.0):
        h = np.zeros((self.H, 1))
        x = one_hot(start_idx, self.V)
        out = []

        for _ in range(length):
            h = np.tanh(self.Wxh @ x + self.Whh @ h + self.bh)
            y = self.Why @ h + self.by
            p = softmax((y.ravel() / max(temperature, 1e-6)))
            idx = np.random.choice(range(self.V), p=p)
            out.append(itos[idx])
            x = one_hot(idx, self.V)
        return " ".join(out)


def rnn_main_short_run():
    text = load_poems()
    tokens = ["<bos>"] + simple_tokenize(text) + ["<eos>"]
    stoi, itos = build_vocab(tokens, min_freq=1)
    ids = tokens_to_ids(tokens, stoi)

    rnn = ScratchRNN(vocab_size=len(stoi), hidden_size=128, lr=0.05)
    seq_len = 25
    h = np.zeros((rnn.H, 1))

    epoch_times = []
    samples = []

    for epoch in range(n_epochs):
        t0 = time.perf_counter()
        total_loss = 0.0
        n = 0
        for i in range(0, len(ids) - seq_len - 1, seq_len):
            inp = ids[i:i+seq_len]
            tgt = ids[i+1:i+seq_len+1]
            loss, grads, h = rnn.loss_and_grads(inp, tgt, h)
            rnn.step(grads)
            total_loss += loss
            n += 1

        avg = total_loss / max(n, 1)

        sample = rnn.sample(stoi["<bos>"], itos, length=30, temperature=0.9)
        samples.append(sample)

        t1 = time.perf_counter()
        epoch_times.append(float(t1 - t0))

        print(f"Epoch {epoch+1} | avg loss: {avg:.4f}")
        print("Sample:", sample)
        print(f"Epoch time: {epoch_times[-1]:.3f}s")

    return epoch_times, samples[-1]

scratch_rnn_epoch_times, scratch_rnn_last_sample = rnn_main_short_run()


Epoch 1 | avg loss: 301.3461
Sample: masts : in our my in our my in forgetting this in emma this in emma this in sea this may emma ' in when ' seen our this in
Epoch time: 85.784s
Epoch 2 | avg loss: 323.4304
Sample: pains tis our emma and wild our emma are wild our emma are wild our emma are wild our emma are wild our emma are wild our emma are wild
Epoch time: 87.461s
Epoch 3 | avg loss: 367.3105
Sample: days death ' wild our green and wild ; green and wild our green this wild our green this wild our green this wild our green and wild our emma
Epoch time: 94.385s
Epoch 4 | avg loss: 365.7155
Sample: tea about our , and name of , and call of , seen myself of , and name of , and told are , and call to , seen name
Epoch time: 88.021s
Epoch 5 | avg loss: 366.0527
Sample: high side , we they course had thee it have , we it have , we sky have , we it have , thee they , had name it have
Epoch time: 96.653s
Epoch 6 | avg loss: 364.6420
Sample: lance name swift , seen wild they , seen wild th

In [5]:
class SeqDatasetOneHot(Dataset):
    def __init__(self, X, Y, vocab_size):
        self.X = torch.tensor(X, dtype=torch.long)
        self.Y = torch.tensor(Y, dtype=torch.long)
        self.V = vocab_size

    def __len__(self):
        return self.X.size(0)

    def __getitem__(self, idx):
        x_ids = self.X[idx]
        y_ids = self.Y[idx]

        x_oh = torch.zeros(x_ids.size(0), self.V, dtype=torch.float32)
        x_oh.scatter_(1, x_ids.unsqueeze(1), 1.0)
        return x_oh, y_ids


class OneHotRNNLM(nn.Module):
    def __init__(self, vocab_size, hidden=256):
        super().__init__()
        self.rnn = nn.RNN(input_size=vocab_size,
                          hidden_size=hidden, batch_first=True)
        self.fc = nn.Linear(hidden, vocab_size)

    def forward(self, x_oh, h0=None):
        out, hn = self.rnn(x_oh, h0)
        logits = self.fc(out)
        return logits, hn


@torch.no_grad()
def generate(model, stoi, itos, seed_text="<bos>", max_new=40, temperature=1.0):
    model.eval()
    tokens = seed_text.split()
    ids = [stoi.get(t, stoi["<unk>"]) for t in tokens]
    V = len(stoi)

    h = None
    for _ in range(max_new):
        x = torch.tensor(ids[-1:], dtype=torch.long,
                         device=DEVICE)  # last token
        x_oh = torch.zeros(1, 1, V, device=DEVICE)
        x_oh.scatter_(2, x.view(1, 1, 1), 1.0)

        logits, h = model(x_oh, h)
        next_logits = logits[0, -1] / max(temperature, 1e-6)
        probs = torch.softmax(next_logits, dim=0)
        nxt = torch.multinomial(probs, 1).item()
        ids.append(nxt)

    words = [itos[i] for i in ids]
    return " ".join(words)


def rnn_onehot_main():
    text = load_poems()
    tokens = ["<bos>"] + simple_tokenize(text) + ["<eos>"]
    stoi, itos = build_vocab(tokens, min_freq=1)
    ids = tokens_to_ids(tokens, stoi)

    seq_len = 25
    X, Y = make_sequences(ids, seq_len)
    ds = SeqDatasetOneHot(X, Y, vocab_size=len(stoi))
    dl = DataLoader(ds, batch_size=64, shuffle=True, drop_last=True)

    model = OneHotRNNLM(vocab_size=len(stoi), hidden=256).to(DEVICE)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()

    print("Training One-Hot RNN on", DEVICE)

    epoch_times = []
    samples = []

    for epoch in range(n_epochs):
        model.train()
        t0 = time.perf_counter()
        total = 0.0
        steps = 0
        for x_oh, y in dl:
            x_oh = x_oh.to(DEVICE)
            y = y.to(DEVICE)

            logits, _ = model(x_oh)
            loss = loss_fn(logits.reshape(-1, logits.size(-1)), y.reshape(-1))

            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()

            total += loss.item()
            steps += 1

        print(f"Epoch {epoch+1} | loss: {total/steps:.4f}")

        sample = generate(model, stoi, itos, seed_text="<bos>", max_new=40, temperature=0.9)
        samples.append(sample)

        t1 = time.perf_counter()
        epoch_times.append(float(t1 - t0))

        print("Sample:", sample)
        print(f"Epoch time: {epoch_times[-1]:.3f}s")

    print(f"Total training time (one-hot): {sum(epoch_times):.2f}s")
    return epoch_times, samples[-1]

onehot_rnn_epoch_times, onehot_rnn_last_sample = rnn_onehot_main()


Training One-Hot RNN on mps
Epoch 1 | loss: 6.2165
Sample: <bos> lungs . long i rich be watches . borne with the lift of balanced , and or trouble , i who , house antique to ? about the pen and first . ? d that is bride it more dissatisfied
Epoch time: 25.517s
Epoch 2 | loss: 5.0680
Sample: <bos> noon about -- ! so deny , their of circle following or rest the mine chemist when before emanations well , to give our white i ooze for the loves , you trusted i , i ' the little ,
Epoch time: 23.657s
Epoch 3 | loss: 3.9790
Sample: <bos> are your object with her between us , who is not , now we thousand come trouble than others it is limitless to stronger upon ! retreat you die for ! i love , the sign ' d by the
Epoch time: 23.614s
Epoch 4 | loss: 2.9035
Sample: <bos> to nothing , these thirty i know --great a knoll ago , no heart farewell no less part , and what is not ? never might ride . . " " o heart , i am friend powers ? )
Epoch time: 23.741s
Epoch 5 | loss: 1.8531
Sample: <bos> of the e

In [6]:
class SeqDatasetIdx(Dataset):
    def __init__(self, X, Y):
        self.X = torch.tensor(X, dtype=torch.long)
        self.Y = torch.tensor(Y, dtype=torch.long)

    def __len__(self):
        return self.X.size(0)

    def __getitem__(self, idx):
        return self.X[idx], self.Y[idx]


class EmbRNNLM(nn.Module):
    def __init__(self, vocab_size, emb=128, hidden=256):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb)
        self.rnn = nn.RNN(input_size=emb, hidden_size=hidden, batch_first=True)
        self.fc = nn.Linear(hidden, vocab_size)

    def forward(self, x_ids, h0=None):
        x = self.emb(x_ids)
        out, hn = self.rnn(x, h0)
        logits = self.fc(out)
        return logits, hn


@torch.no_grad()
def generate(model, stoi, itos, seed_text="<bos>", max_new=40, temperature=1.0):
    model.eval()
    tokens = seed_text.split()
    ids = [stoi.get(t, stoi["<unk>"]) for t in tokens]

    h = None
    for _ in range(max_new):
        x = torch.tensor([[ids[-1]]], dtype=torch.long, device=DEVICE)
        logits, h = model(x, h)
        next_logits = logits[0, -1] / max(temperature, 1e-6)
        probs = torch.softmax(next_logits, dim=0)
        nxt = torch.multinomial(probs, 1).item()
        ids.append(nxt)

    return " ".join(itos[i] for i in ids)


def rnn_embedding_main():
    text = load_poems()
    tokens = ["<bos>"] + simple_tokenize(text) + ["<eos>"]
    stoi, itos = build_vocab(tokens, min_freq=1)
    ids = tokens_to_ids(tokens, stoi)

    seq_len = 25
    X, Y = make_sequences(ids, seq_len)
    ds = SeqDatasetIdx(X, Y)
    dl = DataLoader(ds, batch_size=64, shuffle=True, drop_last=True)

    model = EmbRNNLM(vocab_size=len(stoi), emb=128, hidden=256).to(DEVICE)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.CrossEntropyLoss()

    print("Training Embedding RNN on", DEVICE)

    epoch_times = []
    samples = []

    for epoch in range(n_epochs):
        model.train()
        t0 = time.perf_counter()
        total = 0.0
        steps = 0
        for x_ids, y in dl:
            x_ids = x_ids.to(DEVICE)
            y = y.to(DEVICE)

            logits, _ = model(x_ids)
            loss = loss_fn(logits.reshape(-1, logits.size(-1)), y.reshape(-1))

            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()

            total += loss.item()
            steps += 1

        print(f"Epoch {epoch+1} | loss: {total/steps:.4f}")

        sample = generate(model, stoi, itos, seed_text="<bos>", max_new=40, temperature=0.9)
        samples.append(sample)

        t1 = time.perf_counter()
        epoch_times.append(float(t1 - t0))

        print("Sample:", sample)
        print(f"Epoch time: {epoch_times[-1]:.3f}s")

    print(f"Total training time (embedding): {sum(epoch_times):.2f}s")
    return epoch_times, samples[-1]

embedding_rnn_epoch_times, embedding_rnn_last_sample = rnn_embedding_main()

Training Embedding RNN on mps
Epoch 1 | loss: 5.0160
Sample: <bos> dazzle , theology—but ' things them dim-descried ' s best , but i know , perhaps what is his brother at the wife with him on brotherly and invite , not one is more far hot made ; and am
Epoch time: 14.533s
Epoch 2 | loss: 2.4822
Sample: <bos> sat and straw , over the long-leav ' d slave is lit with the iris sheen of the north , where the voice of enjoyment picks ' d , when thou art gone , my dear , i find him
Epoch time: 14.467s
Epoch 3 | loss: 1.1327
Sample: <bos> sullen , one then stand on a heap ' d tale of nature ' s immortality , a venerable thing ? of hay the best , or through those drain ' d from your eyes , my burst rest in
Epoch time: 14.395s
Epoch 4 | loss: 0.6131
Sample: <bos> text and come , with thee , thou : white peacocks , songs at eve , and antique maps of america . farewell at know that i love thee better . i could not die – here we go to
Epoch time: 14.004s
Epoch 5 | loss: 0.4282
Sample: <bos> text yo

In [7]:
total_scratch_rnn_time = sum(scratch_rnn_epoch_times)
total_onehot_rnn_time = sum(onehot_rnn_epoch_times)
total_embedding_rnn_time = sum(embedding_rnn_epoch_times)

print(f"Total training time (Scratch RNN): {total_scratch_rnn_time:.2f}s")
print(f"Total training time (One-Hot RNN): {total_onehot_rnn_time:.2f}s")
print(f"Total training time (Embedding RNN): {total_embedding_rnn_time:.2f}s")

Total training time (Scratch RNN): 848.27s
Total training time (One-Hot RNN): 245.56s
Total training time (Embedding RNN): 136.47s


In [8]:
scratch_rnn_metrics = text_quality_metrics(scratch_rnn_last_sample)
onehot_rnn_metrics = text_quality_metrics(onehot_rnn_last_sample)
embedding_rnn_metrics = text_quality_metrics(embedding_rnn_last_sample)

print("Scratch RNN Metrics:")
print(scratch_rnn_metrics)
print("\nOne-Hot RNN Metrics:")
print(onehot_rnn_metrics)
print("\nEmbedding RNN Metrics:")
print(embedding_rnn_metrics)

Scratch RNN Metrics:
{'token_count': 30, 'unique_tokens': 9, 'unique_ratio': 0.3, 'repeat_2gram_ratio': 0.5862068965517241, 'repeat_3gram_ratio': 0.4642857142857143, 'top_tokens': [('and', 7), ('our', 6), ('wild', 6), ('green', 5), ('of', 2), ('agony', 1), ('dear', 1), ('least', 1), ('they', 1)]}

One-Hot RNN Metrics:
{'token_count': 41, 'unique_tokens': 34, 'unique_ratio': 0.8292682926829268, 'repeat_2gram_ratio': 0.025, 'repeat_3gram_ratio': 0.0, 'top_tokens': [('the', 2), ('!', 2), ('shall', 2), ("'", 2), ('d', 2), ('yet', 2), ('and', 2), ('<bos>', 1), ('fords', 1), ('park', 1)]}

Embedding RNN Metrics:
{'token_count': 41, 'unique_tokens': 30, 'unique_ratio': 0.7317073170731707, 'repeat_2gram_ratio': 0.0, 'repeat_3gram_ratio': 0.0, 'top_tokens': [('the', 7), (',', 4), (';', 2), ('his', 2), ('<bos>', 1), ('text', 1), ('out', 1), ('from', 1), ('crowd', 1), ('steps', 1)]}
