# MIIII

In [9]:
import jax, optax, yaml
import jax.numpy as jnp
from jax import jit, random, tree_util, vmap, grad, value_and_grad
from jax.lib import xla_client
from functional import partial
from tqdm import tqdm
import tikz
from typing import Any, Tuple, List, Dict, Iterator
from oeis import A000032


from src import args_fn, apply_fn, init_fn
import esch

In [6]:
import oeis as OEIS

In [28]:
def loss_fn(params, x, y):  # todo: weight by prime frquency
    y_pred = apply_fn(params, x)
    loss = -jnp.mean(y * jnp.log(y_pred) + (1 - y) * jnp.log(1 - y_pred))
    return loss

In [29]:
@jit
def update_fn(params, state, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    updates, state = opt.update(grads, state, params)
    params = optax.apply_updates(params, updates)
    return params, state, loss

In [30]:
pbar = tqdm(range(conf["epochs"]))
losses = jnp.zeros(conf["epochs"])
for epoch in pbar:
    params, state, loss = update_fn(params, state, x, y)
    losses = losses.at[epoch].add(loss)

100%|██████████| 1000/1000 [00:15<00:00, 66.29it/s]


In [31]:
info = dict(title="Training curves", xlab="Epoch", ylab="MSE")
fig = esch.curves_fn([losses, losses - 0.01, losses + 0.01, losses + 0.02], info)
fig.show()

In [9]:
y.shape

(539,)

In [48]:
def spiral_fn(v):
    # floor sqrt of length of v
    n = jnp.sqrt(v.shape[0]).astype(jnp.int32)
    v, lst = v[: n * n].reshape(n, n), [v[:n]]
    print(lst)
    while v.size:
        v = jnp.rot90(v[1:], 1)
        lst.append(v[0])
    return jnp.concatenate(lst).reshape((n, n))

In [49]:
spiral_fn(jnp.arange(100))

[Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)]


Array([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],
       [19, 29, 39, 49, 59, 69, 79, 89, 99, 98],
       [97, 96, 95, 94, 93, 92, 91, 90, 80, 70],
       [60, 50, 40, 30, 20, 10, 11, 12, 13, 14],
       [15, 16, 17, 18, 28, 38, 48, 58, 68, 78],
       [88, 87, 86, 85, 84, 83, 82, 81, 71, 61],
       [51, 41, 31, 21, 22, 23, 24, 25, 26, 27],
       [37, 47, 57, 67, 77, 76, 75, 74, 73, 72],
       [62, 52, 42, 32, 33, 34, 35, 36, 46, 56],
       [66, 65, 64, 63, 53, 43, 44, 45, 55, 54]], dtype=int32)