# MIIII

In [None]:
import jax.numpy as jnp
from jax import random
import jax
from oeis import oeis
import optax
from functools import partial
from sklearn.metrics import confusion_matrix, f1_score
from matplotlib import pyplot as plt
import miiii

In [None]:
def predict(apply_fn, params, x, _=random.PRNGKey(0), threshold=0.5):
    """Return hard 0/1 predictions from the model output.

    Args:
        apply_fn: callable invoked as ``apply_fn(params, key, x, 0.0)``.
        params: model parameters forwarded to ``apply_fn``.
        x: model inputs forwarded to ``apply_fn``.
        _: PRNG key forwarded to ``apply_fn``. Kept under its original
            name for backward compatibility with existing callers.
        threshold: probability cut-off. An output is 1 when
            ``sigmoid(output) > threshold``. Defaults to 0.5, the
            previously hard-coded value.

    Returns:
        An int32 array with 1 where the sigmoid of the model output
        exceeds ``threshold`` and 0 elsewhere.
    """
    # NOTE(review): the trailing 0.0 looks like a dropout rate (i.e. the
    # inference mode); confirm against miiii.make_apply_fn.
    logits = apply_fn(params, _, x, 0.0)
    return (jax.nn.sigmoid(logits) > threshold).astype(jnp.int32)

In [None]:
# Config and initialisation.
# One root PRNG key is split in two: `key` seeds the parameter
# initialisation here, and `rng` is handed to the training loop later.
conf = miiii.get_conf()
rng, key = random.split(random.PRNGKey(0))

# Dataset built by miiii.prime_fn from conf.n and a base-`conf.base`
# encoder (TODO confirm exact semantics of conf.n in miiii).
data = miiii.prime_fn(
    conf.n,
    partial(miiii.base_n, conf.base),
)
params = miiii.init_fn(key, conf)

In [None]:
# Setup and train.
# Relies on `params`, `conf`, `data` and `rng` from the previous cell.
apply_fn = miiii.make_apply_fn(miiii.vaswani_fn)
train_fn, opt_state = miiii.init_train(apply_fn, params, conf, *data)
# train_fn returns the updated (params, opt_state) pair plus the losses.
(params, opt_state), losses = train_fn(
    conf.epochs,
    rng,
    (params, opt_state),
)

In [None]:
# evaluate
def make_plots(losses, conf, params, data=None, apply_fn=None):
    """Plot train/valid predictions and the loss curves.

    Args:
        losses: losses returned by the training loop; forwarded to
            ``miiii.curve_plot``.
        conf: experiment config (from ``miiii.get_conf()``).
        params: model parameters used for prediction.
        data: ``((train_x, train_y), (valid_x, valid_y))``. When None,
            the notebook-global ``data`` is used (looked up at call time,
            as before).
        apply_fn: model apply function passed to ``predict``. When None,
            the notebook-global ``apply_fn`` is used (looked up at call
            time, as before).
    """
    # Fall back to the notebook globals so existing calls behave exactly
    # as before, while allowing explicit inputs (other splits/models).
    if data is None:
        data = globals()["data"]
    if apply_fn is None:
        apply_fn = globals()["apply_fn"]

    (train_x, train_y), (valid_x, valid_y) = data
    train_pred = predict(apply_fn, params, train_x)
    valid_pred = predict(apply_fn, params, valid_x)
    miiii.polar_plot(train_y, train_pred, conf, "train")
    # NOTE(review): offset by the number of training rows, presumably so
    # validation samples are placed after training samples; confirm in
    # miiii.polar_plot.
    miiii.polar_plot(valid_y, valid_pred, conf, "valid", offset=train_y.shape[0])
    miiii.curve_plot(losses, conf, params)


make_plots(losses, conf, params)