In [2]:
import pickle
import matplotlib.pyplot as plt
import numpy as np
import resting_state_summaries as rss
from nilearn.image import load_img, new_img_like
import nilearn.plotting as plotting
import os

# Parcellation image shared by vec_to_img; the file name indicates the
# Schaefer 2018 200-parcel / 17-network atlas in FSL MNI152 2 mm space.
atlas = load_img('Schaefer2018_200Parcels_17Networks_order_FSLMNI152_2mm.nii.gz')
def vec_to_img(vec):
    """Paint one value per parcel onto the atlas volume.

    Parameters
    ----------
    vec : indexable of numbers
        ``vec[i]`` is written to every voxel whose atlas label equals
        ``i + 1`` for ``i`` in ``0..199``. Entries past the first 200 are
        not read.

    Returns
    -------
    Image produced by ``new_img_like`` with ``atlas`` as the reference.
    Voxels whose label is not in ``1..200`` stay at 0.
    """
    n_rois = 200
    atlas_data = atlas.get_fdata()

    vec_img_data = np.zeros_like(atlas_data)
    for idx_roi in range(n_rois):
        # Masked assignment instead of `mask * value`: multiplying a False
        # mask by NaN/inf yields NaN, which would spread one bad parcel value
        # over the whole volume. It also avoids a full-size temporary per ROI.
        vec_img_data[atlas_data == idx_roi + 1] = vec[idx_roi]

    vec_img = new_img_like(data=vec_img_data, ref_niimg=atlas)
    return vec_img

In [None]:
# --- Subcortex: 8 groups, written as 1-based parcel labels -----------------
subcortex_key_labels = ['Hip','aTha','pTha','Put','Cau','Nac','Amy','Gp']
subcortex_groups = [
    [1, 2, 5, 6, 7, 28, 29, 32, 33, 34],
    [3, 4, 10, 11, 12, 30, 31, 37, 38, 39],
    [8, 9, 23, 35, 36, 50],
    [13, 14, 15, 16, 40, 41, 42, 43],
    [17, 18, 19, 20, 44, 45, 46, 47],
    [24, 25, 51, 52],
    [21, 22, 48, 49],
    [26, 27, 53, 54],
]
# Turn each group into a 0-based integer index array.
subcortex_groups = [np.array(members, dtype=int) - 1 for members in subcortex_groups]

# --- Cortex: 17 networks, written as 1-based cortical parcel labels ---------
cortex_key_labels = ['VisCent','VisPeri','SomMotA','SomMotB','DorsAttnA','DorsAttnB', 'SalVentAttnA',
             'SalVentAttnB','LimbicB','LimbicA','ContA','ContB','ContC','DefaultA','DefaultB',
             'DefaultC','TempPar']
cortex_groups = [
    [1, 2, 3, 4, 5, 6, 101, 102, 103, 104, 105, 106],
    [7, 8, 9, 10, 11, 12, 107, 108, 109, 110, 111, 112],
    [13, 14, 15, 16, 17, 18, 19, 20, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123],
    [21, 22, 23, 24, 25, 26, 27, 28, 124, 125, 126, 127, 128, 129, 130],
    [29, 30, 31, 32, 33, 34, 131, 132, 133, 134, 135, 136],
    [35, 36, 37, 38, 39, 137, 138, 139, 140, 141],
    [40, 41, 42, 43, 44, 45, 46, 142, 143, 144, 145, 146, 147, 148, 149, 150],
    [47, 48, 49, 50, 151, 152, 153, 154, 155, 156],
    [51, 52, 157, 158, 159, 160],
    [53, 54, 55, 56, 161, 162, 163, 164],
    [57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 165, 166, 167, 168, 169, 170],
    [67, 68, 69, 70, 71, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180],
    [72, 73, 74, 181, 182, 183],
    [75, 76, 77, 78, 79, 80, 81, 82, 184, 185, 186, 187, 188, 189],
    [83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 190, 191, 192, 193],
    [96, 97, 98, 194, 195, 196],
    [99, 100, 197, 198, 199, 200],
]
# Offset the cortical labels by 54 so they follow the 54 subcortical parcels.
cortex_groups = [[label + 54 for label in members] for members in cortex_groups]
# Start position of each cortical network when the networks are laid end to end.
cortex_sizes = [len(members) for members in cortex_groups]
cortex_keys = np.append(0, np.cumsum(cortex_sizes)[0:16])
# Same 1-based -> 0-based conversion as for the subcortex.
cortex_groups = [np.array(members, dtype=int) - 1 for members in cortex_groups]

