In [2]:
from sklearn.preprocessing import LabelEncoder, OneHotEncoder, StandardScaler
from sklearn.metrics import classification_report, confusion_matrix
from imblearn.metrics import sensitivity_specificity_support
from sklearn.model_selection import train_test_split
from sklearn.utils.multiclass import unique_labels
from tqdm import tqdm_notebook as tqdm
from scipy.integrate import simps
import matplotlib.pyplot as plt
from scipy import signal
import seaborn as sns
import pandas as pd
import numpy as np
import warnings
import glob
import mne
import os



warnings.filterwarnings('ignore')

Using TensorFlow backend.


In [7]:
def removeMeanFromChannels(channel_data):
    channel_data = channel_data.transpose()
    scaled_data = StandardScaler(with_std = False).fit_transform(channel_data)
    return scaled_data.transpose()
    
def getMneRaw(data_to_process):
    sfreq = 128
    channel_names = ['F3', 'Fz', 'F4', 'C3', 'Cz', 'C4', 'P3',
                     'P4', 'FC5', 'FC1', 'FC2', 'FC4', 'CP5',
                     'CP1', 'CP2', 'CP4']
    channel_type = {k:'eeg' for k in channel_names}
    info = mne.create_info(channel_names, sfreq,verbose=False)
    raw = mne.io.RawArray(removeMeanFromChannels(data_to_process), info, verbose=False)
    raw.set_channel_types(channel_type)
    montage = mne.channels.read_montage('standard_1020')
    raw.set_montage(montage,verbose=False)
    raw.filter(0,50,fir_design='firwin',verbose=False)
    return raw

def getFileInfo(file_path):
    try:
        raw_data = pd.read_csv(file_path,header=None)
        data_to_process = raw_data[:-1]
        label = int( list(set(raw_data.iloc[len(raw_data) -1 ].values))[0] )
        raw = getMneRaw(data_to_process)
        return [ raw, label ]
    except:
        return [0, 0]
    
def getPSDFeatures(data):
    sf = 128
    window = 2*sf
    overlap = window//2
    freqs, psd = signal.welch(data, sf, nperseg=window, noverlap = overlap)
    alpha_indexes = np.where((freqs >= 8)&(freqs < 13))
    alpha_values = psd[alpha_indexes]
    alpha_values.sort()

    beta_indexes = np.where((freqs >= 13)&(freqs < 30))
    beta_values = psd[beta_indexes]
    beta_values.sort()
    features = [alpha_values[-1],beta_values[-1],beta_values[-2]]
    return features


def bandpower(eeg,sf=128,window_sec=None, relative=False):
    band_power = {}
    for channel_name,data in list(zip(eeg.info['ch_names'], eeg.get_data())):
        band_dic = {}
        for band_name in bands:
            band = np.asarray(bands[band_name])
            low, high = band
            if window_sec is not None:
                nperseg = window_sec * sf
            else:
                nperseg = (2 / low) * sf
            freqs, psd = signal.welch(data, sf, nperseg=nperseg)
            freq_res = freqs[1] - freqs[0]
            idx_band = np.logical_and(freqs >= low, freqs <= high)
            bp = simps(psd[idx_band], dx=freq_res)
            if relative:
                bp /= simps(psd, dx=freq_res)
            band_dic[band_name]= bp
        band_power[channel_name] = band_dic
    return band_power

In [4]:
files = glob.glob('../data/open_bci/2colourswifi_transition/*.csv')
bands = {
    'alpha':[8,13],'beta':[13,30],'gamma':[30,200],
'delta':[1,4],'theta':[4,8]
}

In [8]:
training_data_list = []
training_label_list = []

for i in tqdm(files):
    raw_data_label = getFileInfo(i)
    file_features = []
    if raw_data_label[0] != 0:
        band_feature = bandpower(raw_data_label[0])
        for channel in band_feature:
            file_features.append(list(band_feature[channel].values()))
        break
        '''
        raw_data = raw_data_label[0].get_data()
        label = raw_data_label[1]
        for channel in raw_data:
            training_data_list.append(getPSDFeatures(channel))
        training_label_list.append(list(map(int,('{},'.format(label)*raw_data.shape[0]).split(',')[:-1])))
        '''

