In [5]:
import numpy as np
import networkx as nx
from nilearn import datasets
from tqdm import tqdm
from nilearn.connectome import ConnectivityMeasure
import numpy as np
from jax import numpy as jnp
import jraph
import yaml
from src.plots import plot_graph
from src.fmri import get_roi_mask, lh_fmri, rh_fmri

## ROI definitions and sizes

In [6]:
# Names of the regions of interest (ROIs) iterated over in this notebook.
# NOTE(review): the group labels below are inferred from the ROI names alone —
# confirm them against the dataset documentation.
rois = [
    # V1-V3 (ventral / dorsal parts) and hV4
    "V1v", "V1d", "V2v", "V2d", "V3v", "V3d", "hV4",
    # presumably body-selective regions
    "EBA", "FBA-1", "FBA-2", "mTL-bodies",
    # presumably face-selective regions
    "OFA", "FFA-1", "FFA-2", "mTL-faces", "aTL-faces",
    # presumably place-selective regions
    "OPA", "PPA", "RSC",
    # presumably word-selective regions
    "OWFA", "VWFA-1", "VWFA-2", "mfs-words", "mTL-words",
    # presumably coarse stream-level regions
    "early", "midventral", "midlateral", "midparietal",
    "ventral", "lateral", "parietal",
]

# Fetch the "fsaverage" surface mesh bundle through nilearn.
atlas = datasets.fetch_surf_fsaverage("fsaverage")

In [9]:
def get_roi_size(roi, hem):
    """Return the size of an ROI as a plain Python ``int``.

    The mask for ``roi`` in hemisphere ``hem`` is looked up with
    ``get_roi_mask`` and its entries are summed; for a boolean or 0/1 mask
    that sum is the number of selected entries.
    """
    mask = get_roi_mask(roi, hem)
    n_selected = np.sum(mask)
    return int(n_selected)

# Tabulate ROI sizes, keyed as roi_sizes[subject][hemisphere][roi].
roi_sizes = {'subj05': {'left_hem': {}, 'right_hem': {}}}

for roi in rois:
    # Left hemisphere first, then right, for each ROI in turn.
    for hem in ('left_hem', 'right_hem'):
        roi_sizes['subj05'][hem][roi] = get_roi_size(roi, hem)

# Persist the table so the sizes can be read back from config/rois.yaml.
with open('config/rois.yaml', 'w') as f:
    yaml.dump(roi_sizes, f)

    

## Connectome construction

In [3]:
def connectome_from_roi_response(roi, hem):
    """Build a thresholded connectivity graph from one ROI's fMRI responses.

    NOTE(review): the original author annotated this function with
    "this is wrong" without saying why. The hemisphere selection below has
    been made explicit, but the overall methodology still needs confirming.

    Parameters
    ----------
    roi : str
        ROI name, forwarded to ``get_roi_mask``.
    hem : str
        ``"left"`` / ``"left_hem"`` or ``"right"`` / ``"right_hem"``. Both
        spellings appear in this notebook. The value is forwarded unchanged
        to ``get_roi_mask``; which spellings that helper accepts with
        ``atlas="challenge"`` is not visible here — TODO confirm.

    Returns
    -------
    tuple
        Whatever ``connectivity_matrix_to_connectome`` returns for the ROI's
        covariance matrix (in this notebook: the graph and the thresholded
        matrix).

    Raises
    ------
    ValueError
        If ``hem`` is not one of the recognised hemisphere names. Previously
        every value other than the exact string ``"left"`` (including
        ``"left_hem"``) silently selected the right-hemisphere data.
    """
    # Resolve the hemisphere first so that a bad name fails loudly and early.
    if hem in ("left", "left_hem"):
        fmri = lh_fmri
    elif hem in ("right", "right_hem"):
        fmri = rh_fmri
    else:
        raise ValueError(
            f"Unknown hemisphere {hem!r}; expected 'left'/'left_hem' or "
            "'right'/'right_hem'."
        )

    roi_mask = get_roi_mask(roi, hem, atlas="challenge")
    # Keep only the columns selected by the ROI mask.
    roi_response = fmri[:, roi_mask]

    connectivity_measure = ConnectivityMeasure(kind="covariance")
    # fit_transform takes a list of arrays; a single array goes in, so take
    # the single matrix back out.
    connectivity_matrix = connectivity_measure.fit_transform([roi_response])[0]
    return connectivity_matrix_to_connectome(connectivity_matrix)


