In [147]:
from pathlib import Path
from itertools import product

import torch
import torch.nn.functional as F

In [148]:
# Create dataset
from sklearn.model_selection import train_test_split

words = Path.cwd().joinpath("names.txt").read_text().splitlines()
end_tok = "."
chars = [end_tok] + sorted(list(set("".join(words))))
bigrams = [ch1 + ch2 for ch1, ch2 in product(chars, chars)]
char_to_index = {char: index for index, char in enumerate(chars)}
bigram_to_index = {bigram: index for index, bigram in enumerate(bigrams)}
n_chars = len(chars)
n_bigrams = len(bigrams)

xs, ys = [], []
for w in words:
    word_chars = [end_tok] + list(w) + [end_tok]
    for ch1, ch2, ch3 in zip(word_chars, word_chars[1:], word_chars[2:]):
        bigram = ch1 + ch2
        bigram_ix = bigram_to_index[bigram]
        char_ix = char_to_index[ch3]
        xs.append(bigram_ix)
        ys.append(char_ix)


xs, dev_xs, ys, dev_ys = train_test_split(xs, ys, test_size=.2)
dev_xs, test_xs, dev_ys, test_ys = train_test_split(dev_xs, dev_ys, test_size=.5)
xs = torch.tensor(xs)
dev_xs = torch.tensor(dev_xs)
test_xs = torch.tensor(test_xs)
ys = torch.tensor(ys)
dev_ys = torch.tensor(dev_ys)
test_ys = torch.tensor(test_ys)
num = xs.nelement()
xs.shape, dev_xs.shape

(torch.Size([156890]), torch.Size([19611]))

In [149]:
W = torch.randn((n_bigrams, n_chars), requires_grad=True)
W.shape

torch.Size([729, 27])

In [153]:
# Training

for _ in range(500):
    # Forward pass
    # xenc = F.one_hot(xs, num_classes=n_bigrams).float()
    # logits = xenc @ W
    logits = W[xs,:].reshape((num, 27))
    # counts = logits.exp()
    # probs = counts / counts.sum(dim=1, keepdim=True)
    # loss = -probs[torch.arange(num), ys].log().mean()
    loss = F.cross_entropy(logits, ys)
    print(loss.item())

    # Backward pass
    W.grad = None
    loss.backward()
    W.data += -50 * W.grad

2.1092753410339355
2.1092305183410645
2.1091861724853516
2.1091418266296387
2.109097480773926
2.109053134918213
2.109009265899658
2.1089649200439453
2.1089208126068115
2.108876943588257
2.108833074569702
2.1087892055511475
2.108745574951172
2.1087019443511963
2.1086583137512207
2.108614921569824
2.1085715293884277
2.1085283756256104
2.108484983444214
2.1084418296813965
2.108398914337158
2.108355760574341
2.1083128452301025
2.1082699298858643
2.108227014541626
2.108184576034546
2.1081418991088867
2.1080992221832275
2.1080567836761475
2.1080143451690674
2.1079721450805664
2.1079299449920654
2.1078875064849854
2.1078457832336426
2.1078038215637207
2.1077616214752197
2.107720136642456
2.107678174972534
2.1076366901397705
2.1075947284698486
2.107553482055664
2.1075119972229004
2.1074705123901367
2.107429027557373
2.1073880195617676
2.107346773147583
2.1073057651519775
2.107264757156372
2.1072237491607666
2.107182741165161
2.1071419715881348
2.1071012020111084
2.107060670852661
2.10701990127

In [154]:
# Sample from model
generator = torch.Generator().manual_seed(42)

for _ in range(10):
    out = ["."]
    ix = 0
    while True:
        # xenc = F.one_hot(torch.tensor([ix]), num_classes=n_bigrams).float()
        # logits = xenc @ W
        logits = W[ix].reshape((1, 27))
        counts = logits.exp()
        probs = counts / counts.sum(dim=1, keepdim=True)
        ix = torch.multinomial(input=probs, num_samples=1, replacement=True, generator=generator)
        char = chars[ix.item()]
        if char == end_tok:
            break
        out.append(char)
        bigram = "".join(out[len(out)-2:])
        #print(bigram)
        ix = bigram_to_index[bigram]

    print("".join(out))

.abbandyqace
.tren
.que
.pan
.aya
.daidanis
.ely
.imyre
.philandzella
.quinaessalyn


In [155]:
# Eval
xenc = F.one_hot(test_xs, num_classes=n_bigrams).float()
logits = xenc @ W
# counts = logits.exp()
# probs = counts / counts.sum(dim=1, keepdim=True)
# loss = -probs[torch.arange(test_xs.nelement()), test_ys].log().mean()
loss = F.cross_entropy(logits, test_ys)
loss

tensor(2.1289, grad_fn=<NllLossBackward0>)