# Neuroscope

Jupyter workspace for the neuroscope project

In [107]:
import jax
from jax import vmap, jit, lax, random, grad, value_and_grad
import jax.numpy as jnp
import optax
from jax import config
import jraph
import numpy as np

import wandb
import numpy as np
from functools import partial
from tqdm import tqdm
import time
import seaborn as sns
import matplotlib.pyplot as plt
import os, pickle

import syrkis
from src.data import load_subjects, make_kfolds
from src.fmri import get_bold_with_coords_and_faces as get_mesh

In [108]:
# GLOBALS
rng = random.PRNGKey(0)
cfg = syrkis.train.load_config()['neuroscope']
opt = optax.adamw(learning_rate=cfg['lr'])
config.update("jax_platform_name", "cpu")

## Data

In [109]:
subjects = load_subjects(['subj05', 'subj07'], cfg['image_size'])

## fmri

In [110]:
def make_samples(sample, subject, n_subjects):
    lhs, rhs, imgs = sample
    fmri = []
    for idx, lh, rh in tqdm(zip(range(n_subjects), lhs, rhs)):
        lh_graph  = make_graph(*get_mesh(lh, subject, 'lh'))
        rh_graph  = make_graph(*get_mesh(rh, subject, 'rh'))
        graph     = combine_hems(lh_graph, rh_graph)
        fmri.append(graph)
    # fmri = jnp.array(fmri)
    return fmri, imgs

def combine_hems(lh_graph, rh_graph):
    # Concatenate node feature
    # number of 0s we need to add to have nodes be power of 2
    n_node = lh_graph.n_node + rh_graph.n_node
    padding_size = (2 ** jnp.ceil(jnp.log2(n_node)) - n_node).astype(int)[0]
    padding = jnp.zeros((padding_size, 1))
    nodes = jnp.concatenate([lh_graph.nodes, rh_graph.nodes, padding], axis=0)

    # Adjust senders and receivers indices for right hemisphere
    rh_offset = lh_graph.n_node
    rh_senders = rh_graph.senders + rh_offset
    rh_receivers = rh_graph.receivers + rh_offset

    senders = jnp.concatenate([lh_graph.senders, rh_senders], axis=0)
    receivers = jnp.concatenate([lh_graph.receivers, rh_receivers], axis=0)

    n_node += + padding_size
    n_edge  = lh_graph.n_edge + rh_graph.n_edge
    
    return jraph.GraphsTuple(n_node=n_node, n_edge=n_edge, edges=None, globals=None,
                             nodes=nodes, senders=senders, receivers=receivers)


@jit
def make_graph(coords, features, faces):
    senders   = jnp.concatenate([faces[:, 0], faces[:, 1], faces[:, 2]], axis=0)
    receivers = jnp.concatenate([faces[:, 1], faces[:, 2], faces[:, 0]], axis=0)
    n_node = jnp.array([features.shape[0]])
    n_edge = jnp.array([senders.shape[0]])

    graph = jraph.GraphsTuple(n_node=n_node, n_edge=n_edge, edges=None, globals=None,
                              nodes=features[:, None], senders=senders, receivers=receivers)
    return graph

In [111]:
def update_node_fn(node_features, params):
    weights, biases = params
    return jnp.dot(node_features, weights) + biases

def apply_graph_convolution(graph, params):
    # Define the graph convolution layer
    gcn_layer = jraph.GraphConvolution(
        update_node_fn=partial(update_node_fn, params=params),
        aggregate_nodes_fn=jraph.segment_sum,
        add_self_edges=True,
        symmetric_normalization=True,
    )

    # Apply the graph convolution layer to the graph
    return gcn_layer(graph)

In [112]:

def pool_fn(graph, pool_size):
    # Reshape and pool node features
    num_nodes, num_features = graph.nodes.shape
    pooled_features = graph.nodes.reshape(-1, pool_size, num_features)
    pooled_features = jnp.mean(pooled_features, axis=1)

    # Update edges for the pooled graph
    # Create a mapping from old node indices to new pooled node indices
    node_mapping = np.repeat(np.arange(len(pooled_features)), pool_size)[:num_nodes]

    # Update senders and receivers based on the node mapping
    pooled_senders = node_mapping[graph.senders]
    pooled_receivers = node_mapping[graph.receivers]

    # Filter out self-loops created by pooling
    edge_mask = pooled_senders != pooled_receivers
    pooled_senders = pooled_senders[edge_mask]
    pooled_receivers = pooled_receivers[edge_mask]

    # Update the graph with pooled nodes and edges
    pooled_graph = graph._replace(nodes=pooled_features, senders=pooled_senders, receivers=pooled_receivers)
    return pooled_graph

# Example usage
# Assume 'graph' is your input jraph.GraphsTuple

## utils

In [113]:
def latent_side_fn(cfg):
    return cfg['image_size'] // cfg['stride'] ** cfg['conv_layers']

