# Neuroscope

Jupyter workspace for the neuroscope project

In [1]:
import jax
from jax import vmap, jit, lax, random, grad, value_and_grad
import jax.numpy as jnp
import optax
from jax import config
import jraph
import numpy as np

import wandb
import numpy as np
from functools import partial
from tqdm import tqdm
import time
import seaborn as sns
import matplotlib.pyplot as plt
import os, pickle

import syrkis
from src.data import load_subjects, make_kfolds
from src.fmri import get_bold_with_coords_and_faces as get_mesh

In [97]:
# GLOBALS
# Notebook-level state shared by the cells below.
rng = random.PRNGKey(0)  # base PRNG key, fixed seed 0
cfg = syrkis.train.load_config()['neuroscope']  # 'neuroscope' section of the loaded config
opt = optax.adamw(learning_rate=cfg['lr'])  # AdamW optimiser; learning rate taken from cfg['lr']
batch_size = 32  # graphs per batch; read as a global by get_batches and apply_fn

## Data

In [98]:
# Load the two subjects used in this notebook; cfg['image_size'] is forwarded to the loader.
subjects = load_subjects(['subj05', 'subj07'], cfg['image_size'])

## fmri

In [99]:
def make_samples(rng, sample, subject, n_subjects=64):
    """Shuffle one subject's data and build a combined-hemisphere graph per sample.

    Args:
        rng: JAX PRNG key used to draw the shuffle permutation.
        sample: ``(lhs, rhs, imgs)`` triple, each indexable along its first
            axis (presumably left/right hemisphere responses and the matching
            images -- confirm against ``load_subjects``).
        subject: Subject identifier forwarded to ``get_mesh``.
        n_subjects: Maximum number of samples turned into graphs. Despite the
            name, this caps the number of samples, not subjects.

    Returns:
        ``(fmri, imgs)`` where ``fmri`` is a list of graphs and ``imgs`` holds
        the images of exactly those samples, in the same order.
    """
    lhs, rhs, imgs = sample

    # random.permutation is the supported replacement for the deprecated
    # random.shuffle; given an int it permutes arange(n).
    order = random.permutation(rng, len(lhs))
    lhs, rhs, imgs = lhs[order], rhs[order], imgs[order]

    n_used = min(n_subjects, len(lhs))
    fmri = []
    for lh, rh in tqdm(zip(lhs[:n_used], rhs[:n_used]), total=n_used):
        lh_graph = make_graph(*get_mesh(lh, subject, 'lh'))
        rh_graph = make_graph(*get_mesh(rh, subject, 'rh'))
        fmri.append(combine_hems(lh_graph, rh_graph))

    # Only the first len(fmri) shuffled samples became graphs; return the
    # matching images so fmri[i] and imgs[i] describe the same sample.
    return fmri, imgs[:len(fmri)]

def combine_hems(lh_graph, rh_graph):
    """Merge the left- and right-hemisphere graphs into one zero-padded graph.

    Node features of both hemispheres are concatenated and zero-padded so the
    total node count is a power of two. The right hemisphere's sender and
    receiver indices are shifted by the left hemisphere's node count so they
    address the concatenated node array.

    Args:
        lh_graph: Left-hemisphere graph with 1-D ``nodes`` and 1-element
            ``n_node`` / ``n_edge`` arrays (as produced by ``make_graph``).
        rh_graph: Right-hemisphere graph with the same layout.

    Returns:
        A ``jraph.GraphsTuple`` whose ``nodes`` has shape ``(padded_total, 1)``
        and whose ``n_node`` includes the padding nodes.
    """
    n_node = lh_graph.n_node + rh_graph.n_node

    # Next power of two via exact integer arithmetic instead of float
    # log2/ceil, which can round incorrectly for large node counts.
    total = int(n_node[0])
    padded_total = 1 << (total - 1).bit_length() if total > 0 else 0
    padding_size = padded_total - total
    padding = jnp.zeros((padding_size,))
    nodes = jnp.concatenate([lh_graph.nodes, rh_graph.nodes, padding], axis=0)[:, None]

    # Adjust senders and receivers indices for the right hemisphere.
    rh_offset = lh_graph.n_node
    senders = jnp.concatenate([lh_graph.senders, rh_graph.senders + rh_offset], axis=0)
    receivers = jnp.concatenate([lh_graph.receivers, rh_graph.receivers + rh_offset], axis=0)

    # Padding nodes are isolated (no edges) but count towards n_node.
    n_node = n_node + padding_size
    n_edge = lh_graph.n_edge + rh_graph.n_edge

    return jraph.GraphsTuple(n_node=n_node, n_edge=n_edge, edges=None, globals=None,
                             nodes=nodes, senders=senders, receivers=receivers)


