In [70]:
import torch, math
words = open("names.txt", "r").read().splitlines()


In [71]:
chars = sorted(list(set("".join(words))))
stoi = {s:i + 1 for i, s in enumerate(chars)}
stoi["."] = 0
itos = {i: s for s, i in stoi.items()}


In [72]:
# build dataset

def build_dataset(words):
    block_size = 3
    x, y = [], []
    for word in words:
    
        context = [0] * block_size
    
        for ch in word:
            ix = stoi[ch]
            x.append(context)
            y.append(ix)
            context = context[1:] + [ix]
    x = torch.tensor(x)
    y = torch.tensor(y)
    print(f"{x.shape=}, {y.shape=}")
    return x, y

import random
random.seed(42)
random.shuffle(words)
n1 = int(0.8*len(words))
n2 = int(0.9*len(words))
x_train, y_train = build_dataset(words[:n1])
x_dev, y_dev = build_dataset(words[n1:n2])
x_test, y_test = build_dataset(words[n2:])

x.shape=torch.Size([156999, 3]), y.shape=torch.Size([156999])
x.shape=torch.Size([19452, 3]), y.shape=torch.Size([19452])
x.shape=torch.Size([19662, 3]), y.shape=torch.Size([19662])


In [73]:
# building network
g = torch.Generator().manual_seed(31415926)
C = torch.randn((27, 2))
W1 = torch.randn((6, 100), generator=g)
b1 = torch.randn(100, generator=g)
W2 = torch.randn((100, 27), generator=g)
b2 = torch.rand(27, generator=g)
param = [C, W1, b1, W2, b2]
for p in param:
    p.requires_grad = True

In [78]:
#training with training dataset
for _ in range(10000):
    #forward pass
    ix = torch.randint(0, x_train.shape[0], (32, ))
    emb = C[x_train[ix]]
    h = torch.tanh(emb.view(-1, 6) @ W1 + b1)
    logits = h @ W2 + b2
    loss = torch.nn.functional.cross_entropy(logits, y_train[ix])
    #backward pass
    for p in param:
        p.grad = None
    loss.backward()
    lr = 0.1
    for p in param:
        p.data += -lr * p.grad


In [79]:
print(loss.item())

2.4497101306915283


In [80]:
#evaluation with dev dataset
emb = C[x_dev]
h = torch.tanh(emb.view(-1, 6) @ W1 + b1)
logits = h @ W2 + b2
loss = torch.nn.functional.cross_entropy(logits, y_dev)
loss

tensor(2.4646, grad_fn=<NllLossBackward0>)