In [1]:
import seaborn as sns

sns.set_theme()

In [2]:
# Read the dataset

words = open("names.txt").read().splitlines()
print(f"Using {len(words)} names")
words[:10]

Using 32033 names


['emma',
 'olivia',
 'ava',
 'isabella',
 'sophia',
 'charlotte',
 'mia',
 'amelia',
 'harper',
 'evelyn']

In [3]:
# Explore the dataset
chars = sorted(list(set("".join(words))))
stoi = {s: i + 1 for i, s in enumerate(chars)}
stoi["."] = 0
itos = {i: s for s, i in stoi.items()}

M = len(stoi)
print(M, itos)

27 {1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}


In [4]:
import jax.numpy as jnp

context_length = 3

X = []
y = []
for w in words:
    context = [0] * context_length
    for ch in w + ".":
        idx = stoi[ch]
        X.append(context)
        y.append(idx)
        context = context[1:] + [idx]


X = jnp.array(X)
y = jnp.array(y)

X.shape, y.shape

((228146, 3), (228146,))

In [5]:
from jax import random, Array, jit, vmap, value_and_grad
from jax.nn import one_hot, softmax
import jax

#  Define the model
token_space = 27
embedding_space = 2
key = random.key(42)
key, C_key, W1_key, W2_key = random.split(key, 4)
parameters = {
    "C": random.normal(C_key, (token_space, embedding_space)),
    "W1": random.normal(W1_key, (embedding_space * context_length, 100)),
    "W2": random.normal(W2_key, (100, token_space)),
}


@jit
def model(X: Array, parameters: dict[str, Array]):
    emb = jnp.dot(one_hot(X, token_space), parameters["C"]).reshape(
        context_length * embedding_space
    )
    hlogits = jnp.tanh(jnp.dot(emb, parameters["W1"]))
    logits = jnp.dot(hlogits, parameters["W2"])
    probs = softmax(logits)
    return probs


@jit
def criterion(probs: Array, y: int):
    return -jnp.log(probs[y])

In [15]:
@value_and_grad
def forward(parameters: dict[str, Array], X: Array, y: Array):
    preds = vmap(model, in_axes=(0, None))(X, parameters)
    loss = jnp.mean(vmap(criterion)(preds, y))
    return loss

for i in range(20):
    loss, grad = forward(parameters, X, y)
    for k in parameters.keys():
        parameters[k] -= .1 * grad[k]
    print(loss)



13.328408
11.900421
10.83315
10.535785
10.299287
9.456129
9.4127865
8.733887
8.582693
8.2400255
8.023248
7.7062817
7.4838276
7.241955
7.015661
6.794135
6.6001596
6.3675485
6.24095
5.9593678
