In [None]:
import sys,os,glob,warnings, heapq
import pandas as pd
import numpy as np
import pingouin as pg
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import mne, scipy.io
from scipy import io, stats, interpolate
sys.path.append("D:/Dropbox/Projects/featureReplay/misc/")
import draw_sig_contour as dsc

from mne.decoding import (SlidingEstimator, GeneralizingEstimator,
                          cross_val_multiscore, LinearModel, get_coef)
from sklearn.svm import LinearSVC, SVC
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import LeaveOneOut, cross_val_score, KFold
warnings.filterwarnings('ignore')
# make plot interactive
%matplotlib qt 
%matplotlib inline

In [None]:
project_path = 'D:/Dropbox/Projects/featureReplay/'
SUBJECTS = ['S01','S02','S03','S04','S05','S06','S07','S08','S09','S10',
            'S11','S12','S13','S14','S15','S16','S17','S18','S19','S20',
            'S21','S22','S23','S24','S25','S26','S27','S28','S29','S30',
            'S31','S32','S33','S34','S35']

# Subjects to analyse. Select with a slice (e.g. SUBJECTS[n-1:n] for one subject) so the
# result is a list; a bare string such as SUBJECTS[n-1] is wrapped into a list below.
selected_subj = SUBJECTS[:18]
if isinstance(selected_subj, str):
    selected_subj = [selected_subj]
# Number of selected subjects; sizes the arrays/loops in the cells below
# (it was previously referenced there without ever being defined).
n_subjects = len(selected_subj)
print(['Running Subjects:'] + selected_subj)

# Indices of the occipital channels within the full channel list (flat int array).
chans_all       = np.loadtxt(project_path + 'data_v5/misc_data/channels_all.txt', dtype='str')
chans_occipital = np.loadtxt(project_path + 'data_v5/misc_data/occipital_channels.txt', dtype='str')
chans_idx       = np.flatnonzero(np.isin(chans_all, chans_occipital))
assert chans_idx.size > 0, 'no occipital channels found in channels_all.txt'

# Decoding setup: z-score every feature, then L1-penalised logistic regression (liblinear).
# LeaveOneOut is required by the decoding cell, which attributes each fold's score to the
# single held-out trial's label.
# NOTE(review): `multi_class` is deprecated in recent scikit-learn releases; liblinear is
# one-vs-rest regardless, so it can be dropped when upgrading.
cv = LeaveOneOut()
clf = make_pipeline(StandardScaler(), LogisticRegression(C=1, penalty='l1', multi_class='ovr', solver='liblinear'))
# Temporal generalization: train at each time point, test at every time point.
# NOTE: the name `time` would shadow the stdlib `time` module if it were imported.
time = GeneralizingEstimator(clf, n_jobs=20, scoring='accuracy')
# Diagonal-only decoder (train time == test time); cheaper, e.g. for permutation runs.
time_decod = SlidingEstimator(clf, n_jobs=20, scoring='accuracy')
# (subjects, true labels, training times, testing times); 325 = samples per epoch,
# presumably after the 250 Hz resampling named in the data files -- TODO confirm.
accuracies = np.zeros((n_subjects, 4, 325, 325))

### Run decoding over time

In [None]:
# Temporal-generalization decoding, one subject at a time.
# With LeaveOneOut, cross_val_multiscore returns one score matrix per fold and fold k holds
# out trial k, so each fold's (0/1) accuracy can be grouped by that trial's true label.
assert isinstance(cv, LeaveOneOut), (
    'per-label accuracies require LeaveOneOut: one held-out trial per fold, in trial order')

accuracies = None  # allocated once the number of time samples is known
for s, subj_id in enumerate(selected_subj):
    print('>>> loading data for subject:', subj_id)

    ## load data
    cond_data = project_path + 'data_v5/%s/%s_modelTrain_epochs_all_resample250_ica-epo.fif' % (subj_id, subj_id)
    epochs_all = mne.read_epochs(cond_data, preload=True)
    X = epochs_all.get_data()  # n_trials * n_channels * n_times
    y = epochs_all.events[:, 2]

    ## select occipital channels; np.ravel accepts either an np.where() tuple or a flat
    ## index array and, unlike squeeze(), keeps the channel axis when only one is selected
    XX = X[:, np.ravel(chans_idx), :]

    ## classification: scores is (n_folds == n_trials, n_train_times, n_test_times)
    scores = cross_val_multiscore(time, XX, y, cv=cv, n_jobs=20)

    if accuracies is None:
        # (subjects, true labels, training times, testing times)
        accuracies = np.zeros((len(selected_subj), 4) + scores.shape[1:])

    ## average fold scores per true label; assumes event codes 1..4 (TODO confirm)
    for ilabel in range(4):
        accuracies[s, ilabel] = scores[y == ilabel + 1].mean(axis=0)

