In [1]:
import os
import mne
from mne import find_events
from mne.decoding import Vectorizer, SlidingEstimator, cross_val_multiscore

import numpy as np

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV, cross_val_score, train_test_split, StratifiedKFold, RepeatedKFold
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.metrics import roc_auc_score, classification_report, accuracy_score, precision_recall_fscore_support, balanced_accuracy_score

import random
import warnings
warnings.filterwarnings('ignore')

import matplotlib
#matplotlib.use('TKAgg')
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
#%matplotlib tk
import pickle
random.seed(42)



#### Data Preparation Functions

In [2]:
def loadData(s_id, sensors, fname, resampled=False):
    """Read the epochs stored in ``fname`` and resample them when needed.

    Parameters
    ----------
    s_id, sensors :
        Accepted for the callers' convenience; not used inside this function.
    fname : str
        Path of the epochs file handed to ``mne.read_epochs``.
    resampled : bool
        When equal to ``True`` the file is taken as already resampled and is
        only read. Otherwise the epochs are passed to ``resampleData`` with a
        target rate of 100, which also writes a ``*_resampled.fif`` copy.

    Returns
    -------
    The epochs object returned by ``mne.read_epochs`` (resampled in place by
    ``resampleData`` when ``resampled`` is not ``True``).
    """
    # Both paths start from the same read; only the follow-up differs.
    epochs = mne.read_epochs(fname, verbose='error')

    # Equality test (not plain truthiness) is kept on purpose so that the
    # accepted values of ``resampled`` stay exactly as before.
    if resampled == True:
        print(fname + ' loaded!')
        return epochs

    # ``resampleData`` prints its own confirmation message.
    resampleData(100, epochs, fname)
    return epochs

In [3]:
def resampleData(samplingRate, epochs, filename):
    """Resample ``epochs`` and store the result next to the original file.

    Parameters
    ----------
    samplingRate :
        Value forwarded to ``epochs.resample`` (with ``npad='auto'``).
    epochs :
        Object exposing ``resample`` and ``save``; it is modified by the
        ``resample`` call.
    filename : str
        Path of the original file. Its last four characters are dropped and
        ``'_resampled.fif'`` is appended to build the output path.

    Returns
    -------
    The same ``epochs`` object that was passed in.
    """
    epochs.resample(samplingRate, npad='auto')

    # The last four characters are presumably the '.fif' extension.
    resampled_path = '{}_resampled.fif'.format(filename[:-4])
    epochs.save(resampled_path)

    print('{} loaded!'.format(resampled_path))
    return epochs


In [4]:
def extractDataAndLabels(epochs, eventIdsList):
    """Collect the trials and class labels of each requested event.

    Parameters
    ----------
    epochs :
        Object with an ``event_id`` container and name-based indexing whose
        result provides ``get_data()``.
    eventIdsList : list of str
        Event names to look up, in the order the results should be returned.

    Returns
    -------
    (data, labels) : tuple of lists
        One entry per requested event, in the same order. For an event that
        is present, ``data`` holds the ``get_data()`` array and ``labels`` a
        vector with one value per trial: 0 when the event name contains
        ``'living'``, 1 otherwise. For an event that is absent both entries
        are empty lists, so positions stay aligned with ``eventIdsList``.
    """
    data, labels = [], []

    for event_name in eventIdsList:
        # Missing events keep their slot as an empty placeholder.
        if event_name not in epochs.event_id:
            data.append([])
            labels.append([])
            continue

        trials = epochs[event_name].get_data()
        data.append(trials)

        # living --> label 0 / object --> label 1
        make_labels = np.zeros if 'living' in event_name else np.ones
        labels.append(make_labels(trials.shape[0]))

    return data, labels


In [5]:
def concatNonEmpty(lists):
    """Concatenate the non-empty entries of ``lists`` in order.

    Entries whose ``len`` is 0 are skipped. The first non-empty entry is used
    as-is (no copy), and every further non-empty entry is appended to it with
    ``np.concatenate``. When no entry is non-empty an empty list is returned.
    """
    merged = []
    for chunk in lists:
        if len(chunk) == 0:
            continue
        # Nothing accumulated yet: adopt the chunk itself; otherwise extend.
        merged = chunk if len(merged) == 0 else np.concatenate((merged, chunk))
    return merged

