In [None]:
%run 2-source.ipynb

In [None]:
# Set to True to (re)load the time courses of every region from disk.
# NOTE(review): when the flag is False nothing in this cell defines `all_tcs`, yet it is
# used below - presumably `%run 2-source.ipynb` provides it; confirm on a fresh kernel.
get_all_regions = False
if get_all_regions:
    rel_labels, rel_mappings = get_relevant_labels_mappings(path_to_base_package)
    all_tcs = get_all_tcs(output_dir_non_baseline_non_average, overwrite=False)
    raw_region_names = all_tcs.source_region
    # The hemisphere is the second '-'-separated token of the raw region name; it is
    # extracted before the region name itself is replaced by its mapped value.
    all_tcs['hemi'] = raw_region_names.apply(lambda name: name.split('-')[1])
    all_tcs['source_region'] = raw_region_names.apply(lambda name: rel_mappings[name])

significant_regions = ['dlpfc','dorsal_anterior_cingulate','ventral_anterior_cingulate','parastriate','striate']
rel_labels, rel_mappings = get_relevant_labels_mappings(path_to_base_package, regions_in_activity=significant_regions)  # for all regions use 'all'

# Walk the labels in order and hand out one palette color per run of equal labels:
# a label that compares equal to its predecessor reuses the predecessor's color.
muted_pal = sns.color_palette("cubehelix", n_colors=len(rel_labels))
new_col = muted_pal.pop()
new_cols = []
for t in range(len(rel_labels)):
    starts_new_run = t > 0 and not (rel_labels[t] == rel_labels[t-1])
    if starts_new_run:
        new_col = muted_pal.pop()
    new_cols.append(new_col)
# `new_cols` is not consumed below; it only serves as a reference for the colors that
# are set by hand elsewhere.

in_significant_region = all_tcs.source_region.isin(significant_regions)
ordered_input_tcs, all_tcs_tcs = get_epoched_tcs(all_tcs[in_significant_region], rel_labels, rel_mappings)

def get_readable_labels(rel_labels):
    """Return freshly loaded significant-region labels, renamed and colored.

    Every label is renamed to ``"<mapped region>-<hemi>"`` and, when the mapped
    region matches a key of the color table (case-insensitively), its ``color``
    is set to that RGB triple.

    NOTE(review): the ``rel_labels`` argument is never read. It is immediately
    replaced by a new ``get_relevant_labels_mappings`` call restricted to the
    notebook-level ``significant_regions``, so the result does not depend on
    what the caller passes in.
    """
    rel_labels, rel_mappings = get_relevant_labels_mappings(path_to_base_package, regions_in_activity=significant_regions)  # for all regions use 'all'

    dlpfc_rgb = (0.09854228363950114, 0.07115215572295082, 0.16957891809124037)
    dacc_rgb = (0.09406611799930162, 0.3578871412608098, 0.2837709711722866)
    vacc_rgb = (0.49498740849493095, 0.4799034869159042, 0.21147789468974837)
    visual_rgb = (0.8325928529853291, 0.5253446757844744, 0.6869376931865354)
    rel_labels_color_map = {
        'dlpfc': dlpfc_rgb,
        'dorsal_anterior_cingulate': dacc_rgb,
        'ventral_anterior_cingulate': vacc_rgb,
        'striate': visual_rgb,
        'parastriate': visual_rgb,
        'DLPFC': dlpfc_rgb,
        'dACC': dacc_rgb,
        'vACC': vacc_rgb,
        'V1': visual_rgb,
        'V2': visual_rgb,
    }
    # Matching is done on the lowercased key; when several keys lowercase to the
    # same string the last one in the table wins.
    color_by_lowercase_key = {key.lower(): rgb for key, rgb in rel_labels_color_map.items()}

    for l in rel_labels:
        l.name = f"{rel_mappings[l.name]}-{l.hemi}"
        region = l.name.split('-')[0]
        if region in color_by_lowercase_key:
            l.color = color_by_lowercase_key[region]
    return rel_labels
rel_labels = get_readable_labels(rel_labels)

from collections import OrderedDict
# Raw column / region names -> display names.
# The mapping is applied with `Series.replace(..., regex=True)` in the decoding plot
# cell, so the insertion order is preserved exactly: specific patterns come before
# the generic '_' -> ' ' entry, and 'parastriate' comes before 'striate'.
rename_dict = OrderedDict([
    ('index', 'measure'),
    ('Abs_Steer_Wheel_Degree', 'Motor intensity (deg)'),
    ('density', 'Trial opacity'),
    ('NSLR_count_Saccade', 'Saccade count'),
    ('NSLR_mean_duration_Saccade', 'Saccade duration (ms)'),
    # band power, right hemisphere
    ('_rh_4-8_Hz_Power', ' Theta RH'),
    ('_rh_8-15_Hz_Power', ' Alpha RH'),
    ('_rh_15-32_Hz_Power', ' Beta RH'),
    ('_rh_32-55_Hz_Power', ' Gamma RH'),
    # band power, left hemisphere
    ('_lh_4-8_Hz_Power', ' Theta LH'),
    ('_lh_8-15_Hz_Power', ' Alpha LH'),
    ('_lh_15-32_Hz_Power', ' Beta LH'),
    ('_lh_32-55_Hz_Power', ' Gamma LH'),
    # source regions
    ('dorsal_anterior_cingulate', 'dACC'),
    ('ventral_anterior_cingulate', 'vACC'),
    ('dlpfc', 'DLPFC'),
    ('parastriate', 'V2'),
    ('striate', 'V1'),
    ('intermediate_frontal', 'MFG'),
    # generic clean-up and physiological measures
    ('_', ' '),
    ('bpm', 'BPM'),
    ('rmssd', 'RMSSD (ms)'),
    ('pnn50', 'PNN50'),
    ('Left Pupil Diameter', 'Pupil Diameter (mm)'),
])

# Decoding source space data

In [None]:
from mne.decoding import (LinearModel, SlidingEstimator, cross_val_multiscore,
                          get_coef)
from sklearn.feature_selection import SelectKBest, f_classif
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.metrics import make_scorer, mean_squared_error
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

In [None]:
# The epoch metadata identifies participants as 'ppid'; mirror 'pid' under that name
# so both frames can be joined on the same keys.
all_tcs_tcs['ppid'] = all_tcs_tcs['pid']
merge_keys = ['ppid', 'session', 'trial_start_time']
steering_columns = ['Abs_Steer_Wheel_Degree', 'Steer_Wheel_Degree_Categorical']
decoding_df = all_tcs_tcs.merge(
    motor_epochs.metadata[merge_keys + steering_columns],
    on=merge_keys, how='inner', indicator=True,
)

trial_keys = ['pid', 'session', 'trial_start_time']
feature_keys = ['source_region', 'hemi']
n_trials = decoding_df.groupby(by=trial_keys).ngroups
n_features = decoding_df.groupby(by=feature_keys).ngroups
n_times = decoding_df.groupby(by=['time']).ngroups

# all_trials_tcs[trial, feature, time] holds the baseline-corrected activation;
# ys[trial] holds the steering magnitude of the same trial.
all_trials_tcs = np.empty((n_trials, n_features, n_times))
ys = []
trial_no = 0
for key, grouped_df in decoding_df.groupby(by=trial_keys):
    # One row per (source_region, hemi) group, in groupby order.
    # NOTE(review): presumably rows within a group are already ordered by time - confirm.
    region_specific = [
        feature_df['baseline_corr_activation']
        for feature_key, feature_df in grouped_df.groupby(by=feature_keys)
    ]
    all_trials_tcs[trial_no, :, :] = np.array(region_specific)
    # The steering value is read from the first row of the trial.
    ys.append(grouped_df.Abs_Steer_Wheel_Degree.iloc[0])
    trial_no += 1
