In [None]:
import os
# Set in its own cell ahead of the `import torch` cell so the variable is
# already in the environment when torch is first imported.
# NOTE(review): going by its name, this asks PyTorch to fall back to another
# backend for ops the MPS backend does not implement - confirm against the docs.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"


In [None]:
import itertools as it
import time
from typing import Optional

import matplotlib.pyplot as plt
import torch
from torch import Tensor, nn
from torch.nn import functional as F

In [None]:
# Device handle for PyTorch's "mps" backend. None of the cells visible in this
# notebook reference `mps` afterwards, so tensors and the model are created
# without an explicit device.
mps = torch.device("mps")

In [None]:
# Input data set: read the whole corpus into memory as one string.
# The encoding is pinned explicitly so the decoded characters (and therefore
# the vocabulary derived from `text`) do not depend on the platform's default
# locale encoding.
with open("shakespeare.txt", encoding="utf-8") as f:
    text = f.read()
print(f"Input length = {len(text):,d} characters")


In [None]:
# Peek at the first 250 characters of the corpus.
print(text[:250])


In [None]:
# Vocabulary (NB: We aren't covering all ASCII)
# `chars` is every distinct character that occurs in the corpus, in sorted
# order; a character's position in this list becomes its token id below.
chars = sorted(set(text))
vocab_size = len(chars)
# Show the full alphabet on one line together with its size.
print("".join(chars), f"(N={vocab_size})")


In [None]:
# Tokenization: character-level lookup tables in both directions.
# `itos` maps token id -> character; `stoi` is its inverse (character -> id).
itos = dict(enumerate(chars))
stoi = {ch: idx for idx, ch in itos.items()}


def encode(s: str) -> list[int]:
    """Convert a string into its token ids, one id per character.

    Args:
        s: Text to tokenize.

    Returns:
        The ids looked up in the module-level `stoi` table, in input order.

    Raises:
        KeyError: If `s` contains a character that is not a key of `stoi`.
    """
    return [stoi[c] for c in s]


def decode(i: list[int]) -> str:
    """Convert a sequence of token ids back into the string they spell.

    Args:
        i: Token ids, each looked up in the module-level `itos` table.

    Returns:
        The concatenation of the corresponding characters, in input order.

    Raises:
        KeyError: If an id is not a key of `itos`.
    """
    return "".join(itos[n] for n in i)


In [None]:
# Round-trip sanity check: show the token ids for a sample string, then decode
# them again to confirm the original text comes back.
print(encode("Hello there"))
print(decode(encode("Hello there")))


In [None]:
# Into PyTorch
# The whole corpus becomes one 1-D tensor of token ids (torch.long), built from
# the flat id list that `encode` returns.
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)


In [None]:
# Train and test sets: a contiguous split with no shuffling. The leading 90%
# of the token stream is used for training and the remaining tail for testing.
nsplit = int(0.9 * len(data))
train = data[:nsplit]
test = data[nsplit:]


In [None]:
# Context length used by the examples that follow.
block_size = 8
# Cell output: the first block_size + 1 training tokens, next to the window of
# the same length that starts one position later.
train[: block_size + 1], train[1 : block_size + 2]


In [None]:
# Example of context sizes
# `y` is `x` shifted by one token, so y[t] is the token that follows the
# context x[: t + 1] in the training stream.
x = train[:block_size]
y = train[1 : 1 + block_size]

# One window of block_size tokens yields block_size training examples, with
# context lengths 1 .. block_size.
for t in range(block_size):
    print(f"target = {y[t]:>2d}, input = {x[:t+1]}")


