In [199]:
import jax
from jax import vmap, jit, lax
import jax.numpy as jnp
from jax.tree_util import tree_leaves
import numpy as np
import optax
import tensorflow_datasets as tfds
from tqdm import tqdm

from src.data import load_subject, make_kfolds
from src.utils import CONFIG
from src.train import hyperparam_fn
from src.plots import plot_brain, plot_multiples, make_halves

In [200]:
# GLOBALS
image_size = 28   #128 # CONFIG['image_size']
latent_dim = 128
kernel_size = 4
lr = 1e-2
depth = 2  # also deconv depth
max_channels = 2 ** (depth + 2)
precision = jnp.float32

In [201]:
""" subject = load_subject('subj05', image_size=image_size, precision=precision)
hyperparams = hyperparam_fn()
kfolds = make_kfolds(subject, hyperparams)
loader, _ = next(kfolds) """

" subject = load_subject('subj05', image_size=image_size, precision=precision)\nhyperparams = hyperparam_fn()\nkfolds = make_kfolds(subject, hyperparams)\nloader, _ = next(kfolds) "

In [238]:
# mnist = tfds.as_numpy(tfds.load('mnist', batch_size=-1))
# digits = mnist['train']['image'].reshape(-1, 16, 28, 28, 1) / 255.0
# fashion_mnist = tfds.as_numpy(tfds.load('fashion_mnist', batch_size=-1))
# digits = fashion_mnist['train']['image'].reshape(-1, 16, 28, 28, 1) / 255.0
omniglot = tfds.as_numpy(tfds.load('omniglot', batch_size=-1))
digits = omniglot['train']['image'].reshape(-1, 16, 105, 105, 3) / 255.0

In [239]:
@jit
def conv2d(x, w):
    return jax.lax.conv_general_dilated(
        x, w, 
        window_strides=(2, 2), 
        padding='SAME',
        dimension_numbers=("NHWC", "HWIO", "NHWC"))

@jit
def upscale_nearest_neighbor(x, scale_factor=2):
    # Assuming x has shape (batch, height, width, channels)
    b, h, w, c = x.shape
    x = x.reshape(b, h, 1, w, 1, c)
    x = lax.tie_in(x, jnp.broadcast_to(x, (b, h, scale_factor, w, scale_factor, c)))
    return x.reshape(b, h * scale_factor, w * scale_factor, c)

@jit
def deconv2d(x, w):
    x_upscaled = upscale_nearest_neighbor(x)
    return lax.conv_transpose(
        x_upscaled, w, 
        strides=(1, 1), 
        padding='SAME',
        dimension_numbers=("NHWC", "HWIO", "NHWC")) 

In [232]:
def conv(x, w, b, activation=jax.nn.gelu):
    return activation(conv2d(x, w) + b)

def deconv(x, w, b, activation=jax.nn.gelu):
    return activation(deconv2d(x, w) + b)

In [240]:
def encoder(x, params):
    conv_params = params['encoder']['conv']
    fc_params = params['encoder']['fc']
    for idx, (w, b) in enumerate(conv_params):
        x = conv(x, w, b)
    x = x.reshape(x.shape[0], (image_size // (2 ** depth)) ** 2 * max_channels)
    x = jnp.dot(x, fc_params[0]) + fc_params[1]
    return jax.nn.gelu(x)

def decoder(z, params):
    deconv_params = params['decoder']['deconv']
    fc_params = params['decoder']['fc']
    x = jnp.dot(z, fc_params[0]) + fc_params[1]
    x = x.reshape(x.shape[0], (image_size // (2 ** depth)), (image_size // (2 ** depth)), max_channels)
    for idx, (w, b) in enumerate(deconv_params):
        x = deconv(x, w, b)
    return jnp.clip(x, 0, 1)   # sigmoid might be better, but this is faster



def init_layer_params(key, shape, scale=1e-1):
    w = jax.random.normal(key, shape) * scale
    b = jnp.zeros(shape[-1])
    return w, b


def init_encoder(key, scale=1e-1):
    conv_shapes = [
        (kernel_size, kernel_size, 3 if i == 0 else 2 ** (i + 2), 8 if i == 0 else 2 ** (i + 3))
        for i in range(depth)
    ]
    conv_params = [init_layer_params(key, shape, scale) for shape in conv_shapes]
    fc_shape = (max_channels * (image_size // (2 ** depth)) ** 2, latent_dim)
    fc_params = init_layer_params(key, fc_shape, scale)
    return {'conv': conv_params, 'fc': fc_params}


def init_decoder(key, scale=1e-1):
    deconv = [
        (kernel_size, kernel_size, 2 ** (depth - i + 2), 3 if i == depth - 1 else 2 ** (depth - i + 1))
        for i in range(depth)
    ]
    deconv_params = [init_layer_params(key, shape, scale) for shape in deconv]
    fc_shape = (latent_dim, max_channels * (image_size // (2 ** depth)) ** 2)
    fc_params = init_layer_params(key, fc_shape, scale)
    return {'deconv': deconv_params, 'fc': fc_params}

def init_params(key):
    key_encoder, key_decoder = jax.random.split(key)
    return {'encoder': init_encoder(key_encoder), 'decoder': init_decoder(key_decoder)}


@jit
def loss_fn(params, x):
    z = encoder(x, params)
    x_hat = decoder(z, params)
    return jnp.mean((x - x_hat) ** 2)

In [241]:
key, subkey = jax.random.split(jax.random.PRNGKey(0))
params = init_params(subkey)
params = jax.tree_map(lambda x: x.astype(precision), params)
n_params = sum([np.prod(v.shape) for v in tree_leaves(params)])

In [242]:

n_params

206515

In [243]:
opt = optax.adamw(lr)
opt_state = opt.init(params)

def update_fn(params, x, opt_state):
    loss, grads = jax.value_and_grad(loss_fn)(params, x)
    updates, opt_state = opt.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return loss, params, opt_state

eval_batch = digits[0]   # next(loader)[2]
def train_loop(params, opt_state):
    for i in range(10000):
        _, _, img = next(loader)
        img = digits[i]
        loss, params, opt_state = update_fn(params, img, opt_state)
        eval_pred = decoder(encoder(eval_batch, params), params)
        # eval_imgs = make_halves(eval_batch, eval_pred)
        plot_multiples(eval_pred, 5, info_bar=[f"step : {i:05d}",
                                               f"loss : {loss:.3f}",
                                               f"params : {n_params:,}",
                                               f"latent_dim : {latent_dim}"])
    return params, opt_state

In [244]:
train_loop(params, opt_state)

InconclusiveDimensionOperation: Cannot divide evenly the sizes of shapes (16, 27, 27, 16) and (16, 784)