ys = np.array(ys)

In [None]:
# (source_region, hemi) keys in groupby order, i.e. the feature order along axis 1 of
# `all_trials_tcs`. They are derived from the full `decoding_df` rather than from
# `grouped_df`, which is only the loop variable left over from the last trial of the
# previous cell.
ordered_features = list(decoding_df.groupby(by=['source_region','hemi']).size().keys())

In [None]:
# Time-resolved decoding of the steering magnitude (`ys`) from source-space activity.
# A linear regression is fitted independently at every time sample and scored with
# cross-validated R^2: once per (source_region, hemi) feature and once on all
# features jointly (stored as source_region='all', hemi='both').
num_folds = 100
all_feats_decod = []
y = ys

# (X, source_region, hemi) for every decoding run; each single-feature X keeps a
# length-1 feature axis.
decoding_runs = [
    (all_trials_tcs[:, feat_idx:feat_idx + 1, :], ordered_features[feat_idx][0], ordered_features[feat_idx][1])
    for feat_idx in range(all_trials_tcs.shape[1])
]
decoding_runs.append((all_trials_tcs, 'all', 'both'))

for X, region_name, hemi_name in decoding_runs:
    if hemi_name == 'both':
        print('running all')

    # z-score the features, then fit one linear regression per time sample
    clf = make_pipeline(StandardScaler(),
                        LinearModel(LinearRegression()))
    time_decod = SlidingEstimator(clf, scoring='r2')

    # Cross-validated scores, one row per fold and one column per time sample.
    scores = cross_val_multiscore(time_decod, X, y, cv=num_folds, n_jobs=None, verbose=False)
    # The scores are R^2 values and are kept signed: a negative R^2 means the fold
    # predicted worse than a constant, so taking the absolute value (as was done for
    # an error-based scorer) would report those folds as good fits.

    # Long format: one row per (fold, sample).
    scores_df = pd.DataFrame(scores, index=np.arange(num_folds) + 1).T
    samples = list(scores_df.index)
    scores_df = pd.melt(scores_df, var_name='fold', value_name='r2', ignore_index=False)
    scores_df['sample_no'] = samples * num_folds
    scores_df['source_region'] = region_name
    scores_df['hemi'] = hemi_name
    all_feats_decod.append(scores_df)

all_feats_decod = pd.concat(all_feats_decod).reset_index(drop=True)
# Sample index -> seconds.
# NOTE(review): assumes 128 Hz sampling and epochs that start 1.25 s before the event - confirm.
all_feats_decod['time'] = (all_feats_decod['sample_no']-(1.25*128))/128 # sample to time


In [None]:
sns.set_style('white')

# Human-readable region names for the legend (adds a column to `all_feats_decod`).
all_feats_decod['source region'] = all_feats_decod.source_region.replace(rename_dict, regex=True)

# Keep the single-region runs only (the joint run is stored with hemi == 'both') and
# samples later than 1 s before the event.
is_single_region_run = all_feats_decod.hemi != 'both'
is_after_minus_1s = all_feats_decod.time > -1
decoding_to_plot = all_feats_decod[is_after_minus_1s & is_single_region_run]

# NOTE(review): `rel_labels_color_map` is only assigned at notebook level in a later
# cell ("Correlation plots"); presumably `%run 2-source.ipynb` also defines it - confirm
# that this cell runs on a fresh kernel.
g = sns.relplot(
    data=decoding_to_plot,
    x="time", y="r2",
    hue="source region", col="hemi",
    kind="line", ci=None, palette=rel_labels_color_map,
)
hemi_panels = g.axes.flatten()
hemi_panels[0].set_title('Left Hemisphere')
hemi_panels[1].set_title('Right Hemisphere')
g.set_ylabels("$R^2$", clear_inner=False)
g.set_axis_labels("time (relative to motor event)")

# Plot time series

## All time series by freq

### Sample bands for single trial

In [None]:
from mpl_toolkits.axes_grid1 import (make_axes_locatable, ImageGrid,
                                     inset_locator)

# Frequency band edges in Hz. They are defined here (with the same values as in the
# "Source plots" cell further down) so that this cell does not depend on a cell that
# comes later in the notebook.
bands = [4, 8, 15, 32, 55]
band_intervals = list(zip(bands[:-1], bands[1:]))

for band in band_intervals:
    # Band-pass the evoked response, project it to source space and plot the time
    # course of one label as a bare line (axes hidden).
    band_evoked = motor_epochs.average().filter(band[0], band[1], verbose=False)
    src_sample_ts = mne.minimum_norm.apply_inverse(band_evoked, inverse_operator,
                                                   lambda2=1.0 / snr ** 2, verbose=False,
                                                   method="eLORETA", pick_ori="normal")
    # `[0]` keeps the first entry of the extracted time courses.
    # NOTE(review): the last label is dropped via `rel_labels[:-1]`; the reason is not
    # visible here - confirm it is intended.
    tmp = mne.extract_label_time_course(
        src_sample_ts, rel_labels[:-1], inverse_operator['src'], mode='mean_flip', allow_empty=True,
        return_generator=True, verbose=False)[0]

    fig, ax = plt.subplots(figsize=(10,2),dpi=300)
    ax.plot(tmp,c='k', linewidth=3)
    ax.axis('off')
    extent = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
    fig.show()

#### Single post-steer event for prediction

In [None]:
# Raw post-steer event of the first row of `motor_dfs`, drawn in red.
sns.set_style('white')
fig, ax = plt.subplots(figsize=(10,2),dpi=300)
ax.plot(motor_dfs.iloc[0]['post_steer_event_raw'],'r')
ax.set_xlabel('time (samples)')
# Raw string: '\D' is not a valid escape sequence in a normal string literal.
ax.set_ylabel(r'|$\Delta$|deg')
extent = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
fig.show()


### Source plots

In [None]:
from collections import defaultdict  # local import so the cell is self-contained

bands = [4, 8, 15, 32, 55]
band_intervals = list(zip(bands[:-1], bands[1:]))  # consecutive (low, high) band edges in Hz
even_range = np.linspace(-1,0,3)  # time points (s) at which the source estimate is rendered
average_within_range = False  # True: render the mean over each interval between consecutive time points
screenshots = defaultdict(list)  # (band, hemi) -> list of screenshots, one per rendering
add_labels = False  # True: outline the region labels on every rendering


def _capture_source_view(stc, hemi, labels, **view_kwargs):
    """Render `stc` on one hemisphere, optionally outline `labels`, return the screenshot.

    `view_kwargs` carries the few plot arguments that differ between call sites
    (`initial_time`, `views`); everything else is shared by all renderings.
    """
    brain = stc.plot(subjects_dir=subjects_dir,
                     surface='pial', hemi=hemi, size=(1000, 500), background='white', cortex='low_contrast',
                     smoothing_steps=10, time_viewer=False, alpha=.75, colorbar=False, show_traces=False,
                     clim=dict(kind='value', pos_lims=[0,0.784058961088799e-13,1.87367476e-13]),
                     add_data_kwargs=dict(
                         colorbar_kwargs=dict(label_font_size=20,height=0.25,n_labels=5)),
                     **view_kwargs)
    if add_labels:
        for rla in labels:
            brain.add_label(rla, borders=1)
    shot = brain.screenshot()
    brain.close()
    return shot