HBox(children=(IntProgress(value=0, max=288), HTML(value='')))




In [15]:
file_features = []
for channel in band_feature:
    print(list(band_feature[channel].values()))

[6.364429519938637e-09, 1.51816715713491e-09, 2.9590987703923743e-10, 2.535204695326041e-11, 2.1355025379543975e-08]
[1.7879294057513403e-09, 4.972874697504084e-10, 1.8332226225494804e-10, 4.329324206287494e-11, 6.0886039271932916e-09]
[1.6346289355306776e-08, 3.583945249914468e-09, 6.727950569404751e-10, 3.194454644974805e-11, 5.489606949986732e-08]
[3.006683346679199e-09, 6.265584832847615e-10, 1.3790082793909014e-10, 2.4878682281726214e-11, 1.0094525597687092e-08]
[3.264401328958905e-10, 1.3778813897300746e-10, 5.4710350925434166e-11, 2.2358148622010396e-11, 1.0452664882273415e-09]
[1.147052164024975e-09, 3.125217472504673e-10, 7.47968486195624e-11, 2.2764484772000054e-11, 3.827183215580718e-09]
[2.7140234861669213e-09, 5.788168871374607e-10, 1.083072898173398e-10, 2.4813962934003685e-11, 9.097073368257181e-09]
[2.1564050415810278e-09, 4.5148291094986516e-10, 9.985792671053953e-11, 3.52852505911431e-11, 7.212017905150155e-09]
[5.578812525895981e-09, 1.3612712110858756e-09, 3.8116569

In [16]:
band_feature

{'F3': {'alpha': 6.364429519938637e-09,
  'beta': 1.51816715713491e-09,
  'gamma': 2.9590987703923743e-10,
  'delta': 2.535204695326041e-11,
  'theta': 2.1355025379543975e-08},
 'Fz': {'alpha': 1.7879294057513403e-09,
  'beta': 4.972874697504084e-10,
  'gamma': 1.8332226225494804e-10,
  'delta': 4.329324206287494e-11,
  'theta': 6.0886039271932916e-09},
 'F4': {'alpha': 1.6346289355306776e-08,
  'beta': 3.583945249914468e-09,
  'gamma': 6.727950569404751e-10,
  'delta': 3.194454644974805e-11,
  'theta': 5.489606949986732e-08},
 'C3': {'alpha': 3.006683346679199e-09,
  'beta': 6.265584832847615e-10,
  'gamma': 1.3790082793909014e-10,
  'delta': 2.4878682281726214e-11,
  'theta': 1.0094525597687092e-08},
 'Cz': {'alpha': 3.264401328958905e-10,
  'beta': 1.3778813897300746e-10,
  'gamma': 5.4710350925434166e-11,
  'delta': 2.2358148622010396e-11,
  'theta': 1.0452664882273415e-09},
 'C4': {'alpha': 1.147052164024975e-09,
  'beta': 3.125217472504673e-10,
  'gamma': 7.47968486195624e-11,
  

In [45]:
X = np.array(training_data_list)
Y = sum(training_label_list,[])

In [46]:
x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size = 0.3)

In [82]:
from sklearn.neighbors import KNeighborsClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.linear_model import LogisticRegression

In [90]:
classifier = LogisticRegression()
classifier.fit(x_train,y_train)
y_pred = classifier.predict(x_test)
classifier.score(x_test,y_test)

0.5241935483870968

In [93]:
basic_metrics = pd.DataFrame(classification_report(y_test, y_pred, output_dict = True))
basic_metrics = basic_metrics[basic_metrics.columns[:2]]
basic_metrics.columns = [1,2]
other_metrics = pd.DataFrame(sensitivity_specificity_support(y_test, y_pred), columns=[1,2])
other_metrics.index = ['sensitivity','specificity','support']
metric = pd.concat([basic_metrics,other_metrics]).drop_duplicates()

In [95]:
metric

Unnamed: 0,1,2
precision,0.0,0.524194
recall,0.0,1.0
f1-score,0.0,0.687831
support,649.0,715.0
specificity,1.0,0.0
