In [None]:
import warnings
import numpy as np
import scipy.stats as stats
import matplotlib.pyplot as plt
import pandas as pd
import mne
from tqdm import tqdm
import matplotlib
matplotlib.rcParams['pdf.fonttype'] = 42
from ipywidgets import interact
import ipywidgets as widgets
warnings.filterwarnings('ignore')

# Load data

In [None]:
# Paths to the input arrays (.npy). Fill these in before running the notebook.
chlist_path = None  # path to channels list, npy
data_path = None  # path to data, npy
if chlist_path is None or data_path is None:
    raise ValueError('Set chlist_path and data_path to the .npy files before running this cell.')

# Channel names; converted to a list for mne.create_info below.
chlist = np.load(chlist_path)
# Indexed later as data[day, session, sub] -> (trial, time, channel).
# NOTE(review): layout inferred from usage below — confirm against the file.
data = np.load(data_path)

# Time axis: 1-9 s in 0.004 s steps (250 samples/s); the first 1000 samples are dropped.
time = np.r_[1:9:0.004][1000:]

montage = mne.channels.make_standard_montage('standard_1020')
# NOTE(review): sfreq=1000 does not match the 0.004 s step of `time` (250 Hz).
# Presumably harmless, since `info` is only used for the montage, adjacency and
# topomaps here — confirm.
info = mne.create_info(chlist.tolist(), 1000, ch_types='eeg',)
info.set_montage(montage)

# Channel adjacency used to form spatio-temporal clusters.
sensor_adjacency, ch_names = mne.channels.find_ch_adjacency(
    info, 'eeg')

In [None]:
# Per-subject spatio-temporal cluster permutation test (one-sample, lower tail,
# tail=-1) for every (session, subject, cluster-forming alpha) combination.
subs = np.r_[:15]
sessions = np.r_[:2]
crit_muls = [0, 1, 2, 5]  # cluster-forming alpha = 0.05 / 2**crit_mul
day = 0  # only data[0, ...] is analysed in this cell
seed = 42  # fixes the permutation draws so results are reproducible
out_path = '../data/TMS_TIME-CH-TRIAL_DAY2_IM1-2_SHAM_MOTOR-AREA_BCORR_POST_STAT-BY-SUB_NEG.npy'

index = pd.MultiIndex.from_product([sessions, subs, crit_muls])
conds = np.array(index.to_list())
results = []

for session, sub, crit_mul in tqdm(conds):
    a = data[day, session, sub]
    # Drop trials containing any NaN (presumably padding for unequal trial counts).
    a_mask = ~(np.isnan(a).any(-1).any(-1))
    a = a[a_mask]

    if a.shape[0] == 0:
        # No usable trials for this subject/session.
        res = None
    else:
        alpha = 0.05 / (2 ** int(crit_mul))
        # With tail=-1 MNE expects a negative threshold: the lower-alpha quantile
        # of Student's t with n_trials - 1 degrees of freedom.
        thresh = -stats.t.ppf(q=1 - alpha, df=a.shape[0] - 1)
        res = mne.stats.spatio_temporal_cluster_1samp_test(
            a, n_permutations=2000, tail=-1, n_jobs=4, out_type='mask',
            threshold=thresh, adjacency=sensor_adjacency,
            stat_fun=mne.stats.ttest_1samp_no_p, seed=seed, verbose=False)
    results.append(res)

# Pack into a fixed (session, subject, crit_mul, field) object array, with fields
# (T_obs, clusters, cluster_p_values, H0). np.array(results, dtype=object) would
# let NumPy infer the shape, and that shape changes when any entry is None (or
# when inner arrays happen to share a length), silently breaking the reshape.
# Missing results become four None fields.
n_fields = 4
npresults = np.empty((len(results), n_fields), dtype=object)
for i, res in enumerate(results):
    if res is not None:
        for j, item in enumerate(res):
            npresults[i, j] = item
npresults = npresults.reshape(len(sessions), len(subs), len(crit_muls), n_fields)
np.save(out_path, npresults)

In [None]:
from itertools import product

