In [1]:
from pathlib import Path

import torch
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


In [63]:
# Create dataset

words = Path.cwd().joinpath("names.txt").read_text().splitlines()
end_tok = "."
chars = [end_tok] + sorted(list(set("".join(words))))
char_to_index = {char: index for index, char in enumerate(chars)}
n_chars = len(chars)

block_size = 3
X, Y = [], []
for w in words:
    context = [0] * block_size
    for ch in w + end_tok:
        ix = char_to_index[ch]
        X.append(context)
        Y.append(ix)
        # print(''.join(chars[i] for i in context), "->", ch)
        context = context[1:] + [ix]

X = torch.tensor(X)
Y = torch.tensor(Y)
num = X.nelement()

In [160]:
n_embeddings_dims = 2

C = torch.randn(27, n_embeddings_dims)
layer1_size = 100
W1 = torch.randn((n_embeddings_dims * block_size, layer1_size))
b1 = torch.randn(layer1_size)
layer2_size = 27
W2 = torch.randn((layer1_size, layer2_size))
b2 = torch.randn(layer2_size)
parameters = [C, W1, b1, W2, b2]
for p in parameters:
    p.requires_grad = True

In [161]:
steps = 10000
lre = torch.linspace(-3, 0, steps)
lrs = 10**lre

In [178]:
batch_size = 128
lri = []
lossi = []

for i in range(steps):
    # Get minibatch
    ix = torch.randint(0, X.shape[0], (batch_size,))

    # Forward pass
    emb = C[X[ix]]
    h = torch.tanh(emb.view(-1, n_embeddings_dims * block_size) @ W1 + b1)
    logits = h @ W2 + b2
    loss = F.cross_entropy(logits, Y[ix])
    print(loss.item())

    # Backward pass
    for p in parameters:
        p.grad = None
    loss.backward()

    lr = .01
    for p in parameters:
        p.data += -1 * p.grad

    # lri.append(lre[i])
    # lossi.append(loss.item())

2.207943916320801
2.4936907291412354
2.2837698459625244
2.3628122806549072
2.332258939743042
2.256321907043457
2.4115853309631348
2.3658108711242676
2.3872833251953125
2.3019521236419678
2.374828338623047
2.3217172622680664
2.1818506717681885
2.3821544647216797
2.2874157428741455
2.289034605026245
2.2500579357147217
2.369939088821411
2.3196535110473633
2.485610008239746
2.377528667449951
2.3856496810913086
2.336794853210449
2.2810988426208496
2.423938751220703
2.2370502948760986
2.4014840126037598
2.3926210403442383
2.369875192642212
2.4439895153045654
2.352181911468506
2.37534761428833
2.4234063625335693
2.3975963592529297
2.1979172229766846
2.3912012577056885
2.1495471000671387
2.3730130195617676
2.509986400604248
2.4678218364715576
2.63167667388916
2.552236795425415
2.411578416824341
2.2421224117279053
2.3666064739227295
2.346351146697998
2.430196523666382
2.415712594985962
2.283627986907959
2.154448986053467
2.536820650100708
2.258124351501465
2.417584180831909
2.5559399127960205
2

In [179]:
# Eval
emb = C[X]
h = torch.tanh(emb.view(-1, n_embeddings_dims * block_size) @ W1 + b1)
logits = h @ W2 + b2
loss = F.cross_entropy(logits, Y)
loss.item()

2.346143960952759