In [6]:
# Load confidence data
def loadConfData(confFile):
    """Load the confidence ratings stored in ``confFile`` and normalise them.

    Each entry of the loaded array is replaced in place by an integer:

    * ``-1`` when the entry is ``None``,
    * ``0``  when the entry is an empty sequence,
    * ``int(entry[0])`` otherwise.

    Parameters
    ----------
    confFile : str
        Path of a ``.npy`` file readable by ``np.load``.

    Returns
    -------
    The loaded array with its entries converted as described above.
    """
    # NOTE(security): allow_pickle=True unpickles arbitrary Python objects;
    # only load confidence files that come from a trusted source.
    conf = np.load(confFile, allow_pickle=True)

    # Convert confidence values to int and None to -1 to ease their use
    for i in range(len(conf)):
        rating = conf[i]
        # Identity test on purpose: `rating != None` is an elementwise
        # comparison when the entry is an ndarray, which mis-classifies empty
        # arrays and raises for arrays with more than one element.
        if rating is None:
            conf[i] = -1
        elif len(rating) > 0:
            conf[i] = int(rating[0])
        else:
            conf[i] = 0
    return conf

In [18]:
def splitEpochs_byConfidence(confFile, epochs):
    """Split ``epochs`` into low- and high-confidence subsets.

    The ratings are read with ``loadConfData`` and used as positional indices
    into ``epochs``: ratings of 1 or 2 select the low-confidence trials and
    ratings of 3 or more select the high-confidence trials. Ratings of -1
    (None) and 0 (no response) are only counted and reported.

    Parameters
    ----------
    confFile : str
        Path handed to ``loadConfData``.
    epochs :
        Epochs object supporting index-array selection, ``event_id`` and
        name-based selection.
        NOTE(review): the rating vector is assumed to be aligned one-to-one
        with the trials in ``epochs`` - nothing here verifies that.

    Returns
    -------
    (low_conf_epochs, high_conf_epochs)
    """

    def reportEvents(subset, level):
        # Flag unexpected 'real' events and count 'nores' events in a subset.
        for event_name in subset.event_id:
            if 'real' in event_name:
                print('ERROR: real events found in ' + level + ' conf!')
                print(event_name)

            if 'nores' in event_name:
                n_trials = len(subset[event_name].get_data())
                print('In ' + level + ' conf ' + str(event_name) + ' detected: ' + str(n_trials))

    # Load confidence ratings
    ratings = loadConfData(confFile)

    # Distinct ratings that correspond to an actual answer (> 0)
    answered = [c for c in ratings if c > 0]
    print("Unique confidence values: ", np.unique(answered))

    conf = np.array(ratings)

    # Low confidence: ratings 1 and 2
    is_low = (conf <= 2) & (conf > 0)
    conf_low_indices = np.where(is_low)[0]
    print("Number of low confidence responses: ", len(conf_low_indices))

    # High confidence: ratings 3 and above
    is_high = conf >= 3
    conf_high_indices = np.where(is_high)[0]
    print("Number of high confidence responses: ", len(conf_high_indices))

    n_none = len(np.where(conf == -1)[0])
    n_no_resp = len(np.where(conf == 0)[0])
    n_questions = len(np.where(conf > -1)[0])
    print('Number of None ( = -1): ', n_none)
    print('Number of no-resp ( = 0): ', n_no_resp)
    print('Total confidence questions: ', n_questions)

    # Low confidence trials first, then high confidence trials
    low_conf_epochs = epochs[conf_low_indices]
    reportEvents(low_conf_epochs, 'low')

    high_conf_epochs = epochs[conf_high_indices]
    reportEvents(high_conf_epochs, 'high')

    return low_conf_epochs, high_conf_epochs

