# Neuroscope playground

In [None]:
import jax
from jax import grad, jit
import jax.numpy as jnp
import numpy as np
import haiku as hk
import optax
import wandb
from typing import List, Tuple, Dict
from functools import partial
from tqdm import tqdm
from src.model import lh_init, rh_init, loss_fns
from src.eval import evaluate
from src.data import get_data
from src.utils import get_args_and_config

In [None]:
# One cross-validation fold: (img, cat, lh, rh) — the layout get_batch and
# train_subject unpack. lh/rh are presumably per-hemisphere fMRI targets.
Fold = Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]
# One sampled batch for a single hemisphere: (img, cat, fmri), as returned by get_batch.
Batch = Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]
# loss_fns[hem] is unpacked as a (train_loss_fn, val_loss_fn) pair per hemisphere,
# the same pairing train_subject uses — NOTE(review): confirm against src.model.
lh_train_loss_fn, lh_infer_loss_fn = loss_fns['lh']
rh_train_loss_fn, rh_infer_loss_fn = loss_fns['rh']

## data

In [None]:
# Parse CLI args / config first; the dataset loader depends on both.
args, config = get_args_and_config()
data = get_data(args, config)
# Module-level PRNG stream with a fixed seed; train_subject draws init keys from it.
base_key = jax.random.PRNGKey(42)
rng = hk.PRNGSequence(base_key)
# Shared Adam optimizer (lr 1e-3) used by the hemisphere update functions.
opt = optax.adam(1e-3)

In [None]:
# functions
def train(data, config):
    """Train left- and right-hemisphere models for every subject and save them.

    Args:
        data: mapping ``subject -> (folds, _)``; only the folds are used here.
        config: run configuration. ``config['group_name']`` is overwritten with a
            fresh wandb id so all fold runs started by this call share one group.

    For each subject, writes ``results/models/{subject}_lh.npy`` and
    ``results/models/{subject}_rh.npy`` with the list of per-fold parameters
    returned by ``train_subject``.
    """
    import os

    config['group_name'] = wandb.util.generate_id()
    # np.save does not create missing parent directories.
    os.makedirs('results/models', exist_ok=True)
    for subject, (folds, _) in data.items():
        lh_params_lst, rh_params_lst = train_subject(folds, config)
        np.save(f'results/models/{subject}_lh.npy', lh_params_lst)
        np.save(f'results/models/{subject}_rh.npy', rh_params_lst)


def train_subject(folds: List[Fold], config: Dict) -> Tuple[List[hk.Params], List[hk.Params]]:
    """Train one model per cross-validation fold for each hemisphere of a subject.

    Args:
        folds: the subject's folds; each is unpacked as ``(img, cat, lh, rh)``.
        config: run configuration, forwarded to ``train_fold_fn``.
            ``config['hem']`` is overwritten with the hemisphere being trained.

    Returns:
        ``(lh_params_lst, rh_params_lst)``: the trained parameters for every fold,
        for the left and the right hemisphere respectively.
    """
    # TODO: parallelize using pmap or vmap
    splits = [make_fold(folds, fold) for fold in range(len(folds))]  # (train_data, val_data) per fold
    trained = {}
    for hem, init_fn in zip(['lh', 'rh'], [lh_init, rh_init]):
        config['hem'] = hem
        # loss_fns is indexed per hemisphere, matching its module-level use.
        train_loss_fn, val_loss_fn = loss_fns[hem]
        # Draw every fold's init key before training, as before.
        params_lst = [init_fn(next(rng), img) for img, _, _, _ in folds]
        # Keyword arguments: train_fold_fn's third positional parameter is `config`,
        # so positional loss fns would collide with it; `hem` is required.
        trained[hem] = [
            train_fold_fn(
                params,
                split,
                config=config,
                train_loss_fn=train_loss_fn,
                val_loss_fn=val_loss_fn,
                hem=hem,
            )
            for params, split in zip(params_lst, splits)
        ]
    return trained['lh'], trained['rh']


