In [1]:
import jax, optax, yaml
import jax.numpy as jnp
from jax import jit, random, tree_util
from jax.lib import xla_client
from functional import partial
from tqdm import tqdm
import numpy as np


from src.data import data_fn
from src.model import init_fn, apply_fn
import esch

In [2]:
# Read the run configuration; later cells look up 'lr', 'weight_decay' and
# 'epochs' in the resulting mapping.
# NOTE(review): yaml.FullLoader is kept as-is. If config.yaml only holds plain
# scalars/lists/maps, yaml.safe_load would be the stricter choice — confirm.
with open('config.yaml', 'r') as config_file:
    conf = yaml.load(config_file, Loader=yaml.FullLoader)

In [3]:
# Fixed seed so the run is reproducible; `key` is passed to the parameter
# initialiser, `rng` is the other half of the split.
seed_key = random.PRNGKey(0)
rng, key = random.split(seed_key)

# Lion optimiser with learning rate and weight decay taken from the config.
opt = optax.lion(conf['lr'], weight_decay=conf['weight_decay'])

# Initial model parameters and the optimiser state derived from them.
params = init_fn(key, conf)
state = opt.init(params)

# Inputs and targets used by the training loop.
x, y = data_fn(conf)

In [4]:
def loss_fn(params, x, y):  # TODO: weight by prime frequency
    """Return the mean squared error between `apply_fn(params, x)` and `y`."""
    residual = apply_fn(params, x) - y
    return jnp.mean(jnp.square(residual))

In [5]:
@jit
def update_fn(params, state, x, y):
    """Perform one optimisation step and return `(params, state, loss)`.

    Evaluates `loss_fn` and its gradient with respect to `params` on
    `(x, y)`, feeds the gradient through the module-level optimiser `opt`,
    and applies the resulting updates to the parameters.
    """
    value_and_grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = value_and_grad_fn(params, x, y)
    # The current params are passed to the optimiser together with the gradients.
    updates, new_state = opt.update(grads, state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_state, loss

In [6]:
# Training loop: every iteration is one `update_fn` call on the same (x, y).
n_epochs = conf['epochs']
pbar = tqdm(range(n_epochs))

# Collect the per-epoch scalar losses in a Python list and build the array once
# after the loop. Updating a preallocated array with `.at[i].add(...)` outside
# of jit returns a fresh copy of the whole array on every step.
loss_history = []
for epoch in pbar:
    params, state, loss = update_fn(params, state, x, y)
    loss_history.append(loss)

# `losses` has one entry per epoch; fall back to an empty array when the
# configuration asks for zero epochs (stacking an empty list is an error).
losses = jnp.stack(loss_history) if loss_history else jnp.zeros(0)

  0%|          | 0/20 [00:00<?, ?it/s]

100%|██████████| 20/20 [00:00<00:00, 36.51it/s]


In [7]:
# Plot the recorded loss curve together with three copies shifted by
# constant offsets (-0.01, +0.01, +0.02).
info = {'title': 'Training curves', 'xlab': 'Epoch', 'ylab': 'MSE'}
curves = [
    losses,
    losses - 0.01,
    losses + 0.01,
    losses + 0.02,
]
fig = esch.curves_fn(curves, info)
fig.show()