In [1]:
# Corpus of around 32K names, one name per line.
# `with` guarantees the file handle is closed; the explicit encoding avoids
# depending on the platform's default text encoding.
with open("names.txt", "r", encoding="utf-8") as f:
    words = f.read().splitlines()

In [2]:
import torch
import torch.nn.functional as F

In [3]:
# Character vocabulary: each distinct character in the corpus, in sorted order,
# gets an index starting at 1; index 0 is reserved for the "." token that
# marks the start/end of a name.
chars = sorted(set("".join(words)))
stoi = {ch: idx for idx, ch in enumerate(chars, start=1)}
stoi["."] = 0
# Inverse map: index -> character.
itos = {idx: ch for ch, idx in stoi.items()}

In [4]:
# Create training set of trigrams: each pair of consecutive characters is a
# context and the character that follows it is the target. Two "." tokens pad
# the front (so the first letter has a full context) and one "." marks the end.
xs1, xs2, ys = [], [], []

for w in words:
    padded = ['.', '.', *w, '.']
    for first, second, target in zip(padded, padded[1:], padded[2:]):
        xs1.append(stoi[first])   # older context character
        xs2.append(stoi[second])  # newer context character
        ys.append(stoi[target])   # character to predict

# Materialise the index lists as tensors (one entry per trigram).
xs1, xs2, ys = torch.tensor(xs1), torch.tensor(xs2), torch.tensor(ys)

In [5]:
# Passes in a training loop (full-batch gradient descent)

VOCAB_SIZE = 27      # size of the character vocabulary; the sampling cell also assumes 27
NUM_STEPS = 200
LEARNING_RATE = 50
REG_STRENGTH = 0.01  # weight of the L2 penalty on W
LOG_EVERY = 100

g = torch.Generator().manual_seed(2147483647)
# One row per (context position, character): rows [0, VOCAB_SIZE) score the
# first context character, rows [VOCAB_SIZE, 2*VOCAB_SIZE) the second.
W = torch.randn([2 * VOCAB_SIZE, VOCAB_SIZE], generator=g, requires_grad=True)

# The inputs never change during training, so encode them once instead of
# rebuilding the one-hot matrices on every step.
xenc1 = F.one_hot(xs1, num_classes=VOCAB_SIZE).float()
xenc2 = F.one_hot(xs2, num_classes=VOCAB_SIZE).float()
xenc = torch.cat((xenc1, xenc2), 1)

for k in range(NUM_STEPS):
    # forward pass: cross_entropy = mean negative log of the softmax probability
    # of the target, computed via a numerically stable log-softmax (no risk of
    # overflow in logits.exp()).
    logits = xenc @ W
    loss = F.cross_entropy(logits, ys) + REG_STRENGTH * (W**2).mean()

    # also report the last step so the final training loss is visible
    if k % LOG_EVERY == 0 or k == NUM_STEPS - 1:
        print(loss.item())

    # backward pass
    W.grad = None  # reset the gradient before accumulating a new one
    loss.backward()

    # update: modify W outside autograd rather than through W.data
    with torch.no_grad():
        W -= LEARNING_RATE * W.grad

4.242241859436035
2.394003391265869


In [6]:
# Sample 5 names from the trained model. Each name starts from the context
# (".", ".") and ends when index 0 (".") is drawn.
g = torch.Generator().manual_seed(2147483647)

# Pure inference: W has requires_grad=True, so without no_grad every step
# below would record an autograd graph that is never used.
with torch.no_grad():
    for i in range(5):
        out = []

        iix = 0  # older context character index
        ix = 0   # most recent context character index

        while True:
            xenc1 = F.one_hot(torch.tensor([iix]), num_classes=27).float()
            xenc2 = F.one_hot(torch.tensor([ix]), num_classes=27).float()

            logits = torch.cat((xenc1, xenc2), 1) @ W
            counts = logits.exp()
            p = counts / counts.sum(1, keepdims=True)

            # slide the context window forward, then draw the next character
            iix = ix
            ix = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
            out.append(itos[ix])

            if ix == 0:
                break
        print("".join(out))

mfra.
tt.
len.
veroydas.
jalio.
