In [None]:
import mne
import numpy as np
import matplotlib.pyplot as plt
from mne.stats import permutation_cluster_1samp_test

from IPython.display import clear_output
import copy

import os
# Move to the parent directory before the star-import so `_parameters` is found there.
# `_parameters` presumably supplies the names used below without local definition
# (e.g. `subjects`, `dirs`) — verify against that module.
# NOTE(review): os.chdir is not idempotent — re-running this cell moves up one more level.
os.chdir('..')
from _parameters import *

In [None]:
# For every subject, build the normalised same-vs-reversed TFR
#     ((same - rvrs) / (same + rvrs)) * 100
# where each condition is first averaged over the four modes listed below.
# `subjects` and `dirs` are not defined in this notebook (they come from `_parameters`).
mode = ['respL', 'respR', 'itemL', 'itemR']

tfr_samerev = []

for s in subjects:

    print(f'Loading subject {s}')
    cond = {}

    for c in ('same', 'rvrs'):

        print(f'Loading condition {c}')
        mode_tfrs = []
        for m in mode:
            fname = dirs['tfr'] + f'/tfr_multi_enc1_{c}_{m}_s{s}.h5'
            # take the first TFR object read from the file
            tfr_file = mne.time_frequency.read_tfrs(fname)[0]
            mode_tfrs.append(tfr_file)

        # collapse the four modes into a single TFR for this condition
        cond[c] = mne.grand_average(mode_tfrs)

    same_data = cond['same'].data
    rvrs_data = cond['rvrs'].data

    # deep copy of 'same' keeps its metadata; only the data array is replaced
    tfr = copy.deepcopy(cond['same'])
    tfr.data = ((same_data - rvrs_data) / (same_data + rvrs_data)) * 100
    tfr_samerev.append(tfr)

    clear_output()

tfr_samerev_avg = mne.grand_average(tfr_samerev)

In [None]:
# multiplot over channels
# trailing ';' stops Jupyter from rendering the returned figure a second time
tfr_samerev_avg.plot_topo(vmin=-10, vmax=10);

In [None]:
# topography plot: 8-12 Hz band averaged over 1.5-3 s
# trailing ';' stops Jupyter from rendering the returned figure a second time
tfr_samerev_avg.plot_topomap(fmin=8, fmax=12, tmin=1.5, tmax=3, vmin=-3, vmax=3, colorbar=False);

In [None]:
# Pz TFR per subject
# `.copy()` is essential: MNE's `pick` modifies the instance in place, so without
# it every entry of `tfr_samerev` would be reduced to Pz, and re-running earlier
# cells (grand average, topo plots) would silently use a single channel.
# Averaging over axis 0 drops the now single-channel axis of `.data`.
samerev_dat = [subj_tfr.copy().pick('Pz').data.mean(0) for subj_tfr in tfr_samerev]

In [None]:
# Run clusterstat
# One-sample cluster-based permutation test of the per-subject Pz maps against 0.
data = np.asarray(samerev_dat)  # presumably (n_subjects, n_freqs, n_times) — one Pz map per subject

N_PERMUTATIONS = 10000
CLUSTER_ALPHA = .05   # cluster-level p-value threshold
MASK_UPSAMPLE = 10    # each TF bin becomes a MASK_UPSAMPLE x MASK_UPSAMPLE block (finer contours)
CLUSTER_SEED = 42     # fixed RNG seed so permutation p-values are reproducible

t, clust, p, h0 = permutation_cluster_1samp_test(data, n_permutations=N_PERMUTATIONS,
                                                 out_type='mask', seed=CLUSTER_SEED)

# Create mask
# Stack the significant cluster masks. The reshape keeps a well-formed
# (0, n_freqs, n_times) array when nothing survives; indexing an empty cluster
# list used to collapse to 1-D and produce a broken (10, 0) mask after kron.
sig_masks = np.asarray([c for c, pv in zip(clust, p) if pv < CLUSTER_ALPHA], dtype=bool)
sig_masks = sig_masks.reshape((-1,) + data.shape[1:])
if sig_masks.shape[0] == 1:
    # single cluster -> 2-D mask (what np.squeeze produced before)
    sig_masks = sig_masks[0]
# kron upsamples only the last two axes, so 3-D stacks are upsampled per cluster
mask = np.kron(sig_masks, np.ones((MASK_UPSAMPLE, MASK_UPSAMPLE)))

