# Neuroscope playground

In [1]:
# imports
import haiku as hk
import optax
import jax
import jax.numpy as jnp
from jax import random, grad, jit
import sys

sys.path.append("..")
import src
from matplotlib import pyplot as plt
import numpy as np
from tqdm import tqdm
import wandb

In [None]:
# Dark theme for every matplotlib figure drawn later in this notebook.
plt.style.use("dark_background")
# Last expression of the cell, so the notebook displays the devices JAX sees.
jax.devices()

## data

In [None]:
# Run-level knobs; forwarded as CLI flags in `args_list` and reused as defaults below.
n_steps = 3_000  # default number of training iterations for `train`
batch_size = 32  # samples per loader batch
n_samples = 2 ** 13  # 8192; also sets how many batches `target_distribution` reads

In [None]:
# CLI-style options handed to `src.get_setup`; dict order fixes the flag order.
cli_options = {
    "model": "fmri2cat",
    "roi": "V1v,V1d",
    "machine": "local",
    "subject": "subj05",
    "batch_size": str(batch_size),
    "n_samples": str(n_samples),
    "n_steps": str(n_steps),
}
# Flatten to ["--model", "fmri2cat", "--roi", ...] as argparse expects.
args_list = [token for flag, value in cli_options.items() for token in (f"--{flag}", value)]

In [None]:
args, config = src.get_setup(args_list)
# Build the loaders and grab sample batches only on the first run of this
# cell: once `lh` exists in the notebook namespace, re-running is a no-op.
if "lh" not in locals():
    train_loader, val_loader, _ = src.get_loaders(args, config)
    # Each loader yields 6-tuples (img, cat, sup, cap, lh, rh).
    # The second assignment overwrites the first, so after this cell these
    # names hold a *validation* batch and the training batch is discarded.
    # NOTE(review): presumably these are only sample tensors for shape
    # inference (e.g. `init` in the training cell) — confirm that discarding
    # the first training batch is intended.
    img, cat, sup, cap, lh, rh = next(train_loader)
    img, cat, sup, cap, lh, rh = next(val_loader)

## utils

