In [20]:
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float, jaxtyped, Int
from beartype import beartype as typechecker
# from typeguard import typechecked as typechecker
from equinox import Module as JaxClass


@jaxtyped
@typechecker
class Parameters(JaxClass):
    embedding_weights: Float[Array, "hidden_state embedding"]
    hidden_state_weights: Float[Array, "hidden_state hidden_state"]
    output_weights: Float[Array, "embedding hidden_state"]
    hidden_state_bias: Float[Array, "hidden_state"]
    output_bias: Float[Array, "embedding"]

e = 3
h = 10

init_pars = Parameters(jnp.ones((h,e)), jnp.ones((h,h)), jnp.ones((e,h)), jnp.ones((h,)), jnp.ones((e,)))
init_state = jnp.ones((h,))
random_embed = jnp.zeros((e,)).at[3].set(1)  # technically random one-hot, need the embedding matrix...

@jax.jit
@jaxtyped
@typechecker
def update_hidden_state(
    embedding: Float[Array, "embedding"], 
    hidden_state: Float[Array, "hidden_state"], 
    params: Parameters
) -> Float[Array, "hidden_state"]:
    return jax.nn.relu(params.hidden_state_weights @ hidden_state + params.embedding_weights @ embedding + params.hidden_state_bias)

update_hidden_state(random_embed, init_state, init_pars)

Array([11., 11., 11., 11., 11., 11., 11., 11., 11., 11.], dtype=float32)

In [30]:
@jax.jit
@jaxtyped
@typechecker
def output(
    hidden_state: Float[Array, "hidden_state"], 
    params: Parameters
) -> Float[Array, "embedding"]:
    return jax.nn.softmax(params.output_weights @ hidden_state + params.output_bias)


@jax.jit
@jaxtyped
@typechecker
def loss(
    output: Float[Array, "embedding"],
    next_embedding: Float[Array, "embedding"] 
) -> float:
    # index the softmax probs at the word of interest
    return -jnp.log(output[next_embedding.astype("bool")])[0]

loss(jnp.array([0,0.2,0.8]), jnp.array([0,1,0]))

BeartypeCallHintParamViolation: Function __main__.loss() parameter next_embedding="Traced<ShapedArray(int32[3])>with<DynamicJaxprTrace(level=1/0)>" violates type hint <class 'jaxtyping.Float[Array, 'embedding']'>, as <protocol "jax._src.interpreters.partial_eval.DynamicJaxprTracer"> "Traced<ShapedArray(int32[3])>with<DynamicJaxprTrace(level=1/0)>" not instance of <class "jaxtyping.Float[Array, 'embedding']">.

In [18]:
jnp.array([2,3,4,5])[jnp.array([0,0,1,0]).astype("bool")][0]

Array(4, dtype=int32)

In [72]:
import re
file_name = 'bee-movie-names.txt'

with open(file_name, 'r+') as file:
    all_text = file.read()
    all_text = all_text.replace('\n', ' ').replace('  : ', '')
    
# Define a regular expression pattern to match all punctuation marks
punctuation_pattern = r'[^\w\s]'

# Define a regular expression pattern to match words with apostrophes
apostrophe_pattern = r'\w+(?:\'\w+)?'

# Combine the two patterns to match all tokens
token_pattern = punctuation_pattern + '|' + apostrophe_pattern

# Split the text into tokens, including words with apostrophes as separate tokens
all_words = re.findall(token_pattern, all_text.lower())
vocab = list(set(all_words))

sentence_length = 5

vocab_one_hot_indicies = jnp.array([vocab.index(t) for t in all_words], dtype=jnp.int32)
split_indicies = vocab_one_hot_indicies[:(len(vocab)//5)*5].reshape(len(vocab)//5,5)

sentence = split_indicies[0]

def one_hot_sentence(sentence: Int[Array, "sentence"], vocab_size: int) -> Int[Array, "sentence vocab"]:
    return jnp.array([jnp.zeros((vocab_size,)).at[word].set(1) for word in sentence])

Array([[ 465, 1934, 2130,  767,  966],
       [ 618,  759,  987,  464,  704],
       [ 558,  459, 1406,  991, 2187],
       ...,
       [ 767, 1078,  735, 1987,  926],
       [ 704, 1876,  427, 1783, 2403],
       [1934,  624,  914,  558, 1966]], dtype=int32)