@jit
def make_graph(coords, features, faces):
    """Build a single-graph ``GraphsTuple`` from a triangle mesh.

    Every face ``(a, b, c)`` contributes the directed edges a->b, b->c and
    c->a. ``features`` becomes the node array. ``coords`` is accepted but not
    used.
    """
    corner_a, corner_b, corner_c = faces[:, 0], faces[:, 1], faces[:, 2]
    senders = jnp.concatenate((corner_a, corner_b, corner_c), axis=0)
    receivers = jnp.concatenate((corner_b, corner_c, corner_a), axis=0)
    return jraph.GraphsTuple(
        nodes=features,
        edges=None,
        globals=None,
        senders=senders,
        receivers=receivers,
        n_node=jnp.array([features.shape[0]]),
        n_edge=jnp.array([senders.shape[0]]),
    )

In [100]:
def update_node_fn(node_features, params):
    """Affine node update: project the node features and add a bias.

    ``params`` is a ``(weights, biases)`` pair.
    """
    kernel, offset = params
    projected = jnp.dot(node_features, kernel)
    return projected + offset

def apply_graph_convolution(graph, params):
    """Run one graph-convolution layer over ``graph``.

    The layer applies ``update_node_fn`` with ``params`` to the nodes,
    aggregates with ``jraph.segment_sum``, adds self edges and uses symmetric
    normalisation.
    """
    def _node_update(node_features):
        # Bind this layer's parameters to the shared node-update function.
        return update_node_fn(node_features, params=params)

    convolve = jraph.GraphConvolution(
        update_node_fn=_node_update,
        aggregate_nodes_fn=jraph.segment_sum,
        add_self_edges=True,
        symmetric_normalization=True,
    )
    return convolve(graph)

In [101]:

def manual_batch_graphs(graph_list):
    """Concatenate several graphs into one batched ``GraphsTuple``.

    Node arrays are stacked along axis 0, and every graph's sender/receiver
    indices are shifted by the number of node rows that precede it, so edges
    keep pointing at their own graph's nodes.

    Args:
        graph_list: Non-empty sequence of graphs exposing ``nodes``,
            ``senders`` and ``receivers``.

    Returns:
        A ``jraph.GraphsTuple`` in which ``n_node`` and ``n_edge`` hold one
        entry per input graph (the layout used for batched graphs), with
        ``edges`` and ``globals`` set to ``None``.

    Raises:
        ValueError: If ``graph_list`` is empty.
    """
    if not graph_list:
        raise ValueError("manual_batch_graphs needs at least one graph.")

    all_nodes = []
    all_senders = []
    all_receivers = []
    node_counts = []
    edge_counts = []
    offset = 0

    for graph in graph_list:
        n_rows = graph.nodes.shape[0]
        all_nodes.append(graph.nodes)
        all_senders.append(graph.senders + offset)
        all_receivers.append(graph.receivers + offset)
        node_counts.append(n_rows)
        edge_counts.append(graph.senders.shape[0])
        offset += n_rows

    # Concatenate all components
    batched_nodes = jnp.concatenate(all_nodes, axis=0)
    batched_senders = jnp.concatenate(all_senders, axis=0)
    batched_receivers = jnp.concatenate(all_receivers, axis=0)

    print(batched_nodes.shape, batched_senders.shape, batched_receivers.shape)

    # Keep per-graph counts rather than 0-d totals, so the batch can still be
    # split back into its member graphs.
    return jraph.GraphsTuple(
        n_node=jnp.array(node_counts),
        n_edge=jnp.array(edge_counts),
        nodes=batched_nodes,
        senders=batched_senders,
        receivers=batched_receivers,
        edges=None,  # or concatenate edges if your graph has them
        globals=None  # or concatenate globals if your graph has them
    )