assert accuracies is not None, 'no subjects selected'

# save data (time points are taken from the last subject's epochs)
data_path = project_path + 'data_v5/saved_source_data/'
os.makedirs(data_path, exist_ok=True)
np.save(data_path + 'acc_rs250_ica_occipital_find_multiOptimalTime', accuracies)
np.save(data_path + 'time_points_rs250', epochs_all.times)

### Loading data for plotting

In [None]:
# Load the group results written by the decoding cell and convert proportions to percent.
# acc: (n_subjects, 4 true labels, n_train_times, n_test_times)
data_path = project_path + 'data_v5/saved_source_data/'
acc = np.load(data_path + 'acc_rs250_ica_occipital_find_multiOptimalTime.npy') * 100
times = np.load(data_path + 'time_points_rs250.npy')
assert acc.shape[0] == len(selected_subj), 'saved accuracies do not match selected_subj'
assert acc.shape[-1] == times.size, 'accuracy time axis does not match saved time points'

# Diagonal accuracy (train time == test time): (n_subjects, 4, n_times).
# copy() because np.diagonal returns a read-only view.
acc_diag = np.diagonal(acc, axis1=2, axis2=3).copy()
acc_mean = np.mean(acc_diag, axis=0)
acc_sem = stats.sem(acc_diag, axis=0)

# Window searched for each subject's optimal (peak-accuracy) time, as sample indices
# [floor_thr, ceil_thr). NOTE(review): floor_thr/ceil_thr were previously used without ever
# being defined in this notebook (stale kernel state). The default below searches from
# stimulus onset to the end of the epoch -- TODO confirm against the intended window.
opt_tmin, opt_tmax = 0.0, times[-1]  # seconds
floor_thr = int(np.searchsorted(times, opt_tmin, side='left'))
ceil_thr = int(np.searchsorted(times, opt_tmax, side='right'))
assert floor_thr < ceil_thr, 'empty optimal-time search window'

# Optimal time per subject and label = argmax of the diagonal inside the window.
# optimal_time_idx stays float so the saved .npy keeps its previous dtype.
opt_idx = np.argmax(acc_diag[:, :, floor_thr:ceil_thr], axis=-1) + floor_thr
optimal_time_idx = opt_idx.astype(float)
optimal_times = times[opt_idx]
acc_per_subj = np.take_along_axis(acc_diag, opt_idx[..., np.newaxis], axis=-1)[..., 0]

for isubj, subj_id in enumerate(selected_subj):
    for ilabel in range(acc_diag.shape[1]):
        print('%s label:%s optimal time: %s ms, index %s' % (subj_id, ilabel, optimal_times[isubj, ilabel] * 1000, opt_idx[isubj, ilabel]))
    print('\n')

### Plot mean decoding accuracies across subjects

In [None]:
# Figure (raw diagonal): one column per orientation label (titles Ori0/Ori90/Ori180/Ori270).
#   top row:    group-mean temporal-generalization matrix with a contour around
#               cluster-corrected significant regions (p < .05 vs. chance);
#   bottom row: diagonal accuracy, mean +/- SEM across subjects; the black bar at 20 % marks
#               significant diagonal clusters spanning >= min_cluster_len samples, and the
#               solid vertical line marks the group-mean peak time.
color1 = 'darkorange'
chance = 100 / 4.     # chance accuracy (%) for 4 labels
cluster_seed = 42     # fixes permutation sign-flips so cluster p-values are reproducible
min_cluster_len = 10  # shortest significant diagonal cluster (in samples) to draw

fig, axs = plt.subplots(nrows=2, ncols=4, sharex=False, sharey=False, figsize=(26, 2 * 6))

