In [None]:
import numpy as np
import os
import networkx as nx
from nilearn import datasets, plotting
from tqdm import tqdm
from nilearn.connectome import ConnectivityMeasure
import numpy as np
from jax import numpy as jnp
import jraph
import yaml
from src.plots import plot_graph, plot_regions
from src.utils import ROIS, DATA_DIR, SUBJECTS

## ROI definitions, data loading and surface plotting

In [None]:
# fsaverage surface meshes used for plotting: plot_brain reads the pial mesh
# ("pial_left"/"pial_right") and sulcal-depth background ("sulc_left"/"sulc_right").
# NOTE(review): assumes the full-resolution "fsaverage" mesh matches the length of
# the fsaverage-space ROI arrays loaded below — confirm if the mesh is changed.
atlas = datasets.fetch_surf_fsaverage("fsaverage")

# ROI class (the <roi_class> part of the roi_masks file names) -> ROI names it contains.
roi_class_to_roi = {
    "prf-visualrois": ["V1v", "V1d", "V2v", "V2d", "V3v", "V3d", "hV4"],
    "floc-bodies": ["EBA", "FBA-1", "FBA-2", "mTL-bodies"],
    "floc-faces": ["OFA", "FFA-1", "FFA-2", "mTL-faces", "aTL-faces"],
    "floc-places": ["OPA", "PPA", "RSC"],
    "floc-words": ["OWFA", "VWFA-1", "VWFA-2", "mfs-words", "mTL-words"],
    "streams": ["early", "midventral", "midlateral", "midparietal", "ventral", "lateral", "parietal"],
}

# Inverse lookup: ROI name -> the ROI class (mask file group) that defines it.
roi_to_roi_class = {
    roi: roi_class
    for roi_class in roi_class_to_roi
    for roi in roi_class_to_roi[roi_class]
}

In [None]:
def load_roi_data(subject):
    """Load the ROI masks and ROI name/id mappings for one subject.

    Files are read from ``<DATA_DIR>/<subject>/roi_masks``. The returned dict is::

        {
            'mapping':   {roi_class: {'id_to_roi': {id: name}, 'roi_to_id': {name: id}}},
            'challenge': {hem: {roi_class: label_array}},
            'fsaverage': {hem: {roi_class: label_array, 'all-vertices': mask}},
        }

    where ``hem`` is 'lh' or 'rh' and ``roi_class`` ranges over ``roi_class_to_roi``.
    Label arrays hold ROI ids (compare against ``roi_to_id[name]``); the
    'all-vertices' entry marks the fsaverage vertices that carry challenge-space data.
    """
    roi_dir = os.path.join(DATA_DIR, subject, "roi_masks")

    data = {'mapping': {},
            'challenge': {'lh': {}, 'rh': {}},
            'fsaverage': {'lh': {}, 'rh': {}}}

    for roi_class in roi_class_to_roi:
        # Each mapping file stores a single pickled dict (unwrapped with .item()),
        # so allow_pickle=True is required. Unpickling can execute arbitrary code:
        # only point DATA_DIR at trusted data.
        id_to_roi = np.load(os.path.join(roi_dir, f'mapping_{roi_class}.npy'),
                            allow_pickle=True).item()
        data['mapping'][roi_class] = {
            'id_to_roi': id_to_roi,
            'roi_to_id': {name: roi_id for roi_id, name in id_to_roi.items()},
        }

    for hem in ('lh', 'rh'):
        data['fsaverage'][hem]['all-vertices'] = np.load(
            os.path.join(roi_dir, f'{hem}.all-vertices_fsaverage_space.npy'))
        for roi_class in roi_class_to_roi:
            data['challenge'][hem][roi_class] = np.load(
                os.path.join(roi_dir, f'{hem}.{roi_class}_challenge_space.npy'))
            data['fsaverage'][hem][roi_class] = np.load(
                os.path.join(roi_dir, f'{hem}.{roi_class}_fsaverage_space.npy'))

    return data

