In [7]:
import random
from typing import List, NamedTuple, Optional, Tuple

import jax
import numpy as np


In [2]:
def load_data(path: str) -> Tuple[List[str], List[str]]:
    """
    Read a newline-separated word list and build its character vocabulary.

    Args:
        path: path to a UTF-8 text file with one word per line.

    Returns:
        (words, vocab):
        - words: the non-empty, whitespace-stripped lines of the file, in file order.
        - vocab: ['<eos>'] followed by the sorted unique characters occurring in
          `words`, so '<eos>' always has index 0.

    Raises:
        ValueError: if the file contains no non-empty lines.
    """
    # explicit encoding: the platform default (e.g. cp1252 on Windows) may not be UTF-8
    with open(path, 'r', encoding='utf-8') as f:
        data = f.read()

    # strip leading/trailing whitespace, then drop lines that end up empty
    words = [line.strip() for line in data.splitlines()]
    words = [word for word in words if word]
    if not words:
        # fail with a clear message instead of "max() arg is an empty sequence" below
        raise ValueError(f"no words found in {path!r}")

    vocab = ['<eos>'] + sorted(set(''.join(words)))
    lengths = [len(word) for word in words]
    print(f"number of examples in dataset: {len(words)}")
    print(f"max word length: {max(lengths)}")
    print(f"min word length: {min(lengths)}")
    print(f"unique characters in dataset: {len(vocab)}")
    print("vocabulary:")
    print(' '.join(vocab))
    print('example for a word:')
    print(words[0])
    return words, vocab

# Load the names corpus (one name per line) together with its character vocabulary.
words, vocab = load_data(path='names.txt')

number of examples in dataset: 32033
max word length: 15
min word length: 2
unique characters in dataset: 27
vocabulary:
<eos> a b c d e f g h i j k l m n o p q r s t u v w x y z
example for a word:
emma


In [3]:
def encode(word: str, vocab: List[str]) -> List[int]:
    """
    Convert `word` into vocabulary indices framed by the '<eos>' index.

    The result is [eos, idx(c0), idx(c1), ..., eos]; a character missing from
    `vocab` raises ValueError (from list.index).
    """
    eos_index = vocab.index('<eos>')
    char_indices = [vocab.index(ch) for ch in word]
    return [eos_index, *char_indices, eos_index]

def decode(indices: List[int], vocab: List[str]) -> str:
    """
    Turn a sequence of vocabulary indices back into a string.

    Every index is replaced by its vocabulary entry and the entries are
    concatenated, so '<eos>' markers appear verbatim in the result.
    """
    return ''.join(map(vocab.__getitem__, indices))

# Sanity check: show the encode -> decode round trip for the first five names.
for idx in range(5):
    word = words[idx]
    encoded = encode(word, vocab)
    print(f"word: {word}")
    print(f"encoded: {encoded}")
    print(f"decoded: {decode(encoded, vocab)}")
    print()

word: emma
encoded: [0, 5, 13, 13, 1, 0]
decoded: <eos>emma<eos>

word: olivia
encoded: [0, 15, 12, 9, 22, 9, 1, 0]
decoded: <eos>olivia<eos>

word: ava
encoded: [0, 1, 22, 1, 0]
decoded: <eos>ava<eos>

word: isabella
encoded: [0, 9, 19, 1, 2, 5, 12, 12, 1, 0]
decoded: <eos>isabella<eos>

word: sophia
encoded: [0, 19, 15, 16, 8, 9, 1, 0]
decoded: <eos>sophia<eos>



In [4]:
# Encode the whole corpus once; every entry is <eos> + characters + <eos>.
encoded_words = list(map(lambda name: encode(name, vocab), words))
print(
    encoded_words[0],
    decode(encoded_words[0], vocab),
    len(encoded_words),
    len(encoded_words[0]),
    sep='\n',
)

[0, 5, 13, 13, 1, 0]
<eos>emma<eos>
32033
6


