In [156]:
import optax
import jax
from jax import grad, jit
import jax.numpy as jnp
import random
import haiku as hk
import matplotlib.pyplot as plt
import wandb
from PIL import Image
import numpy as np
from tqdm import tqdm
from functools import partial
from src.utils import get_args_and_config
from src.data import load_data

In [157]:
args, _ = get_args_and_config()
# Reuse a `data` object that already exists in the running kernel instead of
# calling load_data again. On a fresh kernel this always loads.
# NOTE(review): a stale `data` is kept even if `args` changed - delete it
# (`del data`) to force a reload.
if 'data' not in globals():
    data = load_data(args)
# Per-subject views: the 'folds' entry is used for training, 'test' for testing.
train_data = {subject: subject_data['folds'] for subject, subject_data in data.items()}
test_data = {subject: subject_data['test'] for subject, subject_data in data.items()}

In [167]:
# Model hyperparameters, kept together in `config`.
# NOTE(review): the original comment mentioned a hyperparameter sweep - confirm
# whether a sweep is expected to override these values.
config = {
    'embed_size': 1024,
    'hidden_dim': 1024,
    'n_layers': 2,
}

In [168]:
def embedding_fn(fmri, config):
    """Project an fMRI batch to a bounded embedding.

    Must be called inside `hk.transform`, because it creates Haiku modules.

    Args:
        fmri: input array; its last axis is mapped by the linear layer.
        config: mapping providing 'embed_size', the width of the embedding.

    Returns:
        tanh of a bias-free linear projection of `fmri`, whose last
        dimension is `config['embed_size']`.
    """
    # NOTE(review): config['n_layers'] used to be read here but was never
    # used - the network is always a single linear layer. Confirm whether a
    # deeper MLP was intended.
    embed_mlp = hk.Sequential([
        # linear layer to get the embedding, without bias
        hk.Linear(config['embed_size'], with_bias=False),
        jax.nn.tanh,
    ])
    return embed_mlp(fmri)

def decoder_fn(z, config):
    """Decode an embedding into a 3-channel image.

    Must be called inside `hk.transform`, because it creates Haiku modules.

    Args:
        z: embedding array; its last axis is mapped by the linear layer.
        config: mapping providing 'hidden_dim'. It must be a perfect square,
            because the hidden vector is reshaped into a square
            single-channel image of side sqrt(hidden_dim).

    Returns:
        Array of shape (batch, 4 * side, 4 * side, 3) with values in (0, 1),
        where side = sqrt(hidden_dim).

    Raises:
        ValueError: if `hidden_dim` is not a perfect square. Previously the
            `reshape(-1, ...)` would fold the leftover elements into the
            batch axis or fail with an opaque shape error.
    """
    hidden_dim = config['hidden_dim']
    side = int(round(hidden_dim ** 0.5))
    if side * side != hidden_dim:
        raise ValueError(
            f"config['hidden_dim'] must be a perfect square to be reshaped "
            f"into a square image, got {hidden_dim}"
        )

    img_mlp = hk.Sequential([
        hk.Linear(hidden_dim, with_bias=False),
        jax.nn.tanh,
    ])
    img_deconv = hk.Sequential([
        # Each stride-2 'SAME' transposed conv doubles height and width:
        # (side, side, 1) -> (2*side, 2*side, 3)
        hk.Conv2DTranspose(3, kernel_shape=4, stride=2, padding='SAME'),
        jax.nn.sigmoid,
        # (2*side, 2*side, 3) -> (4*side, 4*side, 3), i.e. 128x128x3 for
        # hidden_dim=1024.
        # NOTE(review): the original comment claimed 224x224x3, which this
        # stack cannot produce - confirm the target image size.
        hk.Conv2DTranspose(3, kernel_shape=4, stride=2, padding='SAME'),
        jax.nn.sigmoid,
    ])

    hidden = img_mlp(z)
    hidden = hidden.reshape((-1, side, side, 1))
    return img_deconv(hidden)

In [169]:
# Bind the shared hyperparameters so each network function takes only its
# data argument, which is the form hk.transform is given below.
# NOTE(review): the names embedding_fn / decoder_fn are rebound to their
# partials, so re-running this cell wraps the already-bound partials again.
embedding_fn = partial(embedding_fn, config=config)
decoder_fn = partial(decoder_fn, config=config)
# Turn each network into an (init, apply) pair; without_apply_rng removes the
# rng argument from apply, so the calls are apply(params, inputs).
init_embed, apply_embed = hk.without_apply_rng(hk.transform(embedding_fn))
init_decoder, apply_decoder = hk.without_apply_rng(hk.transform(decoder_fn))