In [None]:
def target_distribution(train_loader, steps=n_samples // batch_size):
    """Estimate per-category label frequencies from ``steps`` training batches.

    Every consumed batch is counted. The original discarded one extra batch
    just to learn the label shape. Counts are normalised by the number of
    rows actually seen, not by the global ``batch_size``, so a short final
    batch does not skew the estimate.

    Args:
        train_loader: iterator yielding ``(img, cat, sup, cap, lh, rh)``
            tuples. Only ``cat`` is used. It is presumably a multi-hot
            ``(batch, n_categories)`` label array — confirm against the loader.
        steps: number of batches to consume from ``train_loader``.

    Returns:
        A float32 array with one frequency per label column (the fraction of
        rows in which that column is set).

    Raises:
        ValueError: if ``steps`` is not positive (no rows to average over).
    """
    freqs = None
    n_rows = 0
    for _ in tqdm(range(steps)):
        _, cat, _, _, _, _ = next(train_loader)
        # Accumulate in float32 regardless of the label dtype (int/bool).
        batch_counts = jnp.sum(cat, axis=0, dtype=jnp.float32)
        freqs = batch_counts if freqs is None else freqs + batch_counts
        n_rows += cat.shape[0]
    if n_rows == 0:
        raise ValueError("steps must be positive to estimate the target distribution")
    return freqs / n_rows


def plot_metrics(metrics):
    """Draw loss curves on the left panel and F1 curves on the right.

    Keys ending in "loss" go to the left axis and keys ending in "f1" to the
    right. Each curve's legend label is the key with its last 5 ("_loss") or
    last 3 ("_f1") characters dropped.
    """
    _, (loss_ax, f1_ax) = plt.subplots(1, 2, figsize=(15, 5), dpi=100)
    # (suffix, target axis, characters to strip for the legend label)
    panels = (("loss", loss_ax, 5), ("f1", f1_ax, 3))
    for name, series in metrics.items():
        for suffix, ax, cut in panels:
            if name.endswith(suffix):
                ax.plot(series, label=name[:-cut])
    for suffix, ax, _ in panels:
        ax.set_title(suffix)
        ax.legend()
    plt.show()

## model

In [None]:
def mlp(x):
    """Flatten ``x`` and map it through a 512-512-80 MLP with sigmoid outputs.

    Haiku modules may only be constructed inside a function wrapped by
    ``hk.transform``. Building ``hk.Sequential``/``hk.Linear`` at notebook top
    level raises ``ValueError``. Wrapping construction in a function defers
    it until the call, which happens inside the transformed ``network_fn``.
    The 80 outputs are presumably one probability per category (the setup
    uses model "fmri2cat") — confirm against the dataset.
    """
    net = hk.Sequential(
        [
            hk.Flatten(),
            hk.Linear(512),
            jax.nn.relu,
            hk.Linear(512),
            jax.nn.relu,
            hk.Linear(80),
            jax.nn.sigmoid,
        ]
    )
    return net(x)


def deconv(x):
    """Map a ``(batch, H, W, C)`` feature map to a 3-channel map in [0, 1].

    ``padding="SAME"`` with the default stride 1 keeps H and W unchanged.
    Modules are built at call time for the same reason as in ``mlp``: they
    must be created inside ``hk.transform``.
    """
    net = hk.Sequential(
        [
            hk.Conv2D(32, kernel_shape=3, padding="SAME"),
            jax.nn.relu,
            hk.Conv2DTranspose(32, kernel_shape=3, padding="SAME"),
            jax.nn.relu,
            hk.Conv2DTranspose(3, kernel_shape=3, padding="SAME"),
            jax.nn.sigmoid,
        ]
    )
    return net(x)

In [None]:
def network_fn(x):
    """Forward pass to be wrapped by ``hk.transform``: input features -> ``mlp`` outputs.

    Returns the sigmoid outputs of ``mlp`` (80 per row). These are the
    predictions consumed by ``loss_fn``/``f1_score`` and compared against the
    ``cat`` labels in ``evaluate``.

    The previous version reshaped the 80-wide ``mlp`` output to
    ``(batch_size, 32, 32, 1)``. That needs 1024 values per row, so it could
    never succeed, and it also hard-coded the global batch size. The image
    decoder ``deconv`` is left unwired until ``mlp`` emits a 32*32 feature
    map for it.
    """
    return mlp(x)

In [None]:
# Turn the impure Haiku function into a pure (init, apply) pair; apply needs no rng key.
init, forward = hk.without_apply_rng(hk.transform(network_fn))
# Adam with a constant learning rate of 1e-4.
optimizer = optax.adam(1e-4)
# Empirical label frequencies; note this consumes batches from train_loader.
probs = target_distribution(train_loader)

In [None]:
@jit
def update(params, x, y, opt_state):
    """Apply one step of the global ``optimizer`` to minimise ``loss_fn`` on ``(x, y)``.

    Returns the updated parameters and the new optimizer state.
    """
    gradients = grad(loss_fn)(params, x, y)
    step_updates, next_opt_state = optimizer.update(gradients, opt_state)
    next_params = optax.apply_updates(params, step_updates)
    return next_params, next_opt_state


def f1_score(params, x, y, pred=None, eps=1e-7):
    """Binary F1 of predictions thresholded at 0.5, pooled over all entries.

    Args:
        params: network parameters. Only used when ``pred`` is None.
        x: network input. Only used when ``pred`` is None.
        y: 0/1 labels, any shape matching the predictions.
        pred: optional precomputed scores, which avoids a second forward pass.
        eps: smoothing added to numerator and denominator. It keeps the
            score finite (1.0) when a batch has neither positive labels nor
            positive predictions, a case that was previously 0/0 = NaN. It is
            negligible otherwise.

    Returns:
        Scalar ``2*tp / (2*tp + fp + fn)`` computed over the flattened arrays.
    """
    scores = forward(params, x) if pred is None else pred
    hard = (jnp.ravel(scores) > 0.5).astype(jnp.float32)
    labels = jnp.ravel(jnp.asarray(y)).astype(jnp.float32)
    tp = jnp.sum(hard * labels)
    fp = jnp.sum(hard * (1 - labels))
    fn = jnp.sum((1 - hard) * labels)
    return (2 * tp + eps) / (2 * tp + fp + fn + eps)


def loss_fn(params, x, y, pred=None, eps=1e-7):
    """Soft Dice loss: ``1 - soft F1`` between continuous predictions and labels.

    With soft predictions ``p`` and labels ``y``, ``tp = sum(p*y)``,
    ``fp = sum(p*(1-y))`` and ``fn = sum((1-p)*y)``, so
    ``2*tp / (2*tp + fp + fn)`` is exactly F1 (the Dice coefficient) and is
    differentiable in ``p``. Subtracting from 1 turns it into a loss.

    Args:
        params: network parameters. Differentiated by ``update`` via ``grad``.
        x: network input. Only used when ``pred`` is None.
        y: 0/1 labels, any shape matching the predictions.
        pred: optional precomputed predictions, which avoids a second forward pass.
        eps: smoothing that keeps the loss finite (0.0) when predictions and
            labels are all exactly zero. Without it that case was 0/0 = NaN.

    Returns:
        Scalar loss in [0, 1].
    """
    scores = forward(params, x) if pred is None else pred
    p = jnp.ravel(scores)
    target = jnp.ravel(y)
    tp = jnp.sum(p * target)
    fp = jnp.sum(p * (1 - target))
    fn = jnp.sum((1 - p) * target)
    return 1 - (2 * tp + eps) / (2 * tp + fp + fn + eps)


def baseline(params, x, y, rng):
    """Score the model on uniform [0, 1) noise shaped like ``x``.

    ``rng`` must be an iterator of PRNG keys. One key is drawn per call.
    Returns ``(loss, f1)`` against the real labels ``y``, as a reference
    point for the train/val scores.
    """
    noise = random.uniform(next(rng), x.shape)
    noise_pred = forward(params, noise)
    noise_loss = loss_fn(params, noise, y, noise_pred)
    noise_f1 = f1_score(params, noise, y, noise_pred)
    return noise_loss, noise_f1


def evaluate(params, train_loader, val_loader, probs, rng, steps=20):
    """Average loss/F1 over ``steps`` train, validation and noise-baseline batches.

    Args:
        params: network parameters passed to ``loss_fn``/``f1_score``.
        train_loader, val_loader: iterators yielding
            ``(img, cat, sup, cap, lh, rh)``. Each step consumes one batch
            from each, so evaluation also advances the training stream.
        probs: target distribution. Not used by the current baseline; kept so
            existing callers (``train``) keep working.
        rng: iterator of PRNG keys used by ``baseline``.
        steps: number of batches to average over.

    Returns:
        dict with the mean ``train_loss``, ``train_f1``, ``val_loss``,
        ``val_f1``, ``base_loss`` and ``base_f1``.
    """
    train_loss, train_f1 = [], []
    val_loss, val_f1 = [], []
    base_loss, base_f1 = [], []
    for _ in range(steps):
        _, y, _, _, lh, rh = next(train_loader)  # training
        x = jnp.concatenate([lh, rh], axis=1)
        train_loss.append(loss_fn(params, x, y))
        train_f1.append(f1_score(params, x, y))
        _, val_y, _, _, val_lh, val_rh = next(val_loader)  # validation
        val_x = jnp.concatenate([val_lh, val_rh], axis=1)
        val_loss.append(loss_fn(params, val_x, val_y))
        val_f1.append(f1_score(params, val_x, val_y))
        # baseline takes (params, x, y, rng); passing probs as well raised TypeError.
        b_loss, b_f1 = baseline(params, x, y, rng)
        base_loss.append(b_loss)
        base_f1.append(b_f1)
    return dict(
        train_loss=np.mean(train_loss),
        train_f1=np.mean(train_f1),
        val_loss=np.mean(val_loss),
        val_f1=np.mean(val_f1),
        base_loss=np.mean(base_loss),
        base_f1=np.mean(base_f1),
    )


def train(params, state, train_loader, val_loader, probs, rng, steps=n_steps):
    """Run ``steps`` optimisation steps, logging evaluation metrics to wandb.

    Each step:
    - unpacks a training batch ``(img, cat, sup, cap, lh, rh)``;
    - concatenates ``lh``/``rh`` along axis 1, the same layout ``evaluate``
      uses;
    - takes one ``update`` step against the ``cat`` labels, the same targets
      ``evaluate``, ``loss_fn`` and ``f1_score`` are scored on.

    Metrics are logged roughly 100 times per run.

    Fixes relative to the previous version:
    - the loader yields 6 values, not 5, so the old unpacking raised;
    - the eval interval is clamped to >= 1, so ``steps < 100`` no longer
      divides by zero;
    - the wandb run is always finished, even if training raises.

    Returns:
        The final ``(params, state)``.
    """
    eval_every = max(1, steps // 100)
    wandb.init(project="neuroscope", entity="syrkis", config=args)
    try:
        for step in tqdm(range(steps)):
            _, cat, _, _, lh, rh = next(train_loader)
            fmri = jnp.concatenate([lh, rh], axis=1)
            params, state = update(params, fmri, cat, state)
            if step % eval_every == 0:
                wandb.log(evaluate(params, train_loader, val_loader, probs, rng))
    finally:
        wandb.finish()
    return params, state

## training

In [None]:
# Seeded stream of PRNG keys: one for init, the rest are drawn by the noise baseline.
rng = hk.PRNGSequence(jax.random.PRNGKey(42))
# Initialise params from the sample batch's lh/rh arrays, concatenated along
# axis 1 — the input layout `evaluate` feeds the network.
params = init(next(rng), jnp.concatenate([lh, rh], axis=1))
state = optimizer.init(params)
params, state = train(params, state, train_loader, val_loader, probs, rng)