def get_batches(fmri, imgs):
    """Yield ``(graph_batch, image_batch)`` pairs forever, cycling over the data.

    The data is cut into ``len(fmri) // batch_size`` full batches using the
    module-level ``batch_size``; any remainder is dropped. Each graph batch is
    built once with ``manual_batch_graphs`` and reused on every cycle.

    Args:
        fmri: Sequence of graphs, one per sample.
        imgs: Images aligned with ``fmri`` along the first axis.

    Yields:
        ``(batched_graph, imgs_for_that_batch)`` tuples.

    Raises:
        ValueError: On first iteration, if there are fewer than ``batch_size``
            samples. Without this check the generator would loop forever
            without yielding.
    """
    n_batches = len(fmri) // batch_size
    if n_batches == 0:
        raise ValueError(
            f"Need at least batch_size ({batch_size}) samples to form a batch, got {len(fmri)}."
        )

    batches = []
    for i in range(n_batches):
        window = slice(i * batch_size, (i + 1) * batch_size)
        # Pair each graph batch with the images of the same samples.
        batches.append((manual_batch_graphs(fmri[window]), imgs[window]))

    while True:
        yield from batches

In [121]:
def pool_fn(graph, edge_mask, pool_size=4):
    """Mean-pool consecutive groups of ``pool_size`` nodes and remap the edges.

    Node ``k`` is merged into coarse node ``k // pool_size``. Edge endpoints
    are rewritten to the coarse ids and then filtered with the precomputed
    boolean ``edge_mask``. Only ``nodes``, ``senders`` and ``receivers`` of the
    returned graph are replaced.

    Raises:
        ValueError: If the node count is not a multiple of ``pool_size``.
    """
    n_nodes, n_feats = graph.nodes.shape

    # Guard: the reshape below needs whole groups.
    if n_nodes % pool_size:
        raise ValueError(f"Number of nodes ({n_nodes}) must be divisible by pool_size ({pool_size}).")

    # Average every block of pool_size rows into one coarse node.
    grouped = graph.nodes.reshape(-1, pool_size, n_feats)
    coarse_nodes = jnp.mean(grouped, axis=1)

    # Lookup table: fine node index -> coarse node index.
    coarse_id = jnp.repeat(jnp.arange(n_nodes // pool_size), pool_size)

    # Remap both endpoints, then keep only the edges selected by the mask.
    kept_senders = coarse_id[graph.senders][edge_mask]
    kept_receivers = coarse_id[graph.receivers][edge_mask]

    return graph._replace(nodes=coarse_nodes, senders=kept_senders, receivers=kept_receivers)


def generate_valid_edge_mask(graph, pool_size):
    """Return a boolean mask of edges that are not self-loops after pooling.

    Endpoints are mapped to coarse node ids (``index // pool_size`` via a
    lookup table). An edge is valid when its two coarse endpoints differ.
    """
    coarse_count = graph.nodes.shape[0] // pool_size
    coarse_id = jnp.repeat(jnp.arange(coarse_count), pool_size)
    return coarse_id[graph.senders] != coarse_id[graph.receivers]


def generate_valid_edge_masks(graph, pool_depth):
    """Build one edge mask per pooling level, pooling the graph as it goes.

    Level ``k`` uses ``pool_size = 2 ** k``. From level 1 onwards the graph is
    first pooled with the previous level's mask, then the new mask is computed
    on the pooled graph. Shapes are printed at every level.

    Returns:
        List of ``pool_depth`` boolean masks, one per level.
    """
    masks = []
    for level in range(pool_depth):
        pool_size = 2 ** level
        if level:
            # Levels above 0 operate on the graph pooled with the last mask.
            graph = pool_fn(graph, masks[-1], pool_size)

        level_mask = generate_valid_edge_mask(graph, pool_size)
        masks.append(level_mask)
        print(f'Pooling level {level}, Graph shapes:', graph.nodes.shape, graph.senders.shape, graph.receivers.shape)
        print('Mask shape:', level_mask.shape)

    return masks


def test_pooling_function(graph, pool_depth, masks=None):
    """Walk through the pooling levels and report whether mask sizes match edge counts.

    Level ``i`` uses ``pool_size = 4 ** i``. From level 1 onwards the graph is
    pooled with ``masks[i - 1]`` before a fresh mask is computed.

    Args:
        graph: Graph to pool.
        pool_depth: Number of pooling levels to check.
        masks: Optional list of per-level edge masks. When omitted, the
            notebook-level ``lst`` is used, as before; it is only looked up
            when a level above 0 is actually reached.

    Returns:
        The graph after the last pooling step.
    """
    print("Original graph shapes:", graph.nodes.shape, graph.senders.shape, graph.receivers.shape)

    for i in range(pool_depth):
        pool_size = 4 ** i
        if i > 0:  # Pool the graph for levels > 0
            level_masks = lst if masks is None else masks
            graph = pool_fn(graph, level_masks[i - 1], pool_size)

        mask = generate_valid_edge_mask(graph, pool_size)
        print(f"Pooling level {i}, pool_size: {pool_size}")
        print("Mask shape:", mask.shape)

        # NOTE(review): the mask is computed from this same graph's edges, so
        # the two sizes agree by construction.
        if graph.senders.shape[0] != mask.shape[0]:
            print(f"Error at pooling level {i}: Edge count does not match mask size.")
        else:
            print(f"Pooling level {i} passed.")

    return graph


# Example usage
# NOTE(review): with the make_samples call commented out, `fmri` and `imgs`
# must already exist in the kernel from an earlier run.
# fmri, imgs = make_samples(rng, subjects['subj05'], 'subj05')
batches = get_batches(fmri, imgs)
graph = next(batches)[0]  # first batched graph; the images in the pair are ignored here
lst = generate_valid_edge_masks(graph, 4)  # notebook-level masks, read by apply_fn and test_pooling_function
test_pooling_function(graph, 4)

(2097152, 1) (7446720,) (7446720,)
(2097152, 1) (7446720,) (7446720,)
Pooling level 0, Graph shapes: (2097152, 1) (7446720,) (7446720,)
Mask shape: (7446720,)
Pooling level 1, Graph shapes: (1048576, 1) (7446720,) (7446720,)
Mask shape: (7446720,)
Pooling level 2, Graph shapes: (262144, 1) (6021344,) (6021344,)
Mask shape: (6021344,)
Pooling level 3, Graph shapes: (32768, 1) (5384128,) (5384128,)
Mask shape: (5384128,)
Original graph shapes: (2097152, 1) (7446720,) (7446720,)
Pooling level 0, pool_size: 1
Mask shape: (7446720,)
Pooling level 0 passed.
Pooling level 1, pool_size: 4
Mask shape: (7446720,)
Pooling level 1 passed.
Pooling level 2, pool_size: 16
Mask shape: (6021344,)
Pooling level 2 passed.
Pooling level 3, pool_size: 64
Mask shape: (5384128,)
Pooling level 3 passed.


GraphsTuple(nodes=Array([[ 3.09292316e-01],
       [ 3.15492958e-01],
       [ 3.26975733e-01],
       [ 2.94483423e-01],
       [ 3.07706267e-01],
       [ 2.47480422e-01],
       [ 2.70877063e-01],
       [ 2.29912490e-01],
       [ 2.33156219e-01],
       [ 1.70832768e-01],
       [ 0.00000000e+00],
       [ 0.00000000e+00],
       [ 0.00000000e+00],
       [ 0.00000000e+00],
       [ 0.00000000e+00],
       [ 0.00000000e+00],
       [-1.18194833e-01],
       [-1.06610619e-01],
       [-1.24286771e-01],
       [-1.37095615e-01],
       [-1.70342967e-01],
       [-2.66178429e-01],
       [-2.63474137e-01],
       [-2.47424185e-01],
       [-2.71440744e-01],
       [-1.43191218e-01],
       [ 0.00000000e+00],
       [ 0.00000000e+00],
       [ 0.00000000e+00],
       [ 0.00000000e+00],
       [ 0.00000000e+00],
       [ 0.00000000e+00],
       [ 2.92524755e-01],
       [ 2.96035796e-01],
       [ 2.84823030e-01],
       [ 2.86112159e-01],
       [ 3.12877774e-01],
       [ 3.19221199e

## utils

## Batch norm

In [122]:
@jit
def batch_norm(x, gamma, beta, eps=1e-5):
    """Normalise ``x`` over every axis except the last, then scale and shift.

    Returns ``x`` unchanged when ``cfg['batch_norm']`` is falsy. Statistics
    are computed from the input itself.
    """
    if not cfg['batch_norm']:
        return x
    # e.g. batch x height x width x channels -> reduce over all but channels
    reduce_axes = tuple(range(x.ndim - 1))
    centred = x - jnp.mean(x, axis=reduce_axes, keepdims=True)
    variance = jnp.var(x, axis=reduce_axes, keepdims=True)
    normalised = centred / jnp.sqrt(variance + eps)
    return gamma * normalised + beta

def init_batch_norm(shape):
    """Create batch-norm scale (ones) and shift (zeros) parameters.

    Both have a size-1 axis for every leading dimension of ``shape`` and keep
    only its last dimension, so they broadcast against the activations.
    """
    broadcast_shape = (1,) * (len(shape) - 1) + (shape[-1],)
    return jnp.ones(broadcast_shape), jnp.zeros(broadcast_shape)

## Linear layer

In [123]:
def init_linear_layer(rng, in_dim, out_dim, tensor_dim):
    """Initialise one dense layer: Glorot weights, zero bias, batch-norm params.

    ``tensor_dim`` lets the fMRI embeddings live in the same array as separate
    layers: when it is positive, the weight matrix is reshaped to
    ``(-1, out_dim, tensor_dim)``.

    Returns:
        ``(w, b, gamma, beta)``.
    """
    _, key = jax.random.split(rng, 2)
    w = syrkis.train.glorot_init(key, (in_dim, out_dim))
    if tensor_dim > 0:
        w = w.reshape((-1, out_dim, tensor_dim))

    bias_shape = (out_dim,)
    b = jnp.zeros(bias_shape)
    gamma, beta = init_batch_norm(bias_shape)
    return w, b, gamma, beta

def linear(params, x):
    """Apply a stack of dense layers.

    Every layer except the last is followed by GELU and then ``batch_norm``;
    the last layer's output is returned without either.
    """
    final = len(params) - 1
    for depth, (w, b, gamma, beta) in enumerate(params):
        x = x @ w + b
        if depth != final:
            x = batch_norm(jax.nn.gelu(x), gamma, beta)
    return x

## Convolutions

In [124]:
# Global constants for common parameters
# Layout spec passed to lax.conv_transpose in deconv2d, as (lhs, rhs, out).
DIMENSION_NUMBERS = ("NHWC", "HWIO", "NHWC")


@partial(jit, static_argnames=('scale_factor',))
def upscale_nearest_neighbor(x, scale_factor=cfg['stride']):
    """Nearest-neighbour upsampling of a ``(batch, height, width, channels)`` array.

    Each pixel is repeated ``scale_factor`` times along height and width.

    Args:
        x: Input of shape ``(batch, height, width, channels)``.
        scale_factor: Integer upscaling factor. It defaults to ``cfg['stride']``,
            read once when the function is defined. It is marked static
            because it determines output shapes, so it cannot be traced.

    Returns:
        Array of shape ``(batch, height * scale_factor, width * scale_factor, channels)``.
    """
    b, h, w, c = x.shape
    x = x.reshape(b, h, 1, w, 1, c)
    # The former lax.tie_in wrapper was a deprecated no-op; broadcast directly.
    x = jnp.broadcast_to(x, (b, h, scale_factor, w, scale_factor, c))
    return x.reshape(b, h * scale_factor, w * scale_factor, c)


@jit
def deconv2d(x, w):
    """Upsample ``x`` with nearest neighbour, then apply a stride-1 transposed convolution.

    Uses 'SAME' padding and the layouts given by ``DIMENSION_NUMBERS``.
    """
    return lax.conv_transpose(
        upscale_nearest_neighbor(x),
        w,
        strides=(1, 1),
        padding='SAME',
        dimension_numbers=DIMENSION_NUMBERS,
    )


def conv_fn(fn):
    """Wrap a single-layer function ``fn(x, w, b)`` into a multi-layer applier.

    The returned ``apply_fn(params, x)`` runs ``fn`` once per ``(w, b)`` pair
    in ``params`` and applies GELU after every layer except the last.
    """
    def apply_fn(params, x):
        final = len(params) - 1
        for depth, (w, b) in enumerate(params):
            x = fn(x, w, b)
            if depth != final:
                x = jax.nn.gelu(x)
        return x

    return apply_fn


# Multi-layer deconvolution: each layer is deconv2d plus a bias, chained by conv_fn.
deconv = conv_fn(lambda x, w, b: deconv2d(x, w) + b)


def init_conv_params(rng, in_chan, out_chan, cfg, deconv=False):
    """Initialise the weight and bias of one (de)convolution layer.

    Args:
        rng: JAX PRNG key; a subkey derived from it seeds the weights.
        in_chan: Input channels of the layer.
        out_chan: Output channels of the layer.
        cfg: Config mapping; ``cfg['kernel_size']`` sets the square kernel size.
        deconv: When true, ``in_chan`` and ``out_chan`` are swapped before the
            parameters are created.

    Returns:
        ``(w, b)``: Glorot-initialised kernel of shape
        ``(kernel_size, kernel_size, in_chan, out_chan)`` and a zero bias of
        shape ``(out_chan,)``, both after the optional swap.
    """
    if deconv:
        in_chan, out_chan = out_chan, in_chan
    _, key = jax.random.split(rng, 2)
    kernel_shape = (cfg['kernel_size'], cfg['kernel_size'], in_chan, out_chan)
    w = syrkis.train.glorot_init(key, kernel_shape)
    b = jnp.zeros((out_chan,))
    # Batch-norm parameters used to be built here but were never returned.
    return w, b


def init_conv_layers(rng, cfg, deconv=False):
    """Initialise ``cfg['conv_layers']`` (de)convolution layers.

    Layer ``k`` outputs ``cfg['chan_start'] * cfg['conv_branch'] ** k``
    channels. Its input is ``cfg['in_chans']`` for the first layer and the
    previous layer's output width otherwise. For ``deconv=True`` the layer
    list is returned in reverse order.
    """
    layer_keys = jax.random.split(rng, cfg['conv_layers'])
    layers = []
    for depth, layer_key in enumerate(layer_keys):
        if depth == 0:
            in_chan = cfg['in_chans']
        else:
            in_chan = cfg['chan_start'] * (cfg['conv_branch'] ** (depth - 1))
        out_chan = cfg['chan_start'] * (cfg['conv_branch'] ** depth)
        layers.append(init_conv_params(layer_key, in_chan, out_chan, cfg, deconv))

    if deconv:
        layers.reverse()
    return layers

## Model

In [125]:
def dropout(x, rate, rng):
    """Inverted dropout: zero entries with probability ``rate`` and rescale the rest.

    Kept entries are divided by the keep probability ``1 - rate``.
    """
    keep_prob = 1.0 - rate
    keep_mask = random.bernoulli(rng, keep_prob, x.shape)
    rescaled = x / keep_prob
    return jnp.where(keep_mask, rescaled, 0)


def matmul_slice(A, B_slice):
    """Return the dot product of ``A`` with one slice of a stacked operand."""
    product = jnp.dot(A, B_slice)
    return product
# Map matmul_slice over axis 2 of the second operand, sharing ``A`` across slices.
batched_matmul = vmap(matmul_slice, in_axes=(None, 2))


@jit
def decode_fn(params, z, rng=None):
    """Decode a latent tensor into an image with values in (0, 1).

    Args:
        params: Parameter dict holding the deconvolution stack under
            ``'deconv'`` or, as produced by ``init_fn``, under ``'cnn'``.
        z: Latent input passed to the deconvolution stack.
        rng: Unused; kept for interface compatibility.

    Returns:
        Sigmoid of the deconvolution output.
    """
    # This function originally read only 'deconv', while init_fn stores the
    # stack under 'cnn'. Accept either key so both parameter layouts work.
    layers = params['deconv'] if 'deconv' in params else params['cnn']
    z = deconv(layers, z)
    return jax.nn.sigmoid(z)


def apply_fn(params, fmri):
    """Run the graph-convolution encoder on a batched graph.

    For each entry of ``params['gcn']`` the graph is pooled with the
    notebook-level mask ``lst[i]`` (``pool_fn``'s default ``pool_size``) and
    then passed through one graph convolution. The resulting node features
    are reshaped to ``(batch_size, -1)`` and returned.

    NOTE(review): the early ``return z`` makes the dense and deconv stages
    below unreachable, so this currently returns the flattened graph features
    rather than an image -- confirm whether that is intentional.
    """
    z = fmri
    # Apply graph convolution layers
    for i, p in enumerate(params['gcn']):
        z = pool_fn(z, lst[i])  # relies on the global `lst` having one mask per gcn layer
        z = apply_graph_convolution(z, p)
        print(z.nodes.shape)  # debug output
    z = z.nodes.reshape(batch_size, -1)  # uses the global batch_size
    print(z.shape)  # debug output
    return z

    # Unreachable from here on (see NOTE in the docstring).
    # Apply dense layer
    for i, p in enumerate(params['fcs']):
        z = jnp.dot(z, p[0]) + p[1]
        z = jax.nn.relu(z)

    # Apply image deconv layers to make image
    z = z.reshape(batch_size, 2, 2, -1)
    z = deconv(params['cnn'], z)
    z = jax.nn.sigmoid(z)
    return z


# Reconstruction loss: mean squared error between the target and the model output.
def loss_fn(params, fmri, img):
    """Return the mean squared error between ``img`` and ``apply_fn(params, fmri)``."""
    prediction = apply_fn(params, fmri)
    residual = img - prediction
    return jnp.mean(residual ** 2)

@jit
def update_fn(params, fmri, img, opt_state):
    """Perform one optimiser step on ``loss_fn``.

    Returns:
        ``(params, opt_state, loss)`` after applying the update computed by
        the module-level optimiser ``opt``.
    """
    # Loss value and gradients with respect to params.
    loss, grads = value_and_grad(loss_fn)(params, fmri, img)
    updates, next_opt_state = opt.update(grads, opt_state, params)
    next_params = optax.apply_updates(params, updates)
    return next_params, next_opt_state, loss


def init_fn(rng, cfg, scale=1e-2):
    """Initialise all model parameters.

    Args:
        rng: JAX PRNG key.
        cfg: Config mapping. Reads ``'gcn_layers'`` and ``'latent_dim'`` here,
            and is passed on to ``init_conv_layers``.
        scale: Standard deviation multiplier for the normal-initialised
            graph-convolution and dense weights.

    Returns:
        Dict with keys ``'gcn'`` (list of ``(weights, biases)``; layer ``i``
        has ``2 ** i`` output channels), ``'fcs'`` (two dense
        ``(weights, biases)`` pairs, 1024 -> latent_dim -> 2048) and ``'cnn'``
        (deconvolution layers from ``init_conv_layers``).
    """
    gcn = []
    c_in = 1
    for i in range(cfg['gcn_layers']):
        rng, key = random.split(rng)
        c_out = 2 ** i
        gcn.append((random.normal(key, (c_in, c_out)) * scale, jnp.zeros((c_out,))))
        c_in = c_out

    rng, key = random.split(rng)
    cnn = init_conv_layers(key, cfg, deconv=True)

    # Each dense layer gets its own key; reusing one key for both draws would
    # make the two weight matrices statistically dependent.
    rng, key_in = random.split(rng)
    rng, key_out = random.split(rng)
    fcs = [
        (random.normal(key_in, (1024, cfg['latent_dim'])) * scale, jnp.zeros((cfg['latent_dim'],))),
        (random.normal(key_out, (cfg['latent_dim'], 2048)) * scale, jnp.zeros((2048,))),
    ]
    return {'gcn': gcn, 'fcs': fcs, 'cnn': cnn}


## Training

In [126]:
# NOTE(review): with the make_samples call commented out, `fmri` and `imgs`
# must already exist in the kernel from an earlier run.
# fmri, imgs = make_samples(rng, subjects['subj05'], 'subj05', 64)
cfg        = syrkis.train.load_config()['neuroscope']  # reload the config section

rng        = jax.random.PRNGKey(0)  # reset the PRNG key to seed 0
params     = init_fn(rng, cfg)
n_params   = syrkis.train.n_params(params)
opt_state  = opt.init(params)  # optimiser state for the module-level `opt`
batches    = get_batches(fmri, imgs)
graph      = next(batches)[0]  # first batched graph; its images are ignored here
syrkis.train.n_params(params)  # same call as above; result is discarded
z = apply_fn(params, graph)  # smoke-test forward pass; prints shapes

(2097152, 1) (7446720,) (7446720,)
(2097152, 1) (7446720,) (7446720,)


loc("jit(select_n)/jit(main)/select_n"("/var/folders/g5/r0c49hqx3f95cpg6vnztk5s80000gn/T/ipykernel_44272/2140735711.py":15:11)): error: 'anec.not_equal_zero' op Invalid configuration for the following reasons: Tensor dimensions N1D1C1H1W7971008 are not within supported range, N[1-65536]D[1-16384]C[1-65536]H[1-16384]W[1-16384].
loc("jit(select_n)/jit(main)/select_n"("/var/folders/g5/r0c49hqx3f95cpg6vnztk5s80000gn/T/ipykernel_44272/2140735711.py":15:11)): error: 'anec.not_equal_zero' op Invalid configuration for the following reasons: Tensor dimensions N1D1C1H1W6152416 are not within supported range, N[1-65536]D[1-16384]C[1-65536]H[1-16384]W[1-16384].


(524288, 1)
(131072, 2)


loc("jit(select_n)/jit(main)/select_n"("/var/folders/g5/r0c49hqx3f95cpg6vnztk5s80000gn/T/ipykernel_44272/2140735711.py":15:11)): error: 'anec.not_equal_zero' op Invalid configuration for the following reasons: Tensor dimensions N1D1C1H1W5416896 are not within supported range, N[1-65536]D[1-16384]C[1-65536]H[1-16384]W[1-16384].


(32768, 4)
(8192, 8)
(32, 2048)


loc("jit(select_n)/jit(main)/select_n"("/var/folders/g5/r0c49hqx3f95cpg6vnztk5s80000gn/T/ipykernel_44272/2140735711.py":15:11)): error: 'anec.not_equal_zero' op Invalid configuration for the following reasons: Tensor dimensions N1D1C1H1W5287872 are not within supported range, N[1-65536]D[1-16384]C[1-65536]H[1-16384]W[1-16384].


In [None]:
# Forward-pass loop over 10 batches.
# NOTE(review): the loss/update step is disabled by wrapping it in a string
# literal, so no parameters are updated here.
for i in (pbar := tqdm(range(10))):
    x, y = next(batches)
    y_hat = apply_fn(params, x)
    """ loss = loss_fn(params, x, y)
    params, opt_state, loss = update_fn(params, x, y, opt_state)
    pbar.set_description(f'loss: {loss:.4f}') """

 10%|█         | 1/10 [00:00<00:02,  3.72it/s]

(262144, 1)
(131072, 2)
(65536, 4)
(32768, 8)
(16384, 16)
(16, 16384)


 20%|██        | 2/10 [00:00<00:01,  4.07it/s]

(262144, 1)
(131072, 2)
(65536, 4)
(32768, 8)
(16384, 16)
(16, 16384)


 30%|███       | 3/10 [00:00<00:01,  4.17it/s]

(262144, 1)
(131072, 2)
(65536, 4)
(32768, 8)
(16384, 16)
(16, 16384)


 40%|████      | 4/10 [00:01<00:01,  3.88it/s]

(262144, 1)
(131072, 2)
(65536, 4)
(32768, 8)
(16384, 16)
(16, 16384)


 50%|█████     | 5/10 [00:01<00:01,  3.92it/s]

(262144, 1)
(131072, 2)
(65536, 4)
(32768, 8)
(16384, 16)
(16, 16384)


 60%|██████    | 6/10 [00:01<00:00,  4.06it/s]

(262144, 1)
(131072, 2)
(65536, 4)
(32768, 8)
(16384, 16)
(16, 16384)


 70%|███████   | 7/10 [00:01<00:00,  3.94it/s]

(262144, 1)
(131072, 2)
(65536, 4)
(32768, 8)
(16384, 16)
(16, 16384)
(262144, 1)
(131072, 2)
(65536, 4)
(32768, 8)


 80%|████████  | 8/10 [00:02<00:00,  3.71it/s]

(16384, 16)
(16, 16384)
(262144, 1)
(131072, 2)
(65536, 4)


 90%|█████████ | 9/10 [00:02<00:00,  3.76it/s]

(32768, 8)
(16384, 16)
(16, 16384)
(262144, 1)
(131072, 2)
(65536, 4)


100%|██████████| 10/10 [00:02<00:00,  3.90it/s]

(32768, 8)
(16384, 16)
(16, 16384)





In [None]:
# Inspect the node and edge array shapes of the current batched graph.
graph.nodes.shape, graph.senders.shape, graph.receivers.shape

((1048576, 1), (3723360,), (3723360,))