In [10]:
import glob
import os
import time

import numpy as np
import pandas as pd
from scipy import signal
from tqdm import tqdm

In [11]:
# Build the patterns with os.path.join: the separator is then right on every
# platform, and no backslash inside a normal string literal can be read as an
# (invalid) escape sequence such as "\T" or "\D".
# glob makes no promise about the order of its matches, so sort them to get the
# same file order -- and therefore the same epoch order -- on every run.
train_files = sorted(glob.glob(os.path.join('data', 'Train', 'Data*.csv')))
test_files = sorted(glob.glob(os.path.join('data', 'Test', 'Data*.csv')))
print(train_files[:5])

['data\\Train\\Data_S02_Sess01.csv', 'data\\Train\\Data_S02_Sess02.csv', 'data\\Train\\Data_S02_Sess03.csv', 'data\\Train\\Data_S02_Sess04.csv', 'data\\Train\\Data_S02_Sess05.csv']


In [16]:
def butter_filter(order, low_pass, high_pass, fs,sig):
    """Band-pass filter ``sig`` with a Butterworth design in second-order sections.

    order     -- filter order handed to ``signal.butter``
    low_pass  -- lower band edge, in the same unit as ``fs``
    high_pass -- upper band edge, in the same unit as ``fs``
    fs        -- sampling rate; both edges are divided by half of it
    sig       -- samples to filter, passed to ``signal.sosfilt`` unchanged

    Returns whatever ``signal.sosfilt`` produces for ``sig``.
    """
    half_rate = fs * 0.5  # Nyquist frequency used to normalise the band edges
    band_edges = [edge / half_rate for edge in (low_pass, high_pass)]
    sections = signal.butter(order, band_edges, btype='band', output='sos')
    return signal.sosfilt(sections, sig)