def latent_dim_fn(cfg):
    # should return the size of the loatente dim depending on initial image size, stride, and number of layers, and channels
    channels = cfg['chan_start']
    # calulate latent channels
    latent_channels = int(channels * (cfg['conv_branch'] ** (cfg['conv_layers'] - 1)))
    # calculate latent side
    latent_side = latent_side_fn(cfg)
    # calculate latent dim
    latent_dim = latent_channels * latent_side ** 2
    return latent_dim
        
latent_dim = latent_dim_fn(cfg)
latent_side = latent_side_fn(cfg)
print(latent_side)
print(latent_dim)

2
2048


## Batch norm

In [114]:
@jit
def batch_norm(x, gamma, beta, eps=1e-5):
    if not cfg['batch_norm']:
        return x
    # x: batch x height x width x channels
    axis = tuple(range(len(x.shape) - 1))
    mean = jnp.mean(x, axis=axis, keepdims=True)
    var = jnp.var(x, axis=axis, keepdims=True)
    x = (x - mean) / jnp.sqrt(var + eps)
    x = gamma * x + beta
    return x

def init_batch_norm(shape):
    shape = [1 for _ in range(len(shape) - 1)] + [shape[-1]]
    shape = tuple(shape)
    gamma = jnp.ones(shape)
    beta = jnp.zeros(shape)
    return gamma, beta

## Linear layer

In [115]:
def init_linear_layer(rng, in_dim, out_dim, tensor_dim):
    # tensor dim is for having fmri embedding in same array, but seperate layers.
    rng, key = jax.random.split(rng, 2)
    w_shape = (in_dim, out_dim)
    b_shape = (out_dim,)
    w = syrkis.train.glorot_init(key, w_shape)
    if tensor_dim > 0:
        w = w.reshape((-1, out_dim, tensor_dim))
    b = jnp.zeros(b_shape)
    gamma, beta = init_batch_norm(b_shape)
    return w, b, gamma, beta

def linear(params, x):
    for idx, (w, b, gamma, beta) in enumerate(params):
        x = x @ w + b
        x = jax.nn.gelu(x) if idx != len(params) - 1 else x
        x = batch_norm(x, gamma, beta) if idx != len(params) - 1 else x
    return x

## Convolutions

In [116]:
# Global constants for common parameters
DIMENSION_NUMBERS = ("NHWC", "HWIO", "NHWC")


@jit
def upscale_nearest_neighbor(x, scale_factor=cfg['stride']):
    # Assuming x has shape (batch, height, width, channels)
    b, h, w, c = x.shape
    x = x.reshape(b, h, 1, w, 1, c)
    x = lax.tie_in(x, jnp.broadcast_to(x, (b, h, scale_factor, w, scale_factor, c)))
    return x.reshape(b, h * scale_factor, w * scale_factor, c)


@jit
def deconv2d(x, w):
    x_upscaled = upscale_nearest_neighbor(x)
    return lax.conv_transpose(
        x_upscaled, w, 
        strides=(1, 1), 
        padding='SAME',
        dimension_numbers=DIMENSION_NUMBERS) 


def conv_fn(fn):
    def apply_fn(params, x):
        for i, (w, b) in enumerate(params):
            x = fn(x, w, b)
            # x = batch_norm(x, gamma, beta) if i != len(params) - 1 else x
            # x = jax.nn.tanh(x) if i != len(params) - 1 else x
            x = jax.nn.gelu(x) if i != len(params) - 1 else x
        return x
    return apply_fn


def manual_batch_graphs(graph_list):
    # Initialize lists to hold the concatenated components
    all_nodes = []
    all_senders = []
    all_receivers = []
    offset = 0

    for graph in graph_list:
        all_nodes.append(graph.nodes)
        all_senders.append(graph.senders + offset)
        all_receivers.append(graph.receivers + offset)
        offset += graph.nodes.shape[0]

    # Concatenate all components
    batched_nodes = jnp.concatenate(all_nodes, axis=0)
    batched_senders = jnp.concatenate(all_senders, axis=0)
    batched_receivers = jnp.concatenate(all_receivers, axis=0)

    # Create and return the combined GraphsTuple
    return jraph.GraphsTuple(
        n_node=jnp.array(batched_nodes.shape[0]),
        n_edge=jnp.array(batched_senders.shape[0]),
        nodes=batched_nodes,
        senders=batched_senders,
        receivers=batched_receivers,
        edges=None,  # or concatenate edges if your graph has them
        globals=None  # or concatenate globals if your graph has them
    )


deconv = conv_fn(lambda x, w, b: deconv2d(x, w) + b)


