# Neuroscope

In [None]:
import syrkis
import jax
from jax import vmap, jit, lax
import jax.numpy as jnp
import numpy as np
import optax
import tensorflow_datasets as tfds
from tqdm import tqdm
import time

from src.data import load_subject, make_kfolds

In [None]:
# GLOBALS
# Hyperparameters shared by every cell below. Several functions (some of them
# jitted) read `cfg` as a module-level global rather than taking it as an
# argument, so after editing a value the cells that define those functions
# should be re-run as well.
cfg = {
    'image_size' :   28,    # passed to load_subject; also used by init_fn to size the latent
    'embed_dim'  :   10,    # width of the mu / logvar embedding
    'batch_size' :   60,    # images per batch when MNIST is reshaped into batches
    'kernel_size':    3,    # side of the square convolution kernels
    'channels'   :    1,    # channel count of the first conv layer's input
    'stride'     :    1,    # conv stride; also the default upscale factor in the decoder
    'layers'     :    2,    # number of conv layers in the encoder and in the decoder
    'lr'         :   1e-3,  # learning rate handed to optax.adamw
    'epochs'     :  100,    # bound of the epoch loop in train_loop
    'scale'      :   1e-2,  # multiplier applied to the normal-distributed initial weights
    'beta'       :   1.0,   # weight on the KL term in loss_fn
}

In [None]:
# Load subject 'subj07' at the configured image size and take the first item
# yielded by make_kfolds, unpacked as (loader, eval_loader).
# NOTE(review): load_subject / make_kfolds come from src.data and are not shown
# here. In the cells below `loader` is only referenced from commented-out
# lines; the active training code reads MNIST instead.
subject = load_subject('subj07', image_size=cfg['image_size'])
kfolds = make_kfolds(subject, cfg)
loader, eval_loader = next(kfolds)  # type: ignore

In [None]:
# Load the MNIST training split and turn it into one array of image batches.
mnist = tfds.load('mnist', split='train', shuffle_files=True)
# make it jax
mnist = tfds.as_numpy(mnist)
# Collect every example's 'image' field into a single array and divide by 255
# (presumably mapping 8-bit pixel values into [0, 1] — confirm the dtype).
mnist = jnp.array([x['image'] for x in mnist]) / 255.
# Group into (n_batches, batch_size, 28, 28, 1). The reshape only succeeds when
# the number of images is a multiple of cfg['batch_size']. The 28s are
# hard-coded here rather than taken from cfg['image_size'].
mnist = mnist.reshape(-1, cfg['batch_size'], 28, 28, 1)
# Bare expression: shows the resulting shape as the notebook cell output.
mnist.shape

## Batch Normalization

In [None]:
@jit
def batch_norm(x, gamma, beta, eps=1e-5):
    """Normalize `x` per channel using statistics of the current batch.

    Mean and variance are taken over the first three axes (batch, height,
    width for the NHWC inputs used in this notebook), so every channel is
    standardized independently, then scaled by `gamma` and shifted by
    `beta`. No running statistics are kept: the same batch statistics are
    used on every call.

    Args:
        x: array laid out as batch x height x width x channels.
        gamma: scale, broadcastable against `x`.
        beta: shift, broadcastable against `x`.
        eps: constant added to the variance before the square root.

    Returns:
        Array with the same shape as `x`.
    """
    reduce_axes = (0, 1, 2)
    centered = x - jnp.mean(x, axis=reduce_axes, keepdims=True)
    variance = jnp.var(x, axis=reduce_axes, keepdims=True)
    normalized = centered / jnp.sqrt(variance + eps)
    return gamma * normalized + beta

def init_batch_norm(channels):
    """Return the initial (gamma, beta) pair for a batch-norm layer.

    Both arrays have shape (1, 1, 1, channels) so they broadcast over NHWC
    activations. gamma starts at one and beta at zero.
    """
    param_shape = (1, 1, 1, channels)
    return jnp.ones(param_shape), jnp.zeros(param_shape)

