In [None]:
import numpy as np
import pandas as pd
import os

import scipy.io as scio
from scipy import stats
import base_functions as bf
import conn_base_functions as cbf
import pickle

from nilearn import image, plotting, connectome
import matplotlib.pyplot as plt
import matplotlib
# Embed fonts as Type 42 (TrueType) instead of matplotlib's default Type 3 in PDF/PS
# exports, so text in saved figures stays selectable/editable in vector editors.
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
# plt.rcParams["font.family"] = "Times New Roman"

# IPython magic: render figures inline below each cell (default in modern Jupyter).
%matplotlib inline

In [None]:
# Load bootstrap results: FDR-corrected p-values and bootstrap pattern weights, one value
# per edge (file name suggests the day-1 morning session).
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted result files.
with open('./data/D1Morning_sig_patterns.pkl', 'rb') as fh:  # closes the handle promptly
    D = pickle.load(fh)
pvals_fdr = D['pvals_fdr'].reshape(-1)
patterns = D['boot_patterns'].reshape(-1)

# Every later cell masks `patterns` with conditions on `pvals_fdr`, so they must align,
# and both must hold the strictly-lower-triangle edges of a 442 x 442 matrix.
assert pvals_fdr.shape == patterns.shape, (pvals_fdr.shape, patterns.shape)
assert patterns.size == 442 * 441 // 2, f'expected {442 * 441 // 2} edges, got {patterns.size}'

# With an explicit `diagonal`, vec_to_sym_matrix treats the vector as off-diagonal
# entries only; the diagonal is filled with zeros.
patterns_mtx = connectome.vec_to_sym_matrix(patterns, diagonal=np.zeros(442))

In [None]:
import nibabel as nib
from nilearn.image import load_img, smooth_img
from nilearn.input_data import NiftiMasker, NiftiLabelsMasker

# Atlas label image (file name: Schaefer-400 7-network cortex + subcortex + cerebellum,
# MNI152NLin2009cAsym template, 2 mm resolution).
mask_file = ('./data/whole_brain_mask_Sch7net400_subcortex_cerebellum_'
             'MNI152NLin2009cAsym_res-2space.nii.gz')

# Summarise the atlas image itself per label, without standardisation.
# NOTE(review): `t` is not referenced by any later cell shown here.
seed_masker = NiftiLabelsMasker(labels_img=mask_file, standardize=False)
t = seed_masker.fit_transform(mask_file)

In [None]:
import seaborn as sns
from matplotlib.patches import Rectangle
from matplotlib.colors import LinearSegmentedColormap, ListedColormap
from matplotlib import cm

In [None]:
# Per-region network assignment from the atlas lookup table; only the first 442 rows are used.
rnfile = './data/whole_brain_mask_Sch7net400_subcortex_cerebellum.csv'
masker_info = pd.read_csv(rnfile)
network_name = masker_info['network_name'].to_list()[:442]

# Region indices grouped network by network (in net_list order), for sorting matrices.
net_list = ['Vis', 'SomMot', 'DorsAttn', 'SalVentAttn', 'Limbic', 'Cont', 'Default', 'Subcortex', 'Cerebellum']
reorder_idx = []
for net in net_list:
    members = [pos for pos, name in enumerate(network_name) if name == net]
    if net == 'Subcortex':
        # Subcortical labels start with the right hemisphere; move the first 16 to the end.
        members = members[16:] + members[:16]
    reorder_idx += members

In [None]:
# Figure styling shared by the plotting cells.
tick_size, xylabel_size = 20, 20
fig_size = (10, 10)
# RGB colours written in 0-255 and rescaled to matplotlib's 0-1 range.
use_colors = [tuple(channel / 255 for channel in rgb)
              for rgb in ((140, 140, 140), (45, 86, 98), (188, 208, 217))]
threshold = 0.001

# Network abbreviations, aligned position-by-position with the atlas network names.
net_abbre = ['CON', 'DMN', 'DAN', 'LIM', 'VAN', 'SMN', 'VIS', 'SUB', 'CER']
network_list = ['Cont', 'Default', 'DorsAttn', 'Limbic', 'SalVentAttn',
                'SomMot', 'Vis', 'Subcortex', 'Cerebellum']

In [None]:
# Abbreviated network label for every region (e.g. 'Default' -> 'DMN'), via the
# position of its full name in network_list.
network_abbre_name = [net_abbre[network_list.index(net)] for net in network_name]