# Load the saved per-subject cluster-test results for both conditions.
# allow_pickle=True is required because these are object arrays; only load files
# produced by this notebook (unpickling untrusted data can execute code).
tms_stat_path = '../data/TMS_TIME-CH-TRIAL_DAY2_IM1-2_MOTOR-AREA_BCORR_POST_STAT-BY-SUB_NEG.npy'
sham_stat_path = '../data/TMS_TIME-CH-TRIAL_DAY2_IM1-2_SHAM_MOTOR-AREA_BCORR_POST_STAT-BY-SUB_NEG.npy'
d2tms_stat = np.load(tms_stat_path, allow_pickle=True)
d2sham_stat = np.load(sham_stat_path, allow_pickle=True)
# Cluster-forming t-thresholds (lower tail) for each condition, task and
# alpha = 0.05 / 2**crit_mul.
crit_muls = [0, 1, 2, 5]
arrs = [d2tms_stat, d2sham_stat]
conds = ["TMS", "Sham"]
sessions = ['Im1', 'Im2']
# Sample size per condition; the degrees of freedom are n - 1.
# NOTE(review): presumably the number of subjects per group — confirm.
n_by_cond = (14, 15)
res = []
for cond, sess in product([0, 1], [0, 1]):
    cond_name = conds[cond]
    sess_name = sessions[sess]
    a = n_by_cond[cond]
    for crit_mul in crit_muls:
        print(cond_name, sess_name, a)
        alpha = 0.05/(2**int(crit_mul))
        thresh = -stats.t.ppf(q=1 - alpha, df=a - 1)
        print(crit_mul, alpha, thresh)
        res.append((cond_name, sess_name, alpha, thresh, a - 1))


In [None]:
# Tabulate the thresholds, write them to CSV and display the table.
threshold_columns = ['Condition', 'Task', 'Alpha', 't-threshold', 'df']
df = pd.DataFrame(res, columns=threshold_columns)
df.to_csv("t-thresholds.csv", index=False)
df

In [None]:
# Peek at one stored TMS entry: session 0, subject index 9, threshold index 0,
# field 0 (presumably T_obs, assuming the same layout as the results saved above).
d2tms_stat[0][9][0][0]

In [None]:
def plot_stat_topomap(T_obs, clus, info, axes=None, cl_ind=0):
    """Plot a topomap of T values averaged over one cluster's extent.

    For each channel, ``T_obs`` is averaged over the time points at which that
    channel belongs to the cluster. Channels never in the cluster are drawn as 0.

    Parameters
    ----------
    T_obs : ndarray, (n_times, n_channels)
        Observed T statistics.
    clus : ndarray of bool, same shape as ``T_obs``
        Cluster mask.
    info : mne.Info
        Sensor info; its ``ch_names`` are drawn as labels.
    axes : matplotlib Axes | None
        Axes to draw into.
    cl_ind : int
        Unused; kept for call compatibility.

    Returns
    -------
    Whatever ``mne.viz.plot_topomap`` returns (callers unpack it as ``im, _``).
    """
    # Rows (time points) where the cluster is active on at least one channel.
    active_rows = clus.any(axis=-1)
    in_cluster = np.where(clus, T_obs, np.nan)
    topo = np.nanmean(in_cluster[active_rows], axis=0)
    topo = np.where(np.isnan(topo), 0, topo)
    # NOTE(review): `show_names` is deprecated in newer MNE releases — check the installed version.
    return mne.viz.plot_topomap(
        topo,
        info,
        axes=axes,
        show=False,
        names=info.ch_names,
        show_names=True,
        extrapolate='local',
    )

def plot_topo_n_t(time, t_obs, cluster, info, time_tr_mask=None):
    """Plot a cluster's mean T time course next to its topomap.

    Left panel: ``t_obs * cluster`` averaged over the last (channel) axis and
    restricted to ``time_tr_mask`` (all time points when None). Right panel:
    ``plot_stat_topomap`` over the same time window, with a colorbar.

    NOTE(review): out-of-cluster entries count as 0 in this channel average, so
    the trace is diluted by the channel count; plot_all instead averages only
    the cluster's channels — confirm which is intended.

    Returns the created matplotlib Figure.
    """
    window = np.ones(time.shape, dtype=bool) if time_tr_mask is None else time_tr_mask

    fig, axs = plt.subplots(1, 2, figsize=(12, 4))
    trace_ax, topo_ax = axs

    masked_mean = np.mean(t_obs * cluster, axis=-1)
    trace_ax.plot(time[window], masked_mean[window])
    trace_ax.set_xlabel('Time (sec)')
    trace_ax.set_ylabel('T value')
    trace_ax.grid()

    im, _ = plot_stat_topomap(t_obs[window], cluster[window], info, axes=topo_ax)
    plt.colorbar(im, ax=topo_ax)
    fig.tight_layout()
    return fig