def _iter_time_views(stc):
    """Yield `(stc_to_render, t)` for every rendering taken from `stc`.

    With `average_within_range` the estimate is cropped to each interval between
    consecutive entries of `even_range` and averaged, and `t` is the interval's
    upper bound; otherwise the full estimate is yielded once per entry of `even_range`.
    """
    if average_within_range:
        for indx, t in enumerate(even_range[1:]):
            yield stc.copy().crop(even_range[indx], even_range[indx+1]).mean(), t
    else:
        for t in even_range:
            yield stc, t


# The inverse solutions do not depend on the hemisphere being rendered, so they are
# computed once instead of once per hemisphere.
broadband_stc = mne.minimum_norm.apply_inverse(motor_epochs.average(), inverse_operator,
                                               lambda2=1.0 / snr ** 2, verbose=False,
                                               method="eLORETA", pick_ori="normal")
stc_by_band = {
    band_edges: mne.minimum_norm.apply_inverse(motor_epochs.average().filter(band_edges[0], band_edges[1]), inverse_operator,
                                               lambda2=1.0 / snr ** 2, verbose=False,
                                               method="eLORETA", pick_ori="normal")
    for band_edges in band_intervals
}

for hemi in ['lh','rh']:
    rel_labels_to_add = [label for label in rel_labels if hemi in label.name]

    # Unfiltered data, stored under the key (0,128).
    for stc_view, t in _iter_time_views(broadband_stc):
        view_kwargs = dict(initial_time=t)
        if average_within_range:
            # NOTE(review): only this one variant requests the medial view - confirm
            # that the difference to the other renderings is intended.
            view_kwargs['views'] = 'medial'
        screenshots[(0,128), hemi].append(
            _capture_source_view(stc_view, hemi, rel_labels_to_add, **view_kwargs))

    # One row of renderings per frequency band.
    for band in band_intervals:
        avg_stc = stc_by_band[band]
        for stc_view, t in _iter_time_views(avg_stc):
            # The interval-averaged band renderings are plotted without `initial_time`.
            view_kwargs = {} if average_within_range else dict(initial_time=t)
            screenshots[band, hemi].append(
                _capture_source_view(stc_view, hemi, rel_labels_to_add, **view_kwargs))


In [None]:
# Flatten the screenshots into plotting order: first the unfiltered renderings
# (left hemisphere, then right), then one lh/rh pair of runs per frequency band.
ordered_screen_shots = []
titles = []
# One start time per rendering (the last time point has no interval when averaging).
start_times = even_range[:-1] if average_within_range else even_range
print(len(screenshots[((0,128),'lh')]))

band_names = {(4,8): 'Theta', (8,15): 'Alpha', (15,32): 'Beta', (32,55): 'Gamma'}

# NOTE(review): the titles round the start time to an integer, so -0.5 and 0 are not
# distinguishable in the title text - confirm that is acceptable.
for hemi in ['lh','rh']:
    # Sized by the unfiltered entry itself; previously the length was taken from
    # whatever `band` the preceding cell's loop happened to leave behind.
    for i in range(len(screenshots[((0,128),hemi)])):
        ordered_screen_shots.append(screenshots[((0,128),hemi)][i])
        titles.append(f"({hemi}) {round(start_times[i])}")
for band in band_intervals:
    for hemi in ['lh','rh']:
        for i in range(len(screenshots[(band,hemi)])):
            ordered_screen_shots.append(screenshots[(band,hemi)][i])
            if band in band_names:
                titles.append(f"{band_names[band]} ({hemi}) {round(start_times[i])}")

In [None]:
from mpl_toolkits.axes_grid1 import (make_axes_locatable, ImageGrid,
                                     inset_locator)

# 5 x 6 grid of renderings; zip stops at the shortest of axes / screenshots / titles.
fig = plt.figure(figsize=(10, 20),dpi=300)
axes = ImageGrid(fig, 111, nrows_ncols=(5,6), axes_pad=0)
for ax, image, title in zip(axes, ordered_screen_shots, titles):
    # `title` is paired with each image but is not drawn.
    ax.axis('off')

    # Drop every row and column that consists only of white (255) pixels.
    nonwhite_pix = np.any(image != 255, axis=-1)
    nonwhite_row = np.any(nonwhite_pix, axis=1)
    nonwhite_col = np.any(nonwhite_pix, axis=0)
    cropped_screenshot = image[np.ix_(nonwhite_row, nonwhite_col)]

    # No grid lines and no ticks.
    ax.grid(False)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.imshow(cropped_screenshot)

### Colorbar

In [None]:
# Render one view with the colorbar switched on, to be used as the shared colorbar.
# The evoked data are scaled by 1e6 and the color limits are scaled by the same
# factor relative to the source renderings (1.87e-13 -> 1.87e-7).
avg_me = motor_epochs.average()
avg_me.data = avg_me.data*1e6
avg_stc = mne.minimum_norm.apply_inverse(avg_me, inverse_operator,
                                            lambda2=1.0 / snr ** 2, verbose=False,
                                            method="eLORETA", pick_ori="normal")
# Last rendered time point. This was previously read from the loop variable `t` left
# behind by the source-plot cell, which ends on this value.
colorbar_time = even_range[-1]
brain = avg_stc.plot(subjects_dir=subjects_dir, initial_time=colorbar_time,views=['medial'],smoothing_steps=10,
                                            surface='inflated', hemi='lh', size=(1000, 500),background='white', cortex='low_contrast',
                                            time_viewer=False,colorbar=True,show_traces=False,clim=dict(kind='value', pos_lims=[0,0.784058961088799e-7,1.87367476e-7]),
                                            add_data_kwargs=dict(
                                                colorbar_kwargs=dict(label_font_size=40,height=0.25,n_labels=5,
                                                                     position_x=0.05, position_y=-0.05)))
cbar_screenshot = brain.screenshot()
brain.close()

fig = plt.figure(figsize=(10, 20),dpi=300)
axes = ImageGrid(fig, 111, nrows_ncols=(1,1), axes_pad=0)
ax = axes[0]
ax.axis('off')
# Drop every row and column that consists only of white (255) pixels.
nonwhite_pix = (cbar_screenshot != 255).any(-1)
nonwhite_row = nonwhite_pix.any(1)
nonwhite_col = nonwhite_pix.any(0)
cropped_screenshot = cbar_screenshot[nonwhite_row][:, nonwhite_col]
# Hide grid lines
ax.grid(False)

# Hide axes ticks
ax.set_xticks([])
ax.set_yticks([])
# Show the cropped image, as the other figure cells do; the crop used to be computed
# and then discarded in favour of the uncropped screenshot.
ax.imshow(cropped_screenshot)

## Correlation plots

In [None]:
# Labels of every region ('all'), renamed to "<mapped region>-<hemi>"; the regions
# listed in the color table additionally get a fixed color.
rel_labels, rel_mappings = get_relevant_labels_mappings(path_to_base_package, regions_in_activity='all')  # for all regions use 'all'

# Colors are set by hand; every region is reachable under its long name and under
# its display name.
region_colors = {
    'dlpfc': (0.09854228363950114, 0.07115215572295082, 0.16957891809124037),
    'dorsal_anterior_cingulate': (0.09406611799930162, 0.3578871412608098, 0.2837709711722866),
    'ventral_anterior_cingulate': (0.49498740849493095, 0.4799034869159042, 0.21147789468974837),
    'striate': (0.8325928529853291, 0.5253446757844744, 0.6869376931865354),
    'parastriate': (0.8325928529853291, 0.5253446757844744, 0.6869376931865354),
}
display_name_to_region = {
    'DLPFC': 'dlpfc',
    'dACC': 'dorsal_anterior_cingulate',
    'vACC': 'ventral_anterior_cingulate',
    'V1': 'striate',
    'V2': 'parastriate',
}
rel_labels_color_map = dict(region_colors)
rel_labels_color_map.update(
    {display_name: region_colors[region] for display_name, region in display_name_to_region.items()}
)

