In [8]:
import seaborn as sns
import jax

# Apply seaborn's default theme (called with no arguments).
sns.set_theme()

In [9]:
# Read the dataset: names.txt is split into one entry per line.
# A context manager closes the file handle as soon as the contents are read
# (a bare open(...).read() leaves that to the garbage collector), and the
# explicit encoding keeps the result independent of the platform default.
with open("names.txt", encoding="utf-8") as names_file:
    words = names_file.read().splitlines()
print(f"Using {len(words)} names")
words[:10]

Using 32033 names


['emma',
 'olivia',
 'ava',
 'isabella',
 'sophia',
 'charlotte',
 'mia',
 'amelia',
 'harper',
 'evelyn']

In [10]:
# Build the character vocabulary and the two lookup tables
# (character -> index and index -> character).
chars = sorted(set("".join(words)))
# Characters found in the names are numbered from 1 upwards ...
stoi = dict(zip(chars, range(1, len(chars) + 1)))
# ... which leaves index 0 for the "." marker.
stoi["."] = 0
itos = dict((idx, ch) for ch, idx in stoi.items())

# M is the vocabulary size, "." included.
M = len(stoi)
print(M, itos)

27 {1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}


In [11]:
import jax.numpy as jnp

# Number of preceding characters the model sees when predicting the next one.
context_length = 3

# Slide a window over every name (terminated by "."): each position yields one
# training example made of the previous `context_length` token indices (X) and
# the index of the character that follows them (y).
X = []
y = []
for w in words:
    # Start from an all-zero context; index 0 is the "." token, so it doubles
    # as left padding.
    context = [0] * context_length
    for ch in w + ".":
        idx = stoi[ch]
        X.append(context)
        y.append(idx)
        # Build a new list (no in-place edit) so rows already stored in X
        # stay intact.
        context = context[1:] + [idx]


X = jnp.array(X)
y = jnp.array(y)

# Only a cell's last expression is displayed, so the shapes are printed
# explicitly; as a bare expression in mid-cell they would be silently discarded.
print(X.shape, y.shape)

jax.default_backend()

'METAL'

In [16]:
from jax import random, Array, jit, vmap, value_and_grad
from jax.nn import one_hot, softmax
import jax

# Model hyper-parameters and randomly initialised weights.
token_space = 27  # 26 letters plus the "." token
embedding_space = 2  # width of each token embedding

# Derive one PRNG key per weight matrix from a fixed seed.
key = random.key(42)
key, C_key, W1_key, W2_key = random.split(key, 4)

# C:  embedding table, one row per token.
# W1: hidden layer fed by the concatenated context embeddings (scaled by 1/10).
# W2: output layer mapping the 100 hidden units back to one logit per token.
parameters = dict(
    C=random.normal(C_key, (token_space, embedding_space)),
    W1=random.normal(W1_key, (embedding_space * context_length, 100)) / 10,
    W2=random.normal(W2_key, (100, token_space)),
)


@jit
def model(X: Array, parameters: dict[str, Array]):
    """Return next-token probabilities for a single context window.

    Args:
        X: 1-D array of integer token indices forming one context. Batches
            are handled by wrapping this function in ``vmap``.
        parameters: weights under the keys ``"C"`` (one embedding row per
            token), ``"W1"`` (hidden layer, fed by the flattened context
            embeddings) and ``"W2"`` (output layer, one column per token).

    Returns:
        1-D array of probabilities over the tokens (softmax of the logits).
    """
    embedding_table = parameters["C"]
    # Take the vocabulary size from the embedding table itself rather than
    # from a notebook-level global, so the function cannot get out of sync
    # with the weights it is given. Shapes are static under jit, so this is
    # a plain Python int at trace time.
    vocab_size = embedding_table.shape[0]
    # One-hot rows times the table select one embedding per context token;
    # reshape(-1) concatenates them into a single flat feature vector.
    emb = jnp.dot(one_hot(X, vocab_size), embedding_table).reshape(-1)
    hlogits = jnp.tanh(jnp.dot(emb, parameters["W1"]))
    logits = jnp.dot(hlogits, parameters["W2"])
    probs = softmax(logits)
    return probs


@jit
def criterion(probs: Array, y: int):
    """Negative log-likelihood of the target index ``y`` under ``probs``."""
    target_prob = probs[y]
    return jnp.negative(jnp.log(target_prob))

In [15]:
# Baseline for comparison: the loss of a uniform guess over 27 tokens.
-jnp.log(1/27)

Array(3.295837, dtype=float32, weak_type=True)

In [17]:
@jit
@value_and_grad
def forward(parameters: dict[str, Array], X: Array, y: Array):
    """Mean negative log-likelihood of the targets ``y`` given contexts ``X``.

    Because of ``value_and_grad`` a call returns ``(loss, grads)``, where the
    gradient is taken with respect to ``parameters`` (the first argument).
    The outer ``jit`` compiles the loss-and-gradient computation once instead
    of re-tracing it in Python on every training step.

    Args:
        parameters: model weights, passed through to ``model``.
        X: batch of contexts, one per row along axis 0.
        y: target index for each row of ``X``.
    """
    # in_axes=(0, None): map over the rows of X, share parameters across rows.
    batch_probs = vmap(model, in_axes=(0, None))(X, parameters)
    per_example_loss = vmap(criterion)(batch_probs, y)
    return jnp.mean(per_example_loss)

# Full-batch gradient descent: 2000 steps with a fixed step size of 0.1.
for i in range(2000):
    loss, grad = forward(parameters, X, y)
    # Move every weight matrix against its gradient.
    for k, g in grad.items():
        parameters[k] = parameters[k] - .1 * g
    # Report the loss on the first step and on every 100th step after it.
    if not i % 100:
        print(f"Epoch {i + 1}: {loss}")



Epoch 1: 4.612331867218018
Epoch 101: 2.819538116455078
Epoch 201: 2.7806360721588135
Epoch 301: 2.7291812896728516
Epoch 401: 2.6902899742126465
Epoch 501: 2.657661199569702
Epoch 601: 2.6299777030944824
Epoch 701: 2.6060791015625
Epoch 801: 2.5852952003479004
Epoch 901: 2.5673162937164307
Epoch 1001: 2.551971435546875
Epoch 1101: 2.539088487625122
Epoch 1201: 2.528407096862793
Epoch 1301: 2.5195651054382324
Epoch 1401: 2.512166738510132
Epoch 1501: 2.5058505535125732
Epoch 1601: 2.500328779220581
Epoch 1701: 2.4953911304473877
Epoch 1801: 2.4908902645111084
Epoch 1901: 2.4867217540740967
