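# Figures

Overlap figures (z-slices, surface projections and Venn diagrams) for pairs of clusterized maps per task. Run the shared setup cell, then exactly one of the three configuration cells, then the plotting cell.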

In [None]:
import os
import numpy as np
import pandas as pd

import nibabel as nib
from nilearn import plotting
from nilearn.image import math_img
from matplotlib.colors import ListedColormap
import matplotlib.pyplot as plt

from neuromaps.datasets import fetch_fslr
from neuromaps.transforms import mni152_to_fslr
from surfplot import Plot
from surfplot.utils import threshold

from similarity import calculate_voxel_similarity
from matplotlib_venn import venn2

# Grey-matter mask passed to calculate_voxel_similarity for the Venn voxel counts.
gm_mask = os.path.abspath('../pyALE/utils/mask/Grey10.nii')
# MNI152 1 mm brain template used as background of the z-slice plots.
# Honour $FSLDIR when it is set; otherwise fall back to the default FSL install location.
FSL_DIR = os.environ.get('FSLDIR', '/usr/local/fsl')
bg_img = nib.load(os.path.join(FSL_DIR, 'data', 'standard', 'MNI152_T1_1mm_brain.nii.gz'))

# Initial colormaps; the plotting cell rebuilds both from `colours` before drawing.
cmap = ListedColormap(['#0200F5', '#EA33F7', '#EA3324'])
cmap2 = ListedColormap(['#EA3324', '#EA33F7', '#0200F5', '#0200F5', '#0200F5', '#EA33F7', '#EA3324'])

def _cluster_img_path(task, name):
    """Return the absolute path of the clusterized map `name` produced for `task`.

    The relative path is resolved against the current working directory.
    """
    return os.path.abspath(f'../../output/{task}/evaluation/output_clusterize/{name}.nii.gz')


def olp_and_conj(task, img1, img2, thresh=0.5):
    """Label every voxel by which of the two thresholded maps it belongs to.

    Parameters
    ----------
    task : str
        Task folder under ``../../output``.
    img1, img2 : str
        Map names (file stems) inside the task's ``output_clusterize`` folder.
    thresh : float
        A voxel belongs to a map when its value is strictly greater than `thresh`.

    Returns
    -------
    Float image with 0 = in neither map, 1 = only in `img1`, 2 = in both, 3 = only in `img2`.
    """
    # a*1 + b*3 - (a & b)*2 maps (a only, both, b only) to (1, 2, 3) in one expression.
    return math_img(
        f'(img1 > {thresh}) * 1. + (img2 > {thresh}) * 3. '
        f'- ((img1 > {thresh}) & (img2 > {thresh})) * 2.',
        img1=_cluster_img_path(task, img1),
        img2=_cluster_img_path(task, img2))


def overlap(task, img1, img2, thresh=0.5):
    """Return a binary mask of voxels strictly above `thresh` in both `img1` and `img2`."""
    return math_img(
        f'(img1 > {thresh}) & (img2 > {thresh})',
        img1=_cluster_img_path(task, img1),
        img2=_cluster_img_path(task, img2))


def conjunction(task, img1, img2, thresh=0.5):
    """Return a count map: in how many of the two maps (0, 1 or 2) each voxel exceeds `thresh`."""
    # Cast before adding: numpy's bool + bool is a logical OR, which would never reach 2.
    return math_img(
        f'(img1 > {thresh}).astype(int) + (img2 > {thresh}).astype(int)',
        img1=_cluster_img_path(task, img1),
        img2=_cluster_img_path(task, img2))

In [None]:
# main effects - like in manuscript
# Run exactly one configuration cell before the plotting cell: each of them (re)defines
# tasks, img1s, img2s, z_cut_coords, colours and fig_folder.
# One row per comparison: (task, img1, img2, z cut coordinates of the slice plot).
_comparisons = [
    ('n-back', '2BKvsBASE--0BKvsBASE_P95', '2BKvs0BK_cFWE05', [-32, -4, 16, 28, 46, 56]),
    ('n-back', '2BKvs0BK--1BKvs0BK_P95', '2BKvs1BK_cFWE05', [-30, -24, 2, 12, 26, 46, 56]),
    ('stroop', 'IvsNRB--CvsNRB_P95', 'IvsC_cFWE05', [-28, -6, 3, 20, 35, 45, 54]),
    ('emo-faces', 'EMOvsBASE--NEUvsBASE_P95', 'EMOvsNEU_cFWE05', [-30, -20, -12, -2, 4, 30, 48]),
]
tasks, img1s, img2s, z_cut_coords = (list(column) for column in zip(*_comparisons))

# Panel colours in order: img1 only, overlap, img2 only
# (alternative palette: '#0200F5', '#EA33F7', '#EA3324').
colours = ['#35B779', '#FDE725', '#482878']

fig_folder = '_figures'

In [None]:
# supplement - matched contrasts
# Run exactly one configuration cell before the plotting cell: each of them (re)defines
# tasks, img1s, img2s, z_cut_coords, colours and fig_folder.
# One row per comparison: (task, img1, img2, z cut coordinates of the slice plot).
_comparisons = [
    ('n-back', '2BKvsBASE_matched_2vs0--0BKvsBASE_matched_2vs0_P95', '2BKvs0BK_matched_2vs0_cFWE05',
     [-32, -4, 12, 28, 46, 56]),
    ('n-back', '2BKvs0BK_matched_2vs1--1BKvs0BK_matched_2vs1_P95', '2BKvs1BK_matched_2vs1_cFWE05',
     [-32, -24, 2, 12, 26, 46, 56]),
    ('stroop', 'IvsNRB_matched--CvsNRB_matched_P95', 'IvsC_matched_cFWE05',
     [-28, -6, 3, 20, 35, 45, 54]),
    ('emo-faces', 'EMOvsBASE_matched--NEUvsBASE_matched_P95', 'EMOvsNEU_matched_cFWE05',
     [-30, -20, -2, 4, 30, 48]),
]
tasks, img1s, img2s, z_cut_coords = (list(column) for column in zip(*_comparisons))

# Panel colours in order: img1 only, overlap, img2 only
# (alternative palette: 'blue', 'magenta', 'red' = '#0200F5', '#EA33F7', '#EA3324').
colours = ['#35B779', '#FDE725', '#482878']

fig_folder = '_supplement_figures'

In [None]:
# supplement - pairs of individual cFWE05 maps
# Run exactly one configuration cell before the plotting cell: each of them (re)defines
# tasks, img1s, img2s, z_cut_coords, colours and fig_folder.
# One row per comparison: (task, img1, img2, z cut coordinates of the slice plot).
_comparisons = [
    ('n-back', '2BKvsBASE_cFWE05', '0BKvsBASE_cFWE05', [-32, -4, 12, 28, 46, 56]),
    ('n-back', '2BKvs0BK_cFWE05', '1BKvs0BK_cFWE05', [-32, -24, 2, 12, 26, 46, 56]),
    ('stroop', 'IvsNRB_cFWE05', 'CvsNRB_cFWE05', [-28, -6, 3, 20, 35, 45, 54]),
    ('emo-faces', 'EMOvsBASE_cFWE05', 'NEUvsBASE_cFWE05', [-30, -20, -12, -2, 4, 30, 48]),
]
tasks, img1s, img2s, z_cut_coords = (list(column) for column in zip(*_comparisons))

# Panel colours in order: img1 only, overlap, img2 only
# (alternative palette: '#91bfdb', '#ffffbf', '#fc8d59').
colours = ['#482878', '#35B779', '#FDE725']
fig_folder = '_supplement_figures'

In [None]:
surfaces = fetch_fslr()
lh, rh = surfaces['inflated']

# Colour convention shared by every panel below:
#   colours[0] -> img1 only, colours[1] -> overlap, colours[2] -> img2 only.
cmap = ListedColormap(colours)
# Four white entries followed by the three panel colours. NOTE(review): this layout assumes
# plot_stat_map spreads the colormap over [-vmax, vmax], so that with vmax=3 the labels 1/2/3
# written by olp_and_conj pick colours[0]/[1]/[2] - confirm against the installed nilearn.
cmap2 = ListedColormap(['#FFFFFF', '#FFFFFF', '#FFFFFF', '#FFFFFF'] + colours)


def _project_to_fslr(img, thresh=.1):
    """Project a volumetric MNI152 image onto the fsLR 32k surface.

    Returns the (left, right) vertex arrays after surfplot's `threshold` with `thresh`.
    """
    gii_lh, gii_rh = mni152_to_fslr(img, '32k')
    return threshold(gii_lh.agg_data(), thresh), threshold(gii_rh.agg_data(), thresh)


for task, img1, img2, cut_coords in zip(tasks, img1s, img2s, z_cut_coords):
    plt.rcParams.update(plt.rcParamsDefault)

    # create output dirs (the Venn folder lives inside the comparison folder)
    output_dir = os.path.abspath(f'../../output/{fig_folder}/{task}/comp_{img1}_{img2}')
    voutput_dir = os.path.join(output_dir, 'venn')
    os.makedirs(voutput_dir, exist_ok=True)

    # --- Z slices: 1 = img1 only, 2 = both, 3 = img2 only (labels from olp_and_conj) ---
    olp_conj_img = olp_and_conj(task, img1, img2)
    disp = plotting.plot_stat_map(
        olp_conj_img, display_mode='z',
        draw_cross=False,
        cmap=cmap2,
        vmax=3,
        cut_coords=cut_coords,
        bg_img=bg_img,
        black_bg=False,
    )
    disp.savefig(os.path.join(output_dir, 'z_slices.png'), dpi=200)
    disp.close()

    # Binary masks of both maps, computed once: both surface layouts and the Venn counts use them.
    clusterize_dir = os.path.abspath(f'../../output/{task}/evaluation/output_clusterize')
    img1_mask = math_img('img > 0.5', img=os.path.join(clusterize_dir, f'{img1}.nii.gz'))
    img2_mask = math_img('img > 0.5', img=os.path.join(clusterize_dir, f'{img2}.nii.gz'))

    # --- Surface plots: projections do not depend on the layout, so project once ---
    img1_lh, img1_rh = _project_to_fslr(img1_mask)
    img2_lh, img2_rh = _project_to_fslr(img2_mask)
    olp_lh, olp_rh = _project_to_fslr(overlap(task, img1, img2))

    p1 = Plot(surf_lh=lh, surf_rh=rh, size=(1600, 400), zoom=1.2, layout='row', mirror_views=True, brightness=.6)
    p2 = Plot(surf_lh=lh, surf_rh=rh, brightness=.6)
    for p, name in zip([p1, p2], ['_row', '']):
        # Layers are added in order img2, img1, overlap so the overlap is drawn last.
        # Copies give each Plot its own arrays in case surfplot modifies layer data in place.
        for (vals_lh, vals_rh), colour in (((img2_lh, img2_rh), colours[2]),
                                           ((img1_lh, img1_rh), colours[0]),
                                           ((olp_lh, olp_rh), colours[1])):
            p.add_layer({'left': vals_lh.copy(), 'right': vals_rh.copy()},
                        cmap=ListedColormap([colour]), cbar=False)
        fig = p.build()
        fig.savefig(os.path.join(output_dir, f'surface_projection{name}.png'), dpi=200)
        plt.close(fig)

    # --- Venn diagram of voxel counts ---
    vox_comp = calculate_voxel_similarity(img1=img1_mask, img2=img2_mask, gm_mask=gm_mask, thr_img1=.1, thr_img2=.1)
    print(vox_comp)

    n_overlap = vox_comp['voxel_overlap']
    subsets_vox = (vox_comp['voxel_map1'] - n_overlap,   # '10': img1 only
                   vox_comp['voxel_map2'] - n_overlap,   # '01': img2 only
                   n_overlap)                             # '11': both
    n_total = sum(subsets_vox)

    plt.rcParams.update({'font.size': 24})
    v = venn2(
        subsets_vox,
        set_labels=('', ''),
        # voxel count plus its share of all voxels, rounded to a whole percent
        subset_label_formatter=lambda x: f"{x}\n({x / n_total:1.0%})" if n_total else str(x),
        normalize_to=1.0)
    # get_patch_by_id may return None for an empty subset, so every region is guarded.
    for subset_id, size, colour in zip(('10', '01', '11'), subsets_vox,
                                       (colours[0], colours[2], colours[1])):
        patch = v.get_patch_by_id(subset_id)
        if size == 0 or patch is None:
            continue
        patch.set_color(colour)
        patch.set_edgecolor('none')
        patch.set_alpha(0.7)

    plt.savefig(os.path.join(voutput_dir, 'voxel.svg'), dpi=200)
    plt.close()

    # labelled volume (1 = img1 only, 2 = both, 3 = img2 only), named after that encoding
    nib.save(olp_conj_img, os.path.join(output_dir, f'{img1}-1_AND-2_{img2}-3.nii.gz'))


In [None]:
from nilearn.reporting import get_clusters_table
