In [1]:
from tqdm import tqdm
import optax
import haiku as hk
import jax
import numpy as np
from jax import jit, grad
import jax.numpy as jnp
from functools import partial
from src.data import load_subject, make_kfolds
from src.model import loss_fn, init, apply
from src.plots import plot_brain, plot_decoding
from src.utils import CONFIG

In [2]:
# Load the data for subject 'subj05'; the image size is read from the shared CONFIG.
subject = load_subject('subj05', image_size=CONFIG['image_size'])

In [21]:
def hyperparam_fn():
    """Draw one random hyperparameter configuration.

    Sampling goes through NumPy's global random state, so the result is only
    repeatable when the caller seeds ``np.random`` beforehand.

    Returns:
        dict with the keys
        ``'batch_size'``   -- 32 or 64, chosen with equal probability,
        ``'n_steps'``      -- integer in the half-open range [100, 200),
        ``'dropout_rate'`` -- float drawn uniformly from [0.1, 0.5).
    """
    # The draws happen in a fixed order (batch size, step count, dropout) so a
    # seeded global RNG always produces the same configuration.
    batch_size = np.random.choice([32, 64])
    n_steps = np.random.randint(low=100, high=200)
    dropout_rate = np.random.uniform(low=0.1, high=0.5)
    return dict(batch_size=batch_size, n_steps=n_steps, dropout_rate=dropout_rate)

def update_fn(params, rng, fmri, img, opt_state, opt, dropout_rate):
    """Apply a single optimizer step to the model parameters.

    Args:
        params: current model parameters, as accepted by ``loss_fn``.
        rng: JAX PRNG key; it is split once and the derived subkey is what
            ``loss_fn`` receives.
        fmri: fMRI batch forwarded to ``loss_fn``.
        img: image batch forwarded to ``loss_fn``.
        opt_state: optimizer state belonging to ``opt``.
        opt: optimizer object; only its ``update`` method is used here.
        dropout_rate: forwarded to ``loss_fn`` as a keyword argument.

    Returns:
        Tuple ``(new_params, new_opt_state)``.
    """
    _, step_key = jax.random.split(rng)
    # grad() differentiates with respect to the first positional argument,
    # i.e. the parameters.
    loss_grad = grad(loss_fn)
    gradients = loss_grad(params, step_key, fmri, img, dropout_rate=dropout_rate)
    deltas, next_opt_state = opt.update(gradients, opt_state, params)
    return optax.apply_updates(params, deltas), next_opt_state