In [8]:
def prepareData_conf_behavior(label, epochs, confLevelName):
    """Return the trials of event ``label`` from a confidence-specific subset.

    Parameters
    ----------
    label : str
        Event name to select.
    epochs :
        Subset of epochs (low- or high-confidence trials) with ``event_id``
        and name-based selection.
    confLevelName : str
        Name of the confidence level, used only in the warning message.

    Returns
    -------
    ``epochs[label].get_data()`` when the event exists in ``epochs``;
    otherwise an empty list, after printing a warning.
    """
    if label not in epochs.event_id:
        print('Warning: Event ( ' + label + ' ) not found in the ' + confLevelName + ' data! Returning empty list!')
        return []

    return epochs[label].get_data()



In [9]:

def prepareData_conf_pred(data_omission_living_lowConf_list, data_omission_living_highConf_list, data_omission_obj_lowConf_list, data_omission_obj_highConf_list):
    """Pool omission trials per confidence level and build their labels.

    Each argument is a list of trial arrays (entries may be empty). Within a
    confidence level the living trials come first (label 0) followed by the
    object trials (label 1).

    Returns
    -------
    (data_lowConf, data_highConf, labels_lowConf, labels_highConf)
    """

    def poolOneConfidenceLevel(living_parts, obj_parts):
        # Living trials -> label 0
        living = concatNonEmpty(living_parts)
        living_labels = np.zeros(len(living))

        # Object trials -> label 1
        obj = concatNonEmpty(obj_parts)
        obj_labels = np.ones(len(obj))

        # Combine, living first
        return (concatNonEmpty([living, obj]),
                concatNonEmpty([living_labels, obj_labels]))

    # Low confidence
    data_low, labels_low = poolOneConfidenceLevel(data_omission_living_lowConf_list,
                                                  data_omission_obj_lowConf_list)

    # High confidence
    data_high, labels_high = poolOneConfidenceLevel(data_omission_living_highConf_list,
                                                    data_omission_obj_highConf_list)

    return data_low, data_high, labels_low, labels_high







In [10]:
def prepareData_pred_behavior_conf(living_data_list, obj_data_list ):
    """Pool living and object trials and generate the matching labels.

    Parameters
    ----------
    living_data_list, obj_data_list : list
        Lists of trial arrays; empty entries are ignored when pooling.

    Returns
    -------
    (data_all, labels_all)
        Living trials followed by object trials, and a label vector with 0
        for every living trial and 1 for every object trial.
    """
    pooled_parts, label_parts = [], []

    # Living first (label 0), then object (label 1)
    for parts, make_labels in ((living_data_list, np.zeros), (obj_data_list, np.ones)):
        trials = concatNonEmpty(parts)
        pooled_parts.append(trials)
        label_parts.append(make_labels(len(trials)))

    return concatNonEmpty(pooled_parts), concatNonEmpty(label_parts)
    

In [1]:
# Predictability levels as they appear in the event names (8 -> 80%, 9 -> 90%, 10 -> 100%).
_PRED_LEVELS = ('8', '9', '10')
_CATEGORIES = ('living', 'object')
_OUTCOMES = ('corr', 'incorr')
_CONF_LEVELS = ('low', 'high')


def _omissionEventName(category, level, outcome):
    """Build an omission event name such as 'living_omission_8_corr'."""
    return category + '_omission_' + level + '_' + outcome


def _pairwiseConcat(first_parts, second_parts):
    """Concatenate two equally long lists entry by entry, ignoring empty entries."""
    return [concatNonEmpty([first, second]) for first, second in zip(first_parts, second_parts)]


def _poolLivingAndObject(living_parts, object_parts):
    """Pool all living parts followed by all object parts, ignoring empty entries."""
    return concatNonEmpty([concatNonEmpty(living_parts), concatNonEmpty(object_parts)])


