# Neuroscape playground

In [1]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
import optax
from tqdm import tqdm
import haiku as hk
import numpy as np
import wandb
from functools import partial
from src.data import get_data
from src.utils import get_args_and_config
from src.fmri import plot_brain

100%|██████████| 6/6 [00:01<00:00,  3.42it/s]
100%|██████████| 6/6 [00:09<00:00,  1.54s/it]


## MANGO

In [2]:
args, _ = get_args_and_config()
data = get_data(args)

100%|██████████| 6/6 [00:53<00:00,  8.85s/it]


In [5]:
opt = optax.adamw(0.001)  # perhaps hyper param search for lr and weight decay

In [11]:
def forward(x):
    x_mlp = hk.Sequential([
        hk.nets.MLP([100] * 1, activation=jnp.tanh),
        hk.Linear(100),
    ])
    lh_ml = hk.Sequential([
        hk.nets.MLP([100] * 1, activation=jnp.tanh),
        hk.Linear(19004),
    ])
    rh_ml = hk.Sequential([
        hk.nets.MLP([100] * 1, activation=jnp.tanh),
        hk.Linear(20544),
    ])
    x = x_mlp(x)
    lh_hat = lh_ml(x)
    rh_hat = rh_ml(x)
    return lh_hat, rh_hat

init_fn, apply_fn = hk.without_apply_rng(hk.transform(forward))


def loss_fn(params, batch, hem):
    x, lh, rh = batch
    lh_hat, rh_hat = apply_fn(params, x)
    return jnp.mean((lh_hat - lh) ** 2) if hem == 'lh' else jnp.mean((rh_hat - rh) ** 2)

lh_loss_fn = jit(partial(loss_fn, hem='lh'))
rh_loss_fn = jit(partial(loss_fn, hem='rh'))

In [15]:
def get_fold(fold, fold_idx):
    train_data = [fold for idx, fold in enumerate(fold) if idx != fold_idx]
    train_data = list(map(jnp.vstack, zip(*train_data)))
    val_data = fold[fold_idx]
    return train_data, val_data

def get_batch(data, batch_size):
    while True:
        perm = np.random.permutation(data[0].shape[0])
        for i in range(0, data[0].shape[0], batch_size):
            idx = perm[i:i + batch_size]
            # x = data[0][idx]
            x = jnp.concatenate([data[0][idx], data[3][idx]], axis=1)
            lh = data[1][idx]
            rh = data[2][idx]
            yield x, lh, rh
            
def train(model, data, config):
    group = wandb.util.generate_id()
    for subject, (folds, test_data) in data.items():
        train_data = list(map(jnp.vstack, zip(*folds)))
        with wandb.init(project="neuroscope", entity='syrkis', group=group) as run:
            train_fold(model, train_data, test_data, config)

