# Neuroscope

Jupyter workspace for the neuroscope project

In [67]:
import jax
from jax import vmap, jit, lax, random, grad, value_and_grad
import jax.numpy as jnp
import optax

import numpy as np
from functools import partial
from tqdm import tqdm
import time
import seaborn as sns
import matplotlib.pyplot as plt

import syrkis
from src.data import load_subject, make_kfolds

In [68]:
# GLOBALS
cfg = syrkis.training.load_config()

## Data

In [69]:
# subject = load_subject('subj07', image_size=cfg['image_size'])
# kfolds = make_kfolds(subject, cfg)
# neuroscope_train_batches, neuroscope_eval_batches = next(kfolds)  # type: ignore

In [70]:
rng = jax.random.PRNGKey(0)
keys = jax.random.split(rng, 10)
mnist = syrkis.data.mnist() if 'mnist' not in locals() else eval('mnist')
mnist_x, mnist_y = mnist
mnist_x_batches = mnist_x.reshape(-1, 60, 28, 28, 1)
mnist_y_batches = mnist_y.reshape(-1, 60)

## Batch norm

In [71]:
@jit
def batch_norm(x, gamma, beta, eps=1e-5):
    # x: batch x height x width x channels
    mean = jnp.mean(x, axis=(0, 1, 2), keepdims=True)
    var = jnp.var(x, axis=(0, 1, 2), keepdims=True)
    x = (x - mean) / jnp.sqrt(var + eps)
    x = gamma * x + beta
    return x

def init_batch_norm(channels):
    gamma = jnp.ones((1, 1, 1, channels))
    beta = jnp.zeros((1, 1, 1, channels))
    return gamma, beta

def batch_norm_test():
    gamma, beta = init_batch_norm(2)
    x = jnp.array([[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]])
    normalized_x = batch_norm(x, gamma, beta)
    means = jnp.mean(normalized_x, axis=(0, 1, 2))
    stds = jnp.std(normalized_x, axis=(0, 1, 2))
    assert jnp.allclose(means, jnp.zeros(2))
    assert jnp.allclose(stds, jnp.ones(2))

batch_norm_test()

## Linear layer

In [72]:
def init_linear_layer(rng, in_dim, out_dim, scale):
    keys = jax.random.split(rng, 2)
    w_shape = (in_dim, out_dim)
    b_shape = (out_dim,)
    w = scale * jax.random.normal(keys[0], w_shape)
    b = scale * jax.random.normal(keys[1], b_shape)
    return w, b

def init_linear_layers(rng, in_dim, out_dim, cfg):
    # first layer goes from in_dim to embed_dim, rest are embed to embed, and last is embed to out
    rngs = jax.random.split(rng, cfg['layers'])
    params = []
    for idx, rng in enumerate(rngs):
        layer_in_dim  = cfg['embed_dim'] if idx != 0                 else in_dim
        layer_out_dim = cfg['embed_dim'] if idx != cfg['layers'] - 1 else out_dim
        params.append(init_linear_layer(rng, layer_in_dim, layer_out_dim, cfg['scale']))
    return params

def linear(params, x):
    for w, b in params[:-1]:
        x = jax.nn.gelu(x @ w + b)
    w, b = params[-1]
    return x @ w + b

def test_linear():
    rng = jax.random.PRNGKey(0)
    x = jnp.array([1.0, 2.0])
    params = init_linear_layers(rng, 2, 4, cfg)
    y = linear(params, x)
    assert y.shape == (4,)
    print(y)

test_linear()

[-0.02130392 -0.0382999   0.03523295 -0.0305739 ]


## Convolutions

In [90]:
# Global constants for common parameters
DIMENSION_NUMBERS = ("NHWC", "HWIO", "NHWC")

@jit
def conv2d(x, w):
    return jax.lax.conv_general_dilated(
        x, w, 
        window_strides=(cfg['stride'], cfg['stride']),
        padding='SAME',
        dimension_numbers=DIMENSION_NUMBERS)

@jit
def upscale_nearest_neighbor(x, scale_factor=cfg['stride']):
    # Assuming x has shape (batch, height, width, channels)
    b, h, w, c = x.shape
    x = x.reshape(b, h, 1, w, 1, c)
    x = lax.tie_in(x, jnp.broadcast_to(x, (b, h, scale_factor, w, scale_factor, c)))
    return x.reshape(b, h * scale_factor, w * scale_factor, c)

@jit
def deconv2d(x, w):
    x_upscaled = upscale_nearest_neighbor(x)
    return lax.conv_transpose(
        x_upscaled, w, 
        strides=(1, 1), 
        padding='SAME',
        dimension_numbers=DIMENSION_NUMBERS) 

def conv_fn(fn):
    def apply_fn(params, x):
        for w, b, gamma, beta in params[:-1]:
            x = fn(x, w, b)
            # x = batch_norm(x, gamma, beta)
            # x = jax.nn.gelu(x)
        w, b, gamma, beta = params[-1]
        x = fn(x, w, b)
        # x = batch_norm(x, gamma, beta)
        return x
    return apply_fn

conv   = conv_fn(lambda x, w, b: conv2d(x, w) + b)
deconv = conv_fn(lambda x, w, b: deconv2d(x, w) + b)

