In [None]:
import random
from typing import List

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from tqdm.autonotebook import tqdm

%matplotlib inline

In [None]:
# use the GPU when one is visible to PyTorch, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Load one name per line. The `with` block closes the file handle once it has
# been read (the bare `open(...).read()` form left it open), and the explicit
# encoding makes decoding independent of the platform's locale default.
with open("../../data/names.txt", "r", encoding="utf-8") as names_file:
    words = names_file.read().splitlines()
words[:8]

In [None]:
# number of names loaded (one per line of names.txt)
len(words)

In [None]:
# Build the character vocabulary. Real characters get indices 1..N in sorted
# order; index 0 is reserved for "." (the start/end-of-name token).
chars = sorted(set("".join(words)))
stoi = {ch: idx for idx, ch in enumerate(chars, start=1)}
stoi["."] = 0
# inverse mapping: integer index -> character
itos = {idx: ch for ch, idx in stoi.items()}
print(itos)

In [None]:
class WordTokensDataset(Dataset):
    """Fixed-context next-character examples built from a list of words.

    Each word is terminated with "." and, for every character position, one
    example is emitted: ``X[i]`` holds the indices of the previous
    ``block_size`` characters (left-padded with index 0) and ``Y[i]`` holds
    the index of the current character.

    Args:
        words: Words to tokenize; every character must be a key of the mapping.
        block_size: Number of context characters per example.
        vocab: Character -> index mapping. When omitted, the notebook-level
            ``stoi`` is used (the original behavior).

    Attributes:
        X: int64 tensor of shape (num_examples, block_size).
        Y: int64 tensor of shape (num_examples,).
    """

    def __init__(
        self,
        words: List[str],
        block_size: int,
        vocab: "dict[str, int] | None" = None,
    ):
        # Only touch the global `stoi` when no mapping is passed explicitly.
        mapping = stoi if vocab is None else vocab

        contexts, targets = [], []
        for word in words:
            context = [0] * block_size
            for ch in word + ".":
                ix = mapping[ch]
                contexts.append(context)
                targets.append(ix)
                context = context[1:] + [ix]  # crop and append

        # Explicit dtype + reshape keep X 2-D int64 even for an empty word
        # list, where torch.tensor([]) would give a 1-D float tensor.
        self.X = torch.tensor(contexts, dtype=torch.long).reshape(
            len(contexts), block_size
        )
        self.Y = torch.tensor(targets, dtype=torch.long)

    def __getitem__(self, idx):
        """Return the ``(context, target)`` pair at position ``idx``."""
        return self.X[idx], self.Y[idx]

    def __len__(self):
        """Number of (context, target) examples."""
        return self.X.shape[0]

In [None]:
random.seed(42)
block_size = 3  # number of context characters used to predict the next one

# Shuffle a copy rather than `words` itself. Shuffling in place made the cell
# non-idempotent: each re-run reshuffled an already-shuffled list and produced
# a different split. On a fresh kernel the split is identical to before.
shuffled_words = list(words)
random.shuffle(shuffled_words)

# 80% train / 10% validation / 10% test
n1 = int(0.8 * len(shuffled_words))
n2 = int(0.9 * len(shuffled_words))

train_dataset = WordTokensDataset(shuffled_words[:n1], block_size)
validation_dataset = WordTokensDataset(shuffled_words[n1:n2], block_size)
test_dataset = WordTokensDataset(shuffled_words[n2:], block_size)

In [None]:
# shapes of the training inputs (contexts) and targets
(train_dataset.X.shape, train_dataset.Y.shape)

In [None]:
g = torch.Generator().manual_seed(2147483647)  # for reproducibility

# Layer sizes derived from the data instead of hard-coded 27 / 30, so changing
# the vocabulary or `block_size` keeps the shapes consistent. The random draws
# happen in the same order and with the same sizes as before.
vocab_size = len(stoi)  # characters + "." (previously hard-coded as 27)
n_embd = 10  # embedding dimension per character
n_hidden = 200  # width of the tanh hidden layer

C = torch.randn((vocab_size, n_embd), generator=g).to(device)
W1 = torch.randn((block_size * n_embd, n_hidden), generator=g).to(device)
b1 = torch.randn(n_hidden, generator=g).to(device)
W2 = torch.randn((n_hidden, vocab_size), generator=g).to(device)
b2 = torch.randn(vocab_size, generator=g).to(device)
parameters = [C, W1, b1, W2, b2]

In [None]:
# total number of scalar parameters across all tensors
sum(map(torch.Tensor.numel, parameters))

In [None]:
# turn on gradient tracking for every parameter tensor (in place)
for tensor in parameters:
    tensor.requires_grad_(True)