In [None]:
def heatmap_mod(x, y, **kwargs):
    """Draw a bubble-style heatmap with outlined upper-triangle cells and a colour bar.

    Point ``i`` is drawn at category ``(x[i], y[i])``. Its marker area encodes
    ``size[i]`` and its colour encodes ``color[i]``. Points with size 0 are skipped.

    Parameters
    ----------
    x, y : sequence
        Category label of each point. Elements are taken by position, so pandas
        Series with any index work.
    **kwargs
        color : sequence of numbers (default: all 1).
        palette : list of colours (default: ``sns.color_palette("GnBu", 100)``).
        color_range : (min, max) mapped onto the palette (default: data min/max).
        size : sequence of numbers, same length as ``x`` (default: all 1).
        size_range : (min, max) mapped onto marker areas (default: data min/max).
        size_scale : marker area for the top of ``size_range`` (default 500).
        marker : matplotlib marker (default ``'s'``).
        x_order, y_order : category order along each axis (default: sorted unique labels).
        xlabel, ylabel : axis labels (default: empty).
        figsize : figure size in inches (default ``(7, 6.5)``).
        Any other keyword is forwarded to ``Axes.scatter``.

    Returns
    -------
    matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If ``x``, ``y``, ``size`` and ``color`` differ in length.
    """
    # Materialise as lists so item i is always the i-th element. ``Series[i]`` is
    # label-based and breaks for filtered/re-indexed data.
    x = list(x)
    y = list(y)
    color = list(kwargs['color']) if 'color' in kwargs else [1] * len(x)
    size = list(kwargs['size']) if 'size' in kwargs else [1] * len(x)
    if not (len(x) == len(y) == len(size) == len(color)):
        raise ValueError(f'x, y, size and color must have equal lengths, got '
                         f'{len(x)}, {len(y)}, {len(size)}, {len(color)}')

    if 'palette' in kwargs:
        palette = kwargs['palette']
        n_colors = len(palette)
    else:
        n_colors = 100  # number of discrete colours in the default palette
        palette = sns.color_palette("GnBu", n_colors)

    if 'color_range' in kwargs:
        color_min, color_max = kwargs['color_range']
    else:
        color_min, color_max = min(color), max(color)

    def value_to_color(val):
        """Map ``val`` linearly onto the palette, clipping outside ``color_range``."""
        if color_min == color_max:
            return palette[-1]
        pos = float(val - color_min) / (color_max - color_min)
        pos = min(max(pos, 0), 1)
        return palette[int(pos * (n_colors - 1))]

    if 'size_range' in kwargs:
        size_min, size_max = kwargs['size_range'][0], kwargs['size_range'][1]
    else:
        size_min, size_max = min(size), max(size)
    size_scale = kwargs.get('size_scale', 500)

    def value_to_size(val):
        """Map ``val`` to a marker area. In-range values give 1%-100% of ``size_scale``."""
        if size_min == size_max:
            return 1 * size_scale
        pos = (val - size_min) * 0.99 / (size_max - size_min) + 0.01
        pos = min(max(pos, 0), 1)
        return pos * size_scale

    x_names = list(kwargs['x_order']) if 'x_order' in kwargs else sorted(set(x))
    y_names = list(kwargs['y_order']) if 'y_order' in kwargs else sorted(set(y))
    x_to_num = {name: k for k, name in enumerate(x_names)}
    y_to_num = {name: k for k, name in enumerate(y_names)}

    fig = plt.figure(figsize=kwargs.get('figsize', (7, 6.5)))
    # 1 x 20 grid: main plot uses the left 19 columns, colour bar the last one.
    plot_grid = plt.GridSpec(1, 20, hspace=0.1, wspace=0.1)
    ax = fig.add_subplot(plot_grid[:, :-1])

    marker = kwargs.get('marker', 's')
    consumed = {'color', 'palette', 'color_range', 'size', 'size_range', 'size_scale',
                'marker', 'x_order', 'y_order', 'xlabel', 'ylabel', 'figsize'}
    kwargs_pass_on = {k: v for k, v in kwargs.items() if k not in consumed}

    # One scatter call per point so that zero-size entries are not drawn at all.
    for xi, yi, s, c in zip(x, y, size, color):
        if s == 0:
            continue
        ax.scatter(x=[x_to_num[xi]], y=[y_to_num[yi]], marker=marker,
                   s=[value_to_size(s)], c=[value_to_color(c)], **kwargs_pass_on)

    ax.set_xticks(list(x_to_num.values()))
    ax.set_xticklabels(list(x_to_num), fontsize=15)
    ax.set_yticks(list(y_to_num.values()))
    ax.set_yticklabels(list(y_to_num), fontsize=15)
    ax.set_xticks([t + 0.5 for t in ax.get_xticks()], minor=True)
    ax.set_yticks([t + 0.5 for t in ax.get_yticks()], minor=True)
    ax.set_xlim([-0.55, max(x_to_num.values()) + 0.55])
    ax.set_ylim([-0.55, max(y_to_num.values()) + 0.55])
    ax.set_xlabel(kwargs.get('xlabel', ''))
    ax.set_ylabel(kwargs.get('ylabel', ''))

    ax.tick_params(which='both', width=0)
    for side in ('top', 'right', 'bottom', 'left'):
        ax.spines[side].set_visible(False)
    ax.xaxis.set_tick_params(length=0, width=0, which='both')
    ax.yaxis.set_tick_params(length=0, width=0, which='both')

    # Grey outline around grid cell (column i, row j counted from the top) for i <= j.
    n_rows = len(y_names)
    for i in range(len(x_names)):
        for j in range(i, n_rows):
            ax.add_patch(Rectangle((i - 0.5, n_rows - j - 1.5), 1, 1,
                                   ec='gray', fc='none', linewidth=1., alpha=0.5))

    # Colour bar in the rightmost grid column, drawn as stacked horizontal bars.
    if color_min < color_max:
        cax = fig.add_subplot(plot_grid[:, -1])
        bar_y = np.linspace(color_min, color_max, n_colors)
        # A single-colour palette has no linspace step; use the full range instead.
        bar_height = bar_y[1] - bar_y[0] if n_colors > 1 else color_max - color_min
        cax.barh(y=bar_y, width=[5] * len(palette), left=[0] * len(palette),
                 height=bar_height, color=palette, linewidth=0)
        cax.set_xlim(1, 2)  # bars span x = 0..5; show only a thin slice of them
        cax.grid(False)
        cax.set_xticks([])
        ticks = np.linspace(min(bar_y), max(bar_y), 3)  # min, middle, max
        cax.set_yticks(ticks)
        # '%g' labels non-integer ticks truthfully; int() would label 3.5 as "3".
        cax.set_yticklabels([f'{v:g}' for v in ticks])
        cax.yaxis.tick_right()
        cax.tick_params(length=0)
        cax.yaxis.set_label_position("right")
        cax.set_ylabel('Edge number')

        # Pad by the bar length on both sides so the bar fills ~1/3 of the height.
        cm_len = max(bar_y) - min(bar_y)
        cax.set_ylim([min(bar_y) - cm_len, max(bar_y) + cm_len])
        for side in ('top', 'right', 'bottom', 'left'):
            cax.spines[side].set_visible(False)
        cax.xaxis.set_tick_params(length=0, width=0, which='minor')
    return fig

