In [1]:
from pathlib import Path

import torch
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Create dataset

words = Path.cwd().joinpath("names.txt").read_text().splitlines()
end_tok = "."
chars = [end_tok] + sorted(list(set("".join(words))))
char_to_index = {char: index for index, char in enumerate(chars)}
n_chars = len(chars)

xs, ys = [], []
for w in words:
    word_chars = [end_tok] + list(w) + [end_tok]
    for ch1, ch2 in zip(word_chars, word_chars[1:]):
        ix1 = char_to_index[ch1]
        ix2 = char_to_index[ch2]
        xs.append(ix1)
        ys.append(ix2)


xs = torch.tensor(xs)
ys = torch.tensor(ys)
num = xs.nelement()
W = torch.randn((n_chars, n_chars), requires_grad=True)

In [3]:
# Training

for _ in range(100):
    # Forward pass
    xenc = F.one_hot(xs, num_classes=n_chars).float()
    logits = xenc @ W
    counts = logits.exp()
    probs = counts / counts.sum(dim=1, keepdim=True)
    loss = -probs[torch.arange(num), ys].log().mean()
    print(loss.item())

    # Backward pass
    W.grad = None
    loss.backward()
    W.data += -50 * W.grad

3.690117835998535
3.285710096359253
3.088467836380005
2.962078809738159
2.874926805496216
2.8120381832122803
2.7649848461151123
2.7286720275878906
2.699885606765747
2.676551580429077
2.657283067703247
2.6411068439483643
2.6273181438446045
2.615401268005371
2.6049752235412598
2.5957610607147217
2.5875484943389893
2.5801777839660645
2.5735247135162354
2.5674893856048584
2.561990976333618
2.55696177482605
2.5523464679718018
2.5480964183807373
2.5441715717315674
2.540537118911743
2.5371620655059814
2.5340216159820557
2.5310921669006348
2.5283539295196533
2.525789737701416
2.523383617401123
2.521122455596924
2.518993616104126
2.5169858932495117
2.515089511871338
2.5132956504821777
2.5115966796875
2.5099852085113525
2.5084540843963623
2.50699782371521
2.505610704421997
2.5042884349823
2.503026008605957
2.5018198490142822
2.5006656646728516
2.499561071395874
2.498502254486084
2.4974865913391113
2.496511459350586
2.495574712753296
2.494673728942871
2.493806838989258
2.4929723739624023
2.492168

In [273]:
# Sample from model

for _ in range(10):
    out = []
    ix = 0
    while True:
        xenc = F.one_hot(torch.tensor([ix]), num_classes=n_chars).float()
        logits = xenc @ W
        counts = logits.exp()
        probs = counts / counts.sum(dim=1, keepdim=True)
        ix = torch.multinomial(input=probs, num_samples=1, replacement=True)
        char = chars[ix.item()]
        if char == end_tok:
            break
        out.append(char)
    print("".join(out))

brysein
zlamalartletona
khayaynana
jt
qun
ebran
ashrizrtaenahgu
ja
jaso
vcezinidha
