# Neuroscope playground

In [4]:
# imports
import haiku as hk
import optax
import jax
import jax.numpy as jnp
from jax import random, grad, jit, vmap, lax
import sys; sys.path.append("..")
import src
from matplotlib import pyplot as plt
import numpy as np
from tqdm import tqdm
import wandb
# black background

In [5]:
# Draw every matplotlib figure in this notebook on a black background.
plt.style.use('dark_background')
# Last expression of the cell, so the list of devices JAX can see is displayed.
jax.devices()

[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]

## machine learning

In [6]:
# Run-wide hyper-parameters; later cells read these module-level names.
n_samples = 1 << 13   # 8192
batch_size = 32       # rows per batch
n_steps = 5_000       # number of optimisation steps

In [7]:

# Command-line style argument vector, written as explicit (flag, value) pairs.
# Numeric settings are rendered as text so that every entry is a string.
args_list = [
    *('--model', 'fmri2cat'),
    *('--roi', 'V1d'),
    *('--machine', 'local'),
    *('--subject', 'subj05'),
    *('--batch_size', f'{batch_size}'),
    *('--n_samples', f'{n_samples}'),
    *('--n_steps', f'{n_steps}'),
]

In [8]:
# Turn the argument vector into the project's config / args objects.
config, args = src.get_setup(args_list)
# Only build the loaders when no batch has been fetched yet in this kernel
# (presumably because building them is slow — TODO confirm).
# NOTE(review): this guard depends on leftover kernel state; if `lh` exists but the
# loaders do not (or `args_list` changed), stale data is silently reused.
if 'lh' not in locals():
    train_loader, val_loader, _ = src.get_loaders(args, config)
    # NOTE(review): the training batch fetched on the next line is immediately
    # overwritten by the validation batch, so only the validation batch is kept
    # (it is later used to initialise the network). Confirm the first fetch is intended.
    img, cat, sup, cap, lh, rh =  next(train_loader)
    img, cat, sup, cap, lh, rh =  next(val_loader)

100%|██████████| 6297/6297 [00:49<00:00, 126.30it/s]
100%|██████████| 1575/1575 [00:12<00:00, 129.73it/s]


