# Neuroscope playground

In [None]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
import optax
from tqdm import tqdm
import haiku as hk
import numpy as np
import wandb
from functools import partial
from src.data import get_data
from src.utils import get_args_and_config
from src.fmri import plot_brain
from src.model import network_fn, mse, focal_loss, loss_fn_base

## Train with the final hyperparameters (baseline: alpha and beta set to 0)

In [None]:
# Parse the run arguments (the second return value is discarded) and load the
# dataset from them.
# NOTE(review): `train` iterates `data.items()` and unpacks each value as
# (folds, test_data) — confirm this matches what src.data.get_data returns.
args, _ = get_args_and_config()
data = get_data(args)

In [None]:
# AdamW optimizer with a learning rate of 0.001; this single optimizer object
# is used for both the left- and right-hemisphere models.
# TODO: consider a hyperparameter search over learning rate and weight decay.
opt = optax.adamw(0.001)

In [None]:
# Hyperparameters for the baseline run. The dict is bound into `network_fn`
# and `loss_fn` via functools.partial and also passed to wandb as run config.
config = {
    # Weight of the category (focal) loss in loss_fn; 0 disables it.
    'alpha': 0,
    # Weight of the *other* hemisphere's MSE in loss_fn; 0 disables it.
    'beta': 0,
    # Number of optimisation steps per call to train_fold.
    'n_steps': 6000,
    # Number of samples per mini-batch drawn by get_batch.
    'batch_size': 32,
    # The remaining keys are not read in this notebook; presumably they are
    # consumed by src.model.network_fn — TODO confirm their exact meaning.
    'n_units': 100,
    'n_layers': 2,
    'latent_dim': 100,
    'dropout': 0.15,
    }

In [None]:

# Turn network_fn (with `config` bound) into a Haiku init/apply pair.
forward = hk.transform(partial(network_fn, config=config))
# Module-level PRNG key. It is never split or advanced in this notebook, so
# every `forward.apply` call that reads it receives the same key.
rng = jax.random.PRNGKey(42)


def loss_fn(params, rng, batch, hem, config):
    """Combine the hemisphere MSE losses and the category loss into one scalar.

    Args:
        params: parameters handed to ``forward.apply``.
        rng: PRNG key handed to ``forward.apply``.
        batch: tuple ``(x, lh, rh, cat)`` of inputs and targets.
        hem: ``'lh'`` makes the left hemisphere the primary target; any other
            value makes the right hemisphere the primary target.
        config: mapping providing the mixing weights ``'alpha'`` and ``'beta'``.

    Returns:
        ``(1 - alpha) * fmri_loss + alpha * cat_loss`` where ``fmri_loss`` is
        ``(1 - beta) * primary + beta * secondary`` hemisphere MSE.
    """
    alpha, beta = config['alpha'], config['beta']
    x, lh, rh, cat = batch
    lh_hat, rh_hat, cat_hat = forward.apply(params, rng, x)

    lh_loss, rh_loss = mse(lh_hat, lh), mse(rh_hat, rh)
    cat_loss = focal_loss(cat_hat, cat)

    # Decide once which hemisphere is the main objective and which is auxiliary.
    if hem == 'lh':
        primary, secondary = lh_loss, rh_loss
    else:
        primary, secondary = rh_loss, lh_loss

    fmri_loss = (1 - beta) * primary + beta * secondary
    return (1 - alpha) * fmri_loss + alpha * cat_loss

# JIT-compiled per-hemisphere losses. `hem` and `config` are bound by partial,
# so the compiled functions take only (params, rng, batch).
lh_loss_fn = jit(partial(loss_fn, hem='lh', config=config))
rh_loss_fn = jit(partial(loss_fn, hem='rh', config=config))

In [None]:
def get_fold(fold, fold_idx):
    """Split a sequence of folds into stacked training data and a held-out fold.

    Args:
        fold: sequence of folds; each fold is itself a sequence of arrays.
        fold_idx: index of the fold to hold out for validation.

    Returns:
        ``(train_data, val_data)``: ``train_data`` is a list in which the i-th
        entry is the row-wise stack of the i-th array of every other fold;
        ``val_data`` is ``fold[fold_idx]`` untouched.
    """
    kept = [part for i, part in enumerate(fold) if i != fold_idx]
    # zip(*kept) groups the i-th array of every kept fold together.
    train_data = [jnp.vstack(group) for group in zip(*kept)]
    return train_data, fold[fold_idx]