def init_conv_params(rng, in_chan, out_chan, cfg, deconv=False):
    if deconv:
        in_chan, out_chan = out_chan, in_chan
    rng, key = jax.random.split(rng, 2)
    w_shape = (cfg['kernel_size'], cfg['kernel_size'], in_chan, out_chan)
    w = syrkis.train.glorot_init(key, w_shape)
    b = jnp.zeros((out_chan,))
    gamma, beta = init_batch_norm(b.shape)
    return w, b # , gamma, beta


def init_conv_layers(rng, cfg, deconv=False):
    rngs = jax.random.split(rng, cfg['conv_layers'])
    params = []
    for idx, rng in enumerate(rngs):
        in_chan = cfg['in_chans'] if idx == 0 else cfg['chan_start'] * (cfg['conv_branch'] ** (idx - 1))
        out_chan = cfg['chan_start'] * (cfg['conv_branch'] ** idx)
        params.append(init_conv_params(rng, in_chan, out_chan, cfg, deconv))
    return params[::-1] if deconv else params

## Model

In [117]:
def dropout(x, rate, rng):
    rate = 1.0 - rate
    keep = random.bernoulli(rng, rate, x.shape)
    return jnp.where(keep, x / rate, 0)


def matmul_slice(A, B_slice):
    return jnp.dot(A, B_slice)
batched_matmul = vmap(matmul_slice, in_axes=(None, 2))


@jit
def decode_fn(params, z, rng=None):
    z = deconv(params['deconv'], z)
    z = jax.nn.sigmoid(z)
    return z


def apply_fn(params, fmri):
    z = fmri

    # Apply graph convolution layers
    for i, p in enumerate(params['gcn']):
        z = apply_graph_convolution(z, p)
        z = pool_fn(z, 4)
    z = z.nodes.flatten()

    # Apply dense layer
    for i, p in enumerate(params['fcs']):
        z = jnp.dot(z, p[0]) + p[1]
        z = jax.nn.relu(z)

    # Apply image deconv layers to make image
    z = z.reshape(1, 2, 2, -1)
    z = deconv(params['cnn'], z)
    z = jax.nn.sigmoid(z)
    return z



# This function returns the total loss and its components (recon and KL losses).
def loss_fn(params, fmri, img):
    img_hat = apply_fn(params, fmri)
    recon_loss = jnp.mean((img - img_hat) ** 2)
    return recon_loss

def update_fn(params, fmri, img, subj, opt_state):
    # Get the loss, aux data (recon_loss, kl_loss), and gradients
    loss, grads = value_and_grad(loss_fn)(params, fmri, img, subj)
    updates, opt_state = opt.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss


def init_fn(rng, cfg, scale=1e-2):
    gcn = []
    c_in = 1
    for i in range(cfg['gcn_layers']):
        rng, key = random.split(rng)
        c_out = 2 ** i 
        gcn.append((random.normal(key, (c_in, c_out)) * scale, jnp.zeros((c_out,))))
        c_in = c_out
    
    rng, key = random.split(rng)
    cnn = init_conv_layers(key, cfg, deconv=True)
    rng, key = random.split(rng)
    fcs = [
        (random.normal(key, (1024, cfg['latent_dim'])) * scale, jnp.zeros((cfg['latent_dim'],))),
        (random.normal(key, (cfg['latent_dim'], 2048)) * scale, jnp.zeros((2048,))),
    ]
    return {'gcn': gcn, 'fcs': fcs, 'cnn': cnn}


## Training

In [118]:
cfg = syrkis.train.load_config()['neuroscope']
rng = jax.random.PRNGKey(0)
params = init_fn(rng, cfg)
n_params = syrkis.train.n_params(params)
opt_state = opt.init(params)
syrkis.train.n_params(params)

1963083

In [121]:
fmri, imgs = make_samples(subjects['subj05'], 'subj05', 100)
fmri = jnp.array(fmri)
paraply_fn = jit(vmap(apply_fn, in_axes=(None, 0)))
for i in tqdm(range(0, 100, 2)):
    x, y = fmri[i:i+2], imgs[i:i+2]
    y_hat = paraply_fn(params, x)
    loss = loss_fn(params, x, y)
    

100it [00:10,  9.45it/s]


ValueError: All input arrays must have the same shape.

In [120]:

batched_graph = jraph.batch(fmri)

XlaRuntimeError: UNKNOWN: /var/folders/g5/r0c49hqx3f95cpg6vnztk5s80000gn/T/ipykernel_88449/498112577.py:1:16: error: failed to legalize operation 'mhlo.pad'
/var/folders/g5/r0c49hqx3f95cpg6vnztk5s80000gn/T/ipykernel_88449/498112577.py:1:16: note: called from
/var/folders/g5/r0c49hqx3f95cpg6vnztk5s80000gn/T/ipykernel_88449/498112577.py:1:16: note: see current operation: %89 = "mhlo.pad"(%88, %1) {edge_padding_high = dense<0> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<1> : tensor<1xi64>} : (tensor<2xsi32>, tensor<si32>) -> tensor<3xsi32>