def extract_d(files, e_s = None, baseline = True, bandpass = True, keep_placeholder = True):
    """Cut a fixed-length epoch after every feedback event of each CSV recording.

    Every file is read with pandas. Each row whose ``FeedBackEvent`` column
    equals 1 starts an epoch of 100 rows (0.5 s at the 200 Hz sampling rate
    configured below). All 59 columns of the file are kept in the epoch;
    columns 1..56 are treated as the EEG channels.

    Parameters
    ----------
    files : iterable of str
        Paths of the CSV recordings, processed in the given order.
    e_s : optional
        Accepted for backward compatibility only. The previous implementation
        executed the same slicing whether or not it was given, so it has no
        effect on the result.
    baseline : bool
        When true, the per-channel mean of the 20 samples lying 400 ms to
        300 ms before the event is subtracted from columns 1..56 of the epoch.
    bandpass : bool
        When true, every channel named in ``channels`` is filtered with
        ``butter_filter`` (order 5, 1-40 Hz band) before epoching and written
        to columns 1..56 in list order.
    keep_placeholder : bool
        The previous implementation seeded its result with one uninitialised
        row that was never removed, so real epochs started at index 1.
        ``True`` (default) keeps that leading row so existing indexing is not
        shifted, but fills it with zeros instead of arbitrary memory.
        ``False`` returns the real epochs only.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(n, 100, 59)``; ``n`` is the number of feedback
        events over all files, plus one when ``keep_placeholder`` is true.

    Raises
    ------
    ValueError
        If a file does not have 59 columns, if an epoch would run past the end
        of a recording, or if the baseline window would start before the first
        sample (a negative slice start would otherwise silently wrap around to
        the end of the recording).
    """
    start = time.time()

    freq = 200 #sampling rate
    epoch_time = 0.5 #proposed epoching time in seconds
    epoch = int(freq * epoch_time) #epoch in indices
    num_of_cols = 59
    eeg_cols = 56
    eeg_slice = slice(1, 1 + eeg_cols) #columns 1..56 hold the eeg channels
    b_s = int(-0.4*freq) #index where baseline starts relative to feedback (-400ms)
    b_e = int(-0.3*freq) #index where baseline ends relative to feedback (-300ms)
    order = 5 #butterworth order
    low_pass = 1 #low frequency pass for butterworth filter
    high_pass = 40 #high frequency pass for butterworth filter

    # NOTE(review): 'P08' (digit zero) is reproduced exactly as in the original
    # list; presumably it matches the CSV header -- verify before "correcting" it.
    channels = ['Fp1', 'Fp2', 'AF7', 'AF3', 'AF4', 'AF8', 'F7', 'F5', 'F3', 'F1',
       'Fz', 'F2', 'F4', 'F6', 'F8', 'FT7', 'FC5', 'FC3', 'FC1', 'FCz',
       'FC2', 'FC4', 'FC6', 'FT8', 'T7', 'C5', 'C3', 'C1', 'Cz', 'C2',
       'C4', 'C6', 'T8', 'TP7', 'CP5', 'CP3', 'CP1', 'CPz', 'CP2', 'CP4',
       'CP6', 'TP8', 'P7', 'P5', 'P3', 'P1', 'Pz', 'P2', 'P4', 'P6', 'P8',
       'PO7', 'POz', 'P08', 'O1', 'O2']

    # Epochs are collected in a list and stacked once at the end; stacking
    # inside the loop would copy the whole accumulated array for every epoch.
    epochs = []
    if keep_placeholder:
        epochs.append(np.zeros((epoch, num_of_cols), float))

    for file_no, path in enumerate(files):
        # Same progress line as before: file index, path, shape gathered so far.
        print(file_no, path, (len(epochs), epoch, num_of_cols))
        frame = pd.read_csv(path) #read each file
        event_rows = np.flatnonzero(frame['FeedBackEvent'].to_numpy() == 1)

        # Float copy of the whole table, so filtered values and baseline-corrected
        # values are never truncated and the DataFrame itself is left untouched.
        data = np.array(frame, dtype=float)
        if data.shape[1] != num_of_cols:
            raise ValueError('%s has %d columns, expected %d'
                             % (path, data.shape[1], num_of_cols))

        if bandpass:
            # A dedicated loop variable: the file index must not be overwritten.
            for ch_no, channel in enumerate(channels):
                data[:, 1 + ch_no] = butter_filter(
                    order, low_pass, high_pass, freq, frame[channel].values)

        for row in event_rows: #epoching 100 indexes (0.5 seconds) after each stimulus
            stop = row + epoch
            if stop > data.shape[0]:
                raise ValueError('%s: epoch starting at row %d runs past the end '
                                 'of the recording (%d rows)'
                                 % (path, int(row), data.shape[0]))
            # Copy, so the correction below cannot leak into the file data that
            # later (possibly overlapping) epochs and baselines are read from.
            epoch_block = data[row:stop].copy()

            if baseline:
                if row + b_s < 0:
                    raise ValueError('%s: baseline window for the event at row %d '
                                     'starts before the first sample'
                                     % (path, int(row)))
                #baseline correction of 100ms (20 indexes), 400ms to 300ms before fb
                baseline_mean = data[row + b_s:row + b_e, eeg_slice].mean(axis=0)
                epoch_block[:, eeg_slice] -= baseline_mean

            epochs.append(epoch_block)

    if epochs:
        result = np.stack(epochs)
    else:
        result = np.empty((0, epoch, num_of_cols), float)

    print('Elapsed Time: ' + str(int(time.time()-start)) + ' seconds')
    return result

In [None]:
# Epoch the training recordings first, then the test recordings, leaving every
# option of extract_d at its default (band-pass filter and baseline correction on).
train = extract_d(files=train_files)
test = extract_d(files=test_files)

0 data\Train\Data_S02_Sess01.csv (1, 100, 59)
1 data\Train\Data_S02_Sess02.csv (61, 100, 59)
2 data\Train\Data_S02_Sess03.csv (121, 100, 59)
3 data\Train\Data_S02_Sess04.csv (181, 100, 59)
4 data\Train\Data_S02_Sess05.csv (241, 100, 59)
5 data\Train\Data_S06_Sess01.csv (341, 100, 59)
6 data\Train\Data_S06_Sess02.csv (401, 100, 59)
7 data\Train\Data_S06_Sess03.csv (461, 100, 59)
8 data\Train\Data_S06_Sess04.csv (521, 100, 59)
