In [31]:
import optax
import jax
from jax import grad, jit
import jax.numpy as jnp
import haiku as hk
import matplotlib.pyplot as plt
import wandb
from tqdm import tqdm
from functools import partial
from src.utils import get_args_and_config
from src.data import get_data

In [23]:
args, _ = get_args_and_config()
data = get_data(args)['subj01'][0] if 'data' not in locals() else eval('data')

In [24]:
# hypere param sweep config in config
config = dict(
    embed_size=1000,
    hidden_dim=1000,
    n_layers=2
)

In [25]:
def embedding_fn(fmri, config):
    """embedding function"""
    n_layers = config['n_layers']
    embed_size = config['embed_size']
    img_mlp = hk.Sequential([
        # linear layer to get embedding without bias
        hk.Linear(embed_size, with_bias=False),
        jax.nn.tanh,
    ])
    return img_mlp(fmri)

def encoder_fn(img, fmri, config):
    """encoder function"""
    hidden_dim = config['hidden_dim']
    n_layers = config['n_layers']
    img_mlp = hk.Sequential([
        hk.nets.MLP([hidden_dim] * n_layers, activation=jnp.tanh),
        hk.Linear(fmri.shape[-1]),
        jax.nn.sigmoid,
    ])
    return img_mlp(img)

def decoder_fn(fmri, img, config):
    """decoder function"""
    hidden_dim = config['hidden_dim']
    n_layers = config['n_layers']
    img_mlp = hk.Sequential([
        hk.nets.MLP([hidden_dim] * n_layers, activation=jnp.tanh),
        hk.Linear(img.shape[-1]),
        jax.nn.sigmoid,
    ])
    return img_mlp(fmri)


embedding_fn  = partial(embedding_fn, config=config)
lh_encoder_fn = partial(encoder_fn, fmri=data[0][1], config=config)
rh_encoder_fn = partial(encoder_fn, fmri=data[0][2], config=config)
decoder_fn    = partial(decoder_fn, img=data[0][0], config=config)


In [26]:
init_embed, apply_embed           = hk.without_apply_rng(hk.transform(embedding_fn))
init_lh_encoder, apply_lh_encoder = hk.without_apply_rng(hk.transform(lh_encoder_fn))
init_rh_encoder, apply_rh_encoder = hk.without_apply_rng(hk.transform(rh_encoder_fn))
init_decoder, apply_decoder       = hk.without_apply_rng(hk.transform(decoder_fn))

In [32]:
lh_embed_params = init_embed(jax.random.PRNGKey(42), jnp.zeros((1, 19004)))
rh_embed_params = init_embed(jax.random.PRNGKey(42), jnp.zeros((1, 20544)))
decoder_params  = init_decoder(jax.random.PRNGKey(42), jnp.zeros((1, config['embed_size'] * 2)))
params = (lh_embed_params, rh_embed_params, decoder_params)

In [34]:
@jit
def loss_fn(params, lh, rh, img):
    """loss function"""
    lh_embed = apply_embed(params[0], lh)
    rh_embed = apply_embed(params[1], rh)
    embed = jnp.concatenate([lh_embed, rh_embed], axis=-1)
    img_hat = apply_decoder(params[2], embed)
    return jnp.mean((img - img_hat) ** 2)

In [36]:
opt = optax.adamw(0.001)
opt_state = opt.init(params)

In [42]:
@jit
def update(params, opt_state, lh, rh, img):
    """update function"""
    grads = grad(loss_fn)(params, lh, rh, img)
    updates, opt_state = opt.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, opt_state

In [45]:
for i in range(100):
    for img, lh, rh, _ in data:
        loss = loss_fn(params, lh, rh, img) 
        params, opt_state = update(params, opt_state, lh, rh, img)
        print(loss)

84.08149
83.36416
83.91787
85.05655
83.38456
84.01379
83.26535
83.81926
84.98988
83.32216
83.94138
83.16627
83.7529
84.9338
83.231606
83.858955
83.118546
83.67582
84.83652
83.14296
83.77377
83.04422
83.5734
84.746574
83.09522
83.701385
82.951126
83.50966
84.712616
83.00717
83.61382
82.892914
83.48225
84.63931
82.93975
83.58861
82.86928
83.41724
84.576355
82.919624
83.544426
82.79644
83.370186
84.52695
82.84847
83.483246
82.742256
83.29314
84.46873
82.80004
83.42263
82.673706
83.28186
84.42301
82.75056
83.36524
82.65597
83.24523
84.40068
82.71333
83.33031
82.62821
83.19574
84.33155
82.65427
83.30718
82.58938
83.15142
84.306145
82.6162
83.29269
82.561
83.102036
84.27894
82.60348
83.262
82.517204
83.05464
84.23411
82.583206
83.171585
82.44595
83.035034
84.173065
82.48716
83.14812
82.3975
82.96004
84.14465
82.44274
83.063354
82.358055
82.92559
84.07737
82.40021
83.04093
82.30195
82.87868
84.043106
82.35311
82.99813
82.25789
82.817154
83.97923
82.306114
82.9514
82.212944
82.771126
83.93689