def local_mins(sig):
    """Find interior local minima of a 1-D signal.

    Sample ``i`` (0 < i < len(sig) - 1) is reported when the signal strictly
    decreases into it and does not decrease right after it
    (``sig[i-1] > sig[i] <= sig[i+1]``). On a flat bottom, the first sample is
    reported. Endpoints are never reported.

    Returns
    -------
    (indices, values) : tuple of ndarray
        Positions of the minima and ``sig`` at those positions.
    """
    sig = np.asarray(sig)
    # falling[k] == 1 when the step sig[k] -> sig[k+1] goes down.
    falling = (np.diff(sig) < 0).astype(int)
    # is_min[k] is True when step k falls and step k+1 does not, i.e. the
    # minimum sits at sample k + 1.
    is_min = np.diff(falling) < 0
    inds = np.arange(1, sig.shape[-1] - 1)[is_min]
    return inds, sig[inds]

def plot_mins(time, sig, ax, min_gap=30, x_offset=0.4):
    """Annotate local minima of ``sig`` on ``ax`` with their time stamps.

    A minimum is labelled only if it lies more than ``min_gap`` samples after
    the preceding local minimum; the first one is always labelled. Each label
    reads ``time[i]`` to 3 decimals and is placed at
    ``(time[i] - x_offset, sig[i])``.
    """
    tinds, vals = local_mins(sig)
    if len(tinds) == 0:
        # Nothing to annotate; the gap mask below would otherwise have length 1
        # and fail to index the empty arrays.
        return
    keep = np.concatenate(([True], np.diff(tinds) > min_gap))
    for tind, val in zip(tinds[keep], vals[keep]):
        timex = time[tind]
        ax.text(timex - x_offset, val, f'{timex:.3f}')

In [None]:
% matplotlib inline
label = 'TMS'
# NOTE(review): this expects `res` to be a dict mapping label -> (data, stat).
# The `res` built by the threshold-table cell above is a list, so this line fails
# on Restart & Run All; the dict presumably came from a cell no longer in the
# notebook. It also rebinds the global `data` loaded at the top.
data, stat = res[label]
# Filled by plot_all with the last figure and where the save button writes it.
name_dict = {}
# Maps each cluster-forming alpha (0.05 / 2**crit_mul) to its threshold index.
# NOTE(review): not used by any visible cell.
critmuls_dict = {0.05 / (2 ** mul): pos for pos, mul in enumerate((0, 1, 2, 5))}


