# MIIII

In [1]:
import jax, optax, yaml
import jax.numpy as jnp
from jax import jit, random, tree_util, vmap, grad, value_and_grad
from jax.lib import xla_client
from functional import partial
from tqdm import tqdm
import tikz


from src import args_fn, apply_fn, init_fn, data_fn, conf_fn
import esch

In [2]:
conf = conf_fn()
rng, key = random.split(random.PRNGKey(0))
opt = optax.lion(conf["lr"], weight_decay=conf["weight_decay"])
params = init_fn(key, conf)
state = opt.init(params)
x, y = data_fn(conf)

In [3]:
def loss_fn(params, x, y):  # todo: weight by prime frquency
    y_pred = apply_fn(params, x)
    loss = -jnp.mean(y * jnp.log(y_pred) + (1 - y) * jnp.log(1 - y_pred))
    return loss

In [4]:
@jit
def update_fn(params, state, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    updates, state = opt.update(grads, state, params)
    params = optax.apply_updates(params, updates)
    return params, state, loss

In [5]:
pbar = tqdm(range(conf["epochs"]))
losses = jnp.zeros(conf["epochs"])
for epoch in pbar:
    params, state, loss = update_fn(params, state, x, y)
    losses = losses.at[epoch].add(loss)

  0%|          | 0/1000 [00:00<?, ?it/s]

100%|██████████| 1000/1000 [00:13<00:00, 73.61it/s]


In [6]:
info = dict(title="Training curves", xlab="Epoch", ylab="MSE")
fig = esch.curves_fn([losses, losses - 0.01, losses + 0.01, losses + 0.02], info)
fig.show()

In [7]:
losses

Array([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
       nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, na