In [None]:
import pickle
import numpy as np
import resting_state_summaries as rss
import matplotlib.pyplot as plt
from nilearn.image import load_img, new_img_like
import nilearn.plotting as plotting
import scipy

In [2]:
# normalize by each run
def normalize_timeseries(datas):
    for i in range(len(datas)):
        ts = datas[i]
        ts = (ts - np.expand_dims(ts.mean(axis=0), axis=0))/np.expand_dims(ts.std(axis=0), axis=0)
        datas[i] = ts
    return(datas)

In [3]:
K = 6
D = 10
with open('Final_model/K6_D10_500subjs_compact_model.pkl', 'rb') as f:
    [model, q, elbos, q_z] = pickle.load(f)

num_roi = model.N
num_subject = len(np.unique(model.tags))
pid = np.unique(model.tags)

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
with open('data/roi_timeseries_rsfMRI_HCP_held_out', 'rb') as f:
    datas = pickle.load(f)
with open('data/tags_rsfMRI_HCP_held_out', 'rb') as f:
    tags = pickle.load(f)

In [5]:
def activity_at_late_time(y_bundles, late_times):
    K = len(y_bundles)
    num_roi = y_bundles[0][0].shape[1]
    activity_at_late_time = np.zeros((K, num_roi))
    for k in range(K):  
        for j in range(num_roi):
            t = late_times[k]
                # collect timeseries values from bundles that durates at least (t+1) time steps
            bold_t = []
            for i in range(len(y_bundles[k])):
                if len(y_bundles[k][i])>=(t+1):
                    bold_t.append(y_bundles[k][i][:,j][t])  # activity of jth ROI at time t of state entry
            # calculate mean responses
            activity_at_late_time[k,j] = np.mean(bold_t)
    return activity_at_late_time

In [6]:
C = model.parent.emissions.Cs[0]

In [7]:
with open('summary_data/state_duration.pkl','rb') as f:
   state_duration = pickle.load(f)

Group-level: state importance evolution

In [8]:
T = 30

In [9]:
z_subject = dict()
y_subject = dict()
z_bundle_subject = dict()
A_roi_subject = dict()
degree_subject = dict()
for s in range(num_subject):
    z_subject[s] = [q_z[i] for i in range(len(datas)) if tags[i] == pid[s]]
    y_subject[s] = [datas[i] for i in range(len(datas)) if tags[i] == pid[s]]
    z_bundle_subject[s] = rss.collect_z_bundle(z_subject[s])
    num_run = 4
    # use subject-level dynamics matrices
    A_roi = []
    for k in range(K):
        A_roi.append(C.dot(model.children[pid[s]].dynamics.As[k]).dot(C.T))
    
    degree = np.zeros((K, num_roi))
    for k in range(K):
        for j in range(num_roi):
            degree[k,j] = np.linalg.norm(A_roi[k][:,j])

    A_roi_subject[s] = A_roi
    degree_subject[s] = degree

In [None]:
z_subject = dict()
y_subject = dict()
z_bundle_subject = dict()
importance_timeseries_subject = dict()
importance_bundles_subject = dict()
for s in range(num_subject):
    # 1. importance timeseries for the subject
    z_subject[s] = [q_z[i] for i in range(len(datas)) if tags[i] == pid[s]]
    y_subject[s] = [datas[i] for i in range(len(datas)) if tags[i] == pid[s]]
    z_bundle_subject[s] = rss.collect_z_bundle(z_subject[s])
    num_run = 4
    importance_timeseries_subject[s] = []
    # use subject-level dynamics matrices
    A_roi = []
    for k in range(K):
        A_roi.append(C.dot(model.children[pid[s]].dynamics.As[k]).dot(C.T))
    degree = np.zeros((K, num_roi))
    for k in range(K):
        for j in range(num_roi):
            degree[k,j] = np.linalg.norm(A_roi[k][:,j])
    for rid in range(num_run):
        data_in_run = y_subject[s][rid]
        states_in_run = z_subject[s][rid]
        y_importance = np.multiply(np.abs(data_in_run), degree[states_in_run,:])
        importance_timeseries_subject[s].append(y_importance)
    # normalize
    importance_timeseries_subject[s] = normalize_timeseries(importance_timeseries_subject[s])
    
    # 2. compute subject-level importance evolution, analogous to activity evolution  
    importance_bundles_subject[s] = rss.collect_y_bundle(np.concatenate(importance_timeseries_subject[s]), z_bundle_subject[s], K)

In [None]:
# with open('summary_data/importance_timeseries.pkl','wb') as f:
#     pickle.dump(importance_timeseries, f)
# with open('summary_data/importance_timeseries_subject.pkl','wb') as f:
#     pickle.dump(importance_timeseries_subject, f)

In [None]:
# determine importance map at group level
importance_bundles = dict()
for k in range(K):
    importance_bundles[k] = []
    for s in range(num_subject):
        for j in range(len(importance_bundles_subject[s][k])):
            importance_bundles[k].append(importance_bundles_subject[s][k][j])
importance_map = activity_at_late_time(importance_bundles, state_duration + 3)

test the null of equal importance (permutation test)

In [None]:
importance_map_subject = dict()
for s in range(num_subject):
    importance_map_subject[s] = activity_at_late_time(importance_bundles_subject[s], state_duration + 3)