In [None]:
# Layout strings shared by conv2d and deconv2d, given in (input, kernel, output)
# order: activations are NHWC (batch, height, width, channels) and kernels are
# HWIO (height, width, in-channels, out-channels).
DIMENSION_NUMBERS = ("NHWC", "HWIO", "NHWC")

@jit
def conv2d(x, w):
    """Convolve `x` with kernel `w` using SAME padding.

    Array layouts come from DIMENSION_NUMBERS and the stride from
    cfg['stride'], applied identically along height and width.
    """
    stride = cfg['stride']
    return lax.conv_general_dilated(
        x,
        w,
        window_strides=(stride, stride),
        padding='SAME',
        dimension_numbers=DIMENSION_NUMBERS,
    )

def upscale_nearest_neighbor(x, scale_factor=cfg['stride']):
    """Upsample an NHWC batch by repeating every pixel along both spatial axes.

    Args:
        x: array of shape (batch, height, width, channels).
        scale_factor: Python int; how many times each pixel is repeated
            along height and along width. Defaults to the value of
            cfg['stride'] at the time this definition is executed.

    Returns:
        Array of shape
        (batch, height * scale_factor, width * scale_factor, channels).
    """
    b, h, w, c = x.shape
    # Add a length-1 axis after height and after width, stretch those axes to
    # `scale_factor`, then fold them back into the spatial dimensions.
    x = x.reshape(b, h, 1, w, 1, c)
    # The former lax.tie_in wrapper is dropped: it had long been a no-op and is
    # no longer part of current JAX releases.
    x = jnp.broadcast_to(x, (b, h, scale_factor, w, scale_factor, c))
    return x.reshape(b, h * scale_factor, w * scale_factor, c)


# `scale_factor` determines output shapes, so it must be a compile-time
# constant; marking it static lets callers pass it explicitly under jit.
upscale_nearest_neighbor = jit(upscale_nearest_neighbor, static_argnames='scale_factor')

@jit
def deconv2d(x, w):
    """Upsample `x`, then apply a unit-stride transposed convolution.

    The spatial enlargement is done by upscale_nearest_neighbor with its
    default scale factor. The transposed convolution itself uses stride
    (1, 1) and SAME padding, with layouts from DIMENSION_NUMBERS.
    """
    upsampled = upscale_nearest_neighbor(x)
    unit_strides = (1, 1)
    return lax.conv_transpose(
        upsampled,
        w,
        strides=unit_strides,
        padding='SAME',
        dimension_numbers=DIMENSION_NUMBERS,
    )

def conv_fn(fn):
    """Build a forward pass that chains `fn` over a list of layer parameters.

    Args:
        fn: callable `fn(x, w, b)` implementing a single layer.

    Returns:
        A function `(params, x) -> x` where `params` is a sequence of
        (w, b, gamma, beta) tuples. Every layer except the last is followed
        by batch_norm and GELU. The last layer is applied bare, so its gamma
        and beta entries are unpacked but not used.
    """
    def run_layers(params, x):
        hidden_layers, output_layer = params[:-1], params[-1]
        for kernel, bias, gamma, beta in hidden_layers:
            pre_activation = batch_norm(fn(x, kernel, bias), gamma, beta)
            x = jax.nn.gelu(pre_activation)
        kernel, bias, _gamma, _beta = output_layer
        return fn(x, kernel, bias)

    return run_layers

# Layer stacks built with conv_fn: `conv` (used by the encoder) adds a bias to
# a convolution, `deconv` (used by the decoder) adds a bias to an
# upsample-then-convolve step.
conv = conv_fn(lambda inputs, kernel, bias: conv2d(inputs, kernel) + bias)
deconv = conv_fn(lambda inputs, kernel, bias: deconv2d(inputs, kernel) + bias)