# --- Combined ordering: subcortex groups first, then cortical networks ------
groups = subcortex_groups + cortex_groups
key_labels = subcortex_key_labels + cortex_key_labels
csum = np.cumsum([len(members) for members in groups])
# Start position of every group in the concatenated ordering.
keys = np.append(0, csum[:-1])
# ROI indices listed group by group.
pi = np.concatenate(groups)

In [4]:
# K is reused further down as the number of states (loops over range(K));
# D is only used here, to select the model file.
K = 6
D = 10
# Build the file name from K and D so that changing either constant cannot
# silently keep loading the K6/D10 model.
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
with open('Final_model/K{}_D{}_500subjs_compact_model.pkl'.format(K, D), 'rb') as f:
    [model, q, elbos, q_z] = pickle.load(f)

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# Distinct subject tags stored on the model (np.unique returns them sorted).
pid = np.unique(model.tags)
num_subject = len(pid)
# ROI count as recorded on the model object.
num_roi = model.N

In [4]:
# Load the per-run time series and the subject tag of each run.
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
with open('data/roi_timeseries_rsfMRI_HCP_held_out', 'rb') as f:
    datas = pickle.load(f)
with open('data/tags_rsfMRI_HCP_held_out', 'rb') as f:
    tags = pickle.load(f)
# Positions of the runs whose subject tag is among the model's subjects (pid).
fitted_run_idx = [i for i in range(len(datas)) if tags[i] in pid]
datas_fitted = [datas[i] for i in fitted_run_idx]
# Take the tags from the very same runs, so tags_fitted[i] always describes
# datas_fitted[i] without assuming exactly 4 runs per subject or a sorted
# run order.
tags_fitted = [tags[i] for i in fitted_run_idx]
# NOTE(review): ys keeps the run order of `datas`. It is later sliced with
# offsets derived from z_bundles, which are concatenated subject by subject
# in pid order; confirm that `datas` is already grouped that way.
ys = np.concatenate(datas_fitted)

In [None]:
# Positions (in `datas` / `tags`) of the runs that belong to each subject.
runs_of_subject = {
    s: [i for i in range(len(datas)) if tags[i] == pid[s]]
    for s in range(num_subject)
}

# Per subject: all runs stacked into one array, plus the per-run state paths.
# NOTE(review): q_z is indexed with positions in `datas`, i.e. this assumes
# q_z holds one entry per run of `datas`, in the same order; confirm.
y_subject = {}
z_subject = {}
for s, run_idx in runs_of_subject.items():
    y_subject[s] = np.concatenate([datas[i] for i in run_idx])
    z_subject[s] = [q_z[i] for i in run_idx]

# Per subject: state bundles and the matching data bundles, as produced by
# the resting_state_summaries helpers.
z_bundle_subject = {}
y_bundle_subject = {}
for s in range(num_subject):
    subject_bundles = rss.collect_z_bundle(z_subject[s])
    z_bundle_subject[s] = subject_bundles
    y_bundle_subject[s] = rss.collect_y_bundle(y_subject[s], subject_bundles, K)

# Every subject's bundles in one flat list, subject after subject.
z_bundles = []
for s in range(num_subject):
    subject_bundles = z_bundle_subject[s]
    z_bundles += [subject_bundles[j] for j in range(len(subject_bundles))]

Set a common transition window, measured in TRs (TR = 0.72 s).

In [13]:
# Total number of TRs in the window around a state transition.
window_length = 20
# Number of TRs taken before the transition. Floor division gives a plain
# Python int (usable directly in range() and indexing) instead of a 0-d array.
time_to_transition = window_length // 2