In [170]:
@jit
def loss_fn(params, lh, rh, img):
    """Mean squared error between `img` and its reconstruction.

    Args:
        params: indexable with three entries - params[0] and params[1] are
            passed to `apply_embed` for `lh` and `rh`, params[2] is passed to
            `apply_decoder`. Each entry must itself be a Haiku parameter
            mapping (a dict), not a list of mappings.
        lh: input to the first embedding network.
        rh: input to the second embedding network.
        img: target compared element-wise with the decoder output.

    Returns:
        Scalar mean of the squared difference between `img` and the decoded
        image.
    """
    lh_params, rh_params, decoder_params = params[0], params[1], params[2]
    # Embed both inputs and join them along the feature axis.
    hemispheres = [apply_embed(lh_params, lh), apply_embed(rh_params, rh)]
    joint = jnp.concatenate(hemispheres, axis=-1)
    residual = img - apply_decoder(decoder_params, joint)
    return jnp.mean(residual ** 2)

In [171]:
@jit
def update(params, opt_state, train_batches):
    """update function"""
    grads = grad(loss_fn)(params, x1, x2, y)
    updates, opt_state = opt.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, opt_state

In [172]:
def get_batches(data_split, batch_size=32):
    """get a batch of data"""
    while True:
        fold_perm = np.random.permutation(len(data_split['subj01']))  # either 5 or 1 (train or val)
        for fold_idx in fold_perm:
            subject_folds = {key: value[fold_idx] for key, value in train_data.items()}
            n_samples = min([len(value[0]) for value in subject_folds.values()])
            sample_perm = np.random.permutation(n_samples)
            for i in range(0, n_samples, batch_size):
                for subject in subject_folds.keys():
                    batch_data = (subject_folds[subject][0][sample_perm[i:i+batch_size]],
                                  subject_folds[subject][1][sample_perm[i:i+batch_size]],
                                  subject_folds[subject][2][sample_perm[i:i+batch_size]])
                    batch_data_gpu = jax.device_put(batch_data)
                    yield batch_data_gpu
                

In [173]:
train_batches = get_batches(train_data)
dummy_batches = [next(train_batches) for _ in range(len(train_data))]
lh_embed_params = [init_embed(jax.random.PRNGKey(42), batch[0]) for batch in dummy_batches]
rh_embed_params = [init_embed(jax.random.PRNGKey(42), batch[1]) for batch in dummy_batches]
decoder_params = init_decoder(jax.random.PRNGKey(42), jnp.zeros((1, config['embed_size'] * 2)))
params = (lh_embed_params, rh_embed_params, decoder_params)
opt = optax.adam(1e-3)
opt_state = opt.init(params)

In [175]:
lens = []
for i in tqdm(range((10_000*8*10) //  32)):
    params, opt_state = update(params, opt_state, train_batches)
    if i % 100 == 0:
        loss = loss_fn(params, *next(train_batches))
        lens.append(loss)


  0%|          | 0/25000 [00:00<?, ?it/s]


TypeError: params argument does not appear valid. It should be a mapping but is of type <class 'list'>. For reference the parameters for apply are `apply(params, rng, ...)`` for `hk.transform` and `apply(params, state, rng, ...)` for `hk.transform_with_state`.
The argument was: [{'linear': {'w': Traced<ShapedArray(float32[19004,1024])>with<DynamicJaxprTrace(level=4/0)>}}, {'linear': {'w': Traced<ShapedArray(float32[19004,1024])>with<DynamicJaxprTrace(level=4/0)>}}, {'linear': {'w': Traced<ShapedArray(float32[19004,1024])>with<DynamicJaxprTrace(level=4/0)>}}, {'linear': {'w': Traced<ShapedArray(float32[19004,1024])>with<DynamicJaxprTrace(level=4/0)>}}, {'linear': {'w': Traced<ShapedArray(float32[19004,1024])>with<DynamicJaxprTrace(level=4/0)>}}, {'linear': {'w': Traced<ShapedArray(float32[18978,1024])>with<DynamicJaxprTrace(level=4/0)>}}, {'linear': {'w': Traced<ShapedArray(float32[19004,1024])>with<DynamicJaxprTrace(level=4/0)>}}, {'linear': {'w': Traced<ShapedArray(float32[18981,1024])>with<DynamicJaxprTrace(level=4/0)>}}].