In [None]:
# Compare the device *type* so e.g. torch.device("cuda:0") also counts as GPU.
ON_CUDA = device.type == "cuda"
BATCH_SIZE = 512 if ON_CUDA else 32
TOTAL_SAMPLES_TO_TRAIN = 40_000_000 if ON_CUDA else 5_000_000
# Train for at least one epoch even if the training set exceeds the budget.
EPOCHS = max(1, TOTAL_SAMPLES_TO_TRAIN // len(train_dataset))

train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    # Pinned host memory only speeds up host->GPU copies; requesting it on a
    # CPU-only machine just triggers a PyTorch warning.
    pin_memory=ON_CUDA,
)

lossi = []  # log10 of the mean training loss of each epoch
stepi = []  # epoch index for each entry of `lossi`

for epoch in tqdm(range(EPOCHS)):
    # Step decay: lr 0.1 for the first half of training, 0.01 afterwards.
    # Constant within an epoch, so compute it once per epoch.
    lr = 0.1 if epoch < EPOCHS * 0.5 else 0.01
    # Kept on `device` so accumulating it does not force a sync every step.
    epoch_loss_sum = torch.zeros((), device=device)

    for X_batch, Y_batch in train_dataloader:
        X_batch = X_batch.to(device, non_blocking=ON_CUDA)
        Y_batch = Y_batch.to(device, non_blocking=ON_CUDA)

        # forward pass
        emb = C[X_batch]  # (B, block_size, 10)
        # flatten each context to W1's input width (block_size * 10)
        h = torch.tanh(emb.view(-1, W1.shape[0]) @ W1 + b1)  # (B, 200)
        logits = h @ W2 + b2  # (B, 27)
        loss = F.cross_entropy(logits, Y_batch)

        # backward pass
        for p in parameters:
            p.grad = None
        loss.backward()

        # update: the in-place SGD step must not be recorded by autograd
        with torch.no_grad():
            for p in parameters:
                p -= lr * p.grad

        # sample-weighted running sum of the loss for this epoch
        epoch_loss_sum += loss.detach() * X_batch.shape[0]

    # Track the mean loss over the whole epoch. The last batch alone is a
    # noisy estimate and may be a tiny partial batch.
    stepi.append(epoch)
    lossi.append((epoch_loss_sum / len(train_dataset)).log10().item())

In [None]:
# training curve: one point per epoch, loss on a log10 scale
plt.plot(stepi, lossi)
plt.xlabel("epoch")
plt.ylabel("log10(training loss)")
plt.title("Training loss")
plt.show()

In [None]:
@torch.no_grad()  # evaluation only: don't build an autograd graph over the whole dataset
def calculate_loss(dataset: WordTokensDataset):
    """Return the mean cross-entropy of the current model on ``dataset``.

    Runs the notebook-level parameters ``C, W1, b1, W2, b2`` over every
    example of ``dataset`` in a single batch on ``device``.

    Args:
        dataset: Object exposing ``X`` (context indices) and ``Y`` (targets).

    Returns:
        The loss as a Python float.
    """
    X = dataset.X.to(device)
    Y = dataset.Y.to(device)
    emb = C[X]
    # flatten each context to W1's input width (block_size * embedding dim)
    h = torch.tanh(emb.view(-1, W1.shape[0]) @ W1 + b1)
    logits = h @ W2 + b2
    return F.cross_entropy(logits, Y).item()

In [None]:
training_loss = calculate_loss(train_dataset)
validation_loss = calculate_loss(validation_dataset)
# ".4f" = 4 decimal places; the previous ":4f" meant width 4 with 6 decimals.
print(f"{training_loss = :.4f}, {validation_loss = :.4f}")

In [None]:
# visualize dimensions 0 and 1 of the embedding matrix C for all characters
# C requires grad, so detach it before handing its values to matplotlib.
C_cpu = C.detach().cpu()

plt.figure(figsize=(8, 8))
plt.scatter(C_cpu[:, 0], C_cpu[:, 1], s=200)
for i in range(C_cpu.shape[0]):
    # label each point with its character, centred inside the marker
    plt.text(
        C_cpu[i, 0].item(),
        C_cpu[i, 1].item(),
        itos[i],
        ha="center",
        va="center",
        color="white",
    )
# The first argument of plt.grid is the on/off flag; the old "minor" string
# was only accepted because it is truthy.
plt.grid(True)
plt.xlabel("embedding dim 0")
plt.ylabel("embedding dim 1")
plt.show()

In [None]:
# sample from the model
g = torch.Generator(device).manual_seed(2147483647 + 10)

# Inference only: no_grad stops every sampling step from building an autograd
# graph. The sampled values are unchanged.
with torch.no_grad():
    for _ in range(20):
        out = []
        context = [0] * block_size  # start from an all-"." context (index 0)
        while True:
            # build the index tensor directly on the same device as C
            emb = C[torch.tensor([context], device=device)]  # (1, block_size, d)
            h = torch.tanh(emb.view(1, -1) @ W1 + b1)
            logits = h @ W2 + b2
            probs = F.softmax(logits, dim=1)
            ix = torch.multinomial(probs, num_samples=1, generator=g).item()
            context = context[1:] + [ix]
            out.append(ix)
            if ix == 0:  # sampled "." -> end of name
                break

        print("".join(itos[i] for i in out))