def _omissionsByOutcome(epochs, outcome):
    """Extract the omission trials with one behavioural outcome ('corr' or 'incorr').

    Returns (data_by_level, labels_by_level, data_pooled, labels_pooled); the
    per-level lists follow _PRED_LEVELS and always put living before object.
    """
    living_data, living_labels = extractDataAndLabels(
        epochs, [_omissionEventName('living', level, outcome) for level in _PRED_LEVELS])
    object_data, object_labels = extractDataAndLabels(
        epochs, [_omissionEventName('object', level, outcome) for level in _PRED_LEVELS])

    return (_pairwiseConcat(living_data, object_data),
            _pairwiseConcat(living_labels, object_labels),
            _poolLivingAndObject(living_data, object_data),
            _poolLivingAndObject(living_labels, object_labels))


def _omissionsByConfidence(low_conf_epochs, high_conf_epochs):
    """Fetch every omission event from the low/high confidence subsets.

    Returns a dict keyed by (confidence, category, level, outcome). Missing
    events are stored as empty lists (see prepareData_conf_behavior).
    """
    conf_data = {}
    first_group = True
    for conf_level, conf_epochs in zip(_CONF_LEVELS, (low_conf_epochs, high_conf_epochs)):
        for category in _CATEGORIES:
            # Visual separator between the four confidence/category groups
            if not first_group:
                print('----------------')
            first_group = False

            for level in _PRED_LEVELS:
                for outcome in _OUTCOMES:
                    conf_data[conf_level, category, level, outcome] = prepareData_conf_behavior(
                        _omissionEventName(category, level, outcome), conf_epochs, conf_level)
    return conf_data


def _confSubset(conf_data, conf_level, levels, outcomes):
    """Pool the selected levels/outcomes of one confidence level into (data, labels)."""
    living_parts = [conf_data[conf_level, 'living', level, outcome]
                    for level in levels for outcome in outcomes]
    object_parts = [conf_data[conf_level, 'object', level, outcome]
                    for level in levels for outcome in outcomes]
    return prepareData_pred_behavior_conf(living_parts, object_parts)