# ROI masks/mappings for every subject, keyed by subject id.
roi_data = {subj: load_roi_data(subj) for subj in SUBJECTS}

In [None]:
def load_fmri(subject):
    """Load the training-split fMRI arrays of both hemispheres for ``subject``.

    Reads ``<DATA_DIR>/<subject>/training_split/training_fmri/<hem>_training_fmri.npy``
    for hem in ('lh', 'rh') and returns ``{'lh': array, 'rh': array}``.
    """
    fmri_dir = os.path.join(DATA_DIR, subject, "training_split", "training_fmri")
    return {
        hem: np.load(os.path.join(fmri_dir, f"{hem}_training_fmri.npy"))
        for hem in ('lh', 'rh')
    }

# Training fMRI for every subject: fmri_data[subject]['lh' | 'rh'].
fmri_data = {subj: load_fmri(subj) for subj in tqdm(SUBJECTS)}

In [None]:
def spaces(subject: str, hem: str, roi: "str | None") -> "np.ndarray | tuple[np.ndarray, np.ndarray]":
    """Return vertex masks for one subject/hemisphere, optionally restricted to an ROI.

    Parameters
    ----------
    subject : key into ``roi_data`` (e.g. "subj01").
    hem : "lh" or "rh".
    roi : ROI name present in ``roi_to_roi_class`` (e.g. "V1v"), or a falsy
        value (None / "") to get the whole-hemisphere mask.

    Returns
    -------
    If ``roi`` is falsy: the fsaverage-space 'all-vertices' array.
    Otherwise: ``(fsaverage_mask, challenge_mask)``, 0/1 int arrays marking the
    ROI's vertices in fsaverage space and in challenge space respectively.

    Raises
    ------
    KeyError : unknown subject, hemisphere or ROI name.
    """
    subject_rois = roi_data[subject]
    if not roi:
        return subject_rois['fsaverage'][hem]['all-vertices']

    roi_class = roi_to_roi_class[roi]
    roi_id = subject_rois['mapping'][roi_class]['roi_to_id'][roi]
    # The per-class arrays hold ROI ids; comparing to one id yields that ROI's mask.
    fsaverage_mask = np.asarray(subject_rois['fsaverage'][hem][roi_class] == roi_id, dtype=int)
    challenge_mask = np.asarray(subject_rois['challenge'][hem][roi_class] == roi_id, dtype=int)
    return fsaverage_mask, challenge_mask


def fsaverage_vec(challenge_vec, subject, roi, hem=None):
    """Project a challenge-space vector onto the full fsaverage surface of one hemisphere.

    Parameters
    ----------
    challenge_vec : 1-D array with one value per challenge-space vertex of one
        hemisphere (e.g. one row of ``fmri_data[subject][hem]``).
    subject : subject key (e.g. "subj01").
    roi : ROI name, or a falsy value to map the whole hemisphere.
    hem : "lh" or "rh". If None (default), inferred by matching
        ``len(challenge_vec)`` against this subject's per-hemisphere
        challenge-vertex counts. Pass it explicitly when both counts are equal.

    Returns
    -------
    np.ndarray : float array over all fsaverage vertices; vertices outside the
        challenge space (or outside ``roi``) are 0.

    Raises
    ------
    ValueError : if ``hem`` is None and the vector length matches neither hemisphere.
    """
    challenge_vec = np.asarray(challenge_vec)
    if hem is None:
        # The challenge-space vertex count may differ between subjects, so a
        # hard-coded length (e.g. 19004) can misclassify hemispheres; compare
        # against this subject's own 'all-vertices' masks instead.
        n_challenge = {h: int(np.count_nonzero(spaces(subject, h, None))) for h in ('lh', 'rh')}
        candidates = [h for h, n in n_challenge.items() if n == challenge_vec.shape[0]]
        if not candidates:
            raise ValueError(
                f"vector of length {challenge_vec.shape[0]} matches neither hemisphere of "
                f"{subject} (vertex counts {n_challenge}); pass hem explicitly"
            )
        hem = candidates[0]

    if roi:
        fsaverage_mask, challenge_mask = spaces(subject, hem, roi)
        fsaverage_response = np.zeros(len(fsaverage_mask))
        fsaverage_response[np.flatnonzero(fsaverage_mask)] = challenge_vec[np.flatnonzero(challenge_mask)]
    else:
        all_vertices = spaces(subject, hem, roi)
        fsaverage_response = np.zeros(len(all_vertices))
        fsaverage_response[np.flatnonzero(all_vertices)] = challenge_vec
    return fsaverage_response


