In [None]:
import sys, os, glob, warnings, time, datetime
from math import sqrt
import numpy as np
import scipy, mne, random
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from scipy import io, stats, interpolate
%matplotlib inline
from sklearn.pipeline import make_pipeline
from sklearn.svm import LinearSVC, SVC
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import LeaveOneOut, cross_val_score, KFold, cross_val_predict
from sklearn.preprocessing import StandardScaler, scale
from scipy.signal import savgol_filter

warnings.filterwarnings('ignore')

In [None]:
project_path = 'D:/Dropbox/Projects/featureReplay/'
# All subject IDs, 'S01' .. 'S35'.
SUBJECTS = ['S%02d' % i for i in range(1, 36)]

# !! Important: select subjects with a slice (e.g. SUBJECTS[n-1:n] for a single
# subject) so that selected_subj is always a list.
selected_subj = SUBJECTS[:18]
n_subjects    = len(selected_subj)
print(['Running Subjects:'] + selected_subj)

# Seed the global NumPy RNG: trial shuffling before averaging uses np.random,
# so without a seed the pseudo-trial averages (and results) differ per run.
RANDOM_SEED = 0
np.random.seed(RANDOM_SEED)

# Get selected channel index (occipital).
# np.isin replaces np.in1d, which is deprecated as of NumPy 2.0.
chans_all       = np.loadtxt(project_path + 'data_v5/misc_data/channels_all.txt', dtype='str')
chans_occipital = np.loadtxt(project_path + 'data_v5/misc_data/occipital_channels.txt', dtype='str')
chans_idx       = np.flatnonzero(np.isin(chans_all, chans_occipital))

# params for decoding
# NOTE(review): `multi_class` is deprecated in recent scikit-learn releases;
# confirm it is still accepted by the installed version.
clf = make_pipeline(StandardScaler(), LogisticRegression(C=0.5, penalty='l2', max_iter=10000, class_weight='balanced',
                                                         multi_class='ovr', solver='liblinear'))

# Decoding probabilities; NaN marks unused slots (e.g. fewer pseudo-trials than 96).
# n_subj, n_conditions, n_times, n_maxTrials(96), n_orientations
proba = np.full((n_subjects, 3, 775, 96, 4), np.nan)

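## Run decoding

For each subject, train a logistic-regression classifier on occipital-channel modelTrain data (four orientations plus the ITI as class 0, with trials averaged into pseudo-trials), then apply it at every time point to pseudo-trials of the three mainPost conditions (full sequence, start only, end only).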

In [None]:
# Decode orientation from occipital channels using pseudo-trials: trials are
# averaged in small groups (after random shuffles) both before training and
# before testing, and the pseudo-trial sets from several shuffles are then
# averaged together ("double mean").

N_CLASSES       = 5                # labels 0..4: ITI (0) plus orientations 1..4
TRAIN_WINDOW    = slice(100, 200)  # training samples averaged into one feature vector per trial
TRAIN_AVG_EVERY = 2                # training trials averaged per pseudo-trial
TRAIN_REPEATS   = 10               # training shuffles averaged together
TEST_AVG_EVERY  = 5                # test trials averaged per pseudo-trial
TEST_REPEATS    = 30               # test shuffles averaged together (per condition and time point)


def _average_every_n(X, n):
    """Average consecutive groups of `n` rows of a 2-D (trials x features) array.

    Leftover rows that do not fill a complete group are averaged together into
    one extra pseudo-trial (for n == 2 that is just the single leftover trial),
    so no trial is discarded.
    """
    n_full = (X.shape[0] // n) * n
    averaged = X[:n_full].reshape(-1, n, X.shape[1]).mean(axis=1)
    if n_full < X.shape[0]:
        averaged = np.vstack((averaged, X[n_full:].mean(axis=0, keepdims=True)))
    return averaged


def _train_pseudo_trials(X, y, n_classes, n_avg, n_repeats):
    """Return (X_aver, y_aver): per-class pseudo-trials averaged over `n_repeats` shuffles.

    Pseudo-trials are stacked in class order 0..n_classes-1. Class sizes do not
    depend on the shuffle, so every repeat yields arrays of the same shape.
    Trials whose label is outside 0..n_classes-1 are ignored.
    """
    repeats = []
    for _ in range(n_repeats):
        order = np.random.permutation(len(y))
        X_shuf, y_shuf = X[order], y[order]
        repeats.append([_average_every_n(X_shuf[y_shuf == c], n_avg) for c in range(n_classes)])
    X_aver = np.mean([np.concatenate(per_class) for per_class in repeats], axis=0)
    y_aver = np.repeat(np.arange(n_classes), [block.shape[0] for block in repeats[0]])
    return X_aver, y_aver


def _test_pseudo_trials(X, n_avg, n_repeats):
    """Return pseudo-trials of `X` (trials x features) averaged over `n_repeats` shuffles."""
    return np.mean([_average_every_n(X[np.random.permutation(len(X))], n_avg)
                    for _ in range(n_repeats)], axis=0)


for s, subj_id in enumerate(selected_subj):
    print('>>> loading data for subject:', subj_id)

    ## ================ load train data (modelTrain) =============== ##
    ## epochs of the orientation period; event codes are the class labels
    ## (assumed 1..4 for the orientations -- TODO confirm)
    cond_data = project_path + 'data_v5/%s/%s_modelTrain_epochs_all_resample250_ica-epo.fif' % (subj_id, subj_id)
    epochs_all = mne.read_epochs(cond_data, preload=True)
    X_ori = epochs_all.get_data()  # n_trials * n_channels * n_times
    y_ori = epochs_all.events[:, 2]

    ## epochs of the ITI period, all labelled 0
    iti_data = project_path + 'data_v5/%s/%s_modelTrain_epochs_all_resample250_ica_ITI-epo.fif' % (subj_id, subj_id)
    epochs_iti = mne.read_epochs(iti_data, preload=True)
    X_iti = epochs_iti.get_data()  # n_trials * n_channels * n_times
    y_iti = np.zeros(X_iti.shape[0])

    X_train = np.vstack((X_ori, X_iti))
    y_train = np.hstack((y_ori, y_iti))

    ## occipital channels, averaged over the training time window -> n_trials x n_channels
    X_train = X_train[:, chans_idx, TRAIN_WINDOW].mean(axis=2)
    X_train_aver, y_train_aver = _train_pseudo_trials(
        X_train, y_train, N_CLASSES, TRAIN_AVG_EVERY, TRAIN_REPEATS)

    ## ================ load test data (mainPost) =============== ##
    cond_data = project_path + 'data_v5/%s/%s_mainPost_epochs_all_resample250_ica-epo.fif' % (subj_id, subj_id)
    epochs_all = mne.read_epochs(cond_data, preload=True)
    X_test_all = epochs_all.get_data()[:, chans_idx, :]  # n_trials * occipital channels * n_times
    y_test_all = epochs_all.events[:, 2]
    assert X_test_all.shape[-1] <= proba.shape[2], \
        'test data has more time points than `proba` can store'

    # find the true orientation order for the current subject
    f_mat = scipy.io.loadmat(project_path + 'data_v5/behavioral_data/%s/MainTask/params_PostTest_%s_R01.mat' % (subj_id, subj_id))
    test_dir = f_mat['p']['Orient'][0, 0].ravel()
    test_dir_idx = (test_dir + 90) / 90  # orientation order, e.g., 1 2 3 4
    seq_cols = test_dir_idx.astype(int) - 1  # orientation columns reordered into sequence order

    # select relevant trials for each condition
    selection = np.zeros(len(y_test_all))
    selection[y_test_all == test_dir_idx[0]] += 1         # full sequence condition
    selection[y_test_all == (test_dir_idx[0] + 10)] += 2  # Start condition
    selection[y_test_all == (test_dir_idx[3] + 10)] += 3  # End condition
    # trial indices per condition do not depend on time, so compute them once
    cond_idx = [np.flatnonzero(selection == (ic + 1)) for ic in range(3)]

    ## ================ do classification from here ================ ##
    clf.fit(X_train_aver, y_train_aver)
    # predict_proba columns follow clf.classes_; the slicing below needs
    # column 0 = ITI and columns 1..4 = orientations 1..4.
    assert list(clf.classes_) == list(range(N_CLASSES)), clf.classes_

    ## apply trained model to each condition and each timepoint
    for t in range(X_test_all.shape[-1]):
        X_test_this_time = X_test_all[:, :, t]

        for ic, idx in enumerate(cond_idx):
            if idx.size == 0:  # no trials for this condition: leave NaNs
                continue
            X_test_cond_shuff = _test_pseudo_trials(X_test_this_time[idx], TEST_AVG_EVERY, TEST_REPEATS)
            y_pred = clf.predict_proba(X_test_cond_shuff)

            # drop the ITI column, reorder orientations into sequence order, store per pseudo-trial
            proba[s, ic, t, :len(y_pred), :] = y_pred[:, 1:][:, seq_cols]

## save data
data_path = project_path + 'data_v5/saved_source_data/'
os.makedirs(data_path, exist_ok=True)
np.save(data_path + 'probability_mainPost_occipital_double_mean', proba)

## Loading data for plotting

In [None]:
# Pseudocolor plot of decoding probability (time x orientation in sequence
# order), averaged over subjects and pseudo-trials, for each test condition.
data_path = project_path + 'data_v5/saved_source_data/'
proba = np.load(data_path + 'probability_mainPost_occipital_double_mean.npy', allow_pickle=True)
proba_mean = np.nanmean(proba, axis=(0, 3))  # n_conds, n_times, n_orientations

# start plotting here
conditions = ['Full sequence', 'Start only', 'End only']
plt.rcParams["font.family"] = "arial"
fig, axs = plt.subplots(nrows=1, ncols=3, sharex=True, sharey=True, figsize=(18, 5))

for ax, cond_name, cond_proba in zip(axs, conditions, proba_mean):
    # Savitzky-Golay smoothing along time only (axis 0), per orientation column.
    y_filtered = savgol_filter(cond_proba, window_length=51, polyorder=1, axis=0)

    im = ax.pcolor(y_filtered)
    ax.set_xlabel('Orientation', fontsize=14)
    ax.set_ylabel('Time (s)', fontsize=14)
    ax.set_xticks(np.arange(4) + .5)  # centre of each orientation column
    ax.set_xticklabels(np.arange(4) + 1)
    # sample index -> seconds, i.e. (sample - 75) / 250 for the resample250 data
    ax.set_yticks([0, 75, 200, 325, 450, 575, 700])
    ax.set_yticklabels([-0.3, 0, 0.5, 1, 1.5, 2, 2.5])
    ax.set_title(cond_name, fontsize=16)
    ax.tick_params(axis='both', which='major', direction='in', top=False, right=False, labelsize=14)
    cbar = plt.colorbar(im, ax=ax)
    cbar.ax.tick_params(labelsize=14)
    cbar.ax.get_yaxis().labelpad = 13
    cbar.ax.set_ylabel('Decoding probability', rotation=90, fontsize=16)

# The y axis is shared by all panels and invert_yaxis() toggles the shared
# limits, so invert exactly once (per-panel calls only work for an odd count).
axs[0].invert_yaxis()

## save figures
save_figure = False
if save_figure:
    fig_path = project_path + 'data_v5/saved_figures/mainPost_proba_rs250_ica_occipital/'
    os.makedirs(fig_path, exist_ok=True)
    fig.savefig(fig_path + "mainPost_probability_meanOverTrials_pcolor.pdf", bbox_inches='tight', dpi=300)

In [None]:
# Plot the unsmoothed decoding-probability time courses, one line per
# orientation (in sequence order), averaged over subjects and pseudo-trials.
data_path = project_path + 'data_v5/saved_source_data/'
proba = np.load(data_path + 'probability_mainPost_occipital_double_mean.npy', allow_pickle=True)
proba_mean = np.nanmean(proba, axis=(0, 3))  # n_conds, n_times, n_orientations

# start plotting here
conditions = ['Full sequence', 'Start-only', 'End-only']
plt.rcParams["font.family"] = "arial"
fig, axs = plt.subplots(nrows=1, ncols=3, sharex=False, sharey=False, figsize=(18, 5))

for i, ax in enumerate(axs):
    for j in range(4):  # predicted label probability
        # raw string: '\c' is an invalid escape sequence in a normal literal
        ax.plot(proba_mean[i, :, j], lw=1.5, label=r'%s$^{\circ}$' % (j * 90))
    if i == 0:  # plot legend in the first panel only
        ax.legend(loc=1, fontsize=14, frameon=False)

    # plot highlight area
    for low, high in zip([75, 250, 425, 600], [175, 350, 525, 700]):
        ax.axvspan(low, high, facecolor='gray', alpha=.2)
    # set parameters for figures
    ax.set_title(conditions[i], fontsize=12)
    ax.set_xlabel('Time (s)', fontsize=12)
    # values are predict_proba outputs in [0, 1], not percentages
    ax.set_ylabel('Decoding probability', fontsize=12)
    # sample index -> seconds, i.e. (sample - 75) / 250
    ax.set_xticks((75, 175, 250, 350, 425, 525, 600, 700))
    ax.set_xticklabels((0, .4, .7, 1.1, 1.4, 1.8, 2.1, 2.5))
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.tick_params(axis='both', which='both', direction='out', labelsize=12)

## save figures
save_figure = False
if save_figure:
    fig_path = project_path + 'data_v5/saved_figures/mainPost_proba_rs250_ica_occipital/'
    os.makedirs(fig_path, exist_ok=True)
    fig.savefig(fig_path + "mainPost_probability_meanOverTrials_curve_raw.pdf", bbox_inches='tight', dpi=300)

### Smooth the curve for visualization

In [None]:
# Smoothed decoding-probability time courses with across-subject SEM bands.
data_path = project_path + 'data_v5/saved_source_data/'
proba = np.load(data_path + 'probability_mainPost_occipital_double_mean.npy', allow_pickle=True)
times = np.arange(proba.shape[2])  # sample indices (775 in the saved array)
proba_mean = np.nanmean(proba, axis=(0, 3))  # n_conds, n_times, n_orientations
# SEM across subjects of each subject's pseudo-trial mean (computed on unsmoothed data)
proba_sem = stats.sem(np.nanmean(proba, axis=3), axis=0)  # n_conds, n_times, n_orientations

# start plotting here
conditions = ['Full sequence', 'Start-only', 'End-only']
plt.rcParams["font.family"] = "arial"
fig, axs = plt.subplots(nrows=1, ncols=3, sharex=False, sharey=False, figsize=(18, 5))

for i, ax in enumerate(axs):
    for j in range(4):  # predicted label probability
        y_filtered = savgol_filter(proba_mean[i, :, j], window_length=51, polyorder=1, mode='interp')
        # raw string: '\c' is an invalid escape sequence in a normal literal
        line, = ax.plot(times, y_filtered, lw=2, label=r'%s$^{\circ}$' % (j * 90))
        # colour the SEM band explicitly with its line's colour
        ax.fill_between(times, y_filtered - proba_sem[i, :, j], y_filtered + proba_sem[i, :, j],
                        facecolor=line.get_color(), edgecolor='none', alpha=.3)
    if i == 2:  # plot legend in the last panel only
        ax.legend(loc=1, fontsize=12, frameon=False)

    # plot highlight area
    for low, high in zip([75, 250, 425, 600], [175, 350, 525, 700]):
        ax.axvspan(low, high, facecolor='gray', alpha=.2)
    # set parameters for figures
    ax.set_title(conditions[i], fontsize=12)
    ax.set_xlabel('Time relative to trial onset (s)', fontsize=12)
    # values are predict_proba outputs in [0, 1], not percentages
    ax.set_ylabel('Decoding probability', fontsize=12)
    # sample index -> seconds, i.e. (sample - 75) / 250
    ax.set_xticks((75, 175, 250, 350, 425, 525, 600, 700))
    ax.set_xticklabels((0, .4, .7, 1.1, 1.4, 1.8, 2.1, 2.5))
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.tick_params(axis='both', which='both', direction='out', labelsize=12)
    ax.set_ylim(0, 0.37)

## save figures
save_figure = False
if save_figure:
    fig_path = project_path + 'data_v5/saved_figures/mainPost_proba_rs250_ica_occipital/'
    os.makedirs(fig_path, exist_ok=True)
    fig.savefig(fig_path + "mainPost_probability_meanOverTrials_curve_smooth.pdf", bbox_inches='tight', dpi=300)