In [1]:
from tqdm import tqdm
import optax
import haiku as hk
import jax
import numpy as np
from jax import jit, grad
import jax.numpy as jnp
from functools import partial
from src.data import load_subject, make_kfolds
from src.model import network_fn, loss_fn
from src.plots import plot_brain, plot_decodings

In [2]:
# Notebook-wide settings; 'image_size' is forwarded to both the data loader and the network below.
args = {'image_size': 32}
# Load the data for subject 'subj05' at the configured image size.
subject = load_subject('subj05', image_size=args['image_size'])
# Wrap network_fn (with image_size bound) via hk.transform and unpack the result into the
# `init` / `apply` functions used by the training code in later cells.
init, apply = hk.transform(partial(network_fn, image_size=args['image_size']))

In [7]:
def hyperparam_fn(random_state=None):
    """Sample one random hyperparameter configuration.

    Args:
        random_state: Optional ``np.random.RandomState`` (or any object
            exposing the legacy ``choice`` / ``randint`` / ``uniform``
            methods). When ``None`` the global ``np.random`` state is used,
            which is what this function always did before the argument
            existed. Pass a seeded ``RandomState`` to get reproducible draws.

    Returns:
        dict with the keys:
            'batch_size'   -- one of 32 or 64,
            'n_steps'      -- integer in [100, 200),
            'dropout_rate' -- float in [0.1, 0.5).
    """
    # The np.random module and RandomState instances share the same sampling
    # method names, so either can serve as the sampler.
    sampler = np.random if random_state is None else random_state
    return {
        'batch_size': sampler.choice([32, 64]),
        'n_steps': sampler.randint(low=100, high=200),
        'dropout_rate': sampler.uniform(low=0.1, high=0.5),
    }

def update_fn(params, fmri, img, opt_state, rng, opt):
    """Perform a single optimizer step and return ``(params, opt_state)``.

    Args:
        params: Current model parameters; gradients of ``loss_fn`` are taken
            with respect to this (its first) argument.
        fmri: Batch passed to ``loss_fn`` as its third argument.
        img: Batch passed to ``loss_fn`` as its fourth argument.
        opt_state: Current optimizer state.
        rng: JAX PRNG key; it is split here and one half is given to ``loss_fn``.
        opt: Optimizer object exposing ``update(grads, opt_state, params)``.

    Returns:
        Tuple ``(new_params, new_opt_state)``.
    """
    # The first half of the split is intentionally unused; only the second
    # half is forwarded to the loss.
    _, step_key = jax.random.split(rng)
    loss_grads = grad(loss_fn)(params, step_key, fmri, img)
    param_updates, new_opt_state = opt.update(loss_grads, opt_state, params)
    new_params = optax.apply_updates(params, param_updates)
    return new_params, new_opt_state


def train_loop(opt, train_loader, val_loader, plot_batch, hyperparams, rng):
    """Initialise the model and train it for ``hyperparams['n_steps']`` steps.

    Args:
        opt: Optimizer exposing ``init(params)`` and ``update(...)``.
        train_loader: Iterator; each ``next()`` yields a 3-tuple whose first
            and third items are fed to ``update_fn`` (the middle item is not
            used here).
        val_loader: Passed through to ``evaluate``.
        plot_batch: Batch reserved for plotting; only referenced by the
            commented-out ``plot_decodings`` call below, so currently unused.
        hyperparams: Mapping providing ``'n_steps'``.
        rng: JAX PRNG key, split once for initialisation and once per step.

    Returns:
        The parameters after the final training step.
    """
    n_steps = hyperparams['n_steps']
    # Evaluate roughly 100 times over the run. ``n_steps // 100`` is 0 for
    # n_steps < 100, which previously made the modulo below raise
    # ZeroDivisionError; clamp the interval to at least one step.
    eval_every = max(1, n_steps // 100)

    rng, key = jax.random.split(rng)
    # The first batch is consumed only to initialise the parameters.
    lh, _, _ = next(train_loader)
    params = init(key, lh)
    opt_state = opt.init(params)
    update = partial(update_fn, opt=opt)
    for step in tqdm(range(n_steps)):
        rng, key = jax.random.split(rng)
        lh, _, img = next(train_loader)
        params, opt_state = update(params, lh, img, opt_state, key)
        if step % eval_every == 0:
            evaluate(params, train_loader, val_loader)
            # plot_decodings(apply(params, key, plot_batch[0]), plot_batch[2])
    return params

def evaluate(params, train_loader, val_loader, n_steps=4):
    """Placeholder evaluation hook: does nothing and returns ``None``.

    All arguments are accepted and ignored. The training loop calls this
    with the current parameters and both loaders.
    TODO: implement the actual evaluation (``n_steps`` is presumably the
    number of batches to evaluate on -- confirm when implementing).
    """
    return None


def train_folds(kfolds, hyperparams, seed=0, max_folds=1):
    """Train a fresh model on each cross-validation fold.

    Args:
        kfolds: Iterable of ``(train_loader, val_loader)`` pairs.
        hyperparams: Mapping forwarded to ``train_loop``.
        seed: Integer seed for the JAX PRNG key that is split per fold.
        max_folds: Maximum number of folds to train. The default of 1 keeps
            the long-standing behaviour of stopping after the first fold
            (previously hard-coded as a ``return`` inside the loop). Pass
            ``None`` to train every fold.

    Returns:
        List with the trained parameters of each fold, in fold order.
        Previously the parameters were discarded and ``None`` was returned.
    """
    if max_folds is not None and max_folds <= 0:
        return []

    rng = jax.random.PRNGKey(seed)
    opt = optax.lion(1e-3)
    plot_batch = None
    fold_params = []
    for train_loader, val_loader in kfolds:
        # Reuse one fixed batch (taken from the first fold) for plotting.
        if plot_batch is None:
            plot_batch = next(train_loader)
        rng, key = jax.random.split(rng)
        fold_params.append(
            train_loop(opt, train_loader, val_loader, plot_batch, hyperparams, key)
        )
        # Check the limit after training so no extra fold is pulled from a
        # lazily evaluated ``kfolds`` iterable.
        if max_folds is not None and len(fold_params) >= max_folds:
            break
    return fold_params

In [8]:
# Draw one random hyperparameter configuration. No NumPy seed is set in the
# visible cells, so the draw can differ between runs.
hyperparams = hyperparam_fn()
# Build the cross-validation folds for the loaded subject; train_folds iterates
# the result as (train_loader, val_loader) pairs.
kfolds = make_kfolds(subject, hyperparams)
# Run training. With this call only the first (train, val) pair in kfolds is trained.
train_folds(kfolds, hyperparams)

100%|██████████| 158/158 [00:05<00:00, 26.68it/s]