In [None]:
# Generate a batch
def batch(
    t: torch.Tensor, batch_size: int, block_size: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample `batch_size` random windows of `block_size` tokens from `t`.

    Args:
        t: Token sequence, indexed along its first dimension.
        batch_size: Number of windows to draw (start offsets are sampled
            independently, so windows may overlap or repeat).
        block_size: Length of each window.

    Returns:
        A pair `(x, y)` whose leading dimensions are `(batch_size, block_size)`.
        `y` is `x` shifted one position along `t`: `y[b, k]` is the element of
        `t` that immediately follows `x[b, k]`.
    """
    # randint's upper bound is exclusive, so the largest start offset is
    # len(t) - block_size - 1, which leaves room for the shifted target window
    # t[i + 1 : i + 1 + block_size].
    ix = torch.randint(len(t) - block_size, (batch_size,))
    # Gather every window with a single advanced-indexing operation instead of
    # slicing and stacking one sample at a time in Python. The offsets are drawn
    # exactly as before, so results under a fixed seed are unchanged.
    pos = (ix[:, None] + torch.arange(block_size)).to(t.device)
    return t[pos], t[pos + 1]


# Fix the seed so the batch sampled below is the same on every run.
torch.manual_seed(1337)
# Positional arguments are batch_size=4 and block_size=8.
xb, yb = batch(train, 4, 8)

print(f"Inputs ({xb.shape}):")
print(xb, "\n")

print(f"Targets ({yb.shape}):")
print(yb, "\n")

# Spell out every (context -> target) example in the batch: for row b and
# position t, the context is xb[b, : t + 1] and the target is yb[b, t].
# The ranges 4 and 8 mirror the batch and block sizes passed to `batch` above.
for b, t in it.product(range(4), range(8)):
    print(f"target = {yb[b,t]:>2d}, input = {xb[b,:t+1]}")


In [None]:
# Bigram language model
class BigramLM(nn.Module):
    """Bigram language model backed by a single embedding table.

    The table has shape ``(vocab_size, vocab_size)``: looking up token ``i``
    yields the logits over the token that follows ``i``. Each prediction
    therefore depends only on the current token.
    """

    def __init__(self, vocab_size: int) -> None:
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, vocab_size)

    def forward(
        self, x: Tensor, y: Optional[Tensor] = None
    ) -> tuple[Tensor, Optional[Tensor]]:
        """Make a prediction of the next token given the current.

        Args:
            x: Token ids of shape ``(batch, time)``.
            y: Optional target ids with the same shape as ``x``.

        Returns:
            ``(logits, loss)``. Without targets, ``logits`` has shape
            ``(batch, time, vocab)`` and ``loss`` is ``None``. With targets,
            ``logits`` is returned flattened to ``(batch * time, vocab)`` and
            ``loss`` is the mean cross-entropy over all positions.
        """
        logits = self.embeddings(x)  # shape: (batch, time, vocab)
        if y is None:
            return logits, None
        b, t, c = logits.shape
        # cross_entropy wants (N, classes) logits and (N,) targets, so fold the
        # batch and time dimensions together. `reshape` (rather than `view`)
        # also accepts targets that are not contiguous in memory.
        logits = logits.reshape(b * t, c)
        loss = F.cross_entropy(logits, y.reshape(b * t))
        return logits, loss

    # Sampling never needs gradients; disabling autograd avoids building a
    # graph for every forward pass in the loop.
    @torch.no_grad()
    def generate(self, x: Tensor, max: int = 1) -> Tensor:
        """Take an input (B, T) and append `max` sampled tokens to each row.

        Args:
            x: Context token ids of shape ``(B, T)``.
            max: Number of tokens to sample. (The name shadows the builtin
                inside this method; it is kept for keyword-argument
                compatibility.)

        Returns:
            Tensor of shape ``(B, T + max)``: the context followed by the
            sampled continuation.
        """
        for _ in range(max):
            # We feed the whole context for generality, though
            # the BigramLM only uses the final token.
            logits, _ = self(x)
            # Distribution over the next token, taken at the last time step.
            probs = F.softmax(logits[:, -1, :], dim=-1)
            newx = torch.multinomial(probs, 1)  # shape: (B, 1)
            x = torch.cat((x, newx), -1)
        return x


In [None]:
# Fix the seed so the initial weights and the sample below are reproducible.
torch.manual_seed(1337)

m = BigramLM(vocab_size)
# Call the module itself rather than `.forward(...)`: nn.Module.__call__ is the
# supported entry point and runs any registered hooks around `forward`.
logits, loss = m(xb, yb)
print(logits.shape, loss)

# Example: Starting from token id 0 (the newline character, assuming it sorts
# first in `chars`), generate 100 tokens with the still-untrained model.
print(decode(m.generate(torch.zeros((1,1), dtype=torch.long), 100)[0].tolist()))

In [None]:
class timer:
    """Context manager that prints the wall-clock time spent inside its block.

    The start time is stored on the instance as ``time`` and, once the block
    exits, the measured duration in seconds is stored as ``elapsed``.
    Exceptions raised inside the block are not suppressed.
    """

    def __enter__(self):
        self.time = time.perf_counter()
        # Return the instance so `with timer() as t:` binds the timer itself
        # instead of None.
        return self

    def __exit__(self, *args, **kwargs):
        now = time.perf_counter()
        # Keep the duration so callers can use it, not just read it from stdout.
        self.elapsed = now - self.time
        print(f"Timer = {self.elapsed} sec")


In [None]:
# Train the bigram model now
# Big LR because a small model
# A fresh model is created here, replacing the `m` used in the untrained demo,
# and Adam is attached to its parameters with lr=1e-3.
m = BigramLM(vocab_size)
optimizer = torch.optim.Adam(m.parameters(), lr=1e-3)
# Per-step training losses. The list lives in this cell rather than in the
# training cell, so re-running the training cell continues training the same
# model and keeps extending the same history.
losses = []

In [None]:
batch_size = 32
# Time the whole run: 10,000 optimizer steps, one random mini-batch per step.
with timer():
    for step in range(10_000):
        # Use the notebook-wide `block_size` rather than repeating the literal
        # 8, so the context length is defined in one place.
        xb, yb = batch(train, batch_size=batch_size, block_size=block_size)
        _, loss = m(xb, yb)
        # Clear the previous step's gradients before backpropagating.
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        # `.item()` stores a plain Python float instead of a tensor.
        losses.append(loss.item())
    print(f"Loss = {losses[-1]}")


In [None]:
# Training curve: the loss recorded at each optimizer step, on a log-scaled
# y axis. Axis labels and a title make the figure readable on its own.
plt.plot(range(len(losses)), losses)
plt.yscale("log")
plt.xlabel("Training step")
plt.ylabel("Cross-entropy loss (mini-batch)")
plt.title("BigramLM training loss")
plt.show()

In [None]:
# Sample 500 tokens from the model after training, starting from a single
# context token with id 0, and print them as text.
print(decode(m.generate(torch.zeros((1,1), dtype=torch.long), 500)[0].tolist()))