def init_conv_params(rng, cfg, deconv=False):
    rng, key1, key2 = jax.random.split(rng, 3)
    out_channels = cfg['channels'] if deconv else cfg['channels'] * 2
    in_channels  = cfg['channels'] if not deconv else cfg['channels'] * 2
    w_shape = (cfg['kernel_size'], cfg['kernel_size'], in_channels, out_channels)
    b_shape = (out_channels,)
    w = cfg['scale'] * jax.random.normal(key1, w_shape)
    b = cfg['scale'] * jax.random.normal(key2, b_shape)
    gamma, beta = init_batch_norm(out_channels)
    return w, b, gamma, beta

def init_conv_layers(rng, cfg, deconv=False):
    rngs = jax.random.split(rng, cfg['layers'])
    params = []
    for idx, rng in enumerate(rngs):
        params.append(init_conv_params(rng, cfg, deconv))
    return params[::-1] if deconv else params

def calculate_latent_dim(cfg):
    # should return the size of the latente dim depending on initial image size, stride, and number of layers, and channels
    image_size = cfg['image_size']
    stride = cfg['stride']
    layers = cfg['layers']
    channels = cfg['channels']
    latent_dim = 2 * image_size ** 2 * channels // (stride ** 2) ** layers
    # Multiply by 2 because of bug (this is a temporary fix)
    return latent_dim

print(calculate_latent_dim(cfg))

1568


## Model

In [100]:
def init_params(rng, cfg):
    latent_dim = calculate_latent_dim(cfg)
    print('latent_dim', latent_dim)
    conv_params   = init_conv_layers(rng, cfg)                            # conv layers
    linear_params = init_linear_layers(rng, latent_dim, latent_dim, cfg)  # linear layers
    deconv_params = init_conv_layers(rng, cfg, deconv=True)               # deconv layers
    return {'conv': conv_params, 'deconv': deconv_params, 'linear': linear_params}

def reparametrize(mu, logvar, rng):
    # mu, logvar: (batch, latent_dim)
    std = jnp.exp(0.5 * logvar)
    eps = jax.random.normal(rng, std.shape)
    return mu + eps * std

def apply_fn(params, x):
    x = conv(params['conv'], x)
    # mu, logvar = linear(params['linear'], x)
    # z = reparametrize(mu, logvar)
    z = linear(params['linear'], x.reshape(cfg['batch_size'], -1))
    x_hat = deconv(params['deconv'], z.reshape(x.shape))
    return x_hat  # , mu, logvar

def model_test():
    rng = jax.random.PRNGKey(0)
    x = jnp.ones((cfg['batch_size'], cfg['image_size'], cfg['image_size'], cfg['channels']))
    params = init_params(rng, cfg)
    x_hat = apply_fn(params, x)
    assert x_hat.shape == x.shape

model_test()

latent_dim 1568


## Training

In [106]:
def kl_divergence(mu, logvar):
    return -0.5 * jnp.sum(1 + logvar - mu ** 2 - jnp.exp(logvar), axis=1).mean()

def cross_entropy(logits, labels, epsilon=1e-12):
    max_logits = jnp.max(logits, axis=1, keepdims=True)
    stabilized_logits = logits - max_logits
    log_sum_exp = jnp.log(jnp.sum(jnp.exp(stabilized_logits), axis=1, keepdims=True) + epsilon)
    labels_one_hot = jnp.eye(logits.shape[1])[labels]
    loss = -jnp.mean(labels_one_hot * (stabilized_logits - log_sum_exp))
    return loss

def mean_squared_error(logits, labels):
    return jnp.mean((logits - labels) ** 2)

def loss_fn(params, x, y):
    logits = apply_fn(params, x)
    return mean_squared_error(logits, y)

def update_fn(opt, opt_state, params, x, y=None):
    if y is None:
        loss, grads = jax.value_and_grad(loss_fn)(params, x)
    else:
        loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    updates, opt_state = opt.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return loss, params, opt_state


In [114]:
rng = jax.random.PRNGKey(0)
opt = optax.adam(1e-3)
params = init_params(rng, cfg)
opt_state = opt.init(params)
update = jit(partial(update_fn, opt))
pbar = tqdm(range(1000))
z_seed = jax.random.normal(rng, (1, calculate_latent_dim(cfg)))

for i in pbar:
    x = mnist_x_batches[i]
    loss, params, opt_state = update(opt_state, params, x, x)
    recon = apply_fn(params, mnist_x_batches[0])
    pbar.set_description(f'loss: {loss:.4f}')
    syrkis.training.plot_multiples(recon)

loss: 0.0051: 100%|██████████| 1000/1000 [00:30<00:00, 33.21it/s]


## Testing

In [39]:
rng = jax.random.PRNGKey(0)
keys = jax.random.split(rng, 10)
mnist = syrkis.data.mnist() if 'mnist' not in locals() else eval('mnist')
mnist_x, mnist_y = mnist
mnist_x_batches = mnist_x.reshape(-1, 60, 28, 28, 1)
mnist_y_batches = mnist_y.reshape(-1, 60)

loss : 0.014: 100%|██████████| 1000/1000 [00:06<00:00, 147.47it/s]
