In [1]:
import torch
import torch.nn.functional as F

# Read the corpus: one name per line.
with open("names.txt", "r") as f:
    words = f.read().splitlines()

# Vocabulary = every distinct character that appears in the corpus, sorted.
chars = sorted(set("".join(words)))

# Character -> index. Real characters start at 1; index 0 is reserved for the
# "." boundary token that is appended to every word when building the dataset.
stoi = {}
for i, c in enumerate(chars, start=1):
    stoi[c] = i
stoi["."] = 0

# Index -> character: the inverse mapping of `stoi`.
itos = dict(zip(stoi.values(), stoi.keys()))

In [2]:
# Build the training set: every character of every word (plus the closing ".")
# becomes one target, paired with the `block_size` character indices before it.
block_size = 3
contexts, targets = [], []
for word in words:
    # Each word starts from a window padded with index 0 (the "." token).
    window = [0] * block_size
    for ch in word + ".":
        ix = stoi[ch]
        contexts.append(window)
        targets.append(ix)
        # Slide the window: drop the oldest index, append the newest.
        # This builds a new list, so rows already stored are never mutated.
        window = window[1:] + [ix]
X = torch.tensor(contexts)
y = torch.tensor(targets)

In [3]:
# Sanity check: X holds one row of `block_size` context indices per example and
# y holds the matching target index, so their first dimensions must agree.
X.shape, y.shape

(torch.Size([228146, 3]), torch.Size([228146]))

In [4]:
# Embedding table: one 2-dimensional vector for each of the 27 symbols.
C = torch.randn(27, 2)

# Multiplying a one-hot row vector by C selects a single row of C, i.e. the
# embedding of symbol 5. This is the matrix-multiply view of an index lookup.
one_hot_5 = F.one_hot(torch.tensor(5), num_classes=27).float()
one_hot_5 @ C

tensor([-0.2358,  1.5623])

In [5]:
# Look up the embedding of every context index: indexing C with X adds a
# trailing embedding dimension of size 2 to X's shape.
emb = C[X]

# Hidden layer parameters: 6 inputs (3 context positions x 2 embedding dims)
# feeding 100 tanh units.
W1 = torch.randn(6, 100)
b1 = torch.randn(100)

# Concatenate the per-position embeddings of each example into one 6-vector,
# then apply the affine map followed by tanh.
emb_flat = emb.view(-1, 6)
h = (emb_flat @ W1 + b1).tanh()

In [6]:
# Output layer parameters: map the 100 hidden units to 27 scores, one per symbol.
# The tuple is evaluated left to right, so W2 is drawn before b2.
W2, b2 = torch.randn(100, 27), torch.randn(27)

In [7]:
# Output layer: project the hidden activations `h` through W2 and add the bias
# b2, giving one unnormalised score (logit) per class for every example.
logits = h @ W2 + b2
logits.shape

torch.Size([228146, 27])

In [8]:
# Manual softmax over the class dimension.
#
# Subtracting each row's maximum logit before exponentiating leaves the
# resulting probabilities mathematically unchanged (softmax is invariant to a
# per-row constant shift) but guarantees the largest exponent is exp(0) = 1.
# Without the shift, a float32 logit above roughly 88 overflows to inf and the
# normalisation below produces nan.
shifted = logits - logits.max(dim=-1, keepdim=True).values
counts = shifted.exp()
# Normalise each row so it sums to 1.
prob = counts / counts.sum(-1, keepdim=True)
prob.shape

torch.Size([228146, 27])

In [9]:
# Average negative log-likelihood: for every example pick the probability the
# model assigned to the correct next character, take its log, average, negate.
row_idx = torch.arange(y.shape[0])
prob_of_target = prob[row_idx, y]
loss = -prob_of_target.log().mean()
print(f"{loss.item():.2f}")

17.92


In [10]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [11]:
# Re-create all model parameters from a fixed seed so that the initialisation
# is reproducible. Values are drawn from the (CPU) generator and only then
# moved to `device`; the draw order C, W1, b1, W2, b2 determines the values.
g = torch.Generator().manual_seed(2147483647)
C, W1, b1, W2, b2 = (
    torch.randn(shape, generator=g).to(device)
    for shape in [
        (27, 2),    # C:  embedding table
        (6, 100),   # W1: hidden layer weights
        (100,),     # b1: hidden layer bias
        (100, 27),  # W2: output layer weights
        (27,),      # b2: output layer bias
    ]
)