def train_fold(model, train_data, val_data, config):
    lh_params = init_fn(jax.random.PRNGKey(42), jnp.ones((1, 180)))
    rh_params = init_fn(jax.random.PRNGKey(42), jnp.ones((1, 180)))
    lh_opt_state = opt.init(lh_params)
    rh_opt_state = opt.init(rh_params)
    train_batches = get_batch(train_data, config['batch_size'])
    val_batches = get_batch(val_data, config['batch_size'])
    for step in tqdm(range(config['n_steps'])):
        train_batch = next(train_batches)
        lh_params, lh_opt_state = lh_update(lh_params, train_batch, lh_opt_state)
        rh_params, rh_opt_state = rh_update(rh_params, train_batch, rh_opt_state)
        if step % (config['n_steps'] // 100) == 0:
            metrics = evaluate(lh_params, rh_params, train_batches, val_batches)
            wandb.log(metrics)
    metrics = evaluate(lh_params, rh_params, train_batches, val_batches, steps=50)
    wandb.finish()

def evaluate(lh_params, rh_params, train_batches, val_batches, steps=3):
    train_metrics = evaluate_fold(lh_params, rh_params, train_batches, steps)
    val_metrics = evaluate_fold(lh_params, rh_params, val_batches, steps)
    metrics = {f'train_{k}': v for k, v in train_metrics.items()}
    metrics.update({f'val_{k}': v for k, v in val_metrics.items()})
    return metrics

def evaluate_fold(lh_params, rh_params, batches, steps):
    metrics = {}
    for i in range(steps):
        batch = next(batches)
        batch_metrics = evaluate_batch(lh_params, rh_params, batch)
        metrics = {k: metrics.get(k, 0) + v for k, v in batch_metrics.items()}
    metrics = {k: v / steps for k, v in metrics.items()}
    return metrics
    

def evaluate_batch(lh_params, rh_params, batch):
    metrics = {}
    for hem, params in zip(['lh', 'rh'], [lh_params, rh_params]):
        mse, corr = evaluate_hem(params, batch, hem)
        metrics[f'{hem}_mse'] = mse
        metrics[f'{hem}_corr'] = corr
    return metrics

def evaluate_hem(params, batch, hem):
    x, lh, rh = batch
    lh_hat, rh_hat = apply_fn(params, x)
    mse = jnp.mean((lh_hat - lh) ** 2) if hem == 'lh' else jnp.mean((rh_hat - rh) ** 2)
    # compute the median collumn wise correlation
    corr = pearsonr(lh_hat, lh) if hem == 'lh' else pearsonr(rh_hat, rh)
    return mse, jnp.median(corr)


# function for computing pearson's correlation coefficient for each voxel of a subject's fMRI data
def pearsonr(pred, target):
    def _pearsonr(x, y):
        corr = jnp.corrcoef(x, y)
        return corr[0, 1]
    hem_corr = vmap(_pearsonr)(pred.T, target.T)
    return hem_corr



def update(params, batch, opt_state, hem):
    """update function"""
    loss_fn = lh_loss_fn if hem == 'lh' else rh_loss_fn
    grads = grad(loss_fn)(params, batch)
    updates, opt_state = opt.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, opt_state

lh_update = jit(partial(update, hem='lh'))
rh_update = jit(partial(update, hem='rh'))

In [16]:
config = { 'n_steps': 6000, 'batch_size': 32 }

In [17]:
train(apply_fn, data, config)

100%|██████████| 6000/6000 [01:02<00:00, 95.49it/s] 


0,1
train_lh_corr,▁▁▂▄▅▆▆▇▆▆▇▇▆▇▇▇▇▆▅▇█▇▆▆█▇▇▆▇▆▇█▇▇▆█▇▆▇▆
train_lh_mse,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_rh_corr,▁▁▂▃▅▆▆▇▇▆▇█▆▇█▇▇▆▆▇██▆▆▇▇█▆▇▆▇▇▇▇▆█▇▆▇▇
train_rh_mse,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_lh_corr,▁▁▂▃▆▅▆▆▆▆▆▆▆▇▆▇▇▆█▇▆█▇▆▆▆▇▅▇▆▇▇▇▆▅▇▇▆▇▆
val_lh_mse,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_rh_corr,▁▁▂▄▆▆▅▆▆▇▇▇▇▇▇█▇██▇▇█▇▇▇▆▇▆▇▆▇▇▇▇▆▇█▇█▇
val_rh_mse,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
train_lh_corr,0.20993
train_lh_mse,0.4747
train_rh_corr,0.22422
train_rh_mse,0.47173
val_lh_corr,0.1862
val_lh_mse,0.47516
val_rh_corr,0.19929
val_rh_mse,0.46957


100%|██████████| 6000/6000 [00:58<00:00, 101.81it/s]


0,1
train_lh_corr,▁▁▂▄▅▆▆█▆▆▅▇▆▇█▆▆█▆█▇██▇▆██▇▇▇▆▆▇▇██▇█▇▇
train_lh_mse,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_rh_corr,▁▁▂▄▅▆▆▇▆▅▆▇▅█▇▆▆▇▆▇▆▇▇▇▆▆▇▇▇▇▆▆▆▇▇▇▆▇▇▆
train_rh_mse,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_lh_corr,▁▁▂▄▇▅▇▆▆▇▆▇▇█▇▇█▆▇▆▆▇▇█▇▇▇▇▇▇▇▇▇▇▇▇█▇▇▇
val_lh_mse,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_rh_corr,▁▁▂▄▆▆▇▆▇▇▅▇▇█▇▆▇▇▇▆▇▇▇█▇▇▇▇█▆▇▇▇▇▇▆▇▇▆▇
val_rh_mse,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
train_lh_corr,0.2174
train_lh_mse,0.49767
train_rh_corr,0.21649
train_rh_mse,0.50366
val_lh_corr,0.19242
val_lh_mse,0.46517
val_rh_corr,0.21077
val_rh_mse,0.44478


100%|██████████| 6000/6000 [01:02<00:00, 96.52it/s] 


0,1
train_lh_corr,▁▁▂▃▄▆▆▆▆▅▇▅▆▇▅▆▆▆▆▆▆▇▆██▆█▇▆▇▆▇▆▆▇▇▆▆▇█
train_lh_mse,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_rh_corr,▁▁▂▄▅▆▆▆▅▆▇▆▆▇▆▆▆▆▆▅▆▇▇██▆▇▇▆▇▇█▆▆▇▇▆▆▆█
train_rh_mse,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_lh_corr,▁▁▃▄▅▇▅▇█▇▆▆▇▇▇█▇▅▆█▇▅▆▇█▆▆▆▆█▆▇▆▇▆▆█▇▆▆
val_lh_mse,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_rh_corr,▁▁▃▄▄▇▆▇█▇▆▆▆▆▇▇▇▅▅▇▇▅▆▆▇▆▆▆▇█▆▇▆▇▆▆█▆▆▆
val_rh_mse,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
train_lh_corr,0.22521
train_lh_mse,0.56904
train_rh_corr,0.21962
train_rh_mse,0.59278
val_lh_corr,0.17655
val_lh_mse,0.55576
val_rh_corr,0.17079
val_rh_mse,0.5497


100%|██████████| 6000/6000 [01:02<00:00, 96.48it/s] 


0,1
train_lh_corr,▁▁▂▃▅▆▆▆▆▇▄▆▇▇▆▇▇▆▇▅▆▆▇▇▆▇▇▇▇▇▇▆▆█▇▇▇██▆
train_lh_mse,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_rh_corr,▁▁▂▃▄▇▆▆▆▇▄▆▆▇▆▇▆▆▇▅▆▆▇▆▆▆▇▇▆▆▇▆▆█▇█▇██▆
train_rh_mse,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_lh_corr,▁▁▂▃▄▆▅▆▇▆▅▆▇▆▆▇▆▅▆▆▆▅▆▆▅▆▆▆▆▆▅▆▆▆▆▆█▆▆▇
val_lh_mse,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_rh_corr,▁▁▂▃▄▅▅▆▆▆▅▆▇▆▆▆▆▅▆▆▅▅▆▅▆▆▆▅▆▅▅▆▆▆▅▆█▆▆▆
val_rh_mse,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
train_lh_corr,0.1568
train_lh_mse,0.63121
train_rh_corr,0.16075
train_rh_mse,0.64032
val_lh_corr,0.18785
val_lh_mse,0.60436
val_rh_corr,0.19962
val_rh_mse,0.63002


100%|██████████| 6000/6000 [00:59<00:00, 100.74it/s]


0,1
train_lh_corr,▁▁▃▄▄▅▆▇▅▆▇▆▅▇▅▆▇▇▅▇▇▇▇▇▆▆▇▇▆▇█▆▇▆▇▆▆▆▇▇
train_lh_mse,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_rh_corr,▁▁▃▄▅▅▇▇▅▆▇▆▅▇▆▆▇▇▆▇▇█▇█▆▅▇▇▆▇█▆█▆▇▆▆▇▇▆
train_rh_mse,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_lh_corr,▁▁▃▄▆▅█▇▇▆▇▅▇▆▆▆▇▆▇▅█▇▆▆▇▆▇▆▆▆▇▆▅▅▆▆▆▆▆▅
val_lh_mse,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_rh_corr,▁▁▃▄▆▅█▇▆▆▆▅▇▅▆▆▇▅▆▅▇▇▆▆▇▆▆▆▆▆▇▆▆▅▆▅▆▆▆▅
val_rh_mse,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
train_lh_corr,0.24321
train_lh_mse,0.45904
train_rh_corr,0.22159
train_rh_mse,0.45258
val_lh_corr,0.19543
val_lh_mse,0.47145
val_rh_corr,0.18449
val_rh_mse,0.47972


100%|██████████| 6000/6000 [00:59<00:00, 101.05it/s]


0,1
train_lh_corr,▁▁▂▄▄▆▇▇▇▇▇▇▆▇█▆▇▇▇▆▇▆██▆█▇▇█▇▇▆▇▇▆▇▆▇▆▇
train_lh_mse,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_rh_corr,▁▁▂▅▄▆▆▆▆▇▇▇▇▆▇▇▇▆▇▆▇▇█▇▇█▇▇▆▇▇▆██▆█▆▆▇█
train_rh_mse,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_lh_corr,▁▁▂▃▅▅▅▆▆▆▆▇▆▅▇▅▆▆▆▆▆▆▇▆▆█▆▆▇▇▆▇▅▇▇▇▆▆▆▆
val_lh_mse,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_rh_corr,▁▁▃▄▄▆▅▇█▆▆▇▅▆▆▅▇▇▆▆▆▇▇▇▆█▇▆▇▇▇▆▆▇██▅▇▆▆
val_rh_mse,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
train_lh_corr,0.18584
train_lh_mse,0.49134
train_rh_corr,0.19898
train_rh_mse,0.48771
val_lh_corr,0.15401
val_lh_mse,0.45773
val_rh_corr,0.14988
val_rh_mse,0.44436