def train_fold_fn(params, fold, config: Dict, train_loss_fn, val_loss_fn, hem) -> hk.Params:
    """Train ``params`` on one cross-validation fold for one hemisphere.

    Opens a wandb run (grouped by ``config['group_name']``) for the fold and
    always closes it, even if training raises.

    Args:
        params: initial model parameters.
        fold: ``(train_data, val_data)`` pair as produced by ``make_fold``.
        config: run configuration; reads ``n_steps``, ``batch_size`` and ``group_name``.
        train_loss_fn: training loss, forwarded to ``evaluate``.
        val_loss_fn: validation loss, forwarded to ``evaluate``.
        hem: ``'lh'`` selects ``lh_update`` and the lh targets; anything else the rh ones.

    Returns:
        The parameters after ``config['n_steps']`` optimizer steps.
    """
    train_data, val_data = fold
    update = lh_update if hem == 'lh' else rh_update
    # The optimizer state must be initialised from the params before the first update.
    opt_state = opt.init(params)
    # Evaluate ~100 times per run; max(1, ...) avoids modulo-by-zero when n_steps < 100.
    eval_every = max(1, config['n_steps'] // 100)
    wandb.init(project="neuroscope", entity='syrkis', config=config, group=config['group_name'])
    try:
        for step in tqdm(range(config['n_steps'])):
            batch = get_batch(train_data, config['batch_size'], hem)
            params, opt_state = update(params, batch, opt_state)
            if step % eval_every == 0:
                # NOTE(review): evaluate is handed get_batch, which requires a `hem`
                # argument — confirm evaluate supplies it (e.g. from config['hem']).
                metrics = evaluate(params, train_data, val_data, get_batch, config, train_loss_fn, val_loss_fn)
                wandb.log(metrics, step=step)
    finally:
        wandb.finish()
    return params


def get_batch(fold: Fold, batch_size: int, hem: str) -> Batch:
    """Sample ``batch_size`` rows of ``fold`` (with replacement) for one hemisphere.

    Rows are drawn with NumPy's global RNG. Returns ``(img, cat, fmri)``, where
    ``fmri`` is the fold's ``lh`` component when ``hem == 'lh'`` and its ``rh``
    component otherwise.
    """
    img, cat, lh_fmri, rh_fmri = fold
    targets = lh_fmri if hem == 'lh' else rh_fmri
    rows = np.random.randint(0, img.shape[0], batch_size)
    return img[rows], cat[rows], targets[rows]


def make_fold(folds: List[Fold], fold: int) -> Tuple[Fold, Fold]:
    """Build the ``(train_data, val_data)`` split for cross-validation fold ``fold``.

    ``val_data`` is ``folds[fold]`` itself. ``train_data`` concatenates, along
    axis 0, every component of all the other folds. It therefore has the same
    ``(img, cat, lh, rh)`` layout as a fold and can be passed to ``get_batch``.

    Raises:
        IndexError: if ``fold`` is not in ``range(len(folds))``.
        ValueError: if there are fewer than two folds (nothing left to train on).
    """
    if not 0 <= fold < len(folds):
        raise IndexError(f'fold {fold} out of range for {len(folds)} folds')
    others = list(folds[:fold]) + list(folds[fold + 1:])
    if not others:
        raise ValueError('need at least two folds to build a training split')
    # zip(*others) groups the k-th component of every remaining fold; concatenating
    # each group keeps all components, including the rh fMRI.
    train_data = tuple(jnp.concatenate(parts) for parts in zip(*others))
    return train_data, folds[fold]


def _optimizer_step(loss_fn, params: hk.Params, batch: Batch, opt_state: optax.OptState) -> Tuple[hk.Params, optax.OptState]:
    """Apply one ``opt`` update to ``params`` using the gradient of ``loss_fn(params, batch)``."""
    grads = grad(loss_fn)(params, batch)
    updates, opt_state = opt.update(grads, opt_state)
    return optax.apply_updates(params, updates), opt_state


def lh_update(params: hk.Params, batch: Batch, opt_state: optax.OptState) -> Tuple[hk.Params, optax.OptState]:
    """One optimizer step on the left-hemisphere training loss; returns (new_params, opt_state)."""
    return _optimizer_step(lh_train_loss_fn, params, batch, opt_state)


def rh_update(params: hk.Params, batch: Batch, opt_state: optax.OptState) -> Tuple[hk.Params, optax.OptState]:
    """One optimizer step on the right-hemisphere training loss; returns (new_params, opt_state)."""
    return _optimizer_step(rh_train_loss_fn, params, batch, opt_state)

In [None]:
# Train lh/rh models for every subject across all folds; params are saved under results/models/.
train(data, config)