def init_conv_params(rng, channels, kernel_size, scale, deconv=False):
    """Create the parameters of a single (de)convolution layer.

    A conv layer maps `channels` to `2 * channels`; with `deconv=True` the
    direction is flipped and the layer maps `2 * channels` to `channels`.

    Args:
        rng: PRNG key for this layer.
        channels: base channel count of the layer.
        kernel_size: side of the square kernel.
        scale: multiplier applied to the standard-normal draws.
        deconv: whether the layer belongs to the decoder.

    Returns:
        Tuple (w, b, gamma, beta): an HWIO kernel, a bias of length
        out_channels, and the batch-norm pair from init_batch_norm.
    """
    # Three sub-keys are drawn; only the last two are consumed.
    _, kernel_key, bias_key = jax.random.split(rng, 3)
    if deconv:
        in_channels, out_channels = channels * 2, channels
    else:
        in_channels, out_channels = channels, channels * 2
    kernel_shape = (kernel_size, kernel_size, in_channels, out_channels)
    w = scale * jax.random.normal(kernel_key, kernel_shape)
    b = scale * jax.random.normal(bias_key, (out_channels,))
    gamma, beta = init_batch_norm(out_channels)
    return w, b, gamma, beta

def init_conv_layers(rng, channels, kernel_size, layers, scale, deconv=False):
    """Create parameters for a stack of `layers` (de)convolution layers.

    Layer `idx` is built with a base width of `channels * 2 ** idx`. For a
    decoder stack (`deconv=True`) the list is returned in reverse, so the
    layer with the largest base width comes first.

    Returns:
        List of (w, b, gamma, beta) tuples, one per layer.
    """
    layer_keys = jax.random.split(rng, layers)
    stack = [
        init_conv_params(layer_key, channels * 2 ** depth, kernel_size, scale, deconv)
        for depth, layer_key in enumerate(layer_keys)
    ]
    if deconv:
        stack.reverse()
    return stack

In [None]:
def encoder(params, x):
    """Run the conv stack on `x` and flatten each sample to a vector.

    Args:
        params: layer parameters for `conv`.
        x: array of shape (batch, height, width, channels).

    Returns:
        Array of shape (batch, features).
    """
    feature_map = conv(params, x)
    batch = feature_map.shape[0]
    return feature_map.reshape(batch, -1)

def decoder(params, z):
    """Reshape flat vectors into a square feature map and run the deconv stack.

    The channel depth is cfg['channels'] * 2 ** cfg['layers'], read from the
    module-level cfg. The spatial side is recovered from the vector length by
    a truncated square root.

    Args:
        params: layer parameters for `deconv`.
        z: array of shape (batch, features).

    Returns:
        Output of `deconv` on the reshaped input.
    """
    depth = cfg['channels'] * 2 ** cfg['layers']
    side = int(np.sqrt(z.shape[1] / depth))
    grid = z.reshape(z.shape[0], side, side, depth)
    return deconv(params, grid)


def init_linear_params(rng, in_dim, out_dim, scale):
    """Create the (weight, bias) pair of a dense layer.

    Both arrays are standard-normal draws multiplied by `scale`. The weight
    uses the first sub-key of `rng` and the bias the second.

    Returns:
        Tuple (w, b) with shapes (in_dim, out_dim) and (out_dim,).
    """
    weight_key, bias_key = jax.random.split(rng, 2)
    weight = scale * jax.random.normal(weight_key, (in_dim, out_dim))
    bias = scale * jax.random.normal(bias_key, (out_dim,))
    return weight, bias

def forward_linear(params, x):
    """Apply a dense layer: `x @ w + b` for `params = (w, b)`."""
    weight, bias = params
    projected = x @ weight
    return projected + bias