In [5]:
def get_dataset(encoded_words: List[List[int]]) -> Tuple[jax.Array, jax.Array]:
    """
    Flatten encoded words into (current character, next character) bigram pairs.

    Args:
        encoded_words: each word as a list of vocabulary indices.

    Returns:
        (X, y): 1-D int32 arrays of equal length, where y[i] is the index that
        immediately follows X[i] inside one of the words. Pairs never span two
        different words.
    """
    X = [cur for word in encoded_words for cur in word[:-1]]
    y = [nxt for word in encoded_words for nxt in word[1:]]
    # explicit integer dtype: without it an empty list becomes a float32 array,
    # which cannot be used to index the weight matrix
    return (
        jax.numpy.array(X, dtype=jax.numpy.int32),
        jax.numpy.array(y, dtype=jax.numpy.int32),
    )

def get_train_val_test(encoded_words: List[List[int]], seed: Optional[int] = None) -> Tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]:
    """
    Shuffle the words and split them 80/10/10 into train/val/test bigram datasets.

    The split is done per word (before forming bigrams), so all bigrams of one
    word land in the same split.

    Args:
        encoded_words: encoded words. The list is NOT modified; a shuffled copy is used.
        seed: if given, shuffle with a private `random.Random(seed)` so the split is
            reproducible; if None, use the global `random` module state (previous behavior).

    Returns:
        (X_train, y_train, X_val, y_val, X_test, y_test), each pair built by `get_dataset`.
    """
    # Shuffle a copy: shuffling in place would reorder the caller's list, so
    # re-running this cell would silently produce a different split.
    shuffled = list(encoded_words)
    shuffler = random if seed is None else random.Random(seed)
    shuffler.shuffle(shuffled)

    n_words = len(shuffled)
    train_end = int(0.8 * n_words)
    val_end = int(0.9 * n_words)
    X_train, y_train = get_dataset(shuffled[:train_end])
    X_val, y_val = get_dataset(shuffled[train_end:val_end])
    X_test, y_test = get_dataset(shuffled[val_end:])
    return X_train, y_train, X_val, y_val, X_test, y_test

# 80/10/10 word-level split, turned into (input, target) bigram arrays per split.
(
    X_train, y_train,
    X_val, y_val,
    X_test, y_test,
) = get_train_val_test(encoded_words)

In [6]:
# Shapes of every split (inputs and targets must match pairwise).
tuple(arr.shape for arr in (X_train, y_train, X_val, y_val, X_test, y_test))

((182581,), (182581,), (22787,), (22787,), (22778,), (22778,))

In [8]:
class Weights(NamedTuple):
    """Trainable parameters of the bigram model; a NamedTuple so JAX treats it as a pytree."""

    # logit table indexed as W[current_char] -> logits over the next character
    W: jax.Array

In [38]:
def init_weights(vocab_size: Optional[int] = None, seed: Optional[int] = None) -> Weights:
    """
    Create a square matrix of standard-normal logits for the bigram model.

    Row i holds the unnormalized next-character logits for current character i.

    Args:
        vocab_size: number of tokens; defaults to len(vocab) of the notebook-level
            vocabulary (the previous, implicit behavior).
        seed: if given, draw from a private NumPy Generator so initialization is
            reproducible; if None, use NumPy's global random state as before.

    Returns:
        Weights with W of shape (vocab_size, vocab_size).
    """
    size = len(vocab) if vocab_size is None else vocab_size
    if seed is None:
        W = np.random.randn(size, size)
    else:
        W = np.random.default_rng(seed).standard_normal((size, size))
    return Weights(W=W)

def forward(weights: Weights, X: jax.Array) -> jax.Array:
    """
    Next-character probability distributions for each input character.

    1) index into the weights matrix W using the input indices (row lookup = logits)
    2) apply a numerically stable softmax over the last axis.

    Args:
        weights: model parameters; W is a (vocab_size, vocab_size) logit table.
        X: integer character indices (any shape).

    Returns:
        Array of shape X.shape + (vocab_size,) whose last axis sums to 1.
    """
    logits = weights.W[X]
    # Subtracting the per-row max leaves the softmax unchanged but keeps exp()
    # from overflowing to inf (inf / inf = nan) when logits grow large.
    shifted = logits - jax.numpy.max(logits, axis=-1, keepdims=True)
    exp_logits = jax.numpy.exp(shifted)
    return exp_logits / jax.numpy.sum(exp_logits, axis=-1, keepdims=True)

