In [180]:
import jax
import jax.numpy as jnp
from jax import grad, jit
import optax

import numpy as np
import os
from tqdm import tqdm
from string import ascii_letters, digits, punctuation, whitespace


In [19]:
alphabet = ascii_letters + digits + punctuation + whitespace

In [20]:
def get_data(n_books=100):
    c2i = {c: i for i, c in enumerate(alphabet)}
    i2c = {i: c for i, c in enumerate(alphabet)}
    PATH = os.getcwd() + '/data/books'
    files = [os.path.join(PATH, f) for f in os.listdir(PATH) if f.endswith('.txt')]
    codex = []
    for f in tqdm(files[:n_books]):
        with open(f, 'r') as file:
            txt = file.read().split('***')[2]
            txt = [c2i[char] for char in txt if char in alphabet]
            if len(txt) > 1000:
                codex += txt
    return codex, c2i, i2c

In [54]:
def get_batches(rng, data, batch_size, block_length):
    while True:
        # data is one looooooong string. Pick batch random starting points
        idxs = jax.random.randint(rng, shape=(batch_size,), minval=0, maxval=len(data) - block_length - 1)
        batch = jnp.array([data[idx:idx+block_length + 1] for idx in idxs])
        yield batch[:, :-1], batch[:, 1:]
        

In [172]:
def init_fn(rng, n_chars, n_embed, n_hidden, scale=1e-2):
    embedding = scale * jax.random.normal(rng, shape=(n_chars, n_embed))
    fc1 = {
        'w': scale * jax.random.normal(rng, shape=(n_embed, n_hidden)),
        'b': scale * jax.random.normal(rng, shape=(n_hidden,))
    }
    fc2 = {
        'w': scale * jax.random.normal(rng, shape=(n_hidden, n_chars)),
        'b': scale * jax.random.normal(rng, shape=(n_chars,))
    }
    return {'embedding': embedding, 'fc': {'fc1': fc1, 'fc2': fc2}}

def apply_fn(params, x):
    x = x.reshape(-1)
    z = params['embedding'][x]
    z = jnp.tanh(jnp.dot(z, params['fc']['fc1']['w']) + params['fc']['fc1']['b'])
    z = jnp.dot(z, params['fc']['fc2']['w']) + params['fc']['fc2']['b']
    return z

def loss_fn(params, x, y):
    y_pred = apply_fn(params, x)
    log_probs = jax.nn.log_softmax(y_pred, axis=-1)
    correct_log_probs = jnp.take_along_axis(log_probs, y.flatten()[:, None], axis=1)
    correct_log_probs = jnp.nan_to_num(correct_log_probs, nan=-1e3)
    loss = -jnp.mean(correct_log_probs)
    return loss


In [181]:
rng = jax.random.PRNGKey(0)
data, c2i, i2c = get_data() if 'data' not in globals() else eval('data, c2i, i2c')
batches = get_batches(rng, data, 32, 128)
params = init_fn(rng, len(alphabet), 64, 128)
opt = optax.adam(1e-3)
opt_state = opt.init(params)

In [182]:
for i in range(100):
    x, y = next(batches)
    loss = loss_fn(params, x, y)
    print(f'Loss: {loss:.4f}')
    grad_fn = jax.grad(loss_fn)
    grads = grad_fn(params, x, y)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

Loss: 762.0879
Loss: 762.0870
Loss: 762.0863
Loss: 762.0854
Loss: 762.0845
Loss: 762.0835
Loss: 762.0824
Loss: 762.0812
Loss: 762.0798
Loss: 762.0782
Loss: 762.0765
Loss: 762.0746
Loss: 762.0725
Loss: 762.0701
Loss: 762.0674
Loss: 762.0645
Loss: 762.0612
Loss: 762.0576
Loss: 762.0535
Loss: 762.0491
Loss: 762.0442
Loss: 762.0389
Loss: 762.0331
Loss: 762.0267
Loss: 762.0198
Loss: 762.0123
Loss: 762.0043
Loss: 761.9957
Loss: 761.9866
Loss: 761.9769
Loss: 761.9666
Loss: 761.9558
Loss: 761.9445
Loss: 761.9328
Loss: 761.9207
Loss: 761.9083
Loss: 761.8958
Loss: 761.8831
Loss: 761.8704
Loss: 761.8578
Loss: 761.8456
Loss: 761.8337
Loss: 761.8225
Loss: 761.8120
Loss: 761.8024
Loss: 761.7937
Loss: 761.7859
Loss: 761.7792
Loss: 761.7733
Loss: 761.7684
Loss: 761.7643
Loss: 761.7609
Loss: 761.7581
Loss: 761.7555
Loss: 761.7533
Loss: 761.7512
Loss: 761.7493
Loss: 761.7474
Loss: 761.7456
Loss: 761.7438
Loss: 761.7421
Loss: 761.7405
Loss: 761.7390
Loss: 761.7375
Loss: 761.7363
Loss: 761.7352
Loss: 761.