# Matching is done on the lowercased key; when several keys lowercase to the same
# string the last one in the table wins.
color_by_lowercase_key = {key.lower(): rgb for key, rgb in rel_labels_color_map.items()}
for l in rel_labels:
    l.name = f"{rel_mappings[l.name]}-{l.hemi}"
    region = l.name.split('-')[0]
    if region in color_by_lowercase_key:
        l.color = color_by_lowercase_key[region]

In [None]:
from statsmodels.stats.multitest import multipletests

# Benjamini-Hochberg FDR correction (method='fdr_bh') of four hard-coded p-values.
# NOTE(review): where these p-values come from is not shown in this notebook - confirm.
multipletests([8e-8, 1.4e-9, 1.95e-7, 0.003023],method='fdr_bh')

In [None]:
from scipy.stats import spearmanr
from statsmodels.stats.multitest import multipletests
# gather all correlations and vertices by label to fit in memory

def get_correlated_acts(output_dir,overwrite=False):
    """Spearman-correlate per-vertex source activity with trial-level measures.

    For every label in the notebook-level `rel_labels` the epochs are projected to
    source space restricted to that label, and each vertex is correlated across
    epochs with three metadata columns of `motor_epochs`. The result is cached as a
    pickle in `output_dir`.

    Parameters
    ----------
    output_dir : str
        Directory prefix of the cache file. The file name is appended directly, so
        the prefix must end with a path separator.
    overwrite : bool
        When False and the cache file exists, the cached result is returned
        without recomputing anything.

    Returns
    -------
    corrs, pvals : dict
        Both keyed by `(label name, measure)`; each value is an array with one entry
        per vertex of the label.
    """
    cache_path = f"{output_dir}correlated_activations.pickle"
    if not overwrite and os.path.isfile(cache_path):
        # pickle can execute arbitrary code on load: only read caches written by
        # this notebook. The context manager closes the file handle again.
        with open(cache_path, 'rb') as cache_file:
            corrs, pvals = pickle.load(cache_file)
        return corrs, pvals

    corrs = {}
    pvals = {}
    for l in rel_labels:
        epoch_stcs = mne.minimum_norm.apply_inverse_epochs(motor_epochs, inverse_operator,
                                            lambda2=1.0 / snr ** 2, verbose=False,
                                            method="eLORETA", pick_ori="normal",label=l)
        # axis 0 indexes the epochs, axis 1 the rows (vertices) of each estimate
        all_stcs_np = np.array([x.data for x in epoch_stcs])
        for m in ['measures.rmssd','Left Pupil Diameter','NSLR_mean_duration.Saccade']:
            corr_measure_data = motor_epochs.metadata[m]
            # One correlation per vertex, across epochs.
            # NOTE(review): only the sample at index 1 of every epoch enters the
            # correlation (`[:, i, 1]`) - confirm that this time point is intended.
            ress = [
                spearmanr(a=all_stcs_np[:,i,1], b=corr_measure_data,nan_policy='omit')
                for i in range(all_stcs_np.shape[1])
            ]
            corrs[(l.name, m)] = np.array([r.correlation for r in ress])
            pvals[(l.name, m)] = np.array([r.pvalue for r in ress])
    with open(cache_path, 'wb') as cache_file:
        pickle.dump([corrs,pvals], cache_file, protocol=pickle.HIGHEST_PROTOCOL)
    return corrs,pvals
# Reuse the cached correlations when the pickle exists; otherwise compute and cache them.
corrs,pvals = get_correlated_acts(output_dir,overwrite=False)

In [None]:
hemi_mapping = {0:'lh', 1: 'rh'} # the index of the list on the vertices object and the hemi it maps to (according to doc)
avg_stc = mne.minimum_norm.apply_inverse(motor_epochs.average(), inverse_operator,
                                        lambda2=1.0 / snr ** 2, verbose=False,
                                        method="eLORETA", pick_ori="normal")
avg_stc_vertices = avg_stc.vertices[0]

# One row per source vertex, in the row order of `avg_stc.data`: all vertices of the
# first hemisphere followed by all vertices of the second. Each hemisphere uses its
# own vertex numbers instead of assuming that both hemispheres share them.
avg_stc_df = pd.DataFrame({
    'activation': avg_stc.data[:,0],
    'vertex': np.concatenate(avg_stc.vertices),
    'hemi': [hemi_mapping[hemi_idx]
             for hemi_idx, hemi_vertices in enumerate(avg_stc.vertices)
             for _vertex in hemi_vertices],
})

# For every measure build a copy of the source estimate whose single time point holds
# the per-vertex correlation; vertices outside every label stay at 0.
avg_stcs_corr = {}
for m in ['measures.rmssd','Left Pupil Diameter','NSLR_mean_duration.Saccade']:
    avg_stc_df['activation'] = 0.0 # default to 0 (float, since correlations are written into it)
    for l in rel_labels:
        this_corrs = corrs[(l.name,m)].copy()
        mask = (avg_stc_df.hemi == l.hemi) & (avg_stc_df.vertex.isin(l.vertices))
        # NOTE(review): assumes the correlations are ordered like the masked rows
        # (ascending vertex number within the label) - confirm.
        avg_stc_df.loc[mask,'activation'] = this_corrs
    avg_stc_replaced = avg_stc.copy()
    avg_stc_replaced.data = avg_stc_df.activation.to_numpy()[:, np.newaxis]
    avg_stcs_corr[m] = avg_stc_replaced

#avg_stcs_corr['measures.rmssd'].plot(clim=dict(kind='value', pos_lims=[0,0.03,0.10]))

In [None]:
# Labels of the significant regions, used to outline them on the correlation maps.
rel_labels_sig, rel_mappings_sig = get_relevant_labels_mappings(path_to_base_package,regions_in_activity=significant_regions) # for all regions use 'all'
rel_labels_color_map = {'dlpfc': (0.09854228363950114, 0.07115215572295082, 0.16957891809124037),
'dorsal_anterior_cingulate': (0.09406611799930162, 0.3578871412608098, 0.2837709711722866),
'ventral_anterior_cingulate': (0.49498740849493095, 0.4799034869159042, 0.21147789468974837),
'striate': (0.8325928529853291, 0.5253446757844744, 0.6869376931865354),
'parastriate': (0.8325928529853291, 0.5253446757844744, 0.6869376931865354),
'DLPFC': (0.09854228363950114, 0.07115215572295082, 0.16957891809124037),
'dACC': (0.09406611799930162, 0.3578871412608098, 0.2837709711722866),
'vACC': (0.49498740849493095, 0.4799034869159042, 0.21147789468974837),
'V1': (0.8325928529853291, 0.5253446757844744, 0.6869376931865354),
'V2': (0.8325928529853291, 0.5253446757844744, 0.6869376931865354)}
for l in rel_labels_sig:
    # Rename with the mapping that was returned together with these labels, not with
    # the `rel_mappings` that another cell left in the namespace.
    l.name = f"{rel_mappings_sig[l.name]}-{l.hemi}"
    for r in rel_labels_color_map:
        if r.lower() == l.name.split('-')[0]:
            l.color = rel_labels_color_map[r]