for ilabel in range(4):
    ori_name = 'Ori%s' % (ilabel * 90)

    ################# full temporal-generalization matrix #################
    t_obs, clusters, cluster_pv, h0 = mne.stats.spatio_temporal_cluster_1samp_test(
        acc[:, ilabel, :, :] - chance, n_permutations=1024, out_type='mask',
        n_jobs=20, seed=cluster_seed)
    # spread each cluster's p-value over the (train time, test time) grid
    p_values = np.ones_like(acc[0, ilabel, :, :])
    for clu, pval in zip(clusters, cluster_pv):
        p_values[clu] = pval
    mask = p_values < 0.05

    ax = axs[0, ilabel]
    im = dsc.plot_contour_image(acc[:, ilabel, :, :].mean(0), times, ax=ax, mask=mask, vmin=25, vmax=35,
                                draw_mask=False, draw_contour=True, colorbar=True,
                                draw_diag=True, draw_zerolines=True,
                                mask_alpha=.5)
    ax.set_xlabel('Testing Time (s)', fontsize=14)
    ax.set_ylabel('Training Time (s)', fontsize=14)
    ax.set_title('Mean Temporal Generalization %s' % ori_name, fontsize=14)
    ax.axvline(0, color='k')
    ax.axhline(0, color='k')
    ax.tick_params(axis='both', which='major', direction='in', top=False, right=False, labelsize=12)

    ################# diagonal decoding time course #################
    t_obs_diag, cluster_diag, cluster_pv_diag, H0_diag = mne.stats.permutation_cluster_1samp_test(
        acc_diag[:, ilabel, :] - chance, n_permutations=1024, n_jobs=20, seed=cluster_seed)
    sig_idx_diag = np.flatnonzero(cluster_pv_diag < 0.05)

    ax = axs[1, ilabel]
    ax.axhline(chance, color='k', linestyle='--', alpha=0.8, lw=1.5)
    ax.axvline(0, color='k', linestyle='--', alpha=0.8, lw=1.5)

    # Solid line at the group-mean peak over the whole epoch. For label 0 the second-highest
    # sample is used because the first peak was not significant for 0 deg.
    # NOTE(review): the second-highest sample may simply neighbour the global maximum rather
    # than belong to a separate peak -- confirm this marks the intended peak.
    if ilabel == 0:
        idx2 = heapq.nlargest(2, range(len(acc_mean[0])), key=acc_mean[0].__getitem__)[1]
        peak_sample = idx2
    else:
        peak_sample = np.argmax(acc_mean[ilabel])
    ax.axvline(times[peak_sample], color='k', linestyle='-', alpha=0.8, lw=1.5)

    ax.plot(times, acc_mean[ilabel], color=color1, lw=2)
    ax.fill_between(times, acc_mean[ilabel] - acc_sem[ilabel], acc_mean[ilabel] + acc_sem[ilabel],
                    color=color1, edgecolor='none', alpha=.35)
    # significance bars for sufficiently long clusters
    for ci in sig_idx_diag:
        sig_times = times[cluster_diag[ci]]
        if len(sig_times) >= min_cluster_len:
            ax.plot(sig_times, np.full(len(sig_times), 20), 'k-', lw=3.5)

    ax.set_title('Mean decoding performance %s' % ori_name, fontsize=14)
    ax.set_ylim([18, 42])
    ax.set_xlabel('Time relative to stimulus onset (s)', fontsize=14)
    ax.set_ylabel('Accuracies (%)', fontsize=14)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.tick_params(axis='both', which='major', direction='in', top=False, right=False, labelsize=14)

## save figures
save_figure = False
if save_figure:
    fig_path = project_path + 'data_v5/saved_figures/modelTrain_rs250_ica_occipital/'
    os.makedirs(fig_path, exist_ok=True)
    fig.savefig(fig_path + "modelTrain_acc_mean_find_multiOptimalTime_raw.pdf", bbox_inches='tight', dpi=300)

### Plot diagonal curves resampled to 100 points for visualization (statistics still use the raw data)

In [None]:
# One linear interpolant per orientation label for the group-mean diagonal accuracy and its
# SEM. The plotting cell evaluates them on `times_smooth`, a coarser 100-point grid, which
# makes the curves look smoother (this is resampling only -- no filtering is applied).
n_curves = 4
acc_mean_smooth = np.empty(n_curves, dtype=object)
acc_sem_smooth = np.empty(n_curves, dtype=object)
for label_idx in range(n_curves):
    acc_mean_smooth[label_idx] = interpolate.interp1d(times, acc_mean[label_idx, :], kind='linear')
    acc_sem_smooth[label_idx] = interpolate.interp1d(times, acc_sem[label_idx, :], kind='linear')

n_smooth_points = 100  # number of samples on the resampled time grid
times_smooth = np.linspace(times.min(), times.max(), n_smooth_points)

In [None]:
# Figure (resampled diagonal): same layout as the raw figure, one column per orientation label.
#   top row:    group-mean temporal-generalization matrix with a contour around
#               cluster-corrected significant regions (p < .05 vs. chance);
#   bottom row: diagonal mean +/- SEM drawn from the 100-point linear interpolants
#               (acc_mean_smooth / acc_sem_smooth); statistics and significance bars use the
#               raw, unresampled data.
color1 = 'darkorange'
chance = 100 / 4.     # chance accuracy (%) for 4 labels
cluster_seed = 42     # fixes permutation sign-flips so cluster p-values are reproducible
min_cluster_len = 10  # shortest significant diagonal cluster (in samples) to draw

