In [4]:
import optax
import jax
import jax.numpy as jnp
import haiku as hk
import matplotlib.pyplot as plt
import wandb
from functools import partial
from src.utils import get_args_and_config
from src.data import get_data

In [24]:
args, config = get_args_and_config()
data = get_data(args)['subj01'][0] if 'data' not in locals() else eval('data')

In [25]:
# hypere param sweep config in config
def get_hyper_params(config):
    """get hyper params"""
    config['n_units'] = 256
    config['n_layers'] = 2
    config['hidden_size'] = 256
    config['embed_size'] = 256
    config['latent_dim'] = 128
    config['dropout'] = 0.2
    config['beta'] = 0.5
    config['alpha'] = 0.5
    return config

config = get_hyper_params(config)

In [18]:
def embedding_fn(fmri, config):
    """embedding function"""
    n_units = config['n_units']
    n_layers = config['n_layers']
    embed_size = config['embed_size']
    img_mlp = hk.Sequential([
        hk.nets.MLP([n_units] * n_layers, activation=jnp.tanh),
        hk.Linear(embed_size),  # image dim (from alexnet PCA)
        jax.nn.sigmoid,
    ])
    return img_mlp(fmri)

def encoder_fn(img, config):
    """encoder function"""
    n_units = config['n_units']
    n_layers = config['n_layers']
    img_mlp = hk.Sequential([
        hk.nets.MLP([n_units] * n_layers, activation=jnp.tanh),
        hk.Linear(100),
        jax.nn.sigmoid,
    ])
    return img_mlp(img)

def decoder_fn(fmri, config):
    """decoder function"""
    n_units = config['n_units']
    n_layers = config['n_layers']
    img_mlp = hk.Sequential([
        hk.nets.MLP([n_units] * n_layers, activation=jnp.tanh),
        hk.Linear(100),
        jax.nn.sigmoid,
    ])
    return img_mlp(fmri)


embedding_fn = partial(embedding_fn, config=config)
encoder_fn = partial(encoder_fn, config=config)
decoder_fn = partial(decoder_fn, config=config)


In [23]:
def train():
    hyper_params = get_hyper_params(config)

    init_embedding, apply_embedding = hk.without_apply_rng(hk.transform(embedding_fn))
    init_encoder, apply_encoder = hk.without_apply_rng(hk.transform(encoder_fn))
    init_decoder, apply_decoder = hk.without_apply_rng(hk.transform(decoder_fn))

    embedding_params = init_embedding(jax.random.PRNGKey(42), data[0][0])
    encoder_params = init_encoder(jax.random.PRNGKey(42), data[0][0])
    decoder_params = init_decoder(jax.random.PRNGKey(42), data[0][0])

    lh_hat = apply_encoder(encoder_params, data[0][0])
    rh_hat = apply_encoder(encoder_params, data[0][0])
    fmri = jnp.concatenate([lh_hat, rh_hat], axis=1)
    img_hat = apply_decoder(decoder_params, fmri)
    print(lh_hat.shape, rh_hat.shape)


train()

ValueError: 'mlp/~/linear_0/w' with retrieved shape (100, 256) does not match shape=[200, 256] dtype=dtype('float32')