def loadDataByParticipant(s_id, task_name, fname, sensors, filep):
    """Load one participant's epochs and arrange them for the requested task.

    Parameters
    ----------
    s_id : str
        Participant identifier; forwarded to ``loadData`` and used to build
        the confidence file path.
    task_name : str
        Selects how the omission trials are split (see Returns).
    fname : str
        Path of the participant's epochs file. If a sibling file ending in
        ``_resampled.fif`` exists it is loaded instead; otherwise ``fname`` is
        loaded and resampled by ``loadData``.
    sensors :
        Forwarded to ``loadData``.
    filep : str
        Prefix of the participant folders; the confidence ratings are read
        from ``<filep>S<s_id>/<s_id>_confValues_AR.npy`` for tasks whose name
        contains ``'conf'``.

    Returns
    -------
    epochs, data_real, labels_real, data_omissions, labels_omissions

    ``data_real`` / ``labels_real`` are ``[80%, 90%, 100%, all levels]`` for
    the real sounds (living trials first, label 0; object trials label 1).

    ``data_omissions`` / ``labels_omissions`` depend on ``task_name``:

    * ``'all_predLevel'``      -> [80%, 90%, 100%, all]
    * ``'all_incorrVScorr'``   -> [correct, incorrect]
    * ``'behavior_conf_pred'`` -> low-conf correct (80, 90, 100), low-conf
      incorrect (80, 90, 100), high-conf correct (...), high-conf incorrect (...)
    * ``'conf_pred'``          -> low-conf (80, 90, 100), high-conf (80, 90, 100)
    * ``'behavior_conf'``      -> [correct&low, correct&high, incorrect&low,
      incorrect&high]
    * ``'all_conf'``           -> [low, high]
    * anything else            -> two empty lists

    Entries for which no trial exists are empty lists.
    """
    # Check if data is already resampled!
    fname_resampled = fname[:-4] + '_resampled.fif'

    if os.path.isfile(fname_resampled):
        print('Already resampled data!')
        epochs = loadData(s_id, sensors, fname_resampled, resampled=True)
    else:
        print('Data will be resampled!')
        epochs = loadData(s_id, sensors, fname, resampled=False)

    # ------------------------- Real sounds -------------------------
    data_real_living, labels_real_living = extractDataAndLabels(
        epochs, ['living_real_' + level for level in _PRED_LEVELS])
    data_real_object, labels_real_object = extractDataAndLabels(
        epochs, ['object_real_' + level for level in _PRED_LEVELS])

    data_real_living_all = concatNonEmpty(data_real_living)
    # np.shape also works when no living trial exists (plain empty list).
    print('Shape of data real living: ', np.shape(data_real_living_all))

    # Real sounds by predictability level and all of them together.
    # concatNonEmpty (rather than np.concatenate) keeps this working when one
    # of the two categories has no trials at all.
    data_real = _pairwiseConcat(data_real_living, data_real_object)
    data_real.append(concatNonEmpty([data_real_living_all, concatNonEmpty(data_real_object)]))

    labels_real = _pairwiseConcat(labels_real_living, labels_real_object)
    labels_real.append(_poolLivingAndObject(labels_real_living, labels_real_object))

    # ------------------------- Omissions -------------------------
    data_omissions, labels_omissions = [], []

    if task_name in ('all_predLevel', 'all_incorrVScorr'):
        # Only these two tasks need the behaviour-based pooling of all omissions.
        corr_level_data, corr_level_labels, corr_data, corr_labels = _omissionsByOutcome(epochs, 'corr')
        incorr_level_data, incorr_level_labels, incorr_data, incorr_labels = _omissionsByOutcome(epochs, 'incorr')

        if task_name == 'all_predLevel':
            # Combine correct and incorrect trials per predictability level,
            # then add the pool over all levels as the last entry.
            data_omissions = _pairwiseConcat(corr_level_data, incorr_level_data)
            data_omissions.append(concatNonEmpty([corr_data, incorr_data]))

            labels_omissions = _pairwiseConcat(corr_level_labels, incorr_level_labels)
            labels_omissions.append(concatNonEmpty([corr_labels, incorr_labels]))
        else:
            print('Number of correct trials: ', len(corr_data))
            print('Number of incorrect trials: ', len(incorr_data))
            data_omissions = [corr_data, incorr_data]
            labels_omissions = [corr_labels, incorr_labels]

    elif 'conf' in task_name:
        # os.path.join instead of a hard-coded '\\' so the path is valid on
        # every platform (identical result on Windows).
        confFile = os.path.join(filep + 'S' + s_id, s_id + "_confValues_AR.npy")

        # Separate low confidence trials and high confidence trials
        low_conf_epochs, high_conf_epochs = splitEpochs_byConfidence(confFile, epochs)
        print('number of low conf epochs: ', len(low_conf_epochs))
        print('number of high conf epochs: ', len(high_conf_epochs))

        conf_data = _omissionsByConfidence(low_conf_epochs, high_conf_epochs)

        subsets = []  # (data, labels) pairs in the order they are returned
        if task_name == 'behavior_conf_pred':  # Split data based on all 3 experimental conditions
            print('!! behavior_conf_pred !!')
            subsets = [_confSubset(conf_data, conf_level, [level], [outcome])
                       for conf_level in _CONF_LEVELS
                       for outcome in _OUTCOMES
                       for level in _PRED_LEVELS]

        elif task_name == 'conf_pred':  # Split data based on predictability and confidence
            print('Data is getting prepared for the task: ', task_name)
            subsets = [_confSubset(conf_data, conf_level, [level], _OUTCOMES)
                       for conf_level in _CONF_LEVELS
                       for level in _PRED_LEVELS]

        elif task_name in ('behavior_conf', 'all_conf'):
            # Do not split by predictability level: pool 80/90/100% together.
            corr_low = _confSubset(conf_data, 'low', _PRED_LEVELS, ['corr'])
            corr_high = _confSubset(conf_data, 'high', _PRED_LEVELS, ['corr'])
            incorr_low = _confSubset(conf_data, 'low', _PRED_LEVELS, ['incorr'])
            incorr_high = _confSubset(conf_data, 'high', _PRED_LEVELS, ['incorr'])

            if task_name == "behavior_conf":  # C & L --- C & H --- I & L --- I & H
                print('confidence and behaviour!')
                subsets = [corr_low, corr_high, incorr_low, incorr_high]
            else:  # 'all_conf': splitting just based on CONFIDENCE
                print('Only confidence!')
                subsets = [
                    (concatNonEmpty([corr_low[0], incorr_low[0]]),
                     concatNonEmpty([corr_low[1], incorr_low[1]])),
                    (concatNonEmpty([corr_high[0], incorr_high[0]]),
                     concatNonEmpty([corr_high[1], incorr_high[1]])),
                ]

        data_omissions = [subset_data for subset_data, _ in subsets]
        labels_omissions = [subset_labels for _, subset_labels in subsets]

    return epochs, data_real, labels_real, data_omissions, labels_omissions