In [None]:
# Number of edges with FDR-corrected p < 0.05 and a negative pattern weight.
((pvals_fdr < 0.05) & (patterns < 0)).sum()

In [None]:
# Count significant edges (FDR-corrected p < 0.05) of one sign within/between networks,
# then reshape the network x network counts into long format for the bubble heatmap.
pos_or_neg = 'neg'  # 'pos': positively weighted edges, 'neg': negatively weighted edges
if pos_or_neg not in ('pos', 'neg'):
    # Without this check an unknown value would silently skip the sign filter below.
    raise ValueError(f"pos_or_neg must be 'pos' or 'neg', got {pos_or_neg!r}")

idx = pvals_fdr < 0.05
npt = np.zeros_like(patterns)
npt[idx] = 1
sig_mtx = connectome.vec_to_sym_matrix(npt, diagonal=np.zeros(442))
# Drop significant edges whose pattern weight has the opposite sign.
if pos_or_neg == 'pos':
    sig_mtx[patterns_mtx < 0] = 0
else:
    sig_mtx[patterns_mtx > 0] = 0

# Sum edges per network pair; rows/columns presumably follow network_list order.
net_mtx, _ = cbf.get_net2net_matrix(sig_mtx, network_name, network_list, 'sum')

# Upper triangle including the diagonal, in row-major order.
rows, cols = np.triu_indices(net_mtx.shape[0], m=net_mtx.shape[1])
data = pd.DataFrame({
    'x': [net_abbre[i] for i in rows],
    'y': [net_abbre[j] for j in cols],
    'v': [net_mtx[i, j] for i, j in zip(rows, cols)],
})

In [None]:
# Bubble heatmap of significant-edge counts per network pair. Marker size and colour
# both encode the count, scaled from 0 to the largest count.
# The name `max7net` is historical: the maximum covers every network in net_mtx,
# not only the first 7.
max7net = np.max(net_mtx)
n_colors = 100  # number of discrete colours in the palette
if pos_or_neg == 'pos':
    palette = sns.color_palette("flare", n_colors)
elif pos_or_neg == 'neg':
    palette = sns.color_palette("crest", n_colors)
else:
    # Fail loudly instead of reusing a `palette` left over from an earlier run.
    raise ValueError(f"pos_or_neg must be 'pos' or 'neg', got {pos_or_neg!r}")