def plot_brain(challenge_vec, subject, hem, roi=None):
    """Plot a challenge-space vector on the fsaverage pial surface (interactive nilearn view).

    Parameters
    ----------
    challenge_vec : values for one hemisphere's challenge-space vertices.
    subject : subject key, used for the ROI lookup and the title.
    hem : "lh" or "rh"; selects both the mesh side and the data mapping.
    roi : optional ROI name; if given, only that ROI's vertices are drawn.

    Returns
    -------
    The nilearn surface view produced by ``view_surf``, resized to 600x600.
    """
    # Pass hem explicitly so the data mapping always matches the mesh side.
    surface = fsaverage_vec(challenge_vec, subject, roi, hem=hem)
    direction = "left" if hem == "lh" else "right"
    view = plotting.view_surf(
        surf_mesh=atlas["pial_" + direction],
        surf_map=surface,
        bg_map=atlas["sulc_" + direction],
        threshold=1e-14,  # hide zero-filled vertices outside the mapped region
        cmap="twilight_shifted",
        colorbar=True,
        title=hem + " hemisphere " + subject,
        black_bg=True,
    )
    return view.resize(height=600, width=600)


# Example: first training image of subj01, left hemisphere, restricted to V1v.
vec = fmri_data["subj01"]["lh"][0]
plot_brain(vec, "subj01", "lh", roi="V1v")


## fMRI

In [None]:
import os
def subject_dir_files(subject):
    """Return full paths of the ``mapping_*`` files in a subject's roi_masks directory.

    Paths are ordered by file name.
    """
    mask_dir = os.path.join(DATA_DIR, subject, "roi_masks")
    mapping_names = sorted(name for name in os.listdir(mask_dir) if name.startswith("mapping_"))
    return [os.path.join(mask_dir, name) for name in mapping_names]

## Connectome

In [None]:
subject = "subj01"
# Reuse the arrays already loaded into fmri_data (get_fmri is not defined in this notebook).
lh_fmri, rh_fmri = fmri_data[subject]['lh'], fmri_data[subject]['rh']
# NOTE(review): `rois` was never defined here; assuming the ROI list imported
# from src.utils (ROIS) is what plot_regions expects — confirm.
rois = ROIS
plot_regions(subject, rois, 'left', 0)

In [None]:
def connectome_from_roi_response(subject, roi, hem):
    """Build a thresholded vertex-by-vertex covariance graph for one ROI.

    The ROI's challenge-space vertices are selected from the subject's training
    fMRI (rows = images), a covariance matrix across vertices is estimated with
    nilearn's ConnectivityMeasure, and the matrix is sparsified into a graph.

    Parameters
    ----------
    subject : subject key (e.g. "subj01").
    roi : ROI name (e.g. "V1d").
    hem : "left"/"lh" or "right"/"rh".

    Returns
    -------
    (networkx.Graph, np.ndarray) as returned by connectivity_matrix_to_connectome.

    Raises
    ------
    ValueError : unknown hemisphere, or the ROI has no vertices in this hemisphere.
    """
    # NOTE(review): previously flagged "# this is wrong" without detail; the
    # undefined mask helper and the use of global, subject-independent fMRI
    # arrays are fixed here, but the covariance-over-images approach itself is unverified.
    hem_key = {"left": "lh", "right": "rh"}.get(hem, hem)
    if hem_key not in ("lh", "rh"):
        raise ValueError(f"hem must be 'left', 'right', 'lh' or 'rh', got {hem!r}")

    _, challenge_mask = spaces(subject, hem_key, roi)
    # spaces() returns 0/1 ints; indexing with an int array would select columns 0 and 1.
    roi_mask = challenge_mask.astype(bool)
    if not roi_mask.any():
        raise ValueError(f"ROI {roi!r} has no vertices in {hem_key} of {subject}")

    roi_response = fmri_data[subject][hem_key][:, roi_mask]
    connectivity_measure = ConnectivityMeasure(kind="covariance")
    connectivity_matrix = connectivity_measure.fit_transform([roi_response])[0]
    return connectivity_matrix_to_connectome(connectivity_matrix)