fig, axs = plt.subplots(nrows=2, ncols=4, sharex=False, sharey=False, figsize=(26, 2 * 6))

for ilabel in range(4):
    ori_name = 'Ori%s' % (ilabel * 90)

    ################# full temporal-generalization matrix #################
    t_obs, clusters, cluster_pv, h0 = mne.stats.spatio_temporal_cluster_1samp_test(
        acc[:, ilabel, :, :] - chance, n_permutations=1024, out_type='mask',
        n_jobs=20, seed=cluster_seed)
    # spread each cluster's p-value over the (train time, test time) grid
    p_values = np.ones_like(acc[0, ilabel, :, :])
    for clu, pval in zip(clusters, cluster_pv):
        p_values[clu] = pval
    mask = p_values < 0.05

    ax = axs[0, ilabel]
    im = dsc.plot_contour_image(acc[:, ilabel, :, :].mean(0), times, ax=ax, mask=mask, vmin=25, vmax=35,
                                draw_mask=False, draw_contour=True, colorbar=True,
                                draw_diag=True, draw_zerolines=True,
                                mask_alpha=.5)
    ax.set_xlabel('Testing Time (s)', fontsize=14)
    ax.set_ylabel('Training Time (s)', fontsize=14)
    ax.set_title('Mean Temporal Generalization %s' % ori_name, fontsize=14)
    ax.axvline(0, color='k')
    ax.axhline(0, color='k')
    ax.tick_params(axis='both', which='major', direction='in', top=False, right=False, labelsize=12)

    ################# diagonal decoding time course #################
    t_obs_diag, cluster_diag, cluster_pv_diag, H0_diag = mne.stats.permutation_cluster_1samp_test(
        acc_diag[:, ilabel, :] - chance, n_permutations=1024, n_jobs=20, seed=cluster_seed)
    sig_idx_diag = np.flatnonzero(cluster_pv_diag < 0.05)

    ax = axs[1, ilabel]
    ax.axhline(chance, color='k', linestyle='--', alpha=0.8, lw=1.5)
    ax.axvline(0, color='k', linestyle='--', alpha=0.8, lw=1.5)

    # Solid line at the (raw) group-mean peak over the whole epoch. For label 0 the
    # second-highest sample is used because the first peak was not significant for 0 deg.
    # NOTE(review): the second-highest sample may simply neighbour the global maximum rather
    # than belong to a separate peak -- confirm this marks the intended peak.
    if ilabel == 0:
        idx2 = heapq.nlargest(2, range(len(acc_mean[0])), key=acc_mean[0].__getitem__)[1]
        peak_sample = idx2
    else:
        peak_sample = np.argmax(acc_mean[ilabel])
    ax.axvline(times[peak_sample], color='k', linestyle='-', alpha=0.8, lw=1.5)

    # evaluate the interpolants once per label
    mean_curve = acc_mean_smooth[ilabel](times_smooth)
    sem_curve = acc_sem_smooth[ilabel](times_smooth)
    ax.plot(times_smooth, mean_curve, color=color1, lw=2)
    ax.fill_between(times_smooth, mean_curve - sem_curve, mean_curve + sem_curve,
                    color=color1, edgecolor='none', alpha=.35)
    # significance bars for sufficiently long clusters (raw time axis)
    for ci in sig_idx_diag:
        sig_times = times[cluster_diag[ci]]
        if len(sig_times) >= min_cluster_len:
            ax.plot(sig_times, np.full(len(sig_times), 20), 'k-', lw=3.5)

    ax.set_title('Mean decoding performance %s' % ori_name, fontsize=14)
    ax.set_ylim([18, 42])
    ax.set_xlabel('Time relative to stimulus onset (s)', fontsize=14)
    ax.set_ylabel('Accuracies (%)', fontsize=14)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.tick_params(axis='both', which='major', direction='in', top=False, right=False, labelsize=14)

## save figures
save_figure = False
if save_figure:
    fig_path = project_path + 'data_v5/saved_figures/modelTrain_rs250_ica_occipital/'
    os.makedirs(fig_path, exist_ok=True)
    fig.savefig(fig_path + "modelTrain_acc_mean_find_multiOptimalTime_smoothed.pdf", bbox_inches='tight', dpi=300)

### Save the optimal time points

In [None]:
# Persist the per-subject x per-label optimal time indices (sample indices into `times`,
# stored as floats) for downstream analyses.
print(optimal_time_idx)
save_data = True
if save_data:
    data_path = project_path + 'data_v5/saved_source_data/'
    os.makedirs(data_path, exist_ok=True)
    np.save(data_path + 'optimal_time_idx_occipital_acc_matrix', optimal_time_idx)