In [9]:
def target_distribution(train_loader, steps=n_samples // batch_size):
    """Estimate how often each category is active in the training data.

    Args:
        train_loader: iterator yielding 6-tuples whose second element is the
            category array of one batch (batch axis first).
        steps: number of batches to accumulate.

    Returns:
        Array shaped like a single row of the category array, holding the
        fraction of accumulated samples in which each category was set.
    """
    # The first batch is only used to learn the shape/dtype of one label row.
    _, cat, _, _, _, _ = next(train_loader)
    freqs = jnp.zeros_like(cat[0])
    n_seen = 0
    for _ in tqdm(range(steps)):
        _, cat, _, _, _, _ = next(train_loader)
        freqs += jnp.sum(cat, axis=0)
        # Count rows actually summed instead of trusting the global batch_size,
        # so the estimate stays correct if the loader yields a different size.
        n_seen += cat.shape[0]
    # max(..., 1) keeps steps=0 from producing 0/0.
    probs = freqs / max(n_seen, 1)
    return probs

def plot_metrics(metrics):
    """Plot loss curves (left panel) and F1 curves (right panel).

    Args:
        metrics: mapping from metric name to a plottable sequence. Names ending
            in 'loss' go to the left panel, names ending in 'f1' to the right;
            the suffix and the character before it are cut from the legend label.
    """
    fig, (loss_ax, f1_ax) = plt.subplots(1, 2, figsize=(15, 5), dpi=100)
    for name, series in metrics.items():
        if name.endswith('loss'):
            loss_ax.plot(series, label=name[:-5])
        if name.endswith('f1'):
            f1_ax.plot(series, label=name[:-3])
    for panel, title in ((loss_ax, 'loss'), (f1_ax, 'f1')):
        panel.set_title(title)
        panel.legend()
    plt.show()

In [10]:
def network_fn(x):
    """MLP: three 128-unit GELU layers followed by an 80-unit sigmoid layer.

    Must be called inside an `hk.transform`, since it constructs Haiku modules.
    """
    layers = []
    # Linear modules are created in the same order as before, so Haiku assigns
    # the same parameter names.
    for _ in range(3):
        layers += [hk.Linear(128), jax.nn.gelu]
    layers += [hk.Linear(80), jax.nn.sigmoid]
    return hk.Sequential(layers)(x)

In [11]:
# Pure (init, apply) pair for the network; apply takes no RNG key.
init, forward = hk.without_apply_rng(hk.transform(network_fn))
# Cosine learning-rate decay over the whole run.
# NOTE(review): optax's `alpha` is a multiplier of `init_value`, so the floor is
# 1e-3 * 1e-5 = 1e-8, not 1e-5 — confirm this is the intended final rate.
scheduler = optax.cosine_decay_schedule(init_value=1e-3, decay_steps=n_steps, alpha=1e-5)
optimizer = optax.adam(learning_rate=scheduler)
# Per-category frequencies; this consumes batches from train_loader.
probs = target_distribution(train_loader)

100%|██████████| 256/256 [00:04<00:00, 63.03it/s]


In [16]:
@jit
def update(params, x, y, opt_state):
    """Run one optimiser step on a single batch.

    Args:
        params: current network parameters.
        x: input batch.
        y: target batch.
        opt_state: current optimiser state.

    Returns:
        Tuple of (updated parameters, updated optimiser state).
    """
    gradients = grad(loss_fn)(params, x, y)
    param_updates, next_opt_state = optimizer.update(gradients, opt_state)
    return optax.apply_updates(params, param_updates), next_opt_state

def f1_score(params, x, y, pred=None):
    """F1 score with counts pooled over every element of the batch.

    Args:
        params: network parameters; only used when `pred` is None.
        x: input batch; only used when `pred` is None.
        y: binary targets.
        pred: optional precomputed binary predictions. When omitted, the network
            output is thresholded at 0.5.

    Returns:
        Scalar F1 = 2*tp / (2*tp + fp + fn), or 0 when there are neither positive
        targets nor positive predictions (previously NaN from 0/0).
    """
    pred = forward(params, x) > 0.5 if pred is None else pred
    pred, y = pred.astype('int32'), y.astype('int32')
    tp = jnp.sum(pred * y)
    fp = jnp.sum(pred * (1 - y))
    fn = jnp.sum((1 - pred) * y)
    denom = 2 * tp + fp + fn
    # denom == 0 implies tp == 0, so dividing by 1 yields 0 instead of NaN.
    return 2 * tp / jnp.where(denom == 0, 1, denom)

def loss_fn(params, x, y, pred=None):
    """Soft F1 loss: 1 - F1 computed from (possibly fractional) predictions.

    Accepts precomputed predictions so the same function can score the 0/1
    baseline as well as the network output.

    Args:
        params: network parameters; only used when `pred` is None.
        x: input batch; only used when `pred` is None.
        y: binary targets.
        pred: optional predictions; defaults to the raw network output.

    Returns:
        Scalar 1 - 2*tp / (2*tp + fp + fn). When the denominator is zero (no
        positive targets and no predicted mass) the loss is 1 instead of NaN.
    """
    pred = forward(params, x) if pred is None else pred
    tp = jnp.sum(pred * y)
    fp = jnp.sum(pred * (1 - y))
    fn = jnp.sum((1 - pred) * y)
    denom = 2 * tp + fp + fn
    # Swap a zero denominator for 1: tp is 0 in that case, so the ratio is 0 and
    # both the value and its gradient stay finite.
    safe_denom = jnp.where(denom == 0, 1.0, denom)
    return 1 - (2 * tp) / safe_denom


def baseline(params, x, y, probs, rng):
    """Score random predictions drawn from the per-category frequencies.

    Each entry is predicted positive when a uniform draw falls below the
    matching entry of `probs`.

    Args:
        params: forwarded to the scoring functions (unused by them when
            predictions are supplied).
        x: input batch, forwarded unchanged.
        y: binary targets; the predictions are drawn with this shape.
        probs: per-category positive rate, broadcast against the draws.
        rng: iterator of PRNG keys; one key is consumed per call.

    Returns:
        Tuple of (loss, f1) for the random predictions.
    """
    # Take the shape from the labels instead of hard-coding 80 categories.
    pred = random.uniform(next(rng), y.shape) < probs
    loss = loss_fn(params, x, y, pred)
    f1 = f1_score(params, x, y, pred)
    return loss, f1

def evaluate(params, train_loader, val_loader, probs, rng, steps=20):
    """Average loss and F1 over `steps` train batches, val batches and baselines.

    Each iteration draws one training batch and one validation batch; the random
    baseline is scored on the validation batch.

    Returns:
        Dict with keys train_loss, train_f1, val_loss, val_f1, base_loss and
        base_f1, each the mean over the evaluated batches.
    """
    history = {
        name: []
        for name in ('train_loss', 'train_f1', 'val_loss', 'val_f1', 'base_loss', 'base_f1')
    }
    for _ in range(steps):
        # training batch
        _, train_y, _, _, train_lh, train_rh = next(train_loader)
        train_x = jnp.concatenate([train_lh, train_rh], axis=1)
        history['train_loss'].append(loss_fn(params, train_x, train_y))
        history['train_f1'].append(f1_score(params, train_x, train_y))
        # validation batch
        _, val_y, _, _, val_lh, val_rh = next(val_loader)
        val_x = jnp.concatenate([val_lh, val_rh], axis=1)
        history['val_loss'].append(loss_fn(params, val_x, val_y))
        history['val_f1'].append(f1_score(params, val_x, val_y))
        # random baseline on the same validation batch
        rand_loss, rand_f1 = baseline(params, val_x, val_y, probs, rng)
        history['base_loss'].append(rand_loss)
        history['base_f1'].append(rand_f1)
    return {name: np.mean(values) for name, values in history.items()}

def train(params, state, train_loader, val_loader, probs, rng, steps=n_steps):
    """Train the network and log evaluation metrics to Weights & Biases.

    Args:
        params: initial network parameters.
        state: initial optimiser state.
        train_loader: iterator of training batches.
        val_loader: iterator of validation batches.
        probs: per-category frequencies used by the random baseline.
        rng: iterator of PRNG keys for the baseline.
        steps: number of optimisation steps.

    Returns:
        Tuple of (trained parameters, final optimiser state).
    """
    # Log about 100 times per run. steps // 100 is 0 when steps < 100, which
    # previously made the modulo below raise ZeroDivisionError.
    log_every = max(1, steps // 100)
    wandb.init(project='neuroscope', entity='syrkis', config=args)
    try:
        for step in tqdm(range(steps)):
            _, y, _, _, lh, rh = next(train_loader)
            x = jnp.concatenate([lh, rh], axis=1)
            params, state = update(params, x, y, state)
            if step % log_every == 0:
                wandb.log(evaluate(params, train_loader, val_loader, probs, rng))
    finally:
        # Close the run even if training fails or is interrupted.
        wandb.finish()
    return params, state


In [19]:
# Seeded key stream: the first key initialises the network, later keys feed the baseline.
rng = hk.PRNGSequence(jax.random.PRNGKey(42))
# Initialise parameters by tracing the network on one example input: the `lh` and `rh`
# arrays left over from the data-loading cell, joined along axis 1.
params = init(next(rng), jnp.concatenate([lh, rh], axis=1))
state = optimizer.init(params)
# Train and log to wandb; rebinding params/state means re-running this cell
# starts again from a fresh initialisation.
params, state = train(params, state, train_loader, val_loader, probs, rng)

0,1
base_f1,▁▂▁▂▃▄▄▁▂▃▂▄▃▄▇▇▆▅▇▆██▆▇▆▅▅▃▃▃▆▄▂▄▃▅▅▄▄▄
base_loss,█▇█▇▆▅▅█▇▆▇▅▆▅▂▂▃▄▂▃▁▁▃▂▃▄▄▆▆▆▃▅▇▅▆▄▄▅▅▅
train_f1,▁▃▃▆▇▆▅▄▆▅▅▇▄▅▅▆▆▅▆▆▆▅▆▅▆▆▆▆█▅▅█▅▅▆▆▅▆▇▆
train_loss,█▆▆▃▂▃▄▅▃▄▄▂▅▄▄▃▃▄▃▃▃▄▃▄▃▃▃▃▁▄▄▁▄▄▃▃▄▃▂▃
val_f1,▁▂▄▅▆██▅▆▅▆▅▅▆▆▆▅▅▄▆▇▆▆▅▆▅▆▆▆▅▆▆▅▆▅▆▆▅▆▆
val_loss,█▇▅▄▃▁▁▄▃▄▃▄▄▃▃▃▄▄▅▃▂▃▃▄▃▄▃▃▃▄▃▃▄▃▄▃▃▄▃▃

0,1
base_f1,0.13507
base_loss,0.86493
train_f1,0.25028
train_loss,0.74972
val_f1,0.26135
val_loss,0.73865


100%|██████████| 5000/5000 [03:04<00:00, 27.04it/s]
