In [1]:
# Load the training corpus: one word per line of names.txt.
# The context manager closes the file handle as soon as the read finishes,
# instead of leaving it open until garbage collection.
with open('names.txt') as f:
    words = f.read().splitlines()

In [2]:
# Vocabulary: the single special token '.' lives at index 0, and every
# character that occurs in the corpus gets an index starting from 1.
chars = sorted(set(''.join(words)))
stoi = {ch: idx for idx, ch in enumerate(chars, start=1)}
stoi['.'] = 0
# Inverse lookup (index -> character).
itos = {idx: ch for ch, idx in stoi.items()}
num_classes = len(stoi)

## Create the dataset

In [3]:
import torch
import torch.nn.functional as F

In [4]:
# Build the bigram training set: each word is wrapped in '.' tokens, and every
# pair of adjacent characters yields one (input index, target index) example.
xs, ys = [], []
for w in words:
    ids = [stoi[ch] for ch in '.' + w + '.']
    xs.extend(ids[:-1])  # every position except the last is an input ...
    ys.extend(ids[1:])   # ... and the character that follows it is the target
xs = torch.tensor(xs)
ys = torch.tensor(ys)
num = xs.numel()
print(f'Number of examples {num}')

# Initialise the network: one (num_classes x num_classes) weight matrix drawn
# from a seeded generator so the run is reproducible.
g = torch.Generator().manual_seed(2147483647)
W = torch.randn((num_classes, num_classes), generator=g, requires_grad=True)

Number of examples 228146


In [5]:
# Gradient descent on the regularised negative log-likelihood.
lr = 50              # learning rate (step size of each update)
reg_strength = 0.01  # weight of the mean-squared-weights penalty

# The inputs never change between steps, so the one-hot encoding and the row
# index used to pick out target probabilities are built once, outside the loop.
xenc = F.one_hot(xs, num_classes=num_classes).float()
rows = torch.arange(num)

for k in range(200):

    # forward pass
    logits = xenc @ W  # interpret these as log counts; another word for them is logits
    counts = logits.exp()  # exponentiate to get the (positive) bigram "counts"
    probs = counts / counts.sum(1, keepdim=True)  # normalise each row into probabilities
    # mean negative log-probability assigned to the true next character,
    # plus the regularisation term that pulls the weights towards zero
    loss = -probs[rows, ys].log().mean() + reg_strength * (W**2).mean()
    print(loss.item())

    # backward pass
    W.grad = None  # drop the previous step's gradient before accumulating a new one
    loss.backward()

    # update weights in place; no_grad keeps the update itself out of the
    # autograd graph without reaching around autograd through `.data`
    with torch.no_grad():
        W -= lr * W.grad

3.7686190605163574
3.3788068294525146
3.161090850830078
3.0271859169006348
2.9344842433929443
2.867231607437134
2.8166539669036865
2.777146339416504
2.745253801345825
2.7188303470611572
2.696505308151245
2.6773722171783447
2.6608052253723145
2.6463513374328613
2.633665084838867
2.622471570968628
2.6125476360321045
2.6037068367004395
2.595794916152954
2.5886809825897217
2.5822560787200928
2.5764293670654297
2.5711236000061035
2.566272735595703
2.5618226528167725
2.5577263832092285
2.5539438724517822
2.550442695617676
2.5471925735473633
2.5441696643829346
2.5413525104522705
2.538721799850464
2.536262273788452
2.5339579582214355
2.531797409057617
2.5297679901123047
2.527860164642334
2.5260636806488037
2.5243709087371826
2.522773265838623
2.52126407623291
2.519836664199829
2.5184857845306396
2.5172057151794434
2.515990734100342
2.5148372650146484
2.5137407779693604
2.51269793510437
2.511704444885254
2.5107579231262207
2.509854793548584
2.5089924335479736
2.5081682205200195
2.50738024711608

## Finally, let's sample from the model

In [6]:
# Sample 10 words from the trained model, using a seeded generator so the
# samples are reproducible.
g = torch.Generator().manual_seed(2147483647)

for i in range(10):
    out = []
    ix = 0  # index 0 is the '.' token, so every word starts from it
    while True:
        # NN inference: same forward pass as in training, for a single input.
        # num_classes (not a hard-coded 27) keeps the encoding width in sync
        # with W; no_grad skips building an autograd graph for each character.
        with torch.no_grad():
            xenc = F.one_hot(torch.tensor([ix]), num_classes=num_classes).float()
            logits = xenc @ W
            counts = logits.exp()
            p = counts / counts.sum(1, keepdim=True)

        # draw the next character index from the predicted distribution
        ix = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
        out.append(itos[ix])
        if ix == 0:  # sampled '.', i.e. end of word
            break
    print(''.join(out))

mor.
axx.
minaymoryles.
kondmaisah.
anchshizarie.
odaren.
iaddash.
h.
jionatien.
egwulo.


In [7]:
# pretty crazy how close these samples are to the ones generated by the bigram model:

# mor.
# axx.
# minaymoryles.
# kondlaisah.
# anchshizarie.
# odaren.
# iaddash.
# h.
# jhinatien.
# egushl.