def init_fn(rng, cfg):
    """Initialize every parameter group of the autoencoder.

    Args:
        rng: PRNG key. It is split six ways; the first sub-key is unused.
        cfg: configuration dict (channels, layers, image_size, stride,
            kernel_size, embed_dim, scale).

    Returns:
        Dict with the groups 'encoder_conv', 'decoder_fc', 'decoder_conv',
        'linear_mu' and 'linear_logvar'.
    """
    # Size of the flattened encoder output: final channel width times the
    # squared spatial side after `layers` strided convolutions.
    # NOTE(review): the floor division matches SAME-padded convolutions only
    # when image_size is divisible by stride ** layers — confirm for other
    # configurations.
    width = cfg['channels'] * 2 ** cfg['layers']
    side = cfg['image_size'] // cfg['stride'] ** cfg['layers']
    latent_dim = width * side ** 2

    keys = jax.random.split(rng, 6)
    enc_key, fc_key, dec_key, mu_key, logvar_key = keys[1], keys[2], keys[3], keys[4], keys[5]

    conv_args = (cfg['channels'], cfg['kernel_size'], cfg['layers'], cfg['scale'])
    return {
        'encoder_conv': init_conv_layers(enc_key, *conv_args),
        'decoder_fc': init_linear_params(fc_key, cfg['embed_dim'], latent_dim, cfg['scale']),
        'decoder_conv': init_conv_layers(dec_key, *conv_args, deconv=True),
        'linear_mu': init_linear_params(mu_key, latent_dim, cfg['embed_dim'], cfg['scale']),
        'linear_logvar': init_linear_params(logvar_key, latent_dim, cfg['embed_dim'], cfg['scale']),
    }


def apply_fn(params, x, rng=None):
    """Encode `x`, optionally sample the embedding, and decode it again.

    Args:
        params: parameter dict as produced by init_fn.
        x: array of shape (batch, height, width, channels).
        rng: PRNG key. When given, the embedding is sampled with
            reparametrize; when None, the mean `mu` is used directly.

    Returns:
        Tuple (x_hat, mu, logvar), where x_hat has passed through a sigmoid.
    """
    hidden = encoder(params['encoder_conv'], x)
    mu = forward_linear(params['linear_mu'], hidden)
    logvar = forward_linear(params['linear_logvar'], hidden)
    if rng is None:
        z = mu
    else:
        z = reparametrize(mu, logvar, rng)
    # Debug output of the embedding statistics. NOTE(review): when this runs
    # inside a jitted caller such as loss_fn, it prints tracer objects at
    # trace time instead of numbers.
    print(z.mean(), z.std())
    expanded = jax.nn.gelu(forward_linear(params['decoder_fc'], z))
    x_hat = jax.nn.sigmoid(decoder(params['decoder_conv'], expanded))
    return x_hat, mu, logvar


def reparametrize(mu, logvar, rng):
    """Sample `mu + std * noise` with `std = exp(logvar / 2)`.

    Args:
        mu: means, shape (batch, latent_dim).
        logvar: log-variances, same shape as `mu`.
        rng: PRNG key used to draw the standard-normal noise.

    Returns:
        Array with the same shape as `mu`.
    """
    sigma = jnp.exp(0.5 * logvar)
    noise = jax.random.normal(rng, sigma.shape)
    return mu + noise * sigma



In [None]:
@jit
def loss_fn(params, x, rng):
    """Reconstruction error plus beta-weighted KL term for one batch.

    The reconstruction term is the squared error averaged over each image
    and then over the batch. The KL term comes from kl_divergence and is
    multiplied by cfg['beta'].

    Args:
        params: parameter dict as produced by init_fn.
        x: array of shape (batch, height, width, channels).
        rng: PRNG key forwarded to apply_fn for sampling.

    Returns:
        Scalar loss.
    """
    x_hat, mu, logvar = apply_fn(params, x, rng)
    squared_error = (x - x_hat) ** 2
    recon_loss = jnp.mean(squared_error, axis=(1, 2, 3)).mean()
    kl_loss = cfg['beta'] * kl_divergence(mu, logvar)
    return recon_loss + kl_loss

def kl_divergence(mu, logvar):
    """KL divergence of N(mu, exp(logvar)) from a standard normal.

    The per-dimension terms are summed over axis 1 and the result is
    averaged over the remaining (batch) axis.
    """
    per_dim = 1 + logvar - mu ** 2 - jnp.exp(logvar)
    per_sample = jnp.sum(per_dim, axis=1)
    return -0.5 * per_sample.mean()

