In [1]:
from tqdm import tqdm
import optax
import haiku as hk
import jax
from jax import jit, grad
import jax.numpy as jnp
from functools import partial
from src.data import load_subject, make_kfolds

In [2]:
# Experiment settings. `image_size` is the side length used when loading images.
args = dict(image_size=32)
# Load one subject's data through the project helper at the configured resolution.
subject = load_subject('subj05', args['image_size'])
# Data-loader settings consumed by make_kfolds.
config = dict(batch_size=32)

In [23]:
def network_fn(fmri, rng=None):
    """Map a batch of fMRI feature vectors to predicted images.

    Fully connected bottleneck: fmri -> 300 -> 100 -> 10 -> 100 -> 300 -> pixels,
    with ReLU after every hidden layer except the 10-unit bottleneck. The flat
    pixel vector is reshaped to (-1, S, S, 3) with S = args['image_size'].

    Args:
        fmri: input features. Assumed (batch, n_features); an earlier comment
            said n_features is 1000. NOTE(review): confirm against load_subject.
        rng: unused; kept only so existing call signatures keep working.

    Returns:
        Array of shape (-1, S, S, 3).
    """
    size = args['image_size']
    # Module construction order is unchanged, so Haiku's auto-generated
    # parameter names (linear, linear_1, ...) stay the same.
    encoder = hk.Sequential([
        hk.Linear(300), jax.nn.relu,
        hk.Linear(100), jax.nn.relu,
        hk.Linear(10),
    ])
    decoder = hk.Sequential([
        hk.Linear(100), jax.nn.relu,
        hk.Linear(300), jax.nn.relu,
        hk.Linear(size * size * 3),
    ])
    latent = encoder(fmri)
    flat_pixels = decoder(latent)
    return flat_pixels.reshape(-1, size, size, 3)

# network_fn keeps no Haiku state (no hk.get_state / hk.set_state), so use plain
# hk.transform. It gives init(rng, fmri) -> params and apply(params, rng, fmri) -> output,
# which is exactly how loss_fn and train_loop call them. transform_with_state would
# instead return (params, state) from init and require apply(params, state, rng, fmri),
# which shifts the arguments at those call sites and breaks the forward pass.
init, apply = hk.transform(network_fn)

In [24]:
def loss_fn(params, rng, fmri, img):
    """Mean squared error between the network's prediction for `fmri` and `img`.

    Args:
        params: network parameters passed straight to `apply`.
        rng: PRNG key forwarded to `apply`.
        fmri: network input batch.
        img: target images; must broadcast against the prediction.

    Returns:
        Scalar mean of the element-wise squared error.
    """
    prediction = apply(params, rng, fmri)
    squared_error = (prediction - img) ** 2
    return jnp.mean(squared_error)

def update_fn(params, fmri, img, opt_state, rng, opt):
    """Apply one optimizer step on the MSE loss for a single batch.

    Args:
        params: current network parameters.
        fmri: input batch.
        img: target image batch.
        opt_state: optimizer state produced by `opt.init` / a previous step.
        rng: PRNG key; a sub-key split from it is passed to the loss.
        opt: optax gradient transformation.

    Returns:
        Tuple `(new_params, new_opt_state)`.
    """
    _, loss_key = jax.random.split(rng)
    gradients = grad(loss_fn)(params, loss_key, fmri, img)
    updates, new_opt_state = opt.update(gradients, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state

def train_loop(opt, train_loader, val_loader, n_steps):
    """Initialise a fresh model and train it for `n_steps` optimizer steps on one fold.

    Args:
        opt: optax gradient transformation used for every update.
        train_loader: iterator yielding `(lh, rh, img)` batches. Only `lh` is fed
            to the network; `rh` is currently ignored.
        val_loader: currently unused. TODO: report a validation loss.
        n_steps: number of optimizer steps to run.

    Returns:
        The final parameters (the pytree produced by `init` and updated by
        `update_fn`). The previous version discarded them, so all training work
        was lost.
    """
    # Same key sequence as before: the split of PRNGKey(0) seeds init and the loop.
    rng, init_key = jax.random.split(jax.random.PRNGKey(0))
    # Shapes are inferred from the first batch's `lh`. This batch is consumed
    # for initialisation and is not trained on.
    params = init(init_key, next(train_loader)[0])
    opt_state = opt.init(params)
    # Compile the step once per fold; `opt` is closed over, so only the array
    # arguments are traced.
    update = jit(partial(update_fn, opt=opt))
    for _ in range(n_steps):
        rng, step_key = jax.random.split(rng)
        lh, _rh, img = next(train_loader)
        params, opt_state = update(params, lh, img, opt_state, step_key)
    return params

def train_folds(kfolds, learning_rate=1e-3, n_steps=100):
    """Train one independent model per cross-validation fold.

    Args:
        kfolds: iterable of `(train_loader, val_loader)` pairs, e.g. from `make_kfolds`.
        learning_rate: Lion learning rate (previously hard-coded to 1e-3).
        n_steps: optimizer steps per fold (previously hard-coded to 100).

    Returns:
        A list holding `train_loop`'s return value for each fold, in fold order.
        The previous version returned nothing, so per-fold results were discarded.
    """
    # Sharing one transformation across folds is fine: train_loop builds a fresh
    # optimizer state with opt.init for every fold.
    opt = optax.lion(learning_rate)
    return [
        train_loop(opt, train_loader, val_loader, n_steps=n_steps)
        for train_loader, val_loader in kfolds
    ]

In [22]:
# Build the cross-validation folds for the loaded subject, then train one model per fold.
kfolds = make_kfolds(subject, config)
# Bind the result so the cell doesn't echo whatever train_folds returns.
fold_results = train_folds(kfolds)

KeyboardInterrupt: 

In [26]:
# Sanity check of hk.dropout at rate 0.5 on a 10x10 block of ones: the output
# shows dropped entries as 0 and kept entries rescaled to 2.
dropout_key = jax.random.PRNGKey(0)
hk.dropout(dropout_key, 0.5, jnp.ones((10, 10)))

Array([[2., 0., 0., 0., 2., 0., 0., 2., 0., 2.],
       [2., 2., 0., 0., 2., 0., 2., 2., 2., 0.],
       [2., 0., 0., 0., 0., 0., 2., 2., 0., 2.],
       [0., 2., 2., 0., 2., 0., 0., 2., 0., 2.],
       [0., 2., 0., 0., 0., 0., 2., 2., 2., 0.],
       [0., 0., 0., 0., 2., 2., 0., 0., 0., 2.],
       [0., 0., 2., 0., 0., 0., 0., 2., 2., 0.],
       [2., 0., 0., 0., 0., 2., 0., 0., 2., 0.],
       [0., 0., 2., 2., 2., 2., 2., 0., 2., 0.],
       [2., 2., 2., 2., 0., 2., 2., 0., 0., 2.]], dtype=float32)