In [1]:
# Load the dataset: one name per line.
# `with` guarantees the file handle is closed; the encoding is explicit so the result
# does not depend on the platform's default encoding.
with open('names.txt', 'r', encoding='utf-8') as f:
    words = f.read().splitlines()

In [2]:
# Build the character vocabulary from every distinct character that appears in the names.
# Characters get indices in sorted order starting at 1 (a = 1, b = 2, ... for lowercase
# names), which leaves index 0 free for the special '.' start/end token.
# NOTE: this assumes '.' never occurs inside a name; otherwise it would be remapped to 0 here.
joined_text = "".join(words)
all_characters = sorted(set(joined_text))
# string -> int
stoi = dict(zip(all_characters, range(1, len(all_characters) + 1)))
stoi['.'] = 0  # special start/end-of-name token
# int -> string (inverse of stoi)
itos = {idx: ch for ch, idx in stoi.items()}

In [3]:
# Quickly check for one-letter names: display the length of the shortest name.
# min() is a single O(n) pass; there is no need to sort every length just to read index 0.
min(len(w) for w in words)

2

In [4]:
# create the counts table. Since this is a trigram instead of a bigram, we need 3 dimensions
import torch

In [5]:
# Trigram count table: one axis per position (ch1, ch2, ch3). The size is derived from the
# vocabulary ('.' plus the name characters, 27 for this dataset) instead of a hard-coded literal.
vocab_size = len(stoi)
N = torch.zeros((vocab_size, vocab_size, vocab_size), dtype=torch.int32)
N.shape

torch.Size([27, 27, 27])

In [6]:
# Count every trigram (ch1, ch2, ch3) of every name. Each name is padded with '.' on both
# sides so the first character gets a start context and the last one an end transition.
for w in words:
    padded = ['.'] + list(w) + ['.']
    for trigram in zip(padded, padded[1:], padded[2:]):
        i1, i2, i3 = (stoi[ch] for ch in trigram)
        N[i1, i2, i3] += 1

In [7]:
# Sanity checks on the trigram counts. Indices are looked up via stoi rather than hard-coded.
# "..." cannot occur: every name is padded with a single '.' on each side.
print(N[stoi['.'], stoi['.'], stoi['.']])
assert N[stoi['.'], stoi['.'], stoi['.']] == 0
# "emm" should appear, since "emma" is a name in our dataset.
emm_count = N[stoi['e'], stoi['m'], stoi['m']]
print(emm_count)
assert emm_count > 0
# Trigrams ".b?", i.e. names starting with "b": counts for each possible following character.
print(N[stoi['.'], stoi['b'], :])

tensor(0, dtype=torch.int32)
tensor(100, dtype=torch.int32)
tensor([  0, 169,   0,   0,   0, 253,   0,   0,   9,  41,   1,   0,  85,   0,
          0,  77,   0,   0, 646,   0,   0,  21,   0,   0,   0,   4,   0],
       dtype=torch.int32)


In [8]:
# Plan for the NN: the input is the two context characters (x_1, x_2), each one-hot encoded
# over 27 classes and concatenated into a 2 * 27 = 54-dim vector; the output is a prediction for ch3.

In [9]:
# Trigram training set: each input is the index pair of the two context characters
# (x_1, x_2), and its label y is the index of the character that follows them.
xs, ys = [], []
for w in words:
    padded = ['.'] + list(w) + ['.']
    for context_a, context_b, target in zip(padded, padded[1:], padded[2:]):
        pair = (stoi[context_a], stoi[context_b])
        label = stoi[target]
        xs.append(pair)
        ys.append(label)

# xs: integer tensor with one (x_1, x_2) row per example; ys: one label per example.
xs = torch.tensor(xs)
ys = torch.tensor(ys)

In [10]:
import torch.nn.functional as F
# Encode x: one-hot both context indices at once over 27 classes, giving
# (num_examples, 2, 27), then flatten the last two axes so each row is
# [one_hot(x_1), one_hot(x_2)] side by side -> (num_examples, 54).
# This equals concatenating the two one-hot matrices along dim=1 (the columns);
# concatenating along dim=0 (the rows) would instead append one matrix below the
# other, like appending one list to another, which is not what we want.
x_enc = F.one_hot(xs, num_classes=27).float().flatten(start_dim=1)
x_enc.shape

torch.Size([196113, 54])

In [11]:
# Init the neural network: a single linear layer mapping the 54-dim encoded context
# (two concatenated 27-way one-hots) to 27 logits, one per possible next character.
# A seeded generator makes the initial weights reproducible across kernel restarts.
g = torch.Generator().manual_seed(2147483647)
W = torch.randn((54, 27), generator=g, requires_grad=True)
# Number of training examples (rows of xs). xs.shape[0] is exact, unlike
# int(xs.nelement() / 2), which round-trips through float division.
num = xs.shape[0]

In [12]:
# Train the single-layer network with full-batch gradient descent.
# The one-hot encoding of xs is identical at every step, so build it once outside the loop.
x_enc = torch.cat([F.one_hot(xs[:, 0], num_classes=27).float(),
                   F.one_hot(xs[:, 1], num_classes=27).float()], dim=1)
num_steps = 100
learning_rate = 30
for k in range(num_steps):
    # forward pass
    logits = x_enc @ W  # predict log counts
    # Mean negative log-likelihood of the true next character. cross_entropy applies a
    # log-softmax internally, which is numerically more stable than exp -> normalize -> log.
    loss = F.cross_entropy(logits, ys)
    if k % 10 == 0:
        print(k, loss.item())
    # backward pass
    W.grad = None  # set grad to 0
    loss.backward()

    # update; no_grad keeps the in-place step out of the autograd graph
    with torch.no_grad():
        W -= learning_rate * W.grad

0 4.326144695281982
10 2.6685147285461426
20 2.4744510650634766
30 2.3979923725128174
40 2.357487201690674
50 2.3328921794891357
60 2.31642484664917
70 2.3045926094055176
80 2.2956583499908447
90 2.288663864135742


In [13]:
# Sample one name for each starting letter (indices 1..26). The context is seeded with
# ('.', letter), and the model then samples one character at a time until it emits '.'.
# A dedicated seeded generator makes the samples reproducible. no_grad avoids building an
# autograd graph through W for every sampled character (this is inference only).
sample_gen = torch.Generator().manual_seed(2147483647)
with torch.no_grad():
    for i in range(26):
        out = [itos[i+1]]
        ix = [0, i+1]
        while True:
            x_enc = torch.cat([F.one_hot(torch.tensor(ix[0]), num_classes=27).float(),
                               F.one_hot(torch.tensor(ix[1]), num_classes=27).float()], dim=0)
            logits = x_enc @ W
            p = F.softmax(logits, dim=0)  # same as exp(logits) normalized to sum to 1
            next_ix = torch.multinomial(p, num_samples=1, replacement=True, generator=sample_gen).item()
            ix = [ix[1], next_ix]
            out.append(itos[next_ix])
            if next_ix == 0:  # '.' ends the name
                break
        print("".join(out))
        

adellyna.
beelynis.
co.
dynneleel.
eley.
fikay.
getynn.
hais.
iolux.
jayn.
kaibrisjai.
lany.
macwro.
ni.
od.
prenzayishorf.
qh.
ramhlerani.
sovina.
telyahduzadilpfiasharin.
ua.
vakijah.
wilon.
xenn.
ylo.
zalyath.