# The dataset has to live on the same device as the parameters.
X, y = X.to(device), y.to(device)

parameters = [C, W1, b1, W2, b2]

In [12]:
# Confirm the dataset shapes after X and y were moved to `device`.
print(X.shape, y.shape)

torch.Size([228146, 3]) torch.Size([228146])


In [13]:
# Report the total number of scalar parameters and switch on gradient tracking
# for every parameter tensor so that loss.backward() populates their .grad.
total_params = 0
for p in parameters:
    total_params += p.numel()
    p.requires_grad_(True)
print(total_params)

3481


In [14]:
# Minibatch SGD: 10000 parameter updates, each computed on 32 random examples.
# The loop variable is a named `step` rather than `_`: the value is read in the
# print below, and in IPython `_` is the "last output" variable, so assigning
# to it shadows that feature for the rest of the session.
for step in range(10000):
    # forward pass
    # Sample the minibatch from the seeded generator `g` so a fresh
    # Restart & Run All reproduces the same run, then move the indices to the
    # data's device once instead of implicitly on every indexing operation.
    ix = torch.randint(0, X.shape[0], (32,), generator=g).to(X.device)
    emb = C[X[ix]]  # (32, 3, 2)
    h = torch.tanh(emb.view(-1, 6) @ W1 + b1)  # (N, 100)
    logits = h @ W2 + b2  # (N, 27)
    # cross_entropy fuses log-softmax and NLL in a numerically stable way.
    loss = F.cross_entropy(logits, y[ix])
    # backward pass
    # Clear stale gradients first; backward() accumulates into .grad.
    for p in parameters:
        p.grad = None
    loss.backward()
    # update
    # Update in place under no_grad() instead of going through `.data`, which
    # bypasses autograd's correctness checks.
    with torch.no_grad():
        for p in parameters:
            p -= 0.1 * p.grad
    # NOTE: each iteration is one minibatch step, not a full pass over the
    # data; the "Epoch" label is kept so the output format stays the same.
    print(f"Epoch {step}: {loss.item():.2f}")

Epoch 0: 20.25
Epoch 1: 17.59
Epoch 2: 15.60
Epoch 3: 14.19
Epoch 4: 14.16
Epoch 5: 12.68
Epoch 6: 12.20
Epoch 7: 9.58
Epoch 8: 13.82
Epoch 9: 10.96
Epoch 10: 12.46
Epoch 11: 13.24
Epoch 12: 12.02
Epoch 13: 11.57
Epoch 14: 9.42
Epoch 15: 8.47
Epoch 16: 8.04
Epoch 17: 7.75
Epoch 18: 8.28
Epoch 19: 10.05
Epoch 20: 7.90
Epoch 21: 9.91
Epoch 22: 9.92
Epoch 23: 5.98
Epoch 24: 5.83
Epoch 25: 8.96
Epoch 26: 6.27
Epoch 27: 8.22
Epoch 28: 8.15
Epoch 29: 6.19
Epoch 30: 6.32
Epoch 31: 6.01
Epoch 32: 5.76
Epoch 33: 7.88
Epoch 34: 5.57
Epoch 35: 7.93
Epoch 36: 8.40
Epoch 37: 5.31
Epoch 38: 8.82
Epoch 39: 7.91
Epoch 40: 5.17
Epoch 41: 4.40
Epoch 42: 4.89
Epoch 43: 5.01
Epoch 44: 4.92
Epoch 45: 4.62
Epoch 46: 4.99
Epoch 47: 6.54
Epoch 48: 4.90
Epoch 49: 5.39
Epoch 50: 4.32
Epoch 51: 6.35
Epoch 52: 5.06
Epoch 53: 4.85
Epoch 54: 4.49
Epoch 55: 5.12
Epoch 56: 4.01
Epoch 57: 5.09
Epoch 58: 4.06
Epoch 59: 5.62
Epoch 60: 4.66
Epoch 61: 3.92
Epoch 62: 4.44
Epoch 63: 3.66
Epoch 64: 4.33
Epoch 65: 4.76
Epoch 