def get_batch(data, batch_size):
    """Yield shuffled mini-batches forever.

    A new random permutation of the rows is drawn at the start of every pass
    over the data. The last batch of a pass is smaller than ``batch_size``
    when the number of rows is not a multiple of it.

    Args:
        data: indexable whose first four entries are arrays sharing their
            leading dimension (read as x, lh, rh, cat).
        batch_size: maximum number of rows per batch.

    Yields:
        ``(x, lh, rh, cat)`` tuples holding the same rows of each array.
    """
    while True:
        n_rows = data[0].shape[0]
        order = np.random.permutation(n_rows)
        for start in range(0, n_rows, batch_size):
            rows = order[start:start + batch_size]
            yield tuple(data[k][rows] for k in range(4))
            
def train(data, config):
    """Train one model pair per entry of ``data``, each in its own wandb run.

    All runs share a freshly generated wandb group id. For every entry the
    folds are stacked row-wise into a single training set and the entry's
    test data is handed to ``train_fold`` as the validation set.

    Args:
        data: mapping whose values are ``(folds, test_data)`` pairs.
        config: hyperparameter dict, forwarded to wandb and ``train_fold``.
    """
    group = wandb.util.generate_id()
    for folds, test_data in data.values():
        # Merge every fold: the i-th arrays of all folds are stacked together.
        stacked = [jnp.vstack(parts) for parts in zip(*folds)]
        with wandb.init(project="neuroscope", entity='syrkis', group=group, config=config):
            train_fold(stacked, test_data, config)

