In [25]:
import jax, optax, yaml
import jax.numpy as jnp
from jax import grad, jit, vmap, random
from jax.tree_util import tree_flatten

from einops import rearrange
from functional import partial

import numpy as np
import yaml
import numpy as np
from tqdm import tqdm
from src.data import data_fn
from src.model import init_fn, apply_fn

In [26]:
# Read the run configuration (learning rate, weight decay, epochs, ...) once.
# NOTE(review): FullLoader is fine for a trusted local file; switch to
# yaml.safe_load if config.yaml can ever come from an untrusted source.
with open('config.yaml') as f:
    conf = yaml.full_load(f)

In [27]:
# Split the seed-0 root key in two: `key` feeds init_fn below, `rng` is not used in this cell.
rng, key = random.split(random.PRNGKey(0), 2)

# Lion optimizer, with learning rate and weight decay taken from the config.
opt = optax.lion(conf['lr'], weight_decay=conf['weight_decay'])

# Model parameters and the optimizer state built from them.
params = init_fn(key, conf)
state = opt.init(params)

# Inputs and targets used by every training step.
x, y = data_fn(conf)

In [28]:
def loss_fn(params, x, y):
    """Mean squared error between the model output for `x` and the target `y`.

    Args:
        params: model parameters passed through to `apply_fn`.
        x: model inputs.
        y: targets compared element-wise against the predictions.

    Returns:
        The mean of the squared residuals over all elements.
    """
    # TODO: weight by prime frequency
    residual = apply_fn(params, x) - y
    return jnp.mean(jnp.square(residual))

In [29]:
@jit
def update_fn(params, x, y, state):
    """Run one jit-compiled optimisation step.

    Args:
        params: current model parameters.
        x: model inputs forwarded to `loss_fn`.
        y: targets forwarded to `loss_fn`.
        state: current optimizer state for the module-level `opt`.

    Returns:
        A `(params, state, loss)` tuple: the updated parameters, the updated
        optimizer state, and the loss evaluated at the parameters passed in
        (i.e. before the update is applied).
    """
    # Loss value and its gradient with respect to `params` in a single pass.
    value_and_grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = value_and_grad_fn(params, x, y)

    # `params` is handed to the optimizer as well, alongside the gradients.
    updates, next_state = opt.update(grads, state, params)
    next_params = optax.apply_updates(params, updates)
    return next_params, next_state, loss

In [31]:
# Training loop: one optimizer step on (x, y) per epoch, for conf['epochs'] epochs.
pbar = tqdm(range(conf['epochs']))
for epoch in pbar:
    params, state, loss = update_fn(params, x, y, state)
    # set_description needs a str. `loss.item()` is a plain Python float, which
    # has no `.round` method, so format it to two decimals instead.
    pbar.set_description(f"{loss.item():.2f}")
    # tqdm.write prints above the bar rather than breaking it across lines,
    # which a bare print() does.
    tqdm.write(str(loss))

 20%|██        | 2/10 [00:00<00:01,  5.86it/s]

0.115819216
0.114141785


 40%|████      | 4/10 [00:00<00:00,  6.28it/s]

0.1125456
0.111201555


 60%|██████    | 6/10 [00:00<00:00,  6.08it/s]

0.11036402
0.110401034


 80%|████████  | 8/10 [00:01<00:00,  6.27it/s]

0.11144861
0.11053997


100%|██████████| 10/10 [00:01<00:00,  6.18it/s]

0.110624015
0.11054166