def train_loop(rng, opt, train_loader, val_loader, plot_batch, hyperparams):
    """Train a freshly initialised model and collect periodic evaluation metrics.

    Args:
        rng: JAX PRNG key; split for initialisation, for every update step and
            for every evaluation.
        opt: optimizer exposing ``init`` (its ``update`` is used by ``update_fn``).
        train_loader: iterator yielding ``(lh, rh, img)`` batches; only ``lh``
            and ``img`` are used for training. ``evaluate`` also draws batches
            from this iterator.
        val_loader: iterator of validation batches, handed to ``evaluate``.
        plot_batch: currently unused; kept so existing callers keep working.
        hyperparams: mapping providing at least ``'n_steps'`` and
            ``'dropout_rate'``.

    Returns:
        List with one entry (whatever ``evaluate`` returns) per evaluation.
    """
    metrics = []
    n_steps = hyperparams['n_steps']
    # Evaluate roughly 100 times per run. The interval is clamped to 1 so that
    # runs shorter than 100 steps do not compute ``step % 0``, which raises
    # ZeroDivisionError.
    eval_every = max(1, n_steps // 100)

    rng, init_key = jax.random.split(rng, 2)
    # One batch is consumed up front purely to give ``init`` an example input.
    lh, _, _ = next(train_loader)
    params = init(init_key, lh)
    opt_state = opt.init(params)
    update = partial(update_fn, opt=opt, dropout_rate=hyperparams['dropout_rate'])

    for step in tqdm(range(n_steps)):
        rng, step_key = jax.random.split(rng)
        lh, _, img = next(train_loader)
        params, opt_state = update(params, step_key, lh, img, opt_state)
        if step % eval_every == 0:
            rng, eval_key = jax.random.split(rng)
            metrics.append(evaluate(params, eval_key, train_loader, val_loader))
    return metrics


def evaluate(params, rng, train_loader, val_loader, n_steps=10):
    """Report the mean training and validation loss over ``n_steps`` batches.

    Each iteration pulls one ``(lh, rh, img)`` batch from each loader and
    scores it with ``loss_fn`` on ``lh`` and ``img``, using a fresh PRNG subkey
    per call. Both loaders are advanced by ``n_steps`` batches as a side effect.

    Args:
        params: model parameters passed to ``loss_fn``.
        rng: JAX PRNG key, split three ways on every iteration.
        train_loader: iterator of training batches.
        val_loader: iterator of validation batches.
        n_steps: number of batches averaged per loader.

    Returns:
        str of the form ``'train_loss: <mean>, val_loss: <mean>'``.
    """
    train_losses, val_losses = [], []
    for _ in range(n_steps):
        rng, train_key, val_key = jax.random.split(rng, 3)
        train_lh, _, train_img = next(train_loader)
        train_losses.append(loss_fn(params, train_key, train_lh, train_img))
        val_lh, _, val_img = next(val_loader)
        val_losses.append(loss_fn(params, val_key, val_lh, val_img))
    mean_train = sum(train_losses) / n_steps
    mean_val = sum(val_losses) / n_steps
    return f'train_loss: {mean_train}, val_loss: {mean_val}'


def train_folds(kfolds, hyperparams, seed=0):
    """Train one model per cross-validation fold and gather each fold's metrics.

    Args:
        kfolds: iterable of ``(train_loader, val_loader)`` pairs.
        hyperparams: hyperparameter mapping forwarded to ``train_loop``.
        seed: integer seed for the JAX PRNG key from which per-fold keys are
            split.

    Returns:
        dict mapping the fold index to the metrics ``train_loop`` returned for
        that fold; empty when ``kfolds`` yields nothing.
    """
    metrics = {}
    rng = jax.random.PRNGKey(seed)
    opt = optax.lion(1e-3)
    plot_batch = None
    for idx, (train_loader, val_loader) in enumerate(kfolds):
        # One batch is taken from the first fold's training loader and then
        # handed to every train_loop call.
        if plot_batch is None:
            plot_batch = next(train_loader)
        rng, fold_key = jax.random.split(rng)
        metrics[idx] = train_loop(fold_key, opt, train_loader, val_loader, plot_batch, hyperparams)
    # Return only once every fold has been trained; returning from inside the
    # loop would stop after the first fold.
    return metrics

In [22]:
# Seed NumPy's global RNG so hyperparam_fn() draws the same configuration on
# every Restart & Run All; the same seed is passed on for the JAX key used in
# train_folds.
SEED = 0
np.random.seed(SEED)
hyperparams = hyperparam_fn()
kfolds = make_kfolds(subject, hyperparams)
# Keep the per-fold metrics in a name so later cells can use them, and leave
# them as the last expression so the notebook still displays them.
fold_metrics = train_folds(kfolds, hyperparams, seed=SEED)
fold_metrics

 23%|██▎       | 44/191 [00:06<00:18,  7.91it/s]

In [10]:
rng = jax.random.PRNGKey(0)
rng, key = jax.random.split(rng)
# `train_loader` is not defined anywhere at notebook scope, so this cell used
# to fail with NameError on a fresh kernel. Build a fresh set of folds and take
# the training loader of the first one; iter() makes this work whether
# make_kfolds returns a list or a generator.
train_loader, _ = next(iter(make_kfolds(subject, hyperparams)))
lh, rh, img = next(train_loader)

In [11]:
# Initialise model parameters, using the sample batch `lh` as the example input.
params = init(key, lh)

In [12]:
# Run apply() on the sample batch with the freshly initialised parameters.
# NOTE(review): `key` is the same PRNG key that was already passed to init();
# split a new subkey here if the two calls should use independent randomness.
apply(params, key, lh)

Array([[[[-3.16308439e-02, -3.31890076e-01, -5.12075238e-02],
         [ 4.85625491e-02,  2.17759639e-01,  5.90159558e-02],
         [-2.78418064e-01, -6.86557889e-02,  2.25962013e-01],
         ...,
         [-9.51540470e-02, -2.32540250e-01, -4.32052374e-01],
         [-2.11033791e-01, -2.48100042e-01,  4.59478647e-01],
         [ 3.46645892e-01, -2.35105738e-01, -1.11824542e-01]],

        [[-2.35176325e-01, -1.32810012e-01, -9.88911763e-02],
         [-4.55371052e-01,  1.92810908e-01, -2.92965651e-01],
         [-6.06114268e-02, -4.25028712e-01,  1.31005749e-01],
         ...,
         [-1.79364175e-01,  9.62745026e-02,  4.54996139e-01],
         [-8.48789662e-02, -7.97832534e-02, -1.35977253e-01],
         [-9.35600474e-02,  1.97868705e-01,  3.33004594e-01]],

        [[ 1.24057353e-01, -1.16122976e-01,  2.40907833e-01],
         [ 2.87088841e-01, -2.07734168e-01, -6.42771600e-03],
         [ 5.59472516e-02,  2.47780457e-02, -2.08534971e-02],
         ...,
         [ 2.19584569e-0