#### MVPA functions

In [12]:
def getBestParams(filename):
    """Read stored parameter values from a text file.

    Every non-blank line is split on whitespace and its last token is taken
    as the value, e.g. the line ``"best penalty: l1"`` yields ``"l1"``.

    Parameters
    ----------
    filename : str
        Path of the text file to read.

    Returns
    -------
    list of str
        The last whitespace-separated token of each non-blank line, in file
        order. Values are returned as strings; no type conversion is done.
    """
    params = []

    with open(filename, 'r') as file:
        for line in file:
            tokens = line.split()
            # Blank or whitespace-only lines produce no tokens; indexing
            # [-1] on them would raise IndexError, so they are skipped.
            if tokens:
                params.append(tokens[-1])

    return params

In [4]:
def trainAndTest_MVPA(data_real, labels_real, test_data, test_labels, outputfilename, bestParametersFile,
                      tlim, modelsFile, nFolds=5, bestParamsFound=False):
    """Train one classifier per time point on real trials and score it on held-out data.

    ``data_real`` is split 75/25 (stratified, fixed seed) into a
    cross-validation part and a held-out part. For each of the first ``tlim``
    time points a pipeline of ``StandardScaler`` and an L1-penalised
    ``LogisticRegression`` (liblinear solver) is cross-validated, refit on the
    whole cross-validation part, and then scored on the held-out real trials
    and on every non-empty entry of ``test_data``.

    Parameters
    ----------
    data_real : array
        Indexed as ``data_real[:, :, tp]``: trials on the first axis, time on
        the last axis (middle axis presumably channels - TODO confirm).
    labels_real : array
        One label per trial of ``data_real``; also used for stratification.
    test_data : sequence of arrays
        Additional test sets, each indexed like ``data_real``. Empty entries
        are skipped and keep a score of 0.
    test_labels : sequence of arrays
        Labels matching the entries of ``test_data``.
    outputfilename : str
        Destination passed to ``np.save`` for the results.
    bestParametersFile, modelsFile, bestParamsFound
        Currently unused: the hyper-parameter grid search and the pickling of
        the fitted models are disabled and the penalty is fixed to ``'l1'``.
        Kept so existing calls keep working.
    tlim : int
        Number of time points (from index 0) to decode.
    nFolds : int, default 5
        Number of CV folds; also used as the number of CV repeats.

    Returns
    -------
    list
        ``[Test_score_real, <one score vector per entry of test_data>, CV_score]``
        where the score vectors have length ``tlim`` and ``CV_score`` has
        shape ``(nFolds * nFolds, tlim)``.
    """
    train_data_real, test_data_real, train_labels_real, test_labels_real = train_test_split(
        data_real, labels_real,
        test_size=0.25,
        random_state=42,
        stratify=labels_real)

    repeats = nFolds
    rkf = RepeatedKFold(n_splits=nFolds, n_repeats=repeats, random_state=42)

    CV_score = np.zeros((repeats * nFolds, tlim))
    Test_score_real = np.zeros(tlim)
    Test_score_omissions = [np.zeros(tlim) for _ in range(len(test_data))]

    # The penalty is the same for every time point, so report it once.
    print('Penalty: l1')

    for tp in np.arange(tlim):
        d2t_cv = train_data_real[:, :, tp]    # cross-validation data at this time point
        d2t_test = test_data_real[:, :, tp]   # held-out real data at this time point

        clf = make_pipeline(StandardScaler(), LogisticRegression(solver='liblinear', penalty='l1'))

        # Cross-validated ROC AUC on the training part.
        CV_score[:, tp] = cross_val_score(clf, d2t_cv, train_labels_real, cv=rkf, scoring='roc_auc')

        # Refit on all cross-validation data, then generalise to the held-out trials.
        clf.fit(d2t_cv, train_labels_real)
        labels_test_estim = clf.predict(d2t_test)

        # NOTE(review): the score is computed from hard class predictions,
        # whereas the CV score above uses scoring='roc_auc'; the two are
        # therefore not computed the same way. Confirm this is intended.
        Test_score_real[tp] = roc_auc_score(test_labels_real, labels_test_estim)

        for i in range(len(test_data)):
            if len(test_data[i]) == 0:
                # Empty test set: the score stays at its initial value of 0.
                continue
            d2t_test_omissions = test_data[i][:, :, tp]
            labels_test_estim_omissions = clf.predict(d2t_test_omissions)
            try:
                Test_score_omissions[i][tp] = roc_auc_score(test_labels[i], labels_test_estim_omissions)
            except ValueError:
                # Deliberate best-effort: if the score cannot be computed for
                # this set (e.g. presumably when it holds a single class),
                # the entry keeps its initial value of 0.
                pass

    results = [Test_score_real]
    results.extend(Test_score_omissions)
    results.append(CV_score)

    # The entries have different shapes (1-D score vectors and a 2-D CV
    # matrix). Passing such a ragged list straight to np.save raises
    # ValueError on recent NumPy versions, so wrap it explicitly in a 1-D
    # object array. Load it back with np.load(..., allow_pickle=True).
    results_to_save = np.empty(len(results), dtype=object)
    for idx, entry in enumerate(results):
        results_to_save[idx] = entry
    np.save(outputfilename, results_to_save)

    return results