In [13]:
B = 1000
z_score = np.zeros((num_subject, K, num_roi))
for s in range(num_subject):
    null_maps = []
    for b in range(B):
        spatial_shuffle = np.arange(num_roi)
        np.random.shuffle(spatial_shuffle) 
        null_maps.append(importance_map_subject[s][:, spatial_shuffle])
    mu = np.mean(null_maps, axis = 0)
    sigma = np.std(null_maps, axis = 0, ddof=1)
    z_score[s,:,:] = (importance_map_subject[s] - mu) / sigma

In [None]:
pvalue_mat = np.zeros((K,num_roi))
mu_0 = 0
for k in range(K):
    for i in range(num_roi):
                scores = z_score[:,k,i]
                N = num_subject-np.sum(np.isnan(scores))
                test_statistics = np.sqrt(N)*(np.nanmean(scores)-mu_0)/np.nanstd(scores)
                df = N-1
                pvalue_mat[k,i] = scipy.stats.t.sf(test_statistics, df)

In [15]:
pvalues_corrected = scipy.stats.false_discovery_control(pvalue_mat)
rejection_mat = pvalues_corrected<0.05

In [None]:
# with open('summary_data/region_importance_rejection_mat.pkl','wb') as f:
#     pickle.dump(rejection_mat, f)

In [None]:
with open('summary_data/state_order.pkl','rb') as f:
   state_order = pickle.load(f)

importance map, cortex (thresholded)

In [20]:
atlas = load_img('Schaefer2018_200Parcels_17Networks_order_FSLMNI152_2mm.nii.gz')
def vec_to_img(vec):
    n_rois = 200
    atlas_data = atlas.get_fdata()

    vec_img_data = np.zeros_like(atlas_data)
    for idx_roi in range(n_rois):
        vec_img_data += (atlas_data==idx_roi+1) * vec[idx_roi]

    vec_img = new_img_like(data=vec_img_data, ref_niimg=atlas)
    return vec_img

In [24]:
from PIL import Image,ImageChops
def trim(im):
    bg = Image.new(im.mode, im.size, im.getpixel((0,0)))
    diff = ImageChops.difference(im, bg)
    diff = ImageChops.add(diff, diff, 2.0, -100)
    #Bounding box given as a 4-tuple defining the left, upper, right, and lower pixel coordinates.
    #If the image is completely empty, this method returns None.
    bbox = diff.getbbox()
    if bbox:
        return im.crop(bbox)
    
def join_2x2(imgs, out_path, dpi=300, padding=20):
    if len(imgs) != 4:
        raise ValueError("Need exactly 4 images in order: TL, TR, BL, BR")

    w1, h1 = imgs[0].size
    w2, h2 = imgs[1].size
    w3, h3 = imgs[2].size
    w4, h4 = imgs[3].size

    top_row_h = max(h1, h2)
    bottom_row_h = max(h3, h4)
    total_w = max(w1+w2,w3+w4) + padding
    total_h = top_row_h + bottom_row_h + padding

    grid = Image.new("RGBA", (total_w, total_h), (0, 0, 0, 0))
    grid.paste(imgs[0], (0, 0))
    grid.paste(imgs[1], (w1+padding, 0))
    grid.paste(imgs[2], (0, top_row_h+padding))
    grid.paste(imgs[3], (w3+padding, top_row_h+padding))

    grid.save(out_path, dpi=(dpi, dpi))

In [None]:
# thresholded maps
for k in range(K):
    state = state_order[k]
    vec = importance_map[state,:].copy()
    vec[rejection_mat[state, :] == False] = 0 
    vec_img = vec_to_img(vec[54:254])
    grid = []
    for view in ['lateral', 'medial']:
        for hemi in ['left', 'right']:
            fig,ax = plotting.plot_img_on_surf(
                        vec_img,
                        surf_mesh='fsaverage5', bg_on_data=True, inflate=True,
                        hemispheres=[hemi], views=[view],
                        vmin = -.3, vmax=0.3, 
                        threshold=1e-10,
                        colorbar=False,
                        cmap = 'seismic'
                    )
            fig.set_size_inches(4, 4)
            fig.savefig("maps/region_importance/importance_map_thresholded/state%s_%s_%s.png"%(k+1, hemi, view), dpi=300, transparent=True)
            plt.close(fig)
            grid.append(trim(Image.open("maps/region_importance/importance_map_thresholded/state%s_%s_%s.png"%(k+1, hemi, view))))
    join_2x2(grid, "maps/region_importance/importance_map_thresholded/state%s.png"%(k+1))

importance map, subcortex (thresholded)

In [None]:
# np.sum(rejection_mat[state_order, 0:54], axis = 1)

array([48, 11,  0,  2,  0,  0])

In [28]:
atlas = load_img('Schaefer2018_200Parcels_7Networks_order_Tian_Subcortex_S4_MNI152NLin6Asym_2mm.nii.gz')
def vec_to_img(vec):
    n_rois = 254
    atlas_data = atlas.get_fdata()

    vec_img_data = np.zeros_like(atlas_data)
    for idx_roi in range(n_rois):
        vec_img_data += (atlas_data==idx_roi+1) * vec[idx_roi]

    vec_img = new_img_like(data=vec_img_data, ref_niimg=atlas)
    return vec_img

In [29]:
for k in range(K):
    state = state_order[k]
    vec = importance_map[state,:].copy()
    vec[rejection_mat[state,:] == False]  = 0
    vec[54:254] = 0     # set entire cortex to zero, display subcortex only
    vec_img = vec_to_img(vec)
    vec_img.to_filename(f'maps/region_importance/importance_map_subcortex_thresholded/state%s.nii.gz'%(k+1))