In [1]:

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F

In [None]:
words = open("names.txt").read().splitlines()

In [7]:
def chtoidx(ch):
    if ord(ch) >= ord('a') and ord(ch) <= ord('z'): 
        return ord(ch) - ord('a') + 1
    else: 
        return 0

def idxtoch(idx):
    if idx - 1 + ord('a') >= ord('a') and idx - 1 + ord('a') <= ord('z'):
        return chr(idx - 1 + ord('a'))
    else:
        return '.'

In [100]:
## create datasets
block_size = 5 ## context window fed into the nn to predict the next character

def build_dataset(words):
    X, Y = [], []

    for w in words:
        ex = [chtoidx('.')] * block_size
        for ch in w + '.':
            X.append(ex)
            Y.append(chtoidx(ch))
            ex = ex[1:] + [chtoidx(ch)]

    X = torch.tensor(X)
    Y = torch.tensor(Y)
    return X, Y

import random
random.seed(129387)
random.shuffle(words)

n1 = int(0.8 * len(words))
n2 = int(0.9 * len(words))
Xtr, Ytr = build_dataset(words[:n1]) # training dataset
Xdev, Ydev = build_dataset(words[:n1]) # development dataset
Xte, Yte = build_dataset(words[:n1]) # test dataset

In [108]:
## generate neural network
## structure:
## - convert each input context window character into vector with a lookup table (C)
## - feed vectors into hidden layer 1 (L1)
## - feed L1 outputs into logit output layer (L2)

g = torch.Generator().manual_seed(187246324)
ch_embed_dim = 8 ## dimensions of vectors to embed input characters into
n_L1_neurons = 200 ## number of neurons in L1

C = torch.randn((27, ch_embed_dim)) ## lookup table to index character into higher dimensional vectors
W1 = torch.randn((block_size * ch_embed_dim, n_L1_neurons)) ## weights in L1 neurons
b1 = torch.randn(n_L1_neurons) ## biases in L1 neurons
W2 = torch.randn((n_L1_neurons, 27)) ## weights in L2 neurons
b2 = torch.randn(27) ## biases in L2 neurons
params = [C, W1, b1, W2, b2]
for p in params: 
    p.requires_grad = True
sum(p.nelement() for p in params)

13843

In [110]:
## train

lr = 0.01 ## learning rate
batch_size = 200

for i in range(10000):

    batch = torch.randint(0, Xtr.shape[0], (batch_size,))

    ## compute loss
    logits = (C[Xtr[batch]].view(-1, ch_embed_dim * block_size) @ W1 + b1).tanh() @ W2 + b2 ## output of the nn over all exemples in X
    loss = F.cross_entropy(logits, Ytr[batch]) ## applies softmax and negative log likelihood over output of the nn and labels in Y

    print(loss.item())

    ## calculate gradients
    for p in params: 
        p.grad = None
    
    loss.backward()

    ## update params
    for p in params:
        p.data += lr * -p.grad

2.422813892364502
2.5009028911590576
2.381863594055176
2.3139188289642334
2.4660253524780273
2.5111565589904785
2.4576966762542725
2.347952127456665
2.8280911445617676
2.3624117374420166
2.6262547969818115
2.5308268070220947
2.4091291427612305
2.451314687728882
2.4525046348571777
2.4533400535583496
2.6503963470458984
2.512369155883789
2.53171968460083
2.3665223121643066
2.8101589679718018
2.5271003246307373
2.4132797718048096
2.445856809616089
2.591738224029541
2.5298585891723633
2.4976372718811035
2.4351906776428223
2.4729318618774414
2.6392910480499268
2.478109836578369
2.516871690750122
2.3930752277374268
2.518659830093384
2.5432076454162598
2.5765950679779053
2.539994716644287
2.534346342086792
2.5933804512023926
2.5382251739501953
2.485466241836548
2.5795352458953857
2.5487496852874756
2.343205213546753
2.538409471511841
2.3789010047912598
2.506570339202881
2.559648036956787
2.5991976261138916
2.562411308288574
2.5243799686431885
2.406562566757202
2.500676155090332
2.4170262813568

In [114]:
## sample from model

for _ in range(20):
    ctx = [chtoidx('.')] * block_size
    out = []
    while True:
        probs = F.softmax((C[torch.tensor(ctx)].view(-1) @ W1 + b1).tanh() @ W2 + b2, dim=0)
        next_character = torch.multinomial(probs, 1)
        next_character = next_character[0].item()
        if next_character == 0:
            break
        else:
            ctx = ctx[1:] + [next_character]
            out.append(next_character)
    name = ''.join(idxtoch(idx) for idx in out)
    print(name)

savalria
aariandur
andia
narirham
brana
mosisehia
jenrlen
esrm
keyaalson
lana
jirbar
zougan
pyinit
magenda
ehain
takae
moriaha
tange
ryau
jyeglinna