def train_fold(train_data, val_data, config):
    """Train a left- and a right-hemisphere model and log metrics to wandb.

    Both models are trained on the same batches. Metrics are logged roughly
    100 times over the run, and a final 50-step evaluation is logged before
    the wandb run is finished.

    Args:
        train_data: sequence of arrays (x, lh, rh, cat) used for training.
        val_data: sequence of arrays (x, lh, rh, cat) used for validation.
        config: mapping providing ``'batch_size'`` and ``'n_steps'``.
    """
    # Both models are initialised from the same key and the same dummy input,
    # so they start from identical parameters.
    # NOTE(review): the dummy input hard-codes 100 input features — confirm
    # this matches the feature dimension of the data.
    init_key = jax.random.PRNGKey(42)
    dummy_x = jnp.ones((1, 100))
    lh_params = forward.init(init_key, dummy_x)
    rh_params = forward.init(init_key, dummy_x)
    lh_opt_state = opt.init(lh_params)
    rh_opt_state = opt.init(rh_params)

    train_batches = get_batch(train_data, config['batch_size'])
    val_batches = get_batch(val_data, config['batch_size'])

    n_steps = config['n_steps']
    # Clamp to 1 so runs shorter than 100 steps do not take a modulo by zero.
    log_every = max(1, n_steps // 100)

    for step in tqdm(range(n_steps)):
        train_batch = next(train_batches)
        lh_params, lh_opt_state = lh_update(lh_params, train_batch, lh_opt_state)
        rh_params, rh_opt_state = rh_update(rh_params, train_batch, rh_opt_state)
        if step % log_every == 0:
            # Evaluation draws its training batches from the same generator
            # that feeds the updates above.
            metrics = evaluate(lh_params, rh_params, train_batches, val_batches)
            wandb.log(metrics)

    # Final, less noisy evaluation; log it so the result is not thrown away.
    metrics = evaluate(lh_params, rh_params, train_batches, val_batches, steps=50)
    wandb.log(metrics)
    wandb.finish()

def evaluate(lh_params, rh_params, train_batches, val_batches, steps=3):
    """Average metrics over a few training and validation batches.

    Args:
        lh_params: parameters of the left-hemisphere model.
        rh_params: parameters of the right-hemisphere model.
        train_batches: batch iterator for training data (``steps`` batches are
            consumed from it).
        val_batches: batch iterator for validation data.
        steps: number of batches to average over for each split.

    Returns:
        Dict with every training metric under a ``train_`` prefix and every
        validation metric under a ``val_`` prefix.
    """
    metrics = {}
    # Training split first (training=True), then validation (training=False).
    splits = (('train', train_batches, True), ('val', val_batches, False))
    for prefix, batches, training in splits:
        scores = evaluate_fold(lh_params, rh_params, batches, steps, training=training)
        metrics.update({f'{prefix}_{name}': value for name, value in scores.items()})
    return metrics

def evaluate_fold(lh_params, rh_params, batches, steps, training=True):
    """Return per-metric means over ``steps`` batches drawn from ``batches``.

    Args:
        lh_params: parameters of the left-hemisphere model.
        rh_params: parameters of the right-hemisphere model.
        batches: iterator yielding batches; ``steps`` items are consumed.
        steps: number of batches to evaluate.
        training: forwarded to ``evaluate_batch``.

    Returns:
        Dict mapping each metric name to its sum over the batches divided by
        ``steps`` (empty when ``steps`` is 0).
    """
    totals = {}
    for _ in range(steps):
        batch_metrics = evaluate_batch(lh_params, rh_params, next(batches), training)
        # Running sum per metric; unseen keys start from 0.
        totals = {name: totals.get(name, 0) + value for name, value in batch_metrics.items()}
    return {name: total / steps for name, total in totals.items()}
    

def evaluate_batch(lh_params, rh_params, batch, training):
    """Evaluate both hemisphere models on one batch.

    Args:
        lh_params: parameters of the left-hemisphere model.
        rh_params: parameters of the right-hemisphere model.
        batch: tuple ``(x, lh, rh, cat)``.
        training: forwarded to ``evaluate_hem``.

    Returns:
        Dict with keys ``lh_mse``, ``lh_corr``, ``rh_mse`` and ``rh_corr``.
    """
    metrics = {}
    for hem, params in (('lh', lh_params), ('rh', rh_params)):
        hem_mse, hem_corr = evaluate_hem(params, batch, hem, training)
        metrics.update({f'{hem}_mse': hem_mse, f'{hem}_corr': hem_corr})
    return metrics

def evaluate_hem(params, batch, hem, training):
    """Score one hemisphere's predictions on a batch.

    Args:
        params: model parameters passed to ``forward.apply``.
        batch: tuple ``(x, lh, rh, _)``; the last entry is ignored.
        hem: ``'lh'`` scores the left-hemisphere output; any other value
            scores the right-hemisphere output.
        training: forwarded to the network as its ``training`` keyword.

    Returns:
        ``(mse, corr)``: the mean squared error and the median of the
        column-wise Pearson correlations between prediction and target.
    """
    x, lh, rh, _ = batch
    # The forward pass uses the module-level ``rng`` key.
    lh_hat, rh_hat, _ = forward.apply(params, rng, x, training=training)
    pred, target = (lh_hat, lh) if hem == 'lh' else (rh_hat, rh)
    error = jnp.mean((pred - target) ** 2)
    # Median of the column-wise correlations.
    column_corr = pearsonr(pred, target)
    return error, jnp.median(column_corr)


def pearsonr(pred, target):
    """Compute Pearson's correlation coefficient column by column.

    Used to get one correlation value per voxel of a subject's fMRI data.

    Args:
        pred: predicted values; columns are paired with those of ``target``.
        target: ground-truth values with the same layout as ``pred``.

    Returns:
        Array with one correlation coefficient per column.
    """
    def column_r(pred_col, target_col):
        # corrcoef returns the 2x2 correlation matrix; the off-diagonal is r.
        return jnp.corrcoef(pred_col, target_col)[0, 1]

    # Transpose so that vmap iterates over columns rather than rows.
    return vmap(column_r)(pred.T, target.T)



def update(params, batch, opt_state, hem):
    """Take one optimizer step on the loss of the given hemisphere.

    Args:
        params: current model parameters.
        batch: tuple ``(x, lh, rh, cat)`` passed to the loss.
        opt_state: current optimizer state for ``params``.
        hem: ``'lh'`` selects ``lh_loss_fn``; any other value selects
            ``rh_loss_fn``.

    Returns:
        ``(new_params, new_opt_state)`` after applying the update.
    """
    hem_loss = lh_loss_fn if hem == 'lh' else rh_loss_fn
    # NOTE(review): this passes the module-level ``rng`` unchanged on every
    # call; if the network uses it for dropout, the same key is reused at
    # every step — confirm that this is intended.
    gradients = grad(hem_loss)(params, rng, batch)
    param_updates, next_opt_state = opt.update(gradients, opt_state, params)
    return optax.apply_updates(params, param_updates), next_opt_state

# JIT-compiled update steps with the hemisphere fixed by partial, so callers
# pass only (params, batch, opt_state).
lh_update = jit(partial(update, hem='lh'))
rh_update = jit(partial(update, hem='rh'))

In [None]:
# Run the baseline: one wandb run per entry of `data`, using `config` above.
train(data, config)