# Neuroscape playground

In [None]:
# imports
from multiprocessing import Pool
import sys
sys.path.append("..")
import src
import jax
from jax import random, grad, jit, vmap, lax
import jax.numpy as jnp
import importlib
from matplotlib import pyplot as plt
import os
import numpy as np
import pandas as pd
from nilearn import datasets, plotting, maskers
from tqdm import tqdm
# (plot styling, including the black background, is configured in the next cell)

In [None]:
# Re-import the local `src` package so edits to it are picked up without a kernel restart.
importlib.reload(src);
# Figure styling for every plot below: black background, serif font.
plt.style.use('dark_background')
plt.rcParams.update({'font.family': 'Times New Roman'})

## Machine learning

In [None]:
# Sizes forwarded to src.get_setup through `args_list` in the next cell.
# Note: the manual training loop further down uses its own step count, not `n_steps`.
n_steps, batch_size, n_samples = 100, 32, 2000

In [None]:

# Flat CLI-style list of flag/value pairs handed to src.get_setup; the numeric values
# come from the size constants defined in the previous cell.
args_list = [
    token
    for flag, value in (
        ('--model', 'fmri2cat'),
        ('--roi', 'V1d'),
        ('--machine', 'local'),
        ('--subject', 'subj05'),
        ('--batch_size', str(batch_size)),
        ('--n_samples', str(n_samples)),
        ('--n_steps', str(n_steps)),
    )
    for token in (flag, value)
]

In [None]:
# Parse the setup from `args_list`, then (re)build the train/val loaders.
# Flip reload_data to False after the first run to keep the existing loaders; they are
# still built when they do not exist yet (e.g. after a kernel restart), so
# Restart & Run All keeps working regardless of this flag.
reload_data = True
config, args = src.get_setup(args_list)
if reload_data or 'train_loader' not in globals():
    train_loader, val_loader, _ = src.get_loaders(args, config)
    # Pull (and discard) one batch from each loader -- presumably a smoke test that
    # both iterators produce data.
    next(train_loader), next(val_loader);

In [None]:
# Initialise the MLP and the metric/baseline accumulators.
# The input width is read from a real batch rather than from `lh`/`rh`, which are only
# created by the training-loop cell below and do not exist on a fresh kernel.
n_classes = 80  # output units; also the length of the per-class frequency counter
rng = jax.random.PRNGKey(0)
# Batches unpack as (img, cat, sup, cap, lh, rh), matching the training loop.
# NOTE: this consumes one training batch just to read the feature widths.
sample_lh, sample_rh = tuple(next(train_loader))[-2:]
layer_sizes = [sample_rh.shape[-1] + sample_lh.shape[-1], 128, 128, n_classes]
params = src.model.init_mlp(layer_sizes, rng)
metrics = {
    key: []
    for key in ('train_f1', 'val_f1', 'train_loss', 'val_loss', 'baseline_f1', 'baseline_loss')
}
# Running per-class target sums and number of samples seen; the training loop uses
# freqs / n as the class probabilities of the frequency baseline.
freqs, n = np.zeros(n_classes), 0

In [None]:
def loss_fn(params, x, y, pred=None, eps=1e-7):
    """Mean binary cross-entropy between targets and predicted probabilities.

    Args:
        params: MLP parameters passed to ``src.model.forward_mlp``.
        x: model input; only used when ``pred`` is None.
        y: binary target array with the same shape as the predictions.
        pred: optional precomputed predictions; when given, the forward pass is skipped.
        eps: predictions are clipped to ``[eps, 1 - eps]`` before taking logs.

    Returns:
        Scalar mean BCE over all elements.
    """
    pred = src.model.forward_mlp(params, x) if pred is None else pred
    # A prediction of exactly 0 or 1 (hard 0/1 predictions, saturated outputs) would give
    # log(0) = -inf and then 0 * -inf = nan; clipping keeps the loss finite.
    pred = jnp.clip(pred, eps, 1 - eps)
    return -jnp.mean(y * jnp.log(pred) + (1 - y) * jnp.log(1 - pred))

def f1_score(params, x, y, pred=None, threshold=0.5):
    """Micro-averaged F1 of thresholded predictions (counts pooled over all elements).

    Args:
        params: MLP parameters passed to ``src.model.forward_mlp``.
        x: model input; only used when ``pred`` is None.
        y: binary target array with the same shape as the predictions.
        pred: optional precomputed scores/probabilities (or hard 0/1 predictions).
        threshold: scores strictly above this count as positive predictions.

    Returns:
        Scalar ``tp / (tp + (fp + fn) / 2)``. Returns 0.0 instead of 0/0 = nan when
        there are no true positives, false positives or false negatives.
    """
    pred = src.model.forward_mlp(params, x) if pred is None else pred
    # Float 0/1 so the (1 - hard) arithmetic below never operates on booleans.
    hard = jnp.asarray(pred > threshold, dtype=jnp.float32)
    tp = jnp.sum(hard * y)
    fp = jnp.sum(hard * (1 - y))
    fn = jnp.sum((1 - hard) * y)
    denom = tp + 0.5 * (fp + fn)
    # denom == 0 implies tp == 0, so dividing by 1 there yields 0 instead of nan.
    return tp / jnp.where(denom > 0, denom, 1.0)