# One rendering of the correlation map per (measure, hemisphere).
screenshots = defaultdict(list)
titles = []
for m in ['measures.rmssd','Left Pupil Diameter','NSLR_mean_duration.Saccade']:
    for hemi in ['lh','rh']:
        rel_labels_to_add = [label for label in rel_labels_sig if hemi in label.name]
        brain = avg_stcs_corr[m].plot(subjects_dir=subjects_dir, clim=dict(kind='value', pos_lims=[0,0.03,0.05]),surface='pial', hemi=hemi, size=(1000, 500),background='white', cortex='low_contrast',
                                        smoothing_steps=10, time_viewer=False,alpha=.75,colorbar=False,show_traces=False,
                                        add_data_kwargs=dict(
                                            colorbar_kwargs=dict(label_font_size=40,height=0.25,n_labels=5,
                                                                     position_x=0.05, position_y=.85)))
        # `add_labels` is the notebook-level switch set in the source-plot cell.
        if add_labels:
            for rla in rel_labels_to_add:
                brain.add_label(rla, borders=1)
        screenshots[m, hemi] = brain.screenshot()

        brain.close()

In [None]:
# Render one correlation map with the colorbar switched on, to be used as the shared
# colorbar. The measure and hemisphere are named explicitly; they used to be read from
# the loop variables `m` and `hemi` left behind by the previous cell, which end on
# these values.
colorbar_measure = 'NSLR_mean_duration.Saccade'
colorbar_hemi = 'rh'
brain = avg_stcs_corr[colorbar_measure].plot(subjects_dir=subjects_dir, clim=dict(kind='value', pos_lims=[0,0.03,0.05]),surface='pial', hemi=colorbar_hemi, size=(1000, 500),background='white', cortex='low_contrast',
                                        smoothing_steps=10, time_viewer=False,alpha=.75,colorbar=True,show_traces=False,
                                        add_data_kwargs=dict(
                                            colorbar_kwargs=dict(label_font_size=40,height=0.25,n_labels=5,
                                                                     position_x=0.05, position_y=.85)))
cbar_screenshot = brain.screenshot()
brain.close()

fig = plt.figure(figsize=(10, 20),dpi=300)
axes = ImageGrid(fig, 111, nrows_ncols=(1,1), axes_pad=0)
ax = axes[0]
ax.axis('off')
# Drop every row and column that consists only of white (255) pixels.
nonwhite_pix = (cbar_screenshot != 255).any(-1)
nonwhite_row = nonwhite_pix.any(1)
nonwhite_col = nonwhite_pix.any(0)
cropped_screenshot = cbar_screenshot[nonwhite_row][:, nonwhite_col]
# Hide grid lines
ax.grid(False)

# Hide axes ticks
ax.set_xticks([])
ax.set_yticks([])
# Show the cropped image, as the other figure cells do; the crop used to be computed
# and then discarded in favour of the uncropped screenshot.
ax.imshow(cropped_screenshot)

In [None]:
# Plotting order: all measures for the left hemisphere, then all measures for the right.
correlation_measures = ['measures.rmssd','Left Pupil Diameter','NSLR_mean_duration.Saccade']
panel_keys = [
    (measure_name, hemi_name)
    for hemi_name in ['lh','rh']
    for measure_name in correlation_measures
]
ordered_screen_shots = [screenshots[panel_key] for panel_key in panel_keys]
titles = [f"{measure_name} {hemi_name}" for measure_name, hemi_name in panel_keys]

In [None]:
from mpl_toolkits.axes_grid1 import (make_axes_locatable, ImageGrid,
                                     inset_locator)

# 2 x 3 grid of renderings; zip stops at the shortest of axes / screenshots / titles.
fig = plt.figure(figsize=(10, 20),dpi=300)
axes = ImageGrid(fig, 111, nrows_ncols=(2,3), axes_pad=0)
for ax, image, title in zip(axes, ordered_screen_shots, titles):
    ax.axis('off')
    # The title is only printed, not drawn on the figure.
    print('title',title)

    # Drop every row and column that consists only of white (255) pixels.
    nonwhite_pix = np.any(image != 255, axis=-1)
    nonwhite_row = np.any(nonwhite_pix, axis=1)
    nonwhite_col = np.any(nonwhite_pix, axis=0)
    cropped_screenshot = image[np.ix_(nonwhite_row, nonwhite_col)]

    # No grid lines and no ticks.
    ax.grid(False)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.imshow(cropped_screenshot)

In [None]:
# Render a brain whose data are all zero and draw the region labels on it; the
# screenshot is stored in `sshot`.
# NOTE(review): `m`, `hemi` and `rel_labels_to_add` are not assigned in this cell; they
# are whatever earlier loops left in the namespace - confirm the intended measure,
# hemisphere and label set.
empty_brain = avg_stcs_corr[m].copy()
# Zero the copy's data (the entry in `avg_stcs_corr` is copied first).
empty_brain.data[:,:] = 0
brain = empty_brain.plot(subjects_dir=subjects_dir, clim=dict(kind='value', pos_lims=[0,0.05,0.10]),surface='inflated', hemi=hemi, size=(1000, 500),background='white', cortex='low_contrast',
                                        smoothing_steps=10, time_viewer=False,alpha=.1,colorbar=False,show_traces=False,
                                        add_data_kwargs=dict(
                                            colorbar_kwargs=dict(label_font_size=10)))
for rla in rel_labels_to_add:
            brain.add_label(rla, borders=3)
sshot = brain.screenshot()
brain.close()


In [None]:
# Drop every row and column of the screenshot that consists only of white (255) pixels.
nonwhite_pix = (sshot != 255).any(-1)
nonwhite_row = nonwhite_pix.any(1)
nonwhite_col = nonwhite_pix.any(0)
cropped_screenshot = sshot[nonwhite_row][:, nonwhite_col]

fig,axes = plt.subplots(figsize=(10,10),dpi=300)
axes.grid(False)
# Remove the ticks on the axes created above. The calls used to go to `ax`, a leftover
# from an earlier cell that belongs to a different figure.
axes.set_xticks([])
axes.set_yticks([])
plt.axis('off')
axes.get_xaxis().set_visible(False)
axes.get_yaxis().set_visible(False)

axes.imshow(cropped_screenshot)

# Functional connectivity

In [None]:
from pathlib import Path

from mne import (make_forward_solution, setup_source_space,
                 setup_volume_source_space)
from mne.io import read_raw_fif
from mne.minimum_norm import apply_inverse_epochs, make_inverse_operator
from mne.viz import circular_layout
from mne_connectivity import spectral_connectivity_epochs, spectral_connectivity_time
from mne_connectivity.viz import plot_connectivity_circle

