In [46]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

In [7]:
# read in all the words (~32K)
words = open("../names.txt", "r").read().splitlines()

In [8]:
# build the vocabulary of chaacters and mappings to/from integers

chars = sorted(list(set(''.join(words))))
stoi = {s:(i+1) for i,s in enumerate(chars)}
stoi["."] = 0
itos = {s:i for i,s in stoi.items()}


In [32]:
block_size = 3 # context length: how many characters do we take to predict the next one
X, Y = [], []

for w in words:
    context = [0] * block_size

    for ch in (w + '.'):
        ix = stoi[ch]
        X.append(context)
        Y.append(ix)
        context = context[1:] + [ix]

X = torch.tensor(X)
Y = torch.tensor(Y)



In [39]:
g = torch.Generator().manual_seed(2147483647)

C = torch.randn((27, 2), generator=g)
W1 = torch.randn((6, 100), generator=g)
b1 = torch.randn(100, generator=g)
W2 = torch.randn((100, 27), generator=g)
b2 = torch.randn(27, generator=g)

parameters = [C, W1, b1, W2, b2]
sum(p.nelement() for p in parameters) # total number of params


3481

In [40]:
for p in parameters:
    p.requires_grad = True

In [44]:
for k in range(10000):
    # minibatch construct
    ix = torch.randint(0, X.shape[0], (32,))

    # forward pass
    emb = C[X[ix]] # (32, 3, 2)
    h = torch.tanh(emb.view(-1, 6) @ W1 + b1) # (32, 100)
    logits = h @ W2 + b2 # (32, 27)
    loss = F.cross_entropy(logits, Y[ix])
    # backward pass
    for p in parameters:
        p.grad = None 
    loss.backward()
    # update
    for p in parameters: 
        p.data += -0.1 * p.grad
    
    if (k % 1000) == 0:
        print(loss.item())

print(loss.item())

2.725968360900879
2.0295872688293457
2.313206911087036
2.3643887042999268
2.5894908905029297
2.5606229305267334
2.521812915802002
2.491724967956543
2.4782650470733643
2.3053174018859863
2.2272562980651855


In [45]:
emb = C[X]
h = torch.tanh(emb.view(-1, 6) @ W1 + b1)
logits = h @ W2 + b2
loss = F.cross_entropy(logits, Y)
loss

tensor(2.3803, grad_fn=<NllLossBackward0>)

In [49]:
g = torch.Generator().manual_seed(2147483647)

for _ in range(20):

    out = []
    context = [0] * block_size

    while True:
        emb = C[torch.tensor([context])] 
        h = torch.tanh(emb.view(1, -1) @ W1 + b1)
        logits = h @ W2 + b2
        probs = F.softmax(logits, dim=1)
        ix = torch.multinomial(probs, num_samples=1, generator=g).item()
        context = context[1:] + [ix]
        out.append(ix)
        if ix == 0:
            break

    print(''.join(itos[i] for i in out))
        

caralixa.
jahlei.
kimriohlety.
halan.
kejrahnen.
amerari.
kaqei.
neleniahchaiiv.
kaleigsh.
lu.
noin.
quiltis.
lilei.
jadbi.
wazelo.
diarini.
fane.
nivuli.
edaediia.
gineley.
