In [1]:
import jax, optax, yaml
import jax.numpy as jnp
from jax import jit, random, tree_util
from jax.lib import xla_client
from functional import partial
from tqdm import tqdm
import numpy as np


from src.data import data_fn
from src.model import init_fn, apply_fn
import esch

In [2]:
# Load the experiment hyper-parameters (lr, weight_decay, epochs, ...) from
# config.yaml in the current working directory.
# NOTE(review): for a plain-data config, yaml.safe_load would be sufficient.
with open('config.yaml') as config_file:
    conf = yaml.load(config_file, Loader=yaml.FullLoader)

In [3]:
# Fixed seed 0 -> reproducible runs; `key` is consumed by init_fn below,
# `rng` is the remaining half of the split.
rng, key = random.split(random.PRNGKey(0))

# Lion optimiser with learning rate and weight decay taken from the config,
# then model parameters and the matching optimiser state.
opt = optax.lion(learning_rate=conf['lr'], weight_decay=conf['weight_decay'])
params = init_fn(key, conf)
state = opt.init(params)

# Training inputs and targets, built from the same config.
x, y = data_fn(conf)

In [4]:
def loss_fn(params, x, y):
    """Mean-squared error between the model's predictions and the targets.

    Args:
        params: model parameters, passed unchanged to ``apply_fn``.
        x: model inputs.
        y: targets; must broadcast against ``apply_fn(params, x)``.

    Returns:
        Scalar mean of the element-wise squared residuals.
    """
    # TODO: weight by prime frequency
    residual = apply_fn(params, x) - y
    return jnp.mean(jnp.square(residual))

In [5]:
@jit
def update_fn(params, state, x, y):
    """Take one optimisation step of ``opt`` on ``loss_fn`` over (x, y).

    Args:
        params: current model parameters.
        state: optimiser state matching ``params``.
        x, y: inputs and targets forwarded to ``loss_fn``.

    Returns:
        ``(new_params, new_state, loss)``, where ``loss`` is the value at the
        parameters *before* this update.
    """
    step_loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # params are passed so the optimiser can apply its weight_decay term.
    updates, new_state = opt.update(grads, state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_state, step_loss

In [6]:
# Full-batch training loop: one update_fn step per epoch, recording the loss.
pbar         = tqdm(range(conf['epochs']))
loss_history = []  # per-epoch losses kept as device scalars (no host sync per step)
for epoch in pbar:
    params, state, loss = update_fn(params, state, x, y)
    loss_history.append(loss)
# Stack once after training. The previous `losses.at[epoch].add(loss)` runs
# outside jit, so every epoch allocated and copied the whole buffer, which is
# O(epochs^2) work. Zero epochs still yields the empty `jnp.zeros(0)` array.
losses = jnp.stack(loss_history) if loss_history else jnp.zeros(conf['epochs'])

100%|██████████| 20/20 [00:00<00:00, 23.41it/s]


In [7]:
info = {'title': 'Training curves', 'xlab': 'Epoch', 'ylab': 'MSE'}
# NOTE(review): only `losses` is measured. The other three curves are the same
# series shifted by fixed offsets (presumably to exercise esch's multi-curve
# plotting) and should not be read as separate runs.
shifted = [losses + offset for offset in (-0.01, 0.01, 0.02)]
fig = esch.curves_fn([losses, *shifted], info)
fig.show()

In [8]:
# Visualise the computation graph of `loss_fn` evaluated on the current params
# and training data (left as the cell's final expression so it is displayed).
# NOTE(review): the output below lists XLA HLO-style nodes (Parameter, gather,
# shapes with layouts), so esch presumably lowers the function through XLA; confirm.
esch.compute_graph_fn(loss_fn, params, x, y)

Parameter 20</b><br/>sharding={replicated}<br/>s32[541]
Parameter 16</b><br/>sharding={replicated}<br/>f32[64,1]{1,0}
Parameter 13</b><br/>sharding={replicated}<br/>f32[64]
Parameter 12</b><br/>sharding={replicated}<br/>f32[64]
Parameter 5</b><br/>sharding={replicated}<br/>f32[64]
Parameter 4</b><br/>sharding={replicated}<br/>f32[64]
Parameter 18</b><br/>sharding={replicated}<br/>f32[16,64]{1,0}
Parameter 19</b><br/>sharding={replicated}<br/>s32[541,3]{1,0}
Parameter 0</b><br/>f32[16,64]{1,0}
Parameter 1</b><br/>s32[541,3]{1,0}
compare.0</b><br/>direction=LT<br/>pred[541,3]{1,0}<br/>
add.43</b><br/>s32[541,3]{1,0}<br/>
select.0</b><br/>s32[541,3]{1,0}
bitcast.35</b><br/>s32[541,3,1]{2,1,0}
gather.1</b><br/>offset_dims={2}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=2<br/>slice_sizes={1,64}<br/>f32[541,3,64]{2,1,0}
Parameter 17</b><br/>sharding={replicated}<br/>f32[128,64]{1,0}
Parameter 0</b><br/>f32[128,64]{1,0}
iota.1</b><br/>iota_dimension=0<br/>s32[3,1]{1,0}
ga