def connectivity_matrix_to_connectome(connectivity_matrix, percentile=99.0):
    """Sparsify a square connectivity matrix and convert it to a weighted networkx graph.

    Entries whose absolute value is below the ``percentile``-th percentile of
    all absolute entries are zeroed, then the diagonal is zeroed so the graph
    has no self-loops.

    Parameters
    ----------
    connectivity_matrix : square 2-D array-like. It is copied, not modified.
    percentile : float in [0, 100], default 99.0. This equals the former
        hard-coded expression ``100 * (N - N / 100) / N``, which always evaluates to 99.

    Returns
    -------
    (networkx.Graph, np.ndarray) : the graph built by ``nx.from_numpy_array``
        and the thresholded matrix it was built from.

    Raises
    ------
    ValueError : if the matrix is not square 2-D.
    """
    matrix = np.array(connectivity_matrix, copy=True)
    if matrix.ndim != 2 or matrix.shape[0] != matrix.shape[1]:
        raise ValueError(f"expected a square 2-D matrix, got shape {matrix.shape}")

    if matrix.size:  # np.percentile cannot handle an empty matrix
        # Diagonal entries still take part in the percentile; they are zeroed afterwards.
        # Consider thresholding differently, since the number of edges grows with nodes ** 2.
        thresh = np.percentile(np.abs(matrix), percentile)
        matrix[np.abs(matrix) < thresh] = 0
    np.fill_diagonal(matrix, 0)
    graph = nx.from_numpy_array(matrix)
    return graph, matrix

In [None]:
def build_connectome(roi, hem, subject="subj01"):
    """Build a single-graph ``jraph.GraphsTuple`` from an ROI connectome.

    Parameters
    ----------
    roi : ROI name (e.g. "V1d").
    hem : hemisphere passed to connectome_from_roi_response (e.g. "left").
    subject : subject key. Defaults to "subj01", the subject used throughout
        this notebook; previously the subject was not passed at all, so the call raised TypeError.

    Returns
    -------
    jraph.GraphsTuple with one-hot (identity) node features, no edge or global
    features, and one sender/receiver pair per networkx edge.
    """
    nx_graph, _ = connectome_from_roi_response(subject, roi, hem)
    n_nodes = nx_graph.number_of_nodes()
    edge_list = list(nx_graph.edges)

    nodes = jnp.eye(n_nodes)  # one-hot node identity features
    # Explicit int dtype so an edge-less graph still yields integer index arrays.
    senders = jnp.array([e[0] for e in edge_list], dtype=jnp.int32)
    receivers = jnp.array([e[1] for e in edge_list], dtype=jnp.int32)
    # NOTE(review): the networkx graph is undirected but each edge appears once,
    # so messages flow only sender -> receiver; add reversed pairs if symmetric
    # message passing is intended.

    return jraph.GraphsTuple(
        nodes=nodes,
        edges=None,
        senders=senders,
        receivers=receivers,
        # jraph expects per-graph counts with shape [n_graph], not Python ints.
        n_node=jnp.array([n_nodes]),
        n_edge=jnp.array([len(edge_list)]),
        globals=None,
    )


# Build the jraph GraphsTuple for the V1d ROI in the left hemisphere.
graph = build_connectome("V1d", "left")