In [1]:
%pip install pytreeclass jaxtyping jaxlib

Collecting equinox
  Using cached equinox-0.10.3-py3-none-any.whl (111 kB)
Collecting jaxtyping
  Downloading jaxtyping-0.2.19-py3-none-any.whl (24 kB)
Collecting jaxlib
  Downloading jaxlib-0.4.10-cp311-cp311-macosx_11_0_arm64.whl (59.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.4/59.4 MB[0m [31m5.9 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0mm
[?25hCollecting jax>=0.4.4 (from equinox)
  Downloading jax-0.4.10.tar.gz (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m22.3 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25h  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
Collecting typeguard>=2.13.3 (from jaxtyping)
  Downloading typeguard-4.0.0-py3-none-any.whl (33 kB)
Collecting ml-dtypes>=0.1.0 (from jaxlib)
  Using cached ml_dtypes-0.1.0-cp311-cp311-macosx_10_9_universal2.whl (317 kB)
Colle

In [2]:
from __future__ import annotations

import jax
import jax.numpy as jnp
import numpy as np
import re
from jaxtyping import Array, Float, Int
import pytreeclass as pytc
from typing import Generator
from functools import partial
from copy import deepcopy


def prepare_text(file_name, sentence_length):
    with open(file_name, "r+") as file:
        all_text = file.read()
        # all_text = all_text.replace('\n', ' ').replace('  : ', '')

    # Define a regular expression pattern to match all punctuation marks
    punctuation_pattern = r"[^\w\s]"

    # Define a regular expression pattern to match words with apostrophes
    apostrophe_pattern = r"\w+(?:\'\w+)?"
    # Define a regular expression pattern to match newlines
    newline_pattern = r"\n"

    # Combine the three patterns to match all tokens
    token_pattern = (
        punctuation_pattern + "|" + apostrophe_pattern + "|" + newline_pattern
    )

    # Split the text into tokens, including words with apostrophes as separate tokens
    all_words = re.findall(token_pattern, all_text.lower())
    vocab = list(set(all_words))

    vocab_one_hot_indicies = jnp.array(
        [vocab.index(t) for t in all_words], dtype=jnp.int32
    )
    split_indicies = vocab_one_hot_indicies[
        : (len(vocab) // sentence_length) * sentence_length
    ].reshape(len(vocab) // sentence_length, sentence_length)
    # make last word random, shouldn't make too much of an impact (could be better handled with special char?)
    split_indicies_labels = jnp.concatenate(
        (
            vocab_one_hot_indicies[
                1 : ((len(vocab) - 1) // sentence_length) * sentence_length
            ],
            jnp.array([0]),
        )
    ).reshape((len(vocab) - 1) // sentence_length, sentence_length)
    partition_index = 6 * int(len(split_indicies) / 7)
    train = split_indicies[:partition_index]
    train_labels = split_indicies_labels[:partition_index]
    valid = split_indicies[partition_index:]
    valid_labels = split_indicies_labels[partition_index:]

    return train, train_labels, valid, valid_labels, vocab

ModuleNotFoundError: No module named 'pytreeclass'

In [None]:
file_name = "one-fish-two-fish.txt"
sentence_length = 8  # keep even because of how we split the data
train, train_labels, valid, valid_labels, vocab = prepare_text(
    file_name, sentence_length
)

print(f"examples from vocab: {vocab[:10]}")
print(f"total length of vocab: {len(vocab)} unique words")
print(
    f"total length of training data: {len(train)} sentences (each {sentence_length} words)"
)
print(
    f"total length of validation data: {len(valid)} sentences (each {sentence_length} words)"
)

In [117]:
# first sentence in train set
train[0]

Array([  2, 186, 233, 227, 186, 233,  83, 186], dtype=int32)

In [118]:
# first sentence in train labels == same sentence shifted by one word
# i.e. equivalent to train[0][1:] + train[1][0]
train_labels[0]

Array([186, 233, 227, 186, 233,  83, 186, 233], dtype=int32)

In [119]:
# we can reconstruct a sentence by mapping indicies back to words
" ".join([vocab[i] for i in train[0]])

'one fish , two fish , red fish'

In [120]:
" ".join([vocab[i] for i in train_labels[0]])

'fish , two fish , red fish ,'

In [121]:
fish_idx = vocab.index("fish")  # the index of the word "fish" in the vocab
one_hot_fish = np.zeros(len(vocab))  # a vector of zeros with length equal to the vocab
one_hot_fish[fish_idx] = 1  # set the index of the word "fish" to 1

# the syntax in JAX is a little different, but the idea is the same
# we use the `at` method to set the value at a particular index,
# and the `set` method to set the value at that index to 1.
# this is due to the fact that JAX arrays are immutable,
# so we can't just set the value at an index to 1 directly!
fish_idx = vocab.index("fish")
one_hot_fish = jnp.zeros(len(vocab))
one_hot_fish = one_hot_fish.at[fish_idx].set(1)

one_hot_fish

Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0.

In [122]:
def one_hot_sentence(
    sentence: Int[Array, "sentence"], vocab_size: int
) -> Int[Array, "sentence vocab"]:
    return jnp.array([jnp.zeros((vocab_size,)).at[word].set(1) for word in sentence])


# make a very intelligent sentence of all "fish"
fish_sentence = jnp.array([fish_idx] * 10)

# one hot encode the sentence
one_hot_fish_sentence = one_hot_sentence(fish_sentence, len(vocab))

# assert that the sentence is one hot encoded correctly
assert jnp.all(one_hot_fish_sentence == jnp.array([one_hot_fish] * 10))

In [123]:
# we can use `vmap` to automatically transform the function to work on a batch of sentences!
# this will be useful when we want to train our model on multiple sentences at once.
# note that we need to specify the `in_axes` argument to tell JAX which argument
# in the function is the one that we want to map over (in this case, we want to
# map over the first axis of the `sentence` argument, indicated by `0`).
# we also need to specify `None` for the `vocab_size` argument, since it is not
# being mapped over -- it is the same for every sentence in the batch.
batch_one_hot = jax.vmap(one_hot_sentence, in_axes=(0, None))

In [124]:
# This is a container for all the free parameters of an RNN!
# You can see the shapes of each attribute from the type annotations.
# We have a couple of sizes: hidden_state, embedding, vocab
# -> these represent the size of the hidden state weights,
#    the embedding matrix, and the vocabulary respectively.
# This parameters object will be passed to most functions below:
# e.g. access the output weights by calling `params.output_weights` etc.
class Parameters(pytc.TreeClass):
    embedding_weights: Float[Array, "hidden_state embedding"]
    hidden_state_weights: Float[Array, "hidden_state hidden_state"]
    output_weights: Float[Array, "vocab hidden_state"]
    hidden_state_bias: Float[Array, "hidden_state"]
    output_bias: Float[Array, "vocab"]
    embedding_matrix: Float[Array, "embedding vocab"]


# we'll initialize our parameters randomly, but close to 0/identity so that
# we don't have exploding gradients later on!

# set sizes for embeddings, hidden state, vocab, and output vectors
e = 30
h = 16
v = len(vocab)
o = v

params = Parameters(
    embedding_weights=jax.random.truncated_normal(
        lower=-0.1, upper=0.1, shape=[h, e], key=jax.random.PRNGKey(0)
    ),
    hidden_state_weights=jnp.identity(h),
    output_weights=jax.random.truncated_normal(
        lower=-0.1, upper=0.1, shape=[o, h], key=jax.random.PRNGKey(0)
    ),
    hidden_state_bias=jnp.zeros((h,)),
    output_bias=jnp.zeros(
        shape=[
            o,
        ]
    ),
    embedding_matrix=jax.random.truncated_normal(
        lower=-0.1, upper=0.1, shape=[e, v], key=jax.random.PRNGKey(0)
    ),
)

# let's inspect the structure of our parameters
print(pytc.tree_summary(params))
print(pytc.tree_diagram(params))

┌─────────────────────┬───────────┬──────┐
│Name                 │Type       │Count │
├─────────────────────┼───────────┼──────┤
│.embedding_weights   │f32[16,30] │480   │
├─────────────────────┼───────────┼──────┤
│.hidden_state_weights│f32[16,16] │256   │
├─────────────────────┼───────────┼──────┤
│.output_weights      │f32[300,16]│4,800 │
├─────────────────────┼───────────┼──────┤
│.hidden_state_bias   │f32[16]    │16    │
├─────────────────────┼───────────┼──────┤
│.output_bias         │f32[300]   │300   │
├─────────────────────┼───────────┼──────┤
│.embedding_matrix    │f32[30,300]│9,000 │
├─────────────────────┼───────────┼──────┤
│Σ                    │Parameters │14,852│
└─────────────────────┴───────────┴──────┘
Parameters
├── .embedding_weights=f32[16,30](μ=-0.00, σ=0.06, ∈[-0.10,0.10])
├── .hidden_state_weights=f32[16,16](μ=0.06, σ=0.24, ∈[0.00,1.00])
├── .output_weights=f32[300,16](μ=-0.00, σ=0.06, ∈[-0.10,0.10])
├── .hidden_state_bias=f32[16](μ=0.00, σ=0.00, ∈[0.00,0.00])


We'll come to use most of these values later, but for now, we're just focused on embeddings!

Recall what we did in the previous sessions, looking at RNNs for *language modelling*, where every word is turned into a one-hot vector (or "token"). We then multiply these one-hot words by an *embedding matrix* $E$, which multiplies the words to reduce the dimension of that long one-hot vector (=size of the whole vocabulary) to some specified lower dimensional representation (normally ~100 ish). This embedded word is then used to update the hidden state $h$ of the RNN.

Use the embedding matrix (accessible through `params.embedding_matrix`) to fill in the function below, which embeds a single word.

In [125]:
def make_embeddings(
    one_hot_word: Float[Array, "vocab"], params: Parameters
) -> Float[Array, "embedding"]:
    return params.embedding_matrix @ one_hot_word


# I should be a vector of length `e`
assert make_embeddings(one_hot_fish, params).shape == (e,)

In [126]:
# map to work over sentences for later!
embeddings_map = jax.vmap(make_embeddings, in_axes=(0, None))

In [127]:
def update_hidden_state(
    embedding: Float[Array, "embedding"],
    hidden_state: Float[Array, "hidden_state"],
    params: Parameters,
) -> Float[Array, "hidden_state"]:
    return jax.nn.tanh(
        params.hidden_state_weights @ hidden_state
        + params.embedding_weights @ embedding
        + params.hidden_state_bias
    )


# I should be a vector of length `h`
embedding = make_embeddings(one_hot_fish, params)
assert update_hidden_state(embedding, jnp.zeros((h,)), params).shape == (h,)

In [128]:
def output(
    hidden_state: Float[Array, "hidden_state"], params: Parameters
) -> Float[Array, "vocab"]:
    return jax.nn.softmax(params.output_weights @ hidden_state + params.output_bias)


# I should be a vector of length `v`
assert output(jnp.zeros((h,)), params).shape == (v,)

In [129]:
def rnn(
    data: Float[Array, "sentence vocab"], params: Parameters, hidden_size: int
) -> Float[Array, "sentence vocab"]:
    # apply embeddings_map to create a vector of embeddings
    embeddings = embeddings_map(data, params)  # ["sentence embedding"]

    # initialize the hidden state with zeros
    hidden_state = jnp.zeros((hidden_size,))

    # for each word in the vector of embeddings:
    #   > update the hidden state
    #   > compute the output word using that hidden state and store it
    # return the set of outputs
    outputs = []

    for word in embeddings:
        hidden_state = update_hidden_state(word, hidden_state, params)
        outputs.append(output(hidden_state, params))

    return jnp.array(outputs)


# make a very intelligent sentence of all "fish"
fish_sentence = jnp.array([fish_idx] * 10)
one_hot_fish_sentence = one_hot_sentence(fish_sentence, len(vocab))

# run the RNN on "fish fish fish fish fish fish fish fish fish fish"
rnn_outputs = rnn(one_hot_fish_sentence, params, h)

# the output should be a list of 10 vectors of length `v`
# corresponding to the output probabilities at each position in the sentence
assert rnn_outputs.shape == (10, v)

# what is the most likely word at each position?
# (remember that we have randomly initialized our parameters, so this will be nonsense!)
most_likely_words = [vocab[jnp.argmax(output)] for output in rnn_outputs]
most_likely_words

['then', 'then', 'then', 'then', 'then', 'then', 'then', 'then', 'sun', 'sun']

In [130]:
def loss(
    output: Float[Array, "vocab"], next_one_hot_word: Float[Array, "vocab"]
) -> Float[Array, ""]:
    # index the softmax probs at the word of interest
    return -jnp.log(output[jnp.argmax(next_one_hot_word)])


sentence_loss = jax.vmap(loss, in_axes=(0, 0))

In [131]:
def forward_pass(
    data: Float[Array, "sentence vocab"],
    next_words: Float[Array, "sentence vocab"],  # data shifted by 1 to the right
    params: Parameters,
    hidden_size: int,
) -> Float[Array, ""]:
    output = rnn(data, params, hidden_size)
    return sentence_loss(output, next_words).mean(axis=0)


# run the forward pass on our fish sentence
loss_value = forward_pass(one_hot_fish_sentence, one_hot_fish_sentence, params, h)
assert loss_value.shape == ()
print(f"starting loss: {loss_value:.2f}")

# here, we transform the forward pass into the gradient function,
# and also vmap again so it can handle a batch of sentences instead of one.
loss_and_gradient = jax.value_and_grad(forward_pass, argnums=2)
print("gradients are packed up in a Parameters object:")
print(
    pytc.tree_diagram(
        loss_and_gradient(one_hot_fish_sentence, one_hot_fish_sentence, params, h)[1]
    )
)

starting loss: 5.69
gradients are packed up in a Parameters object:
Parameters
├── .embedding_weights=f32[16,30](μ=-0.00, σ=0.02, ∈[-0.05,0.05])
├── .hidden_state_weights=f32[16,16](μ=0.00, σ=0.02, ∈[-0.04,0.06])
├── .output_weights=f32[300,16](μ=-0.00, σ=0.01, ∈[-0.15,0.20])
├── .hidden_state_bias=f32[16](μ=-0.11, σ=0.28, ∈[-0.50,0.32])
├── .output_bias=f32[300](μ=0.00, σ=0.06, ∈[-1.00,0.00])
└── .embedding_matrix=f32[30,300](μ=-0.00, σ=0.00, ∈[-0.15,0.07])


In [132]:
# we can also vmap the gradient function to handle a batch of sentences,
# and jit it to make it faster!
batched_grads = jax.jit(
    jax.vmap(loss_and_gradient, in_axes=(0, 0, None, None)), static_argnums=(3,)
)

In [133]:
def predict_next_words(
    prompt: str,
    vocab: list[str],
    rnn_params: Parameters,
    rnn_hidden_size: int,
    num_predicted_tokens: int,
    include_prompt=True,
) -> str:
    # Define a regular expression pattern to match all punctuation marks
    punctuation_pattern = r"[^\w\s]"

    # Define a regular expression pattern to match words with apostrophes
    apostrophe_pattern = r"\w+(?:\'\w+)?"
    # Define a regular expression pattern to match newlines
    newline_pattern = r"\n"

    # Combine the three patterns to match all tokens
    token_pattern = (
        punctuation_pattern + "|" + apostrophe_pattern + "|" + newline_pattern
    )

    tokens = re.findall(token_pattern, prompt.lower())
    one_hot_indicies = jnp.array([vocab.index(t) for t in tokens], dtype=jnp.int32)
    sentence = one_hot_sentence(one_hot_indicies, len(vocab))
    embeddings = embeddings_map(sentence, rnn_params)  # ["sentence embedding"]

    hidden_state = jnp.zeros((rnn_hidden_size,))
    outputs = [None] * num_predicted_tokens
    for word in embeddings[:-1]:
        hidden_state = update_hidden_state(word, hidden_state, rnn_params)
    hidden_state = update_hidden_state(embeddings[-1], hidden_state, rnn_params)
    outputs[0] = output(hidden_state, rnn_params)

    for i in range(1, num_predicted_tokens):
        embedded_pred = make_embeddings(outputs[i - 1], rnn_params)
        hidden_state = update_hidden_state(embedded_pred, hidden_state, rnn_params)
        outputs[i] = output(hidden_state, rnn_params)

    res = jnp.array(outputs)
    res_indicies = jnp.argmax(res, axis=1)
    words = [vocab[i] for i in res_indicies]
    out = " ".join(words)
    return prompt + " | " + out if include_prompt else out

In [135]:
batch_size = 400

import numpy.random as npr


def batches(training_data: Array, batch_size: int) -> Generator:
    num_train = training_data.shape[0]
    num_complete_batches, leftover = divmod(num_train, batch_size)
    num_batches = num_complete_batches + bool(leftover)

    # batching mechanism, ripped from the JAX docs :)
    def data_stream():
        rng = npr.RandomState(0)
        while True:
            perm = rng.permutation(num_train)
            for i in range(num_batches):
                batch_idx = perm[i * batch_size : (i + 1) * batch_size]
                yield train[batch_idx], train_labels[batch_idx]

    return data_stream()


batch = batches(train, batch_size)
one_hot_valid, one_hot_valid_labels = batch_one_hot(valid, len(vocab)), batch_one_hot(
    valid_labels, len(vocab)
)

#### Setting up training hyperparameters

In [137]:
# training hyperparams, modify at will!
num_iter = 2000
lr = 4e-2
best_loss = 999
best_pars = None


# basic gradient descent
def gradient_descent(param: jax.Array, grads: jax.Array) -> jax.Array:
    return param - lr * grads.mean(axis=0)


# more advanced gradient descent
import optax

opt = optax.chain(
    optax.clip(1),
    optax.adamw(learning_rate=lr),
)
opt_state = opt.init(params)

### Train time!

In [139]:
for i in range(num_iter):
    sentences, sentence_labels = next(batch)
    one_hot_sentences, one_hot_sentence_labels = batch_one_hot(
        sentences, v
    ), batch_one_hot(sentence_labels, v)
    loss, grads = batched_grads(one_hot_sentences, one_hot_sentence_labels, params, h)
    valid_loss, _ = batched_grads(one_hot_valid, one_hot_valid_labels, params, h)
    loss, valid_loss = loss.mean(), valid_loss.mean()

    # gradient descent!
    params = jax.tree_map(gradient_descent, params, grads)

    ## uncomment these lines for advanced version
    # avg_grads = jax.tree_map(lambda g: g.mean(axis=0), grads)
    # updates, opt_state = opt.update(avg_grads, opt_state, params=pars)
    # pars = optax.apply_updates(pars, updates)

    if valid_loss < best_loss:
        best_pars = deepcopy(params)
        best_loss = valid_loss
    if i % 20 == 0:
        print(f"train loss: {loss.mean():.3f}", end=", ")
        print(f"valid loss: {valid_loss.mean():.3f}")

print(f"best valid loss: {best_loss:.3f}")
print(predict_next_words("Red fish ", vocab, params, h, 10, include_prompt=True))

train loss: 5.704, valid loss: 5.705
train loss: 5.638, valid loss: 5.633
train loss: 5.542, valid loss: 5.532
train loss: 5.401, valid loss: 5.391
train loss: 5.223, valid loss: 5.219
train loss: 5.040, valid loss: 5.048
train loss: 4.881, valid loss: 4.910
train loss: 4.757, valid loss: 4.816
train loss: 4.660, valid loss: 4.757
train loss: 4.585, valid loss: 4.721
train loss: 4.525, valid loss: 4.700
train loss: 4.476, valid loss: 4.689
train loss: 4.435, valid loss: 4.682
train loss: 4.400, valid loss: 4.679
train loss: 4.369, valid loss: 4.678
train loss: 4.341, valid loss: 4.678
train loss: 4.316, valid loss: 4.678
train loss: 4.293, valid loss: 4.679
train loss: 4.271, valid loss: 4.680
train loss: 4.252, valid loss: 4.681
train loss: 4.234, valid loss: 4.682
train loss: 4.218, valid loss: 4.684
train loss: 4.202, valid loss: 4.685
train loss: 4.188, valid loss: 4.687
train loss: 4.175, valid loss: 4.688
train loss: 4.162, valid loss: 4.690
train loss: 4.151, valid loss: 4.692
t

In [140]:
print(predict_next_words("Red fish ", vocab, pars, h, 10, include_prompt=True))

Red fish  | 
 
 
 
 
 
 
 
 
 