def connectivity_matrix_to_connectome(connectivity_matrix, percentile=99.0):
    """Threshold a connectivity matrix and turn it into a networkx graph.

    Entries whose absolute value lies below the given percentile of all
    absolute values are set to zero, the diagonal is cleared, and the result
    is handed to ``nx.from_numpy_array``.

    The input is no longer modified: thresholding happens on a copy, so the
    caller's matrix survives and re-running a cell gives the same result.

    Parameters
    ----------
    connectivity_matrix : array-like
        Square matrix of connection weights.
    percentile : float, optional
        Percentile (0-100) of ``abs(connectivity_matrix)`` used as the cut-off.
        The previous hard-coded expression ``100 * (N - N / 100) / N``
        simplifies to 99 for every matrix size ``N``, so the default keeps
        that behaviour (up to floating-point rounding of the old expression).

    Returns
    -------
    graph
        Graph built by ``nx.from_numpy_array`` from the thresholded matrix.
    thresholded : numpy.ndarray
        The thresholded copy, with a zero diagonal.
    """
    thresholded = np.array(connectivity_matrix, copy=True)
    magnitudes = np.abs(thresholded)

    # NOTE(review): the cut-off is computed over every entry, diagonal
    # included, exactly as before. The number of surviving edges therefore
    # grows with the square of the number of nodes — consider a different
    # rule for large matrices.
    thresh = np.percentile(magnitudes, percentile)
    thresholded[magnitudes < thresh] = 0

    # Clear the diagonal so no node is connected to itself.
    np.fill_diagonal(thresholded, 0)

    graph = nx.from_numpy_array(thresholded)
    return graph, thresholded

In [7]:
def build_connectome(roi, hem):
    """Build a ``jraph.GraphsTuple`` for the connectome of one ROI.

    The networkx graph returned by ``connectome_from_roi_response`` is
    converted into a single-graph ``GraphsTuple``: node features are one-hot
    identity rows, and there are no edge features or global features.

    Parameters
    ----------
    roi : str
        ROI name, forwarded to ``connectome_from_roi_response``.
    hem : str
        Hemisphere name, forwarded to ``connectome_from_roi_response``.

    Returns
    -------
    jraph.GraphsTuple
        ``n_node`` and ``n_edge`` are arrays of shape ``(1,)`` (one entry per
        graph), as jraph's batching utilities expect; they were previously
        plain Python ints.
    """
    nx_graph, _ = connectome_from_roi_response(roi, hem)

    num_nodes = len(nx_graph.nodes)
    # Materialise the edge view once and reuse it for both index arrays.
    edge_list = list(nx_graph.edges)

    # Explicit integer dtype: with no surviving edges, jnp.array([]) would
    # otherwise produce a float array, which cannot be used as indices.
    # NOTE(review): each networkx edge is emitted once (u -> v). If the graph
    # is undirected and symmetric message passing is intended, the reverse
    # edges have to be added as well — TODO confirm.
    senders = jnp.array([edge[0] for edge in edge_list], dtype=jnp.int32)
    receivers = jnp.array([edge[1] for edge in edge_list], dtype=jnp.int32)

    return jraph.GraphsTuple(
        nodes=jnp.eye(num_nodes),  # one-hot node identity features
        edges=None,  # no edge features
        senders=senders,
        receivers=receivers,
        n_node=jnp.array([num_nodes]),
        n_edge=jnp.array([len(edge_list)]),
        globals=None,  # no graph-level features
    )


# Build the connectome graph for ROI "V1d" with hemisphere argument "left".
graph = build_connectome(roi="V1d", hem="left")