In [None]:
# Split the seed into two keys. `rng` initializes the parameters here and is
# later also handed to train_loop; `key` is used further down to draw
# generation_batch.
# NOTE(review): `rng` is consumed twice (init_fn and train_loop); consider
# splitting a dedicated training key.
rng, key = jax.random.split(jax.random.PRNGKey(0))
params = init_fn(rng, cfg)
# Parameter count, shown in the info bar of the training plots.
n_params = syrkis.training.n_params(params)

In [None]:
# AdamW optimizer at the configured learning rate, with its state initialized
# from the freshly created parameters.
opt = optax.adamw(cfg['lr'])
opt_state = opt.init(params)
# grad_fn(params, x, rng) returns (loss, gradients); value_and_grad
# differentiates with respect to the first argument, i.e. params.
grad_fn = jax.value_and_grad(loss_fn)

@jit
def update_fn(params, x, opt_state, rng):
    """Perform one optimizer step on batch `x`.

    Args:
        params: current parameters.
        x: input batch passed to the loss.
        opt_state: current optimizer state.
        rng: PRNG key passed to the loss.

    Returns:
        Tuple (loss, params, opt_state) with the loss measured before the
        step and the updated parameters and optimizer state.
    """
    loss, grads = grad_fn(params, x, rng)
    updates, new_opt_state = opt.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return loss, new_params, new_opt_state

# Fixed batch used for the reconstruction rows of the training plots (the
# commented line is the alternative that draws it from the subject loader).
# eval_batch = next(loader)[2]
eval_batch = mnist[0]
# Standard-normal noise with the same shape as the encoder's flattened output
# for eval_batch. train_loop passes it straight to `decoder`, so it does not go
# through the decoder_fc layer or the final sigmoid used in apply_fn.
generation_batch = jax.random.normal(key, (encoder(params['encoder_conv'], eval_batch).shape))
def train_loop(params, opt_state, rng):
    """Train on MNIST batches and plot progress after every step.

    Runs cfg['epochs'] epochs of `8_000 // cfg['batch_size']` optimizer steps
    each, always over the first batches of `mnist`. After each step it plots
    six reconstructions of `eval_batch`, the six matching originals, and six
    images decoded from `generation_batch`.

    Args:
        params: initial parameters.
        opt_state: initial optimizer state.
        rng: PRNG key; a fresh sub-key is split off for every step.

    Returns:
        Tuple (params, opt_state) after the last step.
    """
    # Loop-invariant number of steps per epoch.
    n_batches = 8_000 // cfg['batch_size']
    # Zero-based epoch counter: runs exactly cfg['epochs'] epochs, and the
    # `epoch + 1` shown in the plot counts from 1.
    for epoch in range(cfg['epochs']):
        for i in range(n_batches):
            # Alternative data source (subject loader): _, _, img = next(loader)
            img = mnist[i]
            rng, key = jax.random.split(rng)
            loss, params, opt_state = update_fn(params, img, opt_state, key)
            # No rng is passed, so the reconstruction uses the mean embedding.
            eval_imgs = apply_fn(params, eval_batch)[0]
            gen_imgs = decoder(params['decoder_conv'], generation_batch)
            imgs = jnp.concatenate((eval_imgs[:6], eval_batch[:6], gen_imgs[:6]), axis=0)
            # NOTE(review): the value labelled "mse loss" is the full loss_fn
            # output, i.e. reconstruction error plus the weighted KL term.
            syrkis.training.plot_multiples(imgs, n_rows=3, info_bar=[
                f"mse loss : {loss:.3f}",
                f"embed_dim : {cfg['embed_dim']}",
                f"n_layers : {cfg['layers']}",
                f"epoch : {epoch + 1}",
                f"batch : {i + 1}",
                f"params : {n_params:,}",
                ])
    return params, opt_state

In [None]:
# Keep the trained parameters and optimizer state; as a bare expression the
# result would be discarded and echoed in full as the cell output.
params, opt_state = train_loop(params, opt_state, rng)