In [1]:
import torch
import torch.nn.functional as F

with open("names.txt", "r") as f:
    words = f.read().splitlines()
chars = sorted(list(set("".join(words))))
stoi = {c: i + 1 for i, c in enumerate(chars)}
stoi["."] = 0
itos = {i: c for c, i in stoi.items()}

In [2]:
# prepare data
X = []
y = []
block_size = 3
for word in words:
    context = [0] * block_size
    for ch in word + ".":
        ix = stoi[ch]
        X.append(context)
        y.append(ix)
        context = context[1:] + [ix]
X = torch.tensor(X)
y = torch.tensor(y)

In [3]:
X.shape, y.shape

(torch.Size([228146, 3]), torch.Size([228146]))

In [4]:
C = torch.randn((27, 2))
F.one_hot(torch.tensor(5), num_classes=27).float() @ C

tensor([ 0.9381, -0.0607])

In [5]:
emb = C[X]
W1 = torch.randn((6, 100))
b1 = torch.randn(100)
h = torch.tanh(emb.view(-1, 6) @ W1 + b1)

In [6]:
W2 = torch.randn((100, 27))
b2 = torch.randn(27)

In [7]:
logits = h @ W2 + b2
logits.shape

torch.Size([228146, 27])

In [8]:
counts = logits.exp()
prob = counts / counts.sum(-1, keepdim=True)
prob.shape

torch.Size([228146, 27])

In [9]:
loss = -prob[torch.arange(len(y)), y].log().mean()
print(f"{loss.item():.2f}")

16.83


In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [11]:
g = torch.Generator().manual_seed(2147483647)
C = torch.randn((27, 2), generator=g).to(device)
W1 = torch.randn((6, 100), generator=g).to(device)
b1 = torch.randn(100, generator=g).to(device)
W2 = torch.randn((100, 27), generator=g).to(device)
b2 = torch.randn(27, generator=g).to(device)
X = X.to(device)
y = y.to(device)
parameters = [C, W1, b1, W2, b2]

In [12]:
print(X.shape, y.shape)

torch.Size([228146, 3]) torch.Size([228146])


In [13]:
print(sum(p.nelement() for p in parameters))
for p in parameters:
    p.requires_grad = True

3481


In [15]:
for _ in range(10000):
    # forward pass
    ix = torch.randint(0, X.shape[0], (32,))
    emb = C[X[ix]]
    h = torch.tanh(emb.view(-1, 6) @ W1 + b1)  # (N, 100)
    logits = h @ W2 + b2  # (N, 27)
    loss = F.cross_entropy(logits, y[ix])
    # backward pass
    for p in parameters:
        p.grad = None
    loss.backward()
    # update
    for p in parameters:
        p.data -= 0.1 * p.grad
    print(f"Epoch {_}: {loss.item():.2f}")

Epoch 0: 2.71
Epoch 1: 2.66
Epoch 2: 2.85
Epoch 3: 2.79
Epoch 4: 2.88
Epoch 5: 2.88
Epoch 6: 2.50
Epoch 7: 2.98
Epoch 8: 3.38
Epoch 9: 2.53
Epoch 10: 2.65
Epoch 11: 2.80
Epoch 12: 2.44
Epoch 13: 2.68
Epoch 14: 2.68
Epoch 15: 2.39
Epoch 16: 2.30
Epoch 17: 2.95
Epoch 18: 2.47
Epoch 19: 2.75
Epoch 20: 2.66
Epoch 21: 2.82
Epoch 22: 2.59
Epoch 23: 2.50
Epoch 24: 2.66
Epoch 25: 2.91
Epoch 26: 2.56
Epoch 27: 2.45
Epoch 28: 2.58
Epoch 29: 2.40
Epoch 30: 2.40
Epoch 31: 2.59
Epoch 32: 2.58
Epoch 33: 2.76
Epoch 34: 2.58
Epoch 35: 2.87
Epoch 36: 2.66
Epoch 37: 2.58
Epoch 38: 2.42
Epoch 39: 2.76
Epoch 40: 2.80
Epoch 41: 2.52
Epoch 42: 2.36
Epoch 43: 2.62
Epoch 44: 2.38
Epoch 45: 2.46
Epoch 46: 2.80
Epoch 47: 2.66
Epoch 48: 2.54
Epoch 49: 2.36
Epoch 50: 2.49
Epoch 51: 2.59
Epoch 52: 2.46
Epoch 53: 2.95
Epoch 54: 2.55
Epoch 55: 2.43
Epoch 56: 3.10
Epoch 57: 2.66
Epoch 58: 2.83
Epoch 59: 2.35
Epoch 60: 2.71
Epoch 61: 2.48
Epoch 62: 2.85
Epoch 63: 2.80
Epoch 64: 2.60
Epoch 65: 2.69
Epoch 66: 2.72
Epoch