stat = {
    'mask': mask,  # 2-D for one significant cluster, 3-D (one slice per cluster) otherwise
    'sig': p       # p-values of all clusters, not only the significant ones
}


In [None]:
# Plot settings: reset matplotlib to its defaults, then set the figure font and
# keep SVG text as text (not converted to paths) so it stays editable.
plt.rcdefaults()

plt.rcParams.update({
    'font.family': 'Helvetica Neue',
    'svg.fonttype': 'none',
})

In [None]:
# Plot TFR
# Group-mean Pz same-vs-reversed map with significant clusters outlined.
fig, axes = plt.subplots(figsize=(10, 4))

# (tmin, tmax, fmin, fmax) used to place the image and contours.
# NOTE(review): hard-coded — confirm it matches tfr_samerev_avg.times / .freqs.
extent = (-1, 4, 3, 40)

im = axes.imshow(np.mean(data, 0), cmap='RdBu_r', extent=extent,
                 origin="lower", aspect="auto",
                 vmin=-10, vmax=10)

# dotted reference lines at 0 s and 1.5 s
axes.axvline(0, color='grey', linestyle=':')
axes.axvline(1.5, color='grey', linestyle=':')

mask = stat['mask']

# stat['mask'] is 2-D for one significant cluster and 3-D (one slice per cluster)
# otherwise; treat both as an iterable of 2-D masks (possibly empty).
cluster_masks = mask if mask.ndim == 3 else [mask]
for m in cluster_masks:
    axes.contour(m, colors='black', extent=extent, linewidths=.1, corner_mask=False, antialiased=False)

# Decorations apply regardless of how many clusters survived (previously they
# were only drawn in the single-mask branch).
cbar = plt.colorbar(im, ticks=[-10, 0, 10])
cbar.ax.tick_params(labelsize=12)

axes.set_title('Same vs reversed (Pz)', size=14)

# raw string: '\m' is an invalid escape sequence in a normal string literal
axes.set_xlabel(r'Time after $\mathregular{1^{st}}$ encoding display (s)', fontsize=14)
axes.set_ylabel('Frequency (Hz)', fontsize=14)
axes.tick_params(axis='both', labelsize=12)

axes.set_xlim(-0.25, 3.5)

fname = dirs['plot'] + '/TFR/Pz-same-rvrs.svg'
fig.savefig(fname, format='svg', dpi=300, bbox_inches='tight', transparent=True)


In [None]:
def plot_svsr(data, freq, vlim, title, t_start=0, step=0.25, n_windows=14):
    """Plot topomaps of a TFR over consecutive time windows and save them as one SVG.

    Parameters
    ----------
    data : object with ``plot_topomap`` (e.g. an MNE AverageTFR)
        TFR to plot; here the grand-average same-vs-reversed TFR.
    freq : (fmin, fmax)
        Frequency band (Hz) averaged in each topomap.
    vlim : (vmin, vmax)
        Colour limits shared by all panels.
    title : str
        Figure suptitle; also embedded in the output file name.
    t_start : float, default 0
        Start (s) of the first window.
    step : float, default 0.25
        Width of each window and stride between windows (s).
    n_windows : int, default 14
        Number of consecutive windows / panels (defaults span 0-3.5 s).
    """
    # squeeze=False keeps a 2-D Axes array even when n_windows == 1
    fig, axes = plt.subplots(1, n_windows, figsize=(20, 4), squeeze=False)
    axes = axes[0]

    tmin, tmax = t_start, t_start + step
    for ax in axes:
        data.plot_topomap(tmin=tmin, tmax=tmax,
                          fmin=freq[0], fmax=freq[1],
                          vmin=vlim[0], vmax=vlim[1],
                          colorbar=False, axes=ax, show=False)
        ax.set_title(f'{tmin} - {tmax} s')
        tmin += step
        tmax += step

    fig.suptitle(title)

    # Save before showing so the file is written regardless of how the backend
    # handles show() (e.g. closing the figure after display).
    # NOTE(review): file prefix 'lvsr' differs from the function name 'svsr' — confirm intended.
    fname = dirs['plot'] + f'/topo/lvsr-{title}.svg'
    fig.savefig(fname, format='svg', dpi=300, bbox_inches='tight', transparent=True)
    plt.show()


plot_svsr(tfr_samerev_avg, (8, 12), (-5, 5), 'same-rvrs')