def loss(weights: Weights, X: jax.Array, y: jax.Array) -> jax.Array:
    """
    Mean negative log-likelihood of the true next characters.

    1) look up the logits for each input character
    2) convert them to log-probabilities with a numerically stable log-softmax
    3) pick the log-probability of the true next character y
    4) return the negated mean over all examples

    Computing log-softmax directly (rather than log(softmax)) avoids log(0) = -inf
    when a probability underflows, which would make the loss inf and the gradients nan.

    Args:
        weights: model parameters; W is a (vocab_size, vocab_size) logit table.
        X: integer indices of the current characters.
        y: integer indices of the true next characters, same shape as X.

    Returns:
        Scalar loss.
    """
    logits = weights.W[X]
    shifted = logits - jax.numpy.max(logits, axis=-1, keepdims=True)
    log_norm = jax.numpy.log(jax.numpy.sum(jax.numpy.exp(shifted), axis=-1, keepdims=True))
    log_probs = shifted - log_norm
    # gather log p(y | X) along the vocabulary axis; works for any shape of X/y
    true_log_probs = jax.numpy.take_along_axis(log_probs, y[..., None], axis=-1)
    return -true_log_probs.mean()

def update(weights: Weights, X: jax.Array, y: jax.Array, learning_rate: float) -> Weights:
    """
    Apply one plain gradient-descent step to `weights`.

    Differentiates `loss` with respect to every leaf of `weights` and returns a new
    Weights in which each leaf is moved by -learning_rate * gradient. The input
    Weights object is left as-is.
    """
    grad_fn = jax.grad(loss)
    gradients = grad_fn(weights, X, y)

    def step(param, grad):
        return param - learning_rate * grad

    return jax.tree.map(step, weights, gradients)

@jax.jit
def train_step(weights: Weights, X: jax.Array, y: jax.Array, learning_rate: float) -> Tuple[Weights, jax.Array]:
    """
    One jitted training step: evaluate the loss and apply a gradient-descent update.

    `jax.value_and_grad` yields the loss and its gradient from a single
    forward/backward pass, instead of evaluating the loss once here and again
    inside a separate `jax.grad` call.

    Args:
        weights: current model parameters.
        X: integer indices of the current characters.
        y: integer indices of the true next characters.
        learning_rate: gradient-descent step size.

    Returns:
        (new_weights, loss_value): the updated parameters, and the loss measured at
        the *input* weights (before the update) as a scalar array.
    """
    loss_value, grads = jax.value_and_grad(loss)(weights, X, y)
    new_weights = jax.tree.map(lambda w, g: w - learning_rate * g, weights, grads)
    return new_weights, loss_value

In [41]:
# Full-batch gradient descent on the training bigrams.
weights = init_weights()
N_EPOCHS = 100
LR = 50  # step size for full-batch gradient descent
LOG_EVERY = 10
for epoch in range(N_EPOCHS):
    # loss_value is measured at the weights *before* this epoch's update
    weights, loss_value = train_step(weights, X_train, y_train, LR)
    if epoch % LOG_EVERY == 0 or epoch == N_EPOCHS - 1:
        # Evaluated only when reported: `loss` is not jitted, so computing it every
        # epoch ran a full eager pass over the validation set for nothing.
        # Note that val_loss uses the *updated* weights.
        val_loss = loss(weights, X_val, y_val)
        print(f"epoch: {epoch}, loss: {loss_value}, val_loss: {val_loss}")


epoch: 0, loss: 3.7597320079803467, val_loss: 3.4014439582824707
epoch: 10, loss: 2.6564254760742188, val_loss: 2.650994300842285
epoch: 20, loss: 2.5583841800689697, val_loss: 2.563549518585205
epoch: 30, loss: 2.524369239807129, val_loss: 2.531252384185791
epoch: 40, loss: 2.5066328048706055, val_loss: 2.514086961746216
epoch: 50, loss: 2.4957680702209473, val_loss: 2.503443956375122
epoch: 60, loss: 2.4884531497955322, val_loss: 2.496213674545288
epoch: 70, loss: 2.4831907749176025, val_loss: 2.490983486175537
epoch: 80, loss: 2.4792237281799316, val_loss: 2.487029790878296
epoch: 90, loss: 2.4761250019073486, val_loss: 2.4839417934417725