#### Plotting functions

In [2]:
def plot_MVPA(results, times, tlim, plotname, fullTrial=True, isBehavior=True, omissionEnd=0.3):
    """Plot test-set and cross-validation decoding scores over time and save the figure.

    Parameters
    ----------
    results : sequence
        ``results[0]`` is the test score on real trials, ``results[-1]`` the
        cross-validation score matrix (averaged over its first axis with
        ``np.nanmean``), and every entry in between is one omission test
        score vector.
    times : array-like
        Time axis in seconds.
    tlim : int
        Number of time points shown in the cross-validation panel.
    plotname : str
        File name passed to ``savefig``.
    fullTrial : bool, default True
        If False, the test panel is cut at the sample closest to
        ``omissionEnd``; otherwise the whole ``times`` axis is used.
    isBehavior : bool, default True
        Only used when there are exactly two omission groups: selects the
        correct/incorrect legend labels instead of lowConf/highConf.
    omissionEnd : float, default 0.3
        Cut-off time (seconds) used when ``fullTrial`` is False.
    """
    times = np.asarray(times)
    Test_score_real = results[0]
    CV_score = results[-1]
    # Everything between the first and the last entry is an omission score.
    # Handling this generically avoids an undefined variable when the number
    # of omission groups is not 1, 2 or 3.
    Test_score_omissions = list(results[1:-1])
    n_groups = len(Test_score_omissions)

    if n_groups == 1:
        labels = ['Test Omissions']
    elif n_groups == 2:
        if isBehavior:
            labels = ['Test Omissions_correct', 'Test Omissions_incorrect']
        else:
            labels = ['Test Omissions_lowConf', 'Test Omissions_highConf']
    elif n_groups == 3:
        labels = ['Test Omissions - 80%', 'Test Omissions - 90%', 'Test Omissions - 100%']
    else:
        labels = ['Test Omissions - group %d' % (k + 1) for k in range(n_groups)]

    if not fullTrial:
        # Use the sample nearest to the cut-off: an exact float comparison
        # (times == 0.3) can find no match and fail.
        end_of_omission = int(np.argmin(np.abs(times - omissionEnd)))
        times_omi = times[:end_of_omission + 1]
    else:
        times_omi = times
        end_of_omission = len(times) - 1

    fig, (ax_test, ax_cv) = plt.subplots(nrows=1, ncols=2, figsize=(8, 2), dpi=150)

    # Left panel: scores on the held-out real trials and the omission groups.
    ax_test.plot(times_omi, Test_score_real[:end_of_omission + 1], label='Test Real')
    for score, label in zip(Test_score_omissions, labels):
        ax_test.plot(times_omi, score[:end_of_omission + 1], label=label)
    ax_test.legend(loc='upper center', bbox_to_anchor=(0.5, -0.3))
    ax_test.set_title('Test set')
    ax_test.set_xlabel('Time (s)')
    ax_test.set_ylabel('Accuracy')

    # Right panel: mean cross-validation score; x and y are cut to the same length.
    ax_cv.plot(times[:tlim], np.nanmean(CV_score, axis=0)[:tlim], label='CV Real')
    ax_cv.set_title('Cross-validation set')
    ax_cv.set_xlabel('Time (s)')
    ax_cv.legend(loc='upper center', bbox_to_anchor=(0.5, -0.3))

    fig.savefig(plotname, bbox_inches='tight')
    plt.show()

