In [None]:
import os
import pickle

import matplotlib.pyplot as plt
import nilearn.plotting as plotting
import numpy as np
import scipy
import scipy.stats
from nilearn.image import load_img, new_img_like

In [2]:
# Per-subject transition-importance arrays; the first axis indexes subjects.
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
importance_pkl = 'summary_data/transition_importance_subject.pkl'
with open(importance_pkl, 'rb') as handle:
    transition_importance_subject = pickle.load(handle)

# Group-level map: mean over subjects, ignoring NaN entries.
importance_map = np.nanmean(transition_importance_subject, axis=0)

In [3]:
courses_pkl = 'summary_data/transition_courses_importance.pkl'
with open(courses_pkl, 'rb') as handle:
    transition_courses = pickle.load(handle)

# Keep every (source, destination) state pair whose course list is non-empty.
# The pickle presumably only stores courses for significant transitions -- confirm.
K = 6  # number of states; pairs are 0-based state indices
transitions_of_interest = np.array([
    [src, dst]
    for src in range(K)
    for dst in range(K)
    if len(transition_courses[src][dst]) > 0
])
num_transitions = transitions_of_interest.shape[0]

with open('summary_data/state_order.pkl', 'rb') as handle:
    state_order = pickle.load(handle)

In [4]:
# Dimensions are read from the loaded data instead of being hard-coded
# (500 subjects and 254 ROIs for the data this notebook was written against),
# so the per-subject loops cannot silently skip subjects or index past the end.
num_subject = len(transition_importance_subject)
num_roi = np.shape(transition_importance_subject[0])[1]
# Rows of each subject's array must line up with transitions_of_interest.
assert np.shape(transition_importance_subject[0])[0] == num_transitions, (
    "importance arrays and transitions_of_interest disagree on the number of transitions"
)

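Permutation test for thresholding the transition importance maps: each subject's map is z-scored against a null built by randomly shuffling ROI labels, then the z-scores are tested across subjects with a one-sided one-sample t-test and FDR-corrected.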

In [5]:
B = 1000   # number of spatial permutations per subject
SEED = 0   # fixed so the permutation null (and every downstream p-value) is reproducible
rng = np.random.default_rng(SEED)

# For every subject, z-score each transition's ROI map against a null built by
# randomly permuting the ROI labels B times.
z_score = np.zeros((num_subject, num_transitions, num_roi))
for s in range(num_subject):
    subject_map = np.asarray(transition_importance_subject[s])   # (num_transitions, num_roi)
    # B independent permutations of the ROI indices, one per row: (B, num_roi).
    perms = rng.permuted(np.tile(np.arange(num_roi), (B, 1)), axis=1)
    null_maps = subject_map[:, perms]                            # (num_transitions, B, num_roi)
    # nan-aware statistics: a NaN ROI lands on a different position in each
    # permutation, so plain mean/std would make nearly every ROI's null (and
    # hence z-score) NaN for that transition.
    mu = np.nanmean(null_maps, axis=1)
    sigma = np.nanstd(null_maps, axis=1, ddof=1)
    z_score[s] = (subject_map - mu) / sigma

In [None]:
# One-sided one-sample t-test across subjects for every (transition, ROI) cell:
# H0: mean z-score == mu_0 vs H1: mean z-score > mu_0. NaN z-scores are dropped
# per cell, so each cell has its own sample size and degrees of freedom.
mu_0 = 0
n_valid = np.count_nonzero(~np.isnan(z_score), axis=0)   # (num_transitions, num_roi)
z_mean = np.nanmean(z_score, axis=0)
# The t statistic with n-1 degrees of freedom uses the *sample* SD (ddof=1);
# the population SD (ddof=0) understates the spread, inflating t and deflating p.
z_sd = np.nanstd(z_score, axis=0, ddof=1)
test_statistics = np.sqrt(n_valid) * (z_mean - mu_0) / z_sd
# Cells with fewer than 2 valid subjects yield NaN p-values.
pvalue_mat = scipy.stats.t.sf(test_statistics, n_valid - 1)

In [7]:
FDR_ALPHA = 0.05
# Benjamini-Hochberg FDR over *all* transition x ROI tests at once.
# scipy.stats.false_discovery_control defaults to axis=0, which on this 2-D
# matrix would instead correct each ROI's column across transitions separately.
# The function only accepts p-values in [0, 1], so NaN p-values are left out of
# the family and are never declared significant.
finite = np.isfinite(pvalue_mat)
pvalues_corrected = np.full(pvalue_mat.shape, np.nan)
if finite.any():
    # Boolean indexing flattens to 1-D, so the correction spans every finite test.
    pvalues_corrected[finite] = scipy.stats.false_discovery_control(pvalue_mat[finite])
rejection_mat = pvalues_corrected < FDR_ALPHA

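Transition importance maps: for each transition, the group-mean importance map with ROIs that do not survive FDR correction set to zero, rendered on the cortical surface (lateral/medial views of the left and right hemispheres).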

In [27]:
atlas = load_img('Schaefer2018_200Parcels_17Networks_order_FSLMNI152_2mm.nii.gz')
def vec_to_img(vec, n_rois=200):
    """Paint one value per parcel into the atlas volume.

    Parameters
    ----------
    vec : array-like
        Parcel values; ``vec[k]`` is written into every voxel whose atlas label
        is ``k + 1``. Only the first ``n_rois`` entries are used.
    n_rois : int, default 200
        Number of atlas labels (1..n_rois) to fill.

    Returns
    -------
    Image on the atlas grid (built with ``new_img_like``). Voxels whose label
    is 0 or greater than ``n_rois`` are 0.
    """
    atlas_data = atlas.get_fdata()

    vec_img_data = np.zeros_like(atlas_data)
    for idx_roi in range(n_rois):
        # Masked assignment, not `mask * value`: multiplying by a NaN value
        # would turn every voxel of the volume (mask False included) into NaN.
        vec_img_data[atlas_data == idx_roi + 1] = vec[idx_roi]

    return new_img_like(data=vec_img_data, ref_niimg=atlas)

In [30]:
from PIL import Image,ImageChops
def trim(im):
    """Crop the uniform border from a rendered figure.

    The border colour is taken from the top-left pixel; small differences from
    it (see the add() step) are treated as border too.

    Parameters
    ----------
    im : PIL.Image.Image
        Image to crop; it is not modified.

    Returns
    -------
    PIL.Image.Image
        A new image cropped to the non-border content, or a full-size copy when
        the whole image is border (rather than None, which would break callers
        that read ``.size``).
    """
    border = Image.new(im.mode, im.size, im.getpixel((0, 0)))
    diff = ImageChops.difference(im, border)
    # add() computes (diff + diff) / 2.0 - 100, clipped at 0: differences of 100
    # or less are zeroed so faint anti-aliasing does not enlarge the box.
    diff = ImageChops.add(diff, diff, 2.0, -100)
    # getbbox() returns None when there is nothing non-zero to bound.
    bbox = diff.getbbox()
    return im.crop(bbox) if bbox else im.copy()
    
def join_2x2(imgs, out_path, dpi=300, padding=20):
    """Tile four images into a 2x2 mosaic on a transparent RGBA canvas and save it.

    Parameters
    ----------
    imgs : sequence of 4 PIL images
        Ordered top-left, top-right, bottom-left, bottom-right.
    out_path : str
        Destination passed to ``Image.save``.
    dpi : int, default 300
        Resolution metadata written for both axes.
    padding : int, default 20
        Gap in pixels between the two columns and between the two rows.

    Images are top-aligned within their row. Each right-hand image starts
    ``padding`` px after the left image of its own row, so the right column's
    x offset differs between rows when the left images differ in width.

    Raises
    ------
    ValueError
        If ``imgs`` does not contain exactly four images.
    """
    if len(imgs) != 4:
        raise ValueError("Need exactly 4 images in order: TL, TR, BL, BR")

    (w_tl, h_tl), (w_tr, h_tr), (w_bl, h_bl), (w_br, h_br) = (img.size for img in imgs)

    top_height = max(h_tl, h_tr)
    bottom_y = top_height + padding
    canvas_size = (
        max(w_tl + w_tr, w_bl + w_br) + padding,
        bottom_y + max(h_bl, h_br),
    )
    offsets = [(0, 0), (w_tl + padding, 0), (0, bottom_y), (w_bl + padding, bottom_y)]

    canvas = Image.new("RGBA", canvas_size, (0, 0, 0, 0))
    for tile, offset in zip(imgs, offsets):
        canvas.paste(tile, offset)

    canvas.save(out_path, dpi=(dpi, dpi))

In [39]:
OUT_DIR = "maps/region_importance/transition_importance"
N_CORTICAL = 200  # parcels in the Schaefer-200 atlas rendered by vec_to_img
os.makedirs(OUT_DIR, exist_ok=True)  # savefig fails if the directory is missing

for i in range(num_transitions):
    # Group-mean importance with ROIs that did not survive FDR set to 0
    # (hidden in the plot by the tiny `threshold`).
    vec = np.where(rejection_mat[i], importance_map[i], 0)
    # NOTE(review): assumes the last N_CORTICAL ROIs are the Schaefer parcels in
    # atlas-label order and the leading ones are non-cortical -- confirm.
    vec_img = vec_to_img(vec[-N_CORTICAL:])

    # 1-based state numbers used in the file names.
    src_state, dst_state = transitions_of_interest[i] + 1
    stem = "%s/state%s-to-state%s" % (OUT_DIR, src_state, dst_state)

    # Panel order must match join_2x2 (TL, TR, BL, BR):
    # lateral-left, lateral-right, medial-left, medial-right.
    panels = []
    for view in ['lateral', 'medial']:
        for hemi in ['left', 'right']:
            fig, _ = plotting.plot_img_on_surf(
                vec_img,
                surf_mesh='fsaverage5', bg_on_data=True, inflate=True,
                hemispheres=[hemi], views=[view],
                vmin=-0.01, vmax=0.01,
                threshold=1e-10,
                colorbar=False,
                cmap='seismic',
            )
            fig.set_size_inches(4, 4)
            panel_path = "%s_%s_%s.png" % (stem, hemi, view)
            fig.savefig(panel_path, dpi=300, transparent=True)
            plt.close(fig)  # avoid accumulating open figures across the loop
            # trim() returns a new image, so the file can be closed right away.
            with Image.open(panel_path) as panel:
                panels.append(trim(panel))
    join_2x2(panels, stem + ".png")