In [None]:
def get_connectivity_plot(label_ts, output_dir, fmin, fmax, rel_labels, fig_title = 'Motor', connectivity_type='epochs'):
    """Compute PLI connectivity between label time courses and a circular node layout.

    Despite its name the function does not draw anything: it returns the pieces that
    a circular connectivity plot needs and saves the dense connectivity matrix.

    Parameters
    ----------
    label_ts : array-like
        Label time courses handed to the connectivity estimator.
    output_dir : str
        The matrix is saved to ``{output_dir}/connectivity/{fig_title}_conn`` via
        ``np.save``; the directory is created when missing.
    fmin, fmax : float
        Band edges in Hz (only used for ``connectivity_type='epochs'``).
    rel_labels : list
        Labels providing ``name``, ``color`` and ``pos`` for every node.
    fig_title : str
        Used in the name of the saved matrix.
    connectivity_type : {'epochs', 'time'}
        Which connectivity estimator to call.

    Returns
    -------
    conmat, node_colors, node_angles, con
        Dense connectivity matrix (first entry of the last axis), the label colors,
        the node angles of the circular layout, and the raw connectivity object.

    Raises
    ------
    ValueError
        If ``connectivity_type`` is neither ``'epochs'`` nor ``'time'``.
    """
    import os  # local import: needed only to create the output directory

    # Fail early on an unknown estimator; otherwise `con` would never be assigned
    # and the function would only fail after all the layout work.
    if connectivity_type not in ('epochs', 'time'):
        raise ValueError(
            f"connectivity_type must be 'epochs' or 'time', got {connectivity_type!r}")

    sfreq = motor_epochs.info['sfreq']  # the sampling frequency
    if connectivity_type == 'epochs':
        con = spectral_connectivity_epochs(
            label_ts, method='pli', mode='multitaper', sfreq=sfreq, fmin=fmin,
            fmax=fmax, faverage=True, mt_adaptive=True, n_jobs=5,verbose=False)
    else:
        # NOTE(review): fmin/fmax (and any frequency list) are not forwarded to this
        # estimator - confirm the call against the installed mne_connectivity version.
        con = spectral_connectivity_time(
            label_ts, method='pli', mode='multitaper', sfreq=sfreq, n_jobs=5,verbose=False)

    labels = rel_labels
    node_colors = [label.color for label in labels]
    label_names = [label.name for label in labels]
    lh_labels = [name for name in label_names if name.endswith('lh')]
    rh_labels = [name for name in label_names if name.endswith('rh')]

    # Mean y-coordinate of every left-hemisphere label, used to order the nodes.
    label_ypos_lh = [
        np.mean(labels[label_names.index(name)].pos[:, 1]) for name in lh_labels
    ]
    # A 'Brain-Stem' label, when present, is ordered together with the left hemisphere.
    if 'Brain-Stem' in label_names:
        idx = label_names.index('Brain-Stem')
        lh_labels.append('Brain-Stem')
        label_ypos_lh.append(np.mean(labels[idx].pos[:, 1]))

    # Reorder the left-hemisphere labels by their y-position.
    lh_labels = [label for (yp, label) in sorted(zip(label_ypos_lh, lh_labels))]

    # Right hemisphere: same order as the left, keeping only labels that exist there.
    rh_labels = [label[:-2] + 'rh' for label in lh_labels
                if label != 'Brain-Stem' and label[:-2] + 'rh' in rh_labels]

    # Plot order: left hemisphere reversed, followed by the right hemisphere.
    node_order = lh_labels[::-1] + rh_labels

    node_angles = circular_layout(label_names, node_order, start_pos=90,
                                group_boundaries=[0, len(label_names) // 2])

    # Dense matrix of the first entry along the last axis, saved for later use.
    conmat = con.get_data(output='dense')[:, :, 0]
    connectivity_dir = f"{output_dir}/connectivity"
    os.makedirs(connectivity_dir, exist_ok=True)
    np.save(f"{connectivity_dir}/{fig_title}_conn",conmat)
    return conmat, node_colors, node_angles, con
    

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import cmath
import random

def get_phase_shift(signal,t0=0, dt=1/128,plot=False):
    """Shift the phase of every frequency component of `signal` by one random angle.

    Parameters
    ----------
    signal : 1-D array-like of real values
    t0, dt : float
        Start time and sample spacing; only used for the time axis of the plot.
    plot : bool
        When True, plot the original and shifted signal together with the power and
        phase spectrum of the original.

    Returns
    -------
    newSignal : ndarray
        The phase-shifted signal. It comes from ``np.fft.irfft`` with its default
        output length, so it has ``len(signal)`` samples for an even-length input
        and ``len(signal) - 1`` samples for an odd-length input.
    rand_phase : float
        The applied phase shift in radians, drawn uniformly from [0, pi/2].
    """
    ## Fourier transform of real valued signal
    signalFFT = np.fft.rfft(signal)

    ## Rotate every component by the same random angle between 0 and pi/2 (0-90 degrees)
    rand_phase = random.uniform(0,np.pi/2)
    newSignalFFT = signalFFT * cmath.rect( 1., rand_phase )

    ## Reverse Fourier transform
    newSignal = np.fft.irfft(newSignalFFT)

    ## Add the mean of the original signal (real DC term divided by the length)
    newSignal += signalFFT[0].real/len(signal)

    if plot:
        t = np.linspace( t0, t0+len(signal)*dt, len(signal), endpoint=False )
        # The shifted signal may be one sample shorter than the input; trim the
        # original to the same length so both can share the time axis.
        n_out = len(newSignal)

        ## Power spectral density and phase of the original signal
        signalPSD = np.abs(signalFFT) ** 2
        signalPSD /= len(signalFFT)**2
        signalPhase = np.angle(signalFFT)

        ## Frequencies corresponding to the spectrum
        fftFreq = np.fft.rfftfreq(len(signal), dt)
        plt.figure( figsize=(10, 4) )

        ax1 = plt.subplot( 1, 2, 1 )
        ax1.plot( t[:n_out], signal[:n_out], label='signal')
        ax1.plot( t[:n_out], newSignal, label='new signal')
        ax1.set_ylabel( 'Signal' )
        ax1.set_xlabel( 'time' )
        ax1.legend()

        ax2 = plt.subplot( 1, 2, 2 )
        ax2.plot( fftFreq, signalPSD )
        ax2.set_ylabel( 'Power' )
        ax2.set_xlabel( 'frequency' )

        ax2b = ax2.twinx()
        ax2b.plot( fftFreq, signalPhase, alpha=0.25, color='r' )
        ax2b.set_ylabel( 'Phase', color='r' )

        plt.tight_layout()

        plt.show()
    return newSignal, rand_phase

# Phase-shifted surrogate of every epoch/channel time course.
n_epochs, n_channels, n_times = ordered_input_tcs.shape
# The inverse real FFT returns an even number of samples: n_times for an even-length
# input and n_times - 1 for an odd-length input.
n_shifted_times = 2 * (n_times // 2)
phase_shifted_signals = np.zeros((n_epochs, n_channels, n_shifted_times))
for ep in range(n_epochs):
    for ch in range(n_channels):
        this_sig = ordered_input_tcs[ep,ch,:]
        # Plot only the very first signal as a visual sanity check; the stray name
        # that used to abort the cell after the first signal has been removed.
        show_plot = (ep == 0 and ch == 0)
        newSignal, rand_phase = get_phase_shift(this_sig,plot=show_plot)
        phase_shifted_signals[ep,ch,:] = newSignal

In [None]:
# (low edge in Hz, high edge in Hz, name) of every frequency band.
bands = [(4.,8.,'Theta'), (8.,15.,'Alpha'), (15.,32.,'Beta'), (32., 55., 'Gamma')]
cond = 'All'
con_res = {}  # band tuple -> result of get_connectivity_plot on the real time courses
con_res_phase_random = {}  # same for the phase-shifted surrogate (currently not filled)
# NOTE(review): `rel_labels` is whatever the most recent cell left in the namespace;
# presumably it must contain exactly the labels behind `ordered_input_tcs` - confirm.
for band in bands:
    band_fmin, band_fmax, band_name = band
    print('band', band)
    con_res[band] = get_connectivity_plot(
        ordered_input_tcs, output_dir,
        fmin=band_fmin, fmax=band_fmax,
        rel_labels=rel_labels,
        fig_title=f"{band_name} {cond} Motor",
    )
    # con_res_phase_random[band] = get_connectivity_plot(phase_shifted_signals,output_dir,fmin = band[0], fmax = band[1], fig_title = f"{band[2]} {cond} Phase_Random")

## Wilcoxon signed-rank test for connectivity differences

In [None]:
from scipy.stats import wilcoxon
from statsmodels.stats.multitest import multipletests
def get_wilcoxon_fdr_corrs(orig,shift):
    """Print a per-band Wilcoxon test of ``shift`` vs ``orig`` plus FDR-corrected p-values.

    For every entry of the module-level ``bands`` list, element 3 of
    ``orig[band]`` and ``shift[band]`` is asked for its dense data
    (``get_data(output='dense')``) and the ``[:, :, 0]`` slice is kept. The
    strictly lower triangle of each matrix is extracted and the paired
    differences ``shift - orig`` are passed to ``scipy.stats.wilcoxon``
    (two-tailed, per the original author's note).

    Parameters
    ----------
    orig, shift : mapping
        Keyed by the tuples in ``bands``; each value must be indexable with
        ``[3]`` and that element must provide ``get_data``.

    Returns
    -------
    None
        The statistic and p-value are printed per band, followed by the
        output of ``multipletests(..., method='fdr_bh')`` over all bands.
    """
    def strict_lower_triangle(matrix):
        # k=-1 drops the diagonal, so only entries below it are returned.
        rows, cols = np.tril_indices(len(matrix), k=-1)
        return matrix[rows, cols]

    band_p_values = []
    for band in bands:
        dense_orig = orig[band][3].get_data(output='dense')[:, :, 0]
        dense_shift = shift[band][3].get_data(output='dense')[:, :, 0]
        paired_delta = strict_lower_triangle(dense_shift) - strict_lower_triangle(dense_orig)
        test_result = wilcoxon(paired_delta)
        band_p_values.append(test_result.pvalue)
        print('band', band, 'sum of delta ranks', test_result.statistic, 'p val', test_result.pvalue)

    print('FDR results')
    print(multipletests(band_p_values, method='fdr_bh'))

# con_res_phase_random is only populated when the phase-randomised control run
# in the band loop is enabled; with an empty dict the test would fail with a
# KeyError on the first band, so skip it explicitly instead.
if con_res_phase_random:
    get_wilcoxon_fdr_corrs(con_res, con_res_phase_random)
else:
    print('Skipping phase-random comparison: con_res_phase_random is empty')

## Connectivity Plot

In [None]:
def display_connectivity_plot(rel_labels,con_res_input,band_titles, label_first_plot=True, use_fixed_colors = False, saved_fig_name='connectivity.png',vmin=0,vmax=.16):
    """Draw one circular connectivity plot per frequency band in a single row.

    Iterates over the module-level ``bands`` list; the figure has four polar
    axes, so at most four bands are drawn.

    Parameters
    ----------
    rel_labels
        Passed to ``get_readable_labels``; the ``name`` attribute of each
        returned label (after abbreviation) becomes a node name.
    con_res_input : mapping
        Keyed by the entries of ``bands``. Per band, element 0 is given to
        ``plot_connectivity_circle`` as the connectivity values, element 1
        supplies the node colours (unless ``use_fixed_colors``) and element 2
        the node angles.
    band_titles : sequence of str
        Panel titles, in the same order as ``bands``.
    label_first_plot : bool
        When True the first panel shows node names. All other panels are
        always drawn without names.
    use_fixed_colors : bool
        When True a hard-coded 10-node palette is used for every panel instead
        of the colours stored in ``con_res_input``.
    saved_fig_name : str or None
        If truthy, the figure is saved under this name at 300 dpi.
    vmin, vmax : float
        Colour limits forwarded to ``plot_connectivity_circle``.

    Returns
    -------
    None
    """
    rel_labels = get_readable_labels(rel_labels)
    if use_fixed_colors:
        # Same RGB triples as the per-region colour map in get_readable_labels.
        visual_rgb = (0.8325928529853291, 0.5253446757844744, 0.6869376931865354)
        vacc_rgb = (0.49498740849493095, 0.4799034869159042, 0.21147789468974837)
        dacc_rgb = (0.09406611799930162, 0.3578871412608098, 0.2837709711722866)
        dlpfc_rgb = (0.09854228363950114, 0.07115215572295082, 0.16957891809124037)
        # Hard-coded node order: 4 x visual, 2 x vACC, 2 x dACC, 2 x DLPFC.
        fixed_node_colors = [visual_rgb] * 4 + [vacc_rgb] * 2 + [dacc_rgb] * 2 + [dlpfc_rgb] * 2
        node_colors = [list(fixed_node_colors) for _ in bands]
    else:
        # Take the colours from the results being plotted, not from the global
        # `con_res`, so difference/condition dicts carry their own colours.
        node_colors = [con_res_input[b][1] for b in bands]
    rename_dict = OrderedDict({'-lh': ' LH',
                    '-rh': ' RH',
                'dorsal_anterior_cingulate': 'dACC',
                'ventral_anterior_cingulate': 'vACC',
                'dlpfc': 'DLPFC',
                'parastriate': 'V2',
                    'striate': 'V1',
                    'intermediate_frontal': 'MFG',
                })

    # The abbreviated names do not depend on the band, so build them once.
    display_names = []
    for label in rel_labels:
        original_name = label.name
        shown_name = original_name
        for old, new in rename_dict.items():
            # Membership is tested on the untouched name; 'parastriate' is
            # listed before 'striate' so V2 labels are not rewritten to V1.
            if old in original_name:
                shown_name = shown_name.replace(old, new)
        display_names.append(shown_name)
    blank_names = [''] * len(rel_labels)

    fig, axes = plt.subplots(ncols=4, nrows=1, figsize=(30,20),dpi=300,
                            layout="constrained",subplot_kw={'projection': 'polar'})
    last_index = len(bands) - 1
    for index, (ax, band, node_color, btitle) in enumerate(zip(axes.ravel(), bands, node_colors, band_titles)):
        # Only the last band's panel gets the shared colour bar.
        show_colorbar = index == last_index

        # Hide grid lines and axis ticks.
        ax.grid(False)
        ax.set_xticks([])
        ax.set_yticks([])

        # Node names appear on the first panel only, and only when requested.
        if index == 0 and label_first_plot:
            node_names = list(display_names)
        else:
            node_names = list(blank_names)

        plot_connectivity_circle(
            con_res_input[band][0], node_names, n_lines=300, facecolor='white', textcolor='black',
            node_edgecolor='white', interactive=False, colormap='pink_r',
            node_angles=con_res_input[band][2], node_colors=node_color, fontsize_names=40,
            fontsize_title=60, padding=2, show=False, title=f"{btitle}", fontsize_colorbar=40,
            ax=ax, fig=fig, colorbar=show_colorbar, colorbar_pos=(0, 0.1), colorbar_size=.5,
            linewidth=20, vmin=vmin, vmax=vmax, node_linewidth=4)
    if saved_fig_name:
        plt.savefig(f'{saved_fig_name}',dpi=300)
# Plot the all-trial connectivity for every band. Nothing is written to disk;
# pass saved_fig_name='connectivity.png' to save the figure.
display_connectivity_plot(
    rel_labels,
    con_res,
    band_titles=[r'$\theta$', r'$\alpha$', r'$\beta$', r'$\gamma$'],
    saved_fig_name=None,
)

# Difference between high and low density

In [None]:
# Attach per-trial damage and opacity (the metadata `density` column) to every
# time-course row. Each (pid, session, trial_start_time) group is matched to
# its metadata row; `.item()` raises unless exactly one row matches.
trial_metadata = motor_epochs.metadata
all_dfs = []
for name, group in all_tcs.groupby(['pid','session','trial_start_time']):
    rel_trial = group.iloc[0]
    rel_trial_info = trial_metadata[
        (trial_metadata.ppid == rel_trial.pid)
        & (trial_metadata.session == rel_trial.session)
        & (trial_metadata.trial_start_time == rel_trial.trial_start_time)
    ]
    # assign() returns a new frame, so the groupby sub-frame is not mutated
    # in place (avoids SettingWithCopyWarning).
    all_dfs.append(group.assign(
        trial_damage=rel_trial_info['trial_damage'].item(),
        trial_opacity=rel_trial_info['density'].item(),
    ))
modified_tcs_df = pd.concat(all_dfs)
# Two-quantile (median) split of trial_opacity within each pid.
modified_tcs_df['opacity_bin'] = modified_tcs_df.groupby(['pid'])['trial_opacity'].transform(
    lambda x: pd.qcut(x, 2, labels=['low', 'high']))


In [None]:
rel_labels, rel_mappings = get_relevant_labels_mappings(
    path_to_base_package,
    regions_in_activity=significant_regions,  # for all regions use 'all'
)

# Epoch the significant-region time courses separately for the low and high opacity bins.
low_opacity_rows = modified_tcs_df[
    modified_tcs_df.source_region.isin(significant_regions) & (modified_tcs_df.opacity_bin == 'low')
]
ordered_input_tcs_low_density, all_tcs_tcs_low_density = get_epoched_tcs(
    low_opacity_rows, rel_labels, rel_mappings
)

high_opacity_rows = modified_tcs_df[
    modified_tcs_df.source_region.isin(significant_regions) & (modified_tcs_df.opacity_bin == 'high')
]
ordered_input_tcs_high_density, all_tcs_tcs_high_density = get_epoched_tcs(
    high_opacity_rows, rel_labels, rel_mappings
)

### Connectivity

In [None]:
# Frequency bands as (fmin, fmax, display name).
bands = [
    (4., 8., 'Theta'),
    (8., 15., 'Alpha'),
    (15., 32., 'Beta'),
    (32., 55., 'Gamma'),
]
cond = 'All'
con_res_low_opacity, con_res_high_opacity = {}, {}
for band in bands:
    print('band', band)
    # Arguments shared by the low- and high-opacity runs of this band.
    band_kwargs = dict(fmin=band[0], fmax=band[1], rel_labels=rel_labels)
    con_res_low_opacity[band] = get_connectivity_plot(
        ordered_input_tcs_low_density,
        output_dir,
        fig_title=f"{band[2]} {cond} Motor Low Opacity",
        **band_kwargs,
    )
    con_res_high_opacity[band] = get_connectivity_plot(
        ordered_input_tcs_high_density,
        output_dir,
        fig_title=f"{band[2]} {cond} Motor High Opacity",
        **band_kwargs,
    )


### Connectivity plots and significance

In [None]:
# Per band: copy the high-opacity result and replace its first element with
# the high-minus-low difference; the remaining elements stay as in the
# high-opacity result.
conn_diff = {}
for band_key, high_entry in con_res_high_opacity.items():
    diff_entry = list(high_entry)
    diff_entry[0] = high_entry[0] - con_res_low_opacity[band_key][0]
    conn_diff[band_key] = diff_entry


In [None]:
# Plot the opacity connectivity difference per band with the fixed node
# palette and colour limits of +/-0.1; nothing is written to disk.
display_connectivity_plot(
    rel_labels,
    conn_diff,
    band_titles=[r'$\theta$', r'$\alpha$', r'$\beta$', r'$\gamma$'],
    use_fixed_colors=True,
    label_first_plot=False,
    saved_fig_name=None,
    vmin=-0.1,
    vmax=.1,
)
# Paired test per band: low opacity is `orig`, high opacity is `shift`.
get_wilcoxon_fdr_corrs(con_res_low_opacity, con_res_high_opacity)

## Difference between damage and no damage

In [None]:
# Collapse to one row per trial_start_time, then count the trials with damage
# for each pid. numeric_only=True is required because the frame also holds
# non-numeric columns (e.g. source_region, opacity_bin), on which pandas >= 2.0
# raises a TypeError instead of silently dropping them.
# NOTE(review): this relies on `pid` being numeric so it survives the mean, and
# on trial_start_time not colliding across participants — confirm.
pp_agg = modified_tcs_df.groupby(['trial_start_time']).mean(numeric_only=True).reset_index()
pp_agg = pp_agg[(pp_agg.trial_damage > 0)].groupby(['pid']).count()
# Mean and standard deviation of the per-pid count of damage trials.
np.mean(pp_agg.trial_start_time),np.std(pp_agg.trial_start_time)

In [None]:
rel_labels, rel_mappings = get_relevant_labels_mappings(path_to_base_package,regions_in_activity=significant_regions) # for all regions use 'all'
ordered_input_tcs_no_damage,all_tcs_tcs_no_damage = get_epoched_tcs(modified_tcs_df[modified_tcs_df.source_region.isin(significant_regions) & (modified_tcs_df.trial_damage == 0)], rel_labels, rel_mappings)
ordered_input_tcs_yes_damage,all_tcs_tcs_yes_damage = get_epoched_tcs(modified_tcs_df[modified_tcs_df.source_region.isin(significant_regions) & (modified_tcs_df.trial_damage > 0)], rel_labels, rel_mappings)

# Randomly subsample the no-damage epochs so both conditions have the same
# number of epochs. range() already excludes its upper bound, so every
# no-damage epoch is eligible, and exactly as many epochs are drawn as there
# are damage epochs. random.sample raises ValueError if there are fewer
# no-damage than damage epochs.
random.seed(10)
n_no_damage = ordered_input_tcs_no_damage.shape[0]
n_yes_damage = ordered_input_tcs_yes_damage.shape[0]
rows_id = random.sample(range(n_no_damage), n_yes_damage)
balanced_no_damage = ordered_input_tcs_no_damage[rows_id,:,:]

# Shapes of the epoch arrays; shown only after the arrays above exist so the
# notebook also runs on a fresh kernel.
balanced_no_damage.shape, ordered_input_tcs_no_damage.shape, ordered_input_tcs_yes_damage.shape, ordered_input_tcs_low_density.shape, ordered_input_tcs_high_density.shape

In [None]:
# Frequency bands as (fmin, fmax, display name).
bands = [
    (4., 8., 'Theta'),
    (8., 15., 'Alpha'),
    (15., 32., 'Beta'),
    (32., 55., 'Gamma'),
]
cond = 'All'
con_res_no_damage, con_res_yes_damage = {}, {}
for band in bands:
    print('band', band)
    # Arguments shared by the no-damage and damage runs of this band.
    band_kwargs = dict(fmin=band[0], fmax=band[1], rel_labels=rel_labels)
    con_res_no_damage[band] = get_connectivity_plot(
        balanced_no_damage,
        output_dir,
        fig_title=f"{band[2]} {cond} Motor No Damage",
        **band_kwargs,
    )
    con_res_yes_damage[band] = get_connectivity_plot(
        ordered_input_tcs_yes_damage,
        output_dir,
        fig_title=f"{band[2]} {cond} Motor Yes Damage",
        **band_kwargs,
    )


In [None]:

# Per band: copy the damage result and replace its first element with the
# damage-minus-no-damage difference. The loop runs over the damage results
# themselves, so this cell does not depend on the opacity analysis.
conn_diff_dam = {}
for c in con_res_yes_damage:
    conn_diff_dam[c] = list(con_res_yes_damage[c])
    conn_diff_dam[c][0] = con_res_yes_damage[c][0]-con_res_no_damage[c][0]

In [None]:
# Plot the damage connectivity difference per band with the fixed node palette.
# Colour limits are symmetric around zero (+/-0.16); the previous vmax of 0.61
# looked like transposed digits next to vmin=-0.16 — NOTE(review): confirm.
display_connectivity_plot(rel_labels,conn_diff_dam,band_titles = [r'$\theta$',r'$\alpha$',r'$\beta$',r'$\gamma$'],label_first_plot=False,use_fixed_colors=True,saved_fig_name=None,vmin=-0.16, vmax=0.16)
# Paired test per band: no damage is `orig`, damage is `shift`.
get_wilcoxon_fdr_corrs(con_res_no_damage,con_res_yes_damage)