fig = heatmap_mod(data['x'], data['y'], size=data['v'], color=data['v'], palette=palette,
                  marker='o', size_range=[0, max7net], color_range=[0, max7net],
                  size_scale=500, x_order=net_abbre, y_order=net_abbre[::-1])
# fig.savefig(f'./results/net2net_num_{pos_or_neg}.pdf', bbox_inches='tight')
# fig.savefig(f'./results/net2net_num_{pos_or_neg}.pdf', bbox_inches='tight')

In [None]:
# chord plot
from nichord.chord import plot_chord
from nichord.convert import convert_matrix

# Chord-diagram inputs: edges with FDR-corrected p < 0.05 and the chosen sign,
# weighted by their bootstrap pattern value.
chord_pos_or_neg = 'pos'  # 'pos' or 'neg'
if chord_pos_or_neg == 'pos':
    idx = np.logical_and(pvals_fdr < 0.05, patterns > 0)
elif chord_pos_or_neg == 'neg':
    idx = np.logical_and(pvals_fdr < 0.05, patterns < 0)
else:
    # Fail loudly: otherwise `idx` would silently keep the mask from an earlier cell.
    raise ValueError(f"chord_pos_or_neg must be 'pos' or 'neg', got {chord_pos_or_neg!r}")

# `patterns` is already 1-D, so it can be masked directly.
npt = np.zeros_like(patterns)
npt[idx] = patterns[idx]
sig_mtx = connectome.vec_to_sym_matrix(npt, diagonal=np.zeros(442))

# Node index -> network abbreviation, passed to nichord as node labels.
idx_to_label = dict(enumerate(network_abbre_name))
edges, edge_weights = convert_matrix(sig_mtx)
if len(edge_weights) == 0:
    # np.percentile below would otherwise fail with an opaque IndexError.
    raise ValueError(f'No significant {chord_pos_or_neg!r} edges to plot')

# Edges above the 95th percentile of |weight| stay opaque; the rest are drawn faintly.
abs_weights = np.abs(edge_weights)
top_eidx = abs_weights > np.percentile(abs_weights, 95)

apvals = pvals_fdr[idx]  # FDR p-values of the selected edges
alphas = np.ones_like(edge_weights)
alphas[~top_eidx] = 0.1
linewidths = 0.3 * np.ones_like(edge_weights)

len(edge_weights)

In [None]:
# Region coordinates and per-network colours (colour file is 0-255 RGB, rescaled to 0-1).
coord_reg = pd.read_csv('./data/Sch7net400_subcortex_cerebellum_COG.txt', header=None, sep=' ').to_numpy()
tcolor = pd.read_csv('./data/colormap_8network.txt', header=None, sep=' ').to_numpy() / 255

net_colors = list(tcolor)  # one RGB row per network
net_colors[-2], net_colors[-1] = net_colors[-1], net_colors[-2]  # swap VIS and SUB
net_colors.append(np.array([78/255, 240/255, 199/255]))  # extra colour for CER
net_colors[3] = np.array([252/255, 221/255, 195/255])  # LIM recoloured to be more visible

net_abbre = ['CON', 'DMN', 'DAN', 'LIM', 'VAN', 'SMN', 'VIS', 'SUB', 'CER']
network_list = ['Cont', 'Default', 'DorsAttn', 'Limbic', 'SalVentAttn', 'SomMot', 'Vis', 'Subcortex', 'Cerebellum']

# Per-region network abbreviation, and the matching per-region colour.
network_abbre = [net_abbre[network_list.index(name)] for name in network_name]
region_colors = [net_colors[net_abbre.index(abbr)] for abbr in network_abbre]

In [None]:
# Draw and save the chord diagram. Every edge gets the same sign-dependent colour and
# unit weight; per-edge opacity and line width come from `alphas` / `linewidths`.
if chord_pos_or_neg == 'pos':
    use_c = [182/255, 75/255, 103/255]
elif chord_pos_or_neg == 'neg':
    use_c = [50/255, 128/255, 140/255]
else:
    # Fail loudly instead of reusing a `use_c` left over from an earlier run.
    raise ValueError(f"chord_pos_or_neg must be 'pos' or 'neg', got {chord_pos_or_neg!r}")
network_colors = dict(zip(net_abbre, net_colors))

fp_chord = f'./results/chord_{chord_pos_or_neg}.pdf'
os.makedirs(os.path.dirname(fp_chord), exist_ok=True)  # ensure the output directory exists
n_edges = len(edge_weights)
plot_chord(idx_to_label, edges, edge_weights=[1] * n_edges, arc_setting=True, vmin=-1, vmax=1,
           fp_chord=fp_chord, dpi=300,
           colors=[use_c] * n_edges, cbar=None, network_colors=network_colors,
           linewidths=list(linewidths), alphas=list(alphas), do_ROI_circles=False, label_fontsize=20)