In [15]:
def plot_MVPA_Group(results, labels, plotname):
    """Plot group-level cross-validation and test scores and save the figure.

    Relies on names that are NOT parameters and must exist in the notebook
    namespace when this is called: ``epochs`` (its ``times`` attribute is the
    x axis), ``tlim`` (number of time points plotted) and, for the multi-row
    layout only, ``loop_size``.

    Parameters
    ----------
    results : sequence
        ``results[0]`` is averaged over its first axis with ``np.nanmean`` and
        drawn in the first panel; the remaining entries are plotted as-is.
    labels : sequence of str
        Legend labels, indexed like ``results``.
    plotname : str
        File name passed to ``savefig``.
    """
    nrow = int(len(results) / 2) - 1

    # The grid needs one row per plotted row: with a fixed single row, `axs`
    # is 1-D and the two-index access `axs[i, 0]` below cannot work.
    fig, axs = plt.subplots(nrows=max(nrow, 1), ncols=2, figsize=(10, 6), dpi=150)

    if nrow > 1:
        times = epochs.times[:tlim]
        for i in range(nrow):
            if i == 0:
                axs[i, 0].plot(times, np.nanmean(results[0], axis=0)[:tlim], label=labels[0])
            else:
                # NOTE(review): the indices below depend on the global
                # `loop_size`; verify they stay within `results`.
                axs[i, 0].plot(times, results[i][:tlim], label=labels[i])
                axs[i, 1].plot(times, results[i*loop_size+2][:tlim], label=labels[i*loop_size+2])
                axs[i, 1].plot(times, results[i*loop_size+3][:tlim], label=labels[i*loop_size+3])

    elif nrow == 1:
        times = epochs.times[:tlim]
        axs[0].plot(times, np.nanmean(results[0], axis=0)[:tlim], label=labels[0])
        axs[1].plot(times, results[1][:tlim], label=labels[1])
        axs[1].plot(times, results[2][:tlim], label=labels[2])
        # Sliced like the other curves so x and y always have the same length.
        axs[1].plot(times, results[3][:tlim], label=labels[3])

    # The first panel holds the cross-validation curve, every other one test curves.
    for i, ax in enumerate(axs.flat):
        if i == 0:
            ax.set_title('Cross-validation set')
        else:
            ax.set_title('Test set')
        ax.set(xlabel='Time (s)', ylabel='Accuracy')
        ax.legend(loc='upper center', bbox_to_anchor=(0.1, -0.3))

    fig.tight_layout()
    fig.savefig(plotname, bbox_inches='tight')
    plt.show()