def baseline(params, x, y, probs, rng=None, eps=1e-7):
    """Score an input-independent random predictor driven by per-class probabilities.

    Args:
        params, x: forwarded unchanged to ``loss_fn`` / ``f1_score`` (unused there,
            because explicit predictions are always passed).
        y: binary target array.
        probs: per-class positive probabilities, broadcastable to ``y.shape``
            (the training loop passes running class frequencies ``freqs / n``).
        rng: optional ``numpy.random.Generator`` for reproducible sampling; when None,
            the global ``np.random`` state is used.
        eps: probabilities are clipped to ``[eps, 1 - eps]`` for the loss.

    Returns:
        (loss, f1): BCE of the clipped probabilities themselves, and F1 of hard
        predictions sampled independently per element with those probabilities.
    """
    probs = np.broadcast_to(np.asarray(probs), np.shape(y))
    draws = np.random.random(np.shape(y)) if rng is None else rng.random(np.shape(y))
    sampled = draws < probs
    # BCE is scored on the probabilities, not on the 0/1 samples: hard predictions put
    # log(0) into the loss (nan/inf). Classes never seen so far have probability 0 and
    # are clipped for the same reason.
    loss = loss_fn(params, x, y, np.clip(probs, eps, 1 - eps))
    f1 = f1_score(params, x, y, sampled)
    return loss, f1


# Compiled gradient of loss_fn with respect to its first argument, `params`.
grad_fn = jit(grad(loss_fn))

In [None]:
# Plain gradient-descent training of the MLP, logging train/val/baseline metrics each step.
# Not idempotent: re-running continues from the current `params` and keeps appending to
# `metrics`, `freqs` and `n`; re-run the initialisation cell to start from scratch.
steps = 4000  # independent of `n_steps`, which is only forwarded to src.get_setup
learning_rate = 0.01
for step in tqdm(range(steps)):
    img, cat, sup, cap, lh, rh = next(train_loader)
    lrh = np.concatenate([rh, lh], axis=-1)  # right-hemisphere features first, then left
    # Running per-class target sums / samples seen, used as the baseline's probabilities.
    freqs, n = freqs + np.sum(cat, axis=0), n + cat.shape[0]
    val_img, val_cat, val_sup, val_cap, val_lh, val_rh = next(val_loader)
    val_lrh = np.concatenate([val_rh, val_lh], axis=-1)

    grads = grad_fn(params, lrh, cat)
    params = [(w - learning_rate * dw, b - learning_rate * db)
              for (w, b), (dw, db) in zip(params, grads)]

    # One forward pass per split with the updated params, shared by the loss and F1.
    train_pred = src.model.forward_mlp(params, lrh)
    val_pred = src.model.forward_mlp(params, val_lrh)
    baseline_loss, baseline_f1 = baseline(params, lrh, cat, freqs / n)

    metrics['train_loss'].append(loss_fn(params, lrh, cat, train_pred))
    metrics['train_f1'].append(f1_score(params, lrh, cat, train_pred))
    metrics['val_loss'].append(loss_fn(params, val_lrh, val_cat, val_pred))
    metrics['val_f1'].append(f1_score(params, val_lrh, val_cat, val_pred))
    metrics['baseline_loss'].append(baseline_loss)
    metrics['baseline_f1'].append(baseline_f1)

In [None]:

def rolling_average(lst, window=30):
    """Mean of every full, contiguous length-``window`` slice of ``lst``.

    Returns ``len(lst) - window + 1`` values, including the final window, or an
    empty list when ``lst`` is shorter than ``window``.
    """
    return [np.mean(lst[i:i + window]) for i in range(len(lst) - window + 1)]

# Smoothed learning curves for train, val and the frequency-based random baseline.
fig, axes = plt.subplots(1, 2, figsize=(15, 5), dpi=100)
panels = [('loss', 'Binary Cross Entropy Loss'), ('f1', 'F1 Score')]
for ax, (metric, title) in zip(axes, panels):
    for split in ('train', 'val', 'baseline'):
        ax.plot(rolling_average(metrics[f'{split}_{metric}']), label=split)
    ax.set_title(title)
    ax.set_xlabel('step (start of rolling window)')
    ax.set_ylabel(f'{metric} (rolling mean)')
    ax.legend()
plt.show()

In [None]:
# Count trainable scalars: the element count of every array leaf in the `params` pytree.
param_sizes = list(map(lambda leaf: leaf.size, jax.tree_util.tree_leaves(params)))
num_params = sum(param_sizes)
print(f'Number of parameters: {num_params}')

## Learning Haiku