In [22]:
# Build the character vocabulary.
# Index 0 is the '.' boundary token (marks both start and end of a name);
# the lowercase letters 'a'..'z' occupy indices 1..26.
NUM_CLASSES = 27

# itos: index -> character
itos = ['.'] + list('abcdefghijklmnopqrstuvwxyz')

# stoi: character -> index (letters first, then the boundary token)
stoi = {ch: idx for idx, ch in enumerate(itos) if ch != '.'}
stoi['.'] = 0

In [23]:
# read the words: one name per line.
# A context manager guarantees the file handle is closed even if reading fails.
with open('../data/names.txt', 'r', encoding='utf-8') as names_file:
    words = names_file.read().splitlines()

In [24]:
# sanity check: an empty dataset would later produce a float-typed empty
# tensor and fail confusingly inside F.one_hot, so fail fast here instead
assert words, "no words were read from the names file"
len(words)

32033

In [25]:
import torch
import torch.nn.functional as F

In [26]:
# Split every name into (current char, next char) index pairs.
# Each name is padded with the '.' boundary token on both sides so the model
# also learns how names start and how they end.
pairs = []
for word in words:
    padded = ['.', *word, '.']
    pairs.extend((stoi[a], stoi[b]) for a, b in zip(padded, padded[1:]))

# inputs are the current characters, targets are the following characters
xs = torch.tensor([a for a, _ in pairs])
ys = torch.tensor([b for _, b in pairs])
n = len(xs)
n

228146

In [27]:
# initialise the single-layer weights and a seeded generator.
# The same generator is reused for sampling later, so both the initial weights
# and the generated names are reproducible.
# W is sized from the vocabulary constant rather than a hard-coded 27.
g = torch.Generator().manual_seed(1)
W = torch.randn(NUM_CLASSES, NUM_CLASSES, generator=g, requires_grad=True)

In [30]:
# train our model on the training examples with full-batch gradient descent
NUM_STEPS = 50        # number of gradient-descent iterations
LEARNING_RATE = 10    # step size for the weight update
REG_STRENGTH = 0.01   # L2 penalty on W (pulls the logits towards uniform)

# The inputs never change between steps, so build the one-hot encoding and the
# row index once instead of re-allocating an (n, NUM_CLASSES) matrix per step.
xenc = F.one_hot(xs, num_classes=NUM_CLASSES).float()
rows = torch.arange(n)

for k in range(NUM_STEPS):
    # forward pass: logits -> log-probabilities of the next character.
    # log_softmax is mathematically (logits.exp() / row_sum).log(), but it
    # avoids exp overflow and log(0) = -inf for large-magnitude logits.
    logits = xenc @ W
    logprobs = logits.log_softmax(dim=1)

    # mean negative log-likelihood of the observed next characters + L2 term
    loss = -logprobs[rows, ys].mean() + REG_STRENGTH * (W**2).mean()

    # backward pass
    W.grad = None
    loss.backward()
    print('loss', loss.item())

    # gradient step performed outside autograd (instead of mutating `.data`)
    with torch.no_grad():
        W -= LEARNING_RATE * W.grad


loss 2.5892045497894287
loss 2.5879828929901123
loss 2.5867843627929688
loss 2.585608720779419
loss 2.5844547748565674
loss 2.583322286605835
loss 2.5822107791900635
loss 2.581120014190674
loss 2.5800490379333496
loss 2.5789976119995117
loss 2.5779645442962646
loss 2.5769500732421875
loss 2.575953245162964
loss 2.574974298477173
loss 2.574012041091919
loss 2.573066473007202
loss 2.5721371173858643
loss 2.5712239742279053
loss 2.5703256130218506
loss 2.5694427490234375
loss 2.568574905395508
loss 2.567721128463745
loss 2.5668811798095703
loss 2.5660550594329834
loss 2.5652425289154053
loss 2.5644426345825195
loss 2.5636560916900635
loss 2.5628817081451416
loss 2.562119722366333
loss 2.5613694190979004
loss 2.5606307983398438
loss 2.559903621673584
loss 2.559187412261963
loss 2.5584824085235596
loss 2.5577878952026367
loss 2.557103395462036
loss 2.556429624557495
loss 2.5557656288146973
loss 2.5551114082336426
loss 2.554466962814331
loss 2.5538313388824463
loss 2.5532050132751465
loss 2.

In [76]:
# prediction on inputs: sample new names from the trained model
NUM_SAMPLES = 5  # how many names to generate

# Sampling only needs forward passes. W requires grad, so without no_grad
# autograd would record a graph for every generated character.
with torch.no_grad():
    for _ in range(NUM_SAMPLES):
        out = []
        # start off with the '.' boundary token (index 0)
        ix = 0
        while True:
            # one forward pass of the network for the current character
            xenc = F.one_hot(torch.tensor([ix]), num_classes=NUM_CLASSES).float()
            logits = xenc @ W
            counts = logits.exp()
            # `keepdim` matches the training cell and the documented torch.sum API
            probs = counts / counts.sum(1, keepdim=True)

            # draw the next character; drawing '.' (index 0) ends the name
            ix = torch.multinomial(probs, num_samples=1, generator=g).item()
            if ix == 0:
                break

            out.append(itos[ix])

        print("".join(out))


jirbe
kwxbris
cavn
gdwishon
esive