def plot_all(sub_real_ind, session, crit_mul_ind):
    """Plot every cluster with p < 0.3 for one subject / session / threshold.

    Each cluster gets one row. The left panel shows the smoothed mean T over
    the cluster's channels (right axis) over the trial-averaged signal of
    those channels, with a +/- std band (left axis). The right panel shows the
    cluster topomap. The figure and its target path are stored in ``name_dict``
    for the "Save fig" button.

    Parameters
    ----------
    sub_real_ind : int
        0-based subject index (file names use ``sub_real_ind + 1``).
    session : int
        0 -> 'Im1', 1 -> 'Im2'.
    crit_mul_ind : int
        Index of the cluster-forming threshold in ``stat``.

    Reads the globals ``data``, ``stat``, ``time``, ``info``, ``label`` and
    ``name_dict``. Prints a message and draws nothing when there is no result
    or no cluster below the display threshold.
    """
    p_display_thresh = 0.3  # exploratory display cut-off, not a significance level
    smooth_win = 50  # samples in the moving average applied to the T trace

    session_name = ['Im1', 'Im2'][session]
    sub = sub_real_ind + 1
    trials = data[0, session, sub_real_ind]
    a_s = np.nanstd(trials, 0)  # std over trials <(tr, time, ch) -> (time, ch)>
    a = np.nanmean(trials, 0)  # average over trials <(tr, time, ch) -> (time, ch)>

    # NOTE(review): the leading index 1 selects a fixed slice of `stat` — confirm its meaning.
    entry = stat[1, session, sub_real_ind, crit_mul_ind]
    if entry is None or any(item is None for item in entry):
        print(f'No cluster-test result for subject {sub}, session {session_name}.')
        return
    T_obs, clusters, cluster_p_values, _ = entry

    inds = np.flatnonzero(np.asarray(cluster_p_values) < p_display_thresh)
    print(cluster_p_values)
    if len(inds) == 0:
        # plt.subplots(0, 2) would raise; report instead.
        print(f'No clusters with p < {p_display_thresh}.')
        return

    # squeeze=False keeps `axes` 2-D even when there is a single row.
    fig, axes = plt.subplots(len(inds), 2, figsize=(12, 4 * len(inds)), squeeze=False)

    for (ax1, ax2), ind in zip(axes, inds):
        cluster = clusters[ind]
        pval = cluster_p_values[ind]
        chmask = cluster.any(0)  # channels that are in the cluster at any time

        # Mean T over in-cluster channels at each time (0 where none), smoothed.
        t_trace = np.nanmean(np.where(cluster, T_obs, np.nan), -1)
        t_trace[np.isnan(t_trace)] = 0
        t_trace = np.convolve(t_trace, np.ones(smooth_win), mode='same') / smooth_win

        a_mean = a[..., chmask].mean(-1)
        a_std = a_s[..., chmask].mean(-1)

        ax1_t = ax1.twinx()
        ax1_t.plot(time, t_trace)
        ax1.plot(time, a_mean, color='orange')
        ax1.fill_between(time, a_mean - a_std, a_mean + a_std, color='orange', alpha=0.1)
        # Axis labels are in Russian: "T values", "mean alpha power", "time (s)".
        ax1_t.set_ylabel('T значения')
        ax1.set_ylabel('Средняя энергия в альфа')
        ax1.set_xlabel('Время (сек)')
        ax1.grid()
        im, _ = plot_stat_topomap(T_obs, cluster, info, axes=ax2)
        plt.colorbar(im, ax=ax2)
        ax1.set_title(f'P-значение кластера = {pval} #{ind}')  # "cluster p-value"

    fig.tight_layout()
    plt.show()
    name_dict['dirname'] = f'../stat_by_subject_plots_may-2022-only-plot/{label}'
    # NOTE(review): the file name omits crit_mul_ind, so saving the same
    # subject/session at another threshold overwrites the earlier file.
    name_dict['fname'] = f'base-vs-post_day-{label}_session-{session_name}_sub-{sub}_clus_any.pdf'
    name_dict['fig'] = fig


def save_fig(description):
    """Button callback: save the figure last drawn by plot_all.

    ``description`` receives the clicked Button (ipywidgets passes the widget to
    ``on_click`` handlers); it is not used. The target directory comes from
    ``name_dict['dirname']`` and is created if missing. If no figure has been
    drawn yet, a message is printed instead of raising inside the callback.
    """
    import os

    fig = name_dict.get('fig')
    if fig is None:
        print('Nothing to save yet: draw a figure first.')
        return
    dirname = name_dict['dirname']
    os.makedirs(dirname, exist_ok=True)
    fig.savefig(os.path.join(dirname, name_dict['fname']))


# Interactive browser: pick subject, session and cluster-forming threshold index;
# the button saves the figure currently shown.
sub_slider = widgets.IntSlider(min=0, max=14, step=1, value=0)
session_slider = widgets.IntSlider(min=0, max=1, step=1, value=0)
# Indices 0..3 cover the four thresholds (assuming four, as in crit_muls). The
# default 3 is the last one; a range starting at -1 would give two slider
# positions for that same entry.
crit_slider = widgets.IntSlider(min=0, max=3, step=1, value=3)
save_button = widgets.Button(description="Save fig")
save_button.on_click(save_fig)
display(save_button)

# Only plot_all's own parameters are passed: interact ignores keywords the
# function does not accept, and the button is already displayed above.
interact(plot_all,
         sub_real_ind=sub_slider,
         session=session_slider,
         crit_mul_ind=crit_slider);