In [None]:
# transition course between two states
def activity_evolution_in_transition(ys, z_bundles, k1, k2, window_length = 20):
    """Average the signal of every ROI around all k1 -> k2 transitions.

    Parameters
    ----------
    ys : 2-D array, rows are time points and columns are ROIs
        Rows must line up with the concatenation of all bundles in
        ``z_bundles``: bundle ``i`` covers ``len(z_bundles[i])`` consecutive rows.
    z_bundles : indexable collection of bundles
        The first element of a bundle is read as the state of that bundle.
    k1, k2 : state labels
        A transition is counted wherever a bundle starting with ``k1`` is
        immediately followed by a bundle starting with ``k2``.
    window_length : int
        Number of time points in the returned course. The first
        ``window_length // 2`` points lie before the transition (taken from
        the end of the ``k1`` bundle), the rest start at the transition (taken
        from the start of the ``k2`` bundle).

    Returns
    -------
    dict mapping ROI column index -> 1-D float array of length
    ``window_length``. A time point is NaN when no bundle is long enough to
    contribute to it (including when there is no k1 -> k2 transition at all).

    Notes
    -----
    Every pair of adjacent bundles is treated as a transition; this function
    cannot see run or subject boundaries inside ``z_bundles``.
    """
    ys = np.asarray(ys)
    # Take the ROI count and the pre/post split from the arguments instead of
    # notebook globals, so the function honours its own window_length.
    n_roi = ys.shape[1]
    half_window = int(window_length) // 2

    from_segments = []
    to_segments = []
    # `boundary` is the row of ys at which bundle i starts; it is kept as a
    # running sum instead of recomputing a cumulative sum for every match.
    boundary = len(z_bundles[0]) if len(z_bundles) > 0 else 0
    for i in range(1, len(z_bundles)):
        previous, current = z_bundles[i-1], z_bundles[i]
        if previous[0] == k1 and current[0] == k2:
            from_segments.append(ys[boundary-len(previous):boundary, :])
            to_segments.append(ys[boundary:boundary+len(current), :])
        boundary += len(current)

    # One row per time point in the window, one column per ROI.
    course = np.full((window_length, n_roi), np.nan)
    for t in range(window_length):
        if t < half_window:
            # e.g., if window_length = 10, then t=0 is 5TR before the state
            # transition and t=4 is 1TR before it.
            lag = half_window - t
            # `>=`: a segment of exactly `lag` samples still holds the sample
            # `lag` TRs before the transition (its first one).
            rows = [seg[len(seg)-lag] for seg in from_segments if len(seg) >= lag]
        else:
            lead = t - half_window
            rows = [seg[lead] for seg in to_segments if len(seg) > lead]
        if rows:
            course[t] = np.mean(rows, axis=0)
        # otherwise the time point stays NaN (no warning from an empty mean)

    return {j: course[:, j].copy() for j in range(n_roi)}

In [None]:
# K x K indicator matrix: entry [k1, k2] == 1 marks the k1 -> k2 transitions
# that are analysed further down. It can be hard-coded as here, or imported
# from the transition significance test.
transition_mat_rejection = np.array([
    [0, 0, 0, 1, 1, 0],
    [1, 0, 0, 0, 0, 0],
    [0, 1, 0, 1, 0, 0],
    [1, 0, 0, 0, 0, 0],
    [1, 0, 1, 0, 0, 0],
    [0, 1, 0, 0, 1, 0],
])

In [None]:
# Pre-allocate the nested result containers; every leaf starts as an empty
# dict and is replaced once the corresponding transition has been computed.
#   transition_courses[k1][k2]            -> pooled over all data
#   transition_courses_subject[s][k1][k2] -> one subject
transition_courses = {
    k1: {k2: dict() for k2 in range(K)}
    for k1 in range(K)
}
transition_courses_subject = {
    s: {k1: {k2: dict() for k2 in range(K)} for k1 in range(K)}
    for s in range(num_subject)
}

In [None]:
# state_order is indexed by the loop indices k1 / k2 further down to obtain
# the state labels that are looked up in the bundles.
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
with open('summary_data/state_order.pkl','rb') as state_order_file:
    state_order = pickle.load(state_order_file)

In [None]:
def _network_average(roi_courses):
    """Average the per-ROI courses within each ROI group listed in `groups`."""
    return [np.mean([roi_courses[j] for j in members], axis = 0) for members in groups]

for k1 in range(K):
    for k2 in range(K):
        # Only transitions marked with 1 in transition_mat_rejection are computed;
        # all other entries keep their empty placeholder.
        if transition_mat_rejection[k1,k2] != 1:
            continue
        # Translate the loop indices into the state labels used in the bundles.
        state1 = state_order[k1]
        state2 = state_order[k2]
        # Pass the configured window_length instead of repeating the literal 20,
        # so the window is defined in one place only.
        activity_evolution = activity_evolution_in_transition(ys, z_bundles, state1, state2, window_length = window_length)
        activity_evolution_subject = [activity_evolution_in_transition(y_subject[s], z_bundle_subject[s], state1, state2, window_length = window_length) for s in range(num_subject)]

        # Pooled course: one averaged trace per ROI group.
        network_evolution = _network_average(activity_evolution)
        transition_courses[k1][k2] = network_evolution

        # Same reduction for every subject separately.
        network_evolution_subject = dict()
        for s in range(num_subject):
            network_evolution_subject[s] = _network_average(activity_evolution_subject[s])
            transition_courses_subject[s][k1][k2] = network_evolution_subject[s]

In [None]:
# with open('summary_data/transition_courses.pkl','wb') as f:
#     pickle.dump(transition_courses, f)
# with open('summary_data/transition_courses_subject.pkl','wb') as f:
#     pickle.dump(transition_courses_subject, f)