In [1]:
import random
from typing import List, Optional, Tuple

import jax
import numpy as np

In [2]:
def load_data(path: str) -> Tuple[List[str], List[str]]:
    """
    Load a newline-separated word list and build its character vocabulary.

    Every line is stripped of leading/trailing whitespace and blank lines are
    dropped. The vocabulary is the sorted set of characters that occur in the
    words, with the '<eos>' token prepended at index 0. A short summary of the
    dataset is printed as a side effect.

    Args:
        path: Path to a text file containing one word per line.

    Returns:
        A tuple (words, vocab): the cleaned words in file order, and the
        vocabulary list.

    Raises:
        ValueError: If the file contains no non-empty lines.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform default.
    with open(path, 'r', encoding='utf-8') as f:
        data = f.read()

    stripped = (line.strip() for line in data.splitlines())
    words = [word for word in stripped if word]  # drop blank lines
    if not words:
        # Fail early with a clear message rather than inside max()/min() below.
        raise ValueError(f"no words found in {path!r}")

    lengths = [len(word) for word in words]
    vocab = ['<eos>'] + sorted(set(''.join(words)))
    print(f"number of examples in dataset: {len(words)}")
    print(f"max word length: {max(lengths)}")
    print(f"min word length: {min(lengths)}")
    print(f"unique characters in dataset: {len(vocab)}")
    print("vocabulary:")
    print(' '.join(vocab))
    print('example for a word:')
    print(words[0])
    return words, vocab

# Load the names dataset; load_data also prints a summary of what it read.
words, vocab = load_data('names.txt')

number of examples in dataset: 32033
max word length: 15
min word length: 2
unique characters in dataset: 27
vocabulary:
<eos> a b c d e f g h i j k l m n o p q r s t u v w x y z
example for a word:
emma


In [3]:
def encode(word: str, vocab: List[str]) -> List[int]:
    """
    Turn a word into vocabulary indices, wrapped in the <eos> index on both
    sides (one at the start, one at the end).
    """
    eos = vocab.index('<eos>')
    indices = [eos]
    for ch in word:
        indices.append(vocab.index(ch))
    indices.append(eos)
    return indices

def decode(indices: List[int], vocab: List[str]) -> str:
    """
    Look up each index in the vocabulary and concatenate the resulting tokens
    into a single string.
    """
    pieces = []
    for idx in indices:
        pieces.append(vocab[idx])
    return ''.join(pieces)

# Round-trip the first five words through encode/decode as a sanity check.
for i in range(5):
    print(f"word: {words[i]}")
    encoded = encode(words[i], vocab)  # encode once, reuse for both prints
    print(f"encoded: {encoded}")
    print(f"decoded: {decode(encoded, vocab)}")
    print()

word: emma
encoded: [0, 5, 13, 13, 1, 0]
decoded: <eos>emma<eos>

word: olivia
encoded: [0, 15, 12, 9, 22, 9, 1, 0]
decoded: <eos>olivia<eos>

word: ava
encoded: [0, 1, 22, 1, 0]
decoded: <eos>ava<eos>

word: isabella
encoded: [0, 9, 19, 1, 2, 5, 12, 12, 1, 0]
decoded: <eos>isabella<eos>

word: sophia
encoded: [0, 19, 15, 16, 8, 9, 1, 0]
decoded: <eos>sophia<eos>



In [4]:
# Encode every word once, then spot-check the first entry and the sizes.
encoded_words = [encode(w, vocab) for w in words]
print(
    encoded_words[0],
    decode(encoded_words[0], vocab),
    len(encoded_words),
    len(encoded_words[0]),
    sep='\n',
)

[0, 5, 13, 13, 1, 0]
<eos>emma<eos>
32033
6


In [5]:
def get_dataset(encoded_words: List[List[int]]) -> Tuple[jax.Array, jax.Array]:
    """
    Convert a list of encoded words to bigram (input, target) index arrays.

    For every pair of adjacent indices in a word, the first index is appended
    to X and the second to y.

    Args:
        encoded_words: Words given as lists of vocabulary indices.

    Returns:
        A tuple (X, y) of 1-D integer arrays of equal length.
    """
    X = []
    y = []
    for word in encoded_words:
        for char1, char2 in zip(word[:-1], word[1:]):
            X.append(char1)
            y.append(char2)
    # dtype=int keeps an empty split integer-typed; an empty list would
    # otherwise be converted to a float array.
    return jax.numpy.array(X, dtype=int), jax.numpy.array(y, dtype=int)

def get_train_val_test(encoded_words: List[List[int]], seed: Optional[int] = None) -> Tuple[jax.Array, jax.Array, jax.Array, jax.Array, jax.Array, jax.Array]:
    """
    Split the dataset into training, validation and test sets.

    The words are shuffled and split 80% / 10% / 10% by word; each split is
    then converted to bigram arrays with get_dataset. The input list itself is
    left untouched, so calling this repeatedly (e.g. re-running the cell) does
    not reorder the caller's data.

    Args:
        encoded_words: Words given as lists of vocabulary indices.
        seed: Optional seed for a private random generator, giving a
            reproducible split. When None, the global `random` module state
            is used.

    Returns:
        (X_train, y_train, X_val, y_val, X_test, y_test).
    """
    # Shuffle a shallow copy so the caller's list keeps its order.
    shuffled = list(encoded_words)
    if seed is None:
        random.shuffle(shuffled)
    else:
        random.Random(seed).shuffle(shuffled)

    train_end = int(0.8 * len(shuffled))
    val_end = int(0.9 * len(shuffled))
    X_train, y_train = get_dataset(shuffled[:train_end])
    X_val, y_val = get_dataset(shuffled[train_end:val_end])
    X_test, y_test = get_dataset(shuffled[val_end:])
    return X_train, y_train, X_val, y_val, X_test, y_test

# Shuffle the encoded words and build bigram (X, y) arrays for the train/val/test splits.
X_train, y_train, X_val, y_val, X_test, y_test = get_train_val_test(encoded_words)

In [6]:
# Number of bigram pairs in each split (inputs and targets must match per split).
X_train.shape, y_train.shape, X_val.shape, y_val.shape, X_test.shape, y_test.shape

((182466,), (182466,), (22838,), (22838,), (22842,), (22842,))

In [29]:
EMBEDDING_DIM = 128
EMBEDDING_SEED = 0  # fixed seed so Restart & Run All reproduces the same W
# One row of W per vocabulary entry, drawn from a standard normal distribution.
W = np.random.default_rng(EMBEDDING_SEED).standard_normal((len(vocab), EMBEDDING_DIM))
W.shape

(27, 128)

In [32]:
# Embedding lookup written as a one-hot matrix product; show the first 10
# components of the row for the first training index.
# NOTE(review): the recorded values differ from direct indexing of W around the
# third decimal — presumably reduced matmul precision on the backend; confirm.
(jax.nn.one_hot(X_train, len(vocab)) @ W)[0, :10]

Array([ 0.86328125, -1.015625  , -0.8359375 ,  0.26953125,  0.01367188,
        0.12109375,  0.51953125, -0.671875  ,  1.234375  , -0.0546875 ],      dtype=float32)

In [33]:
# Same lookup done by integer indexing into W, for comparison with the one-hot product.
W[X_train][0, :10]

array([ 0.86347077, -1.01740871, -0.83740075,  0.26996182,  0.01365725,
        0.12102629,  0.5200092 , -0.67272471,  1.23825448, -0.05461122])

In [26]:
# Peek at the training inputs (vocabulary indices of the first character of each bigram).
X_train

Array([ 0,  7, 15, ...,  5, 13, 25], dtype=int32)