In [1]:
import numpy as np
import pandas as pd
import mne
import os
import re
from scipy.io import loadmat


# read raw EEG dataset of .gdf file
class RawEEGData:
    """Raw EEG recording read from a BCI Competition IV 2b ``.gdf`` file.

    Attributes
    ----------
    rawData : numpy.ndarray
        Sample matrix taken from ``raw._data``; callers slice it as
        ``rawData[:, start:stop]`` and index one row per channel.
    channel : list of str
        Channel names, in the same order as the rows of ``rawData``.
    sample_freq : float
        Sampling frequency reported by MNE (``raw.info['sfreq']``).
    event : pandas.DataFrame
        One row per annotated event with columns ``length``, ``position``,
        ``event type``, ``event index``, ``duration`` and ``CHN``
        (type codes and indices are listed in ``EVENT_TABLE``).
    """

    # (GDF event type code, compact type index, description). Single source for
    # both the code -> index mapping used in __init__ and the printed table.
    EVENT_TABLE = (
        (276, 0, "Idling EEG (eyes open)"),
        (277, 1, "Idling EEG (eyes closed)"),
        (768, 2, "Start of a trial"),
        (769, 3, "Cue onset left (class 1)"),
        (770, 4, "Cue onset right (class 2)"),
        (781, 5, "BCI feedback (continuous)"),
        (783, 6, "Cue unknown"),
        (1023, 7, "Rejected trial"),
        (1077, 8, "Horizontal eye movement"),
        (1078, 9, "Vertical eye movement"),
        (1079, 10, "Eye rotation"),
        (1081, 11, "Eye blinks"),
        (32766, 12, "Start of a new run"),
    )

    def __init__(self, file_name):
        """Read ``file_name`` with MNE and extract samples, channel names and events.

        Raises
        ------
        KeyError
            If the file contains an event type code missing from ``EVENT_TABLE``.
        """
        # NOTE(review): stim_channel=-1 and the private _raw_extras layout depend on
        # the MNE version this was written against (newer MNE also offers
        # mne.io.read_raw_gdf) -- confirm before upgrading MNE.
        raw = mne.io.read_raw_edf(file_name, preload=True, stim_channel=-1)
        event_type2idx = {event_type: idx for event_type, idx, _ in self.EVENT_TABLE}
        self.rawData = raw._data
        self.channel = raw._raw_extras[0]['ch_names']
        self.sample_freq = raw.info['sfreq']
        # events tuple as used here: [0] length, [1] position, [2] type,
        # [3] channel, [4] duration
        events = raw._raw_extras[0]['events']
        self.event = pd.DataFrame({
            "length": events[0],
            "position": events[1],
            "event type": events[2],
            "event index": [event_type2idx[event_type] for event_type in events[2]],
            "duration": events[4],
            "CHN": events[3],
        })

    @staticmethod
    def print_type_info():
        """Print the table of GDF event type codes, their indices and descriptions."""
        print("EEG data set event information and index:")
        print("%12s\t%10s\t%30s" % ("Event Type", "Type Index", "Description"))
        for event_type, idx, description in RawEEGData.EVENT_TABLE:
            print("%12d\t%10d\t%30s" % (event_type, idx, description))


# arrange data for training and test
def _parse_session_name(file_name, extension, kind):
    """Return ``("subject<N>", "session<M>")`` for a ``B0<N>0<M>T|E.<extension>`` name.

    Raises ValueError (mentioning ``kind``) when the name does not match the pattern.
    """
    match = re.search(r'B0([0-9])0([0-9])[TE]\.' + re.escape(extension), file_name)
    if match is None:
        raise ValueError("invalid %s file name: %s" % (kind, file_name))
    return "subject" + match.group(1), "session" + match.group(2)


def _extract_trials(raw_eeg_data):
    """Slice every trial out of a recording and group the slices by channel name.

    A trial is an event with index 2 ("Start of a trial"); its slice spans samples
    ``position : position + duration``. Returns ``{channel_name: [trial_0, ...]}``.
    """
    trial_events = raw_eeg_data.event[raw_eeg_data.event['event index'] == 2]
    session_data = dict()
    for _, event_data in trial_events.iterrows():
        start = event_data['position']
        trial_data = raw_eeg_data.rawData[:, start:start + event_data['duration']]
        for idx, channel_name in enumerate(raw_eeg_data.channel):
            session_data.setdefault(channel_name, list()).append(trial_data[idx])
    return session_data


def get_data(data_file_dir, labels_file_dir):
    """Load every recording (``.gdf``) and label file (``.mat``) found in two directories.

    Files are recognised by names of the form ``B0<subject>0<session>T|E.<ext>``;
    other extensions are skipped. Files are processed in sorted name order.

    Parameters
    ----------
    data_file_dir : str
        Directory containing the ``.gdf`` recordings.
    labels_file_dir : str
        Directory containing the ``.mat`` files with a ``classlabel`` variable.

    Returns
    -------
    data : dict
        ``data[subject][session][channel_name]`` -> list of per-trial sample arrays.
    labels : dict
        ``labels[subject][session]`` -> ``classlabel`` array cast to ``int8``.
    sfreq : int
        Sampling frequency, hard-coded to 250 Hz (not read from the files).

    Raises
    ------
    ValueError
        If a ``.gdf`` or ``.mat`` file name does not follow the expected pattern.
    """
    RawEEGData.print_type_info()
    sfreq = 250  # sample frequency of dataset

    # read recordings
    data = dict()
    for data_file in sorted(os.listdir(data_file_dir)):
        if not data_file.endswith(".gdf"):
            continue
        subject, session = _parse_session_name(data_file, "gdf", "data")
        filename = os.path.join(data_file_dir, data_file)
        print(filename)
        raw_eeg_data = RawEEGData(filename)
        data.setdefault(subject, dict())[session] = _extract_trials(raw_eeg_data)

    # read labels
    labels = dict()
    for labels_file in sorted(os.listdir(labels_file_dir)):
        if not labels_file.endswith(".mat"):
            continue
        subject, session = _parse_session_name(labels_file, "mat", "labels")
        filename = os.path.join(labels_file_dir, labels_file)
        print(filename)
        session_label = loadmat(filename)['classlabel'].astype(np.int8)
        labels.setdefault(subject, dict())[session] = session_label

    return data, labels, sfreq

Extracting EDF parameters from U:\MEG-BCI\main\Deep Learning\BCICIV_2b_gdf\B0401T.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 604802  =      0.000 ...  2419.208 secs...


  etmode = np.fromstring(etmode, np.uint8).tolist()[0]


EEG data set event information and index:
  Event Type	Type Index	                   Description
         276	         0	        Idling EEG (eyes open)
         277	         1	       Idling EEG (eyes closed
         768	         2	              Start of a trial
         769	         3	      Cue onset left (class 1)
         770	         4	     Cue onset right (class 2)
         781	         5	      BCI feedback (continuous
         783	         6	                   Cue unknown
        1023	         7	                Rejected trial
        1077	         8	       Horizontal eye movement
        1078	         9	         Vertical eye movement
        1079	        10	                  Eye rotation
        1081	        11	                    Eye blinks
       32766	        12	            Start of a new run
U:\MEG-BCI\main\Deep Learning\subject234\B0401T.gdf
Extracting EDF parameters from U:\MEG-BCI\main\Deep Learning\subject234\B0401T.gdf...
GDF file detected
Setting channel info structure..

  etmode = np.fromstring(etmode, np.uint8).tolist()[0]


U:\MEG-BCI\main\Deep Learning\subject234\B0402T.gdf
Extracting EDF parameters from U:\MEG-BCI\main\Deep Learning\subject234\B0402T.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 696265  =      0.000 ...  2785.060 secs...


  etmode = np.fromstring(etmode, np.uint8).tolist()[0]


U:\MEG-BCI\main\Deep Learning\subject234\B0403T.gdf
Extracting EDF parameters from U:\MEG-BCI\main\Deep Learning\subject234\B0403T.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 468558  =      0.000 ...  1874.232 secs...


  etmode = np.fromstring(etmode, np.uint8).tolist()[0]


U:\MEG-BCI\main\Deep Learning\subject234\B0404E.gdf
Extracting EDF parameters from U:\MEG-BCI\main\Deep Learning\subject234\B0404E.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 467478  =      0.000 ...  1869.912 secs...


  etmode = np.fromstring(etmode, np.uint8).tolist()[0]


U:\MEG-BCI\main\Deep Learning\subject234\B0405E.gdf
Extracting EDF parameters from U:\MEG-BCI\main\Deep Learning\subject234\B0405E.gdf...
GDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 466050  =      0.000 ...  1864.200 secs...


  etmode = np.fromstring(etmode, np.uint8).tolist()[0]


U:\MEG-BCI\main\Deep Learning\subject234\B0401T.mat
U:\MEG-BCI\main\Deep Learning\subject234\B0402T.mat
U:\MEG-BCI\main\Deep Learning\subject234\B0403T.mat
U:\MEG-BCI\main\Deep Learning\subject234\B0404E.mat
U:\MEG-BCI\main\Deep Learning\subject234\B0405E.mat
test: read dataset


In [2]:
import pandas as pd
import numpy as np
import matplotlib
from scipy.signal import butter, lfilter, stft
from scipy import interpolate
import math
from spectrum import pburg
from sklearn import preprocessing
# import matplotlib.pyplot as plt

#from read_data import get_data


# enrich data for each trial
def preprocess_signal( ori_data, start_time, slide_len, segment_len, num, sfreq ):
    """Cut ``num`` overlapping, mean-removed windows out of one trial's signal.

    Window ``i`` starts ``start_time + i * slide_len`` seconds into ``ori_data`` and
    lasts ``segment_len`` seconds. Sample positions are rounded to the nearest
    sample, so float sampling rates and fractional window lengths are accepted.
    A window running past the end of ``ori_data`` is returned shorter (plain slicing).

    Parameters
    ----------
    ori_data : numpy.ndarray
        1-d signal of one trial and channel.
    start_time, slide_len, segment_len : float
        Offsets and lengths in seconds.
    num : int
        Number of windows.
    sfreq : float
        Sampling frequency in Hz.

    Returns
    -------
    list of numpy.ndarray
        The windows, each with its own mean subtracted.
    """
    segment_samples = int(round(segment_len * sfreq))
    processed_data = list()
    for i in range(num):
        # round() instead of int(): products like 0.29 * 100 evaluate to
        # 28.999999999999996, and truncation would shift the window by one sample
        left = int(round((start_time + i * slide_len) * sfreq))
        segment = ori_data[left:left + segment_samples]
        processed_data.append(segment - np.mean(segment))
    return processed_data


# butterworth band pass filter design
def butter_bandpass(lowcut, highcut, fs, order=5):
    """Design a Butterworth band-pass filter.

    ``lowcut`` and ``highcut`` are in Hz and are normalised by the Nyquist
    frequency (``fs / 2``) before being passed to ``scipy.signal.butter``.

    Returns the ``(b, a)`` transfer-function coefficients.
    """
    nyquist = fs * 0.5
    normalized_band = [lowcut / nyquist, highcut / nyquist]
    return butter(order, normalized_band, btype='band')


# butterworth band pass filter
def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
    """Band-pass ``data`` with a causal Butterworth IIR filter.

    The filter comes from ``butter_bandpass`` and is applied once, forward only,
    with ``scipy.signal.lfilter`` along the last axis.
    """
    numerator, denominator = butter_bandpass(lowcut, highcut, fs, order=order)
    return lfilter(numerator, denominator, data)


# combine all the data of all trials and sessions
def combine_processed_data(preprocessed_data, labels):
    """Merge every session of each subject into one flat table of segments.

    Parameters
    ----------
    preprocessed_data : dict
        ``preprocessed_data[subject][session]`` is a DataFrame with one column per
        channel whose cells are lists of segments (one list per trial).
    labels : dict
        ``labels[subject][session]`` is a sequence with one label entry per trial.

    Returns
    -------
    combined_data : dict
        ``combined_data[subject]`` is a DataFrame with one column per channel and
        one row per segment (session order, then trial order, then segment order).
    combined_labels : dict
        ``combined_labels[subject][k]`` is the label entry of the trial that row
        ``k`` came from (repeated once per segment of that trial).

    Raises
    ------
    ValueError
        If a subject has fewer label entries than trials.
    """
    combined_data = dict()
    combined_labels = dict()
    for subject, sessions in preprocessed_data.items():
        # DataFrame.append was removed in pandas 2.0; pd.concat replaces it
        session_frames = [sessions[session] for session in sessions]
        subject_data = pd.concat(session_frames) if session_frames else pd.DataFrame()
        subject_labels = [label for session in sessions for label in labels[subject][session]]

        subject_combined_data = pd.DataFrame()
        for channel in subject_data:
            subject_combined_data[channel] = [segment
                                              for trial_segments in subject_data[channel]
                                              for segment in trial_segments]

        # segment counts per trial are taken from the first channel
        subject_combined_labels = list()
        if len(subject_data.columns) > 0:
            first_channel_trials = subject_data[subject_data.columns[0]]
            if len(subject_labels) < len(first_channel_trials):
                raise ValueError("%s: %d trials but only %d labels"
                                 % (subject, len(first_channel_trials), len(subject_labels)))
            for trial_label, trial_segments in zip(subject_labels, first_channel_trials):
                subject_combined_labels.extend([trial_label] * len(trial_segments))

        combined_data[subject] = subject_combined_data
        combined_labels[subject] = subject_combined_labels

    return combined_data, combined_labels


# subject optimal frequency band selection based on band-pass / AR PSD features
def _candidate_bands(band_range, band_size, step):
    """List the ``(lowcut, highcut)`` windows that fit inside ``band_range``.

    For each width in ``band_size`` the low edge slides from ``band_range[0]`` in
    increments of ``step``; the last window of each width ends exactly at
    ``band_range[1]``. Non-positive widths and widths larger than the range
    contribute no windows.
    """
    low, high = band_range
    bands = list()
    for band in band_size:
        if band <= 0 or band > high - low:
            continue
        # "+ 1" keeps the window whose high edge equals band_range[1]
        num_windows = int((high - low - band) / step) + 1
        for window in range(num_windows):
            lowcut = low + window * step
            bands.append((lowcut, lowcut + band))
    return bands


def _burg_ar_psd(subject_data, num_trials, sfreq, channel, ar_order, nfft):
    """Compute the Burg AR PSD of every trial of every channel in ``channel``.

    Returns ``{channel_name: [psd_trial_0, ...], 'frequency': frequency_axis}``.
    """
    ar_psd = {channel_name: list() for channel_name in channel}
    frequencies = None
    for idx in range(num_trials):
        for channel_name in channel:
            p = pburg(subject_data[channel_name][idx], order=ar_order, NFFT=nfft,
                      sampling=sfreq, scale_by_freq=True)
            if frequencies is None:
                frequencies = np.array(p.frequencies())
                ar_psd['frequency'] = frequencies
            ar_psd[channel_name].append(p.psd)
    return ar_psd


def _fisher_score(left_features, right_features):
    """Fisher ratio: summed squared class-mean difference over summed class variances."""
    left_mean_val = np.mean(left_features, axis=0)
    right_mean_val = np.mean(right_features, axis=0)
    left_var = np.var(left_features, axis=0)
    right_var = np.var(right_features, axis=0)
    return sum(np.square(left_mean_val - right_mean_val)) / sum(left_var + right_var)


def feature_band_selection(data, labels, sfreq, step=1, band_range = (0, 0), band_size=(0, 0, 0),
                              channel=('EEG:C3', 'EEG:Cz', 'EEG:C4'), features_type=0):
    """Pick, per subject, the frequency band whose features best separate class 1 from class 2.

    Every candidate window (see ``band_range``, ``band_size``, ``step``) is scored
    by the Fisher ratio between the per-channel log10-variance features of class-1
    and class-2 segments; the highest-scoring window (first one on ties) is chosen.

    Parameters
    ----------
    data : dict
        ``data[subject][channel_name][idx]`` is the 1-d signal of segment ``idx``.
    labels : dict
        ``labels[subject][idx]`` is the label of segment ``idx``; entries equal to
        1 or 2 are used, all others are ignored.
    sfreq : float
        Sampling frequency in Hz.
    step : float
        Shift of the window's low edge between candidates, in Hz.
    band_range : (float, float)
        Lowest and highest frequency a candidate window may cover, in Hz.
    band_size : iterable of float
        Candidate window widths in Hz.
    channel : sequence of str
        Channels whose features are concatenated.
    features_type : int
        0 -> variance of the 5th-order Butterworth band-passed signal (BP features);
        1 -> variance of the Burg AR(12) PSD restricted to the window (AR features).

    Returns
    -------
    dict
        ``subject -> (lowcut, highcut)``.

    Raises
    ------
    ValueError
        If ``features_type`` is not 0 or 1, or no candidate window fits inside
        ``band_range`` (as with the all-zero defaults).
    """
    if features_type not in (0, 1):
        raise ValueError("feature type wrong!\n band pass features: features_type=0\n "
                         "AR PSD features: features_type=1")
    candidate_bands = _candidate_bands(band_range, band_size, step)
    if not candidate_bands:
        raise ValueError("no candidate frequency band of size %s fits in band_range %s"
                         % (tuple(band_size), tuple(band_range)))

    # AR model parameters
    ar_order = 12
    nfft = 1000

    subject_optimal_frequency_bands = dict()
    for subject in data:
        subject_data = data[subject]
        subject_labels = labels[subject]
        num_trials = len(subject_labels)
        if features_type == 1:
            ar_psd = _burg_ar_psd(subject_data, num_trials, sfreq, channel, ar_order, nfft)
            ar_freq = ar_psd['frequency']

        f_score = list()
        for lowcut, highcut in candidate_bands:
            if features_type == 1:
                psd_idx_start = np.where(ar_freq >= lowcut)[0][0]
                psd_idx_end = np.where(ar_freq >= highcut)[0][0]
            left_features = list()
            right_features = list()
            for idx in range(num_trials):
                features = list()
                for channel_name in channel:
                    if features_type == 0:  # 5th order butterworth filter
                        filtered_data = butter_bandpass_filter(subject_data[channel_name][idx],
                                                               lowcut, highcut, sfreq, order=5)
                    else:  # AR model PSD
                        filtered_data = ar_psd[channel_name][idx][psd_idx_start:psd_idx_end]
                    features.append(math.log10(np.var(filtered_data)))
                if subject_labels[idx] == 1:
                    left_features.append(features)
                elif subject_labels[idx] == 2:
                    right_features.append(features)
            f_score.append(_fisher_score(left_features, right_features))
        # optimal band = the one with the maximum F-score
        subject_optimal_frequency_bands[subject] = candidate_bands[f_score.index(max(f_score))]
    return subject_optimal_frequency_bands

def find_kth_largest(arr, k):
    """Return the k-th largest value of a 1-d array or sequence (``k = 1`` is the maximum).

    Uses ``numpy.partition`` (linear-time selection) on a copy, so ``arr`` is left
    unmodified.

    Parameters
    ----------
    arr : array_like
        1-d values.
    k : int
        Rank from the top, ``1 <= k <= len(arr)``.

    Raises
    ------
    ValueError
        If ``k`` is outside ``[1, len(arr)]``.
    """
    values = np.asarray(arr)
    n = values.size
    if not 1 <= k <= n:
        raise ValueError("k must be in [1, %d], got %r" % (n, k))
    # the k-th largest sits at ascending position n - k
    position = n - k
    return np.partition(values, position)[position]


def recale(data, percentage):
    """Min-max rescale ``data`` in place, clipping its largest values to 1.

    The threshold is the k-th largest value with ``k = int(data.size * percentage)``
    (at least 1, so tiny inputs use their maximum). Entries above the threshold are
    set to 1, treated as artifacts. Every other entry becomes
    ``(value - min) / (threshold - min)``.

    Parameters
    ----------
    data : numpy.ndarray
        Array of any shape; modified in place.
    percentage : float
        Fraction of values considered to be artifacts.

    Returns
    -------
    numpy.ndarray
        ``data`` itself, after rescaling.
    """
    arr = data.flatten()
    min_val = np.min(arr)
    k = max(1, int(arr.size * percentage))
    max_val = find_kth_largest(arr, k)
    # vectorized equivalent of the element-wise rule; assigning through [...] keeps
    # the in-place behaviour callers may rely on
    data[...] = np.where(data > max_val, 1, (data - min_val) / (max_val - min_val))
    return data


# get the input data of CNN by STFT
def get_input_data(data, mu_band, beta_band, channel=('EEG:C3', 'EEG:Cz', 'EEG:C4'), fs=250):
    """Turn every segment of one trial into a CNN input image built from STFT magnitudes.

    For each channel in ``channel`` (in order) the image receives two row blocks:

    * the |STFT| rows whose frequencies lie in ``[mu_band[0], mu_band[1])``;
    * the beta-band |STFT| rows, cubic-spline interpolated onto as many evenly
      spaced frequencies in ``[beta_band[0], beta_band[1])`` as the mu block has rows.

    Each block is rescaled with ``recale(block, 0.05)``.

    Parameters
    ----------
    data : mapping
        ``data[channel_name][idx]`` is the 1-d signal of segment ``idx``, e.g. a
        DataFrame row passed by ``DataFrame.apply(..., axis=1)``.
    mu_band, beta_band : (float, float)
        Band edges in Hz.
    channel : sequence of str
        Channels to stack, top to bottom.
    fs : float, optional
        Sampling frequency in Hz (default 250).

    Returns
    -------
    list of numpy.ndarray
        One ``(2 * n_mu_rows * len(channel), n_frames)`` image per segment.
    """
    # STFT parameters
    wlen = 64  # length of the analysis Hamming window
    nfft = 512  # number of FFT points
    hop = 14  # hop size

    input_data = list()
    # band row slices depend only on the STFT frequency grid -> computed once
    mu_rows = beta_rows = None
    for idx in range(len(data[channel[0]])):
        blocks = list()
        for chn in channel:
            f, t, Fstft = stft(data[chn][idx], fs=fs, window='hamming', nperseg=wlen,
                               noverlap=wlen - hop, nfft=nfft, return_onesided=True,
                               boundary=None, padded=False)
            if mu_rows is None:
                mu_rows = slice(np.where(f >= mu_band[0])[0][0], np.where(f >= mu_band[1])[0][0])
                beta_rows = slice(np.where(f >= beta_band[0])[0][0],
                                  np.where(f >= beta_band[1])[0][0])
            mu_feature_matrix = np.abs(Fstft[mu_rows])
            beta_feature_matrix = np.abs(Fstft[beta_rows])

            # Cubic interpolation of the beta block. RectBivariateSpline (s=0) is
            # SciPy's documented replacement for interp2d(kind='cubic') on a regular
            # grid (interp2d was removed in SciPy 1.14); note its (row, column)
            # coordinate order, the reverse of interp2d's (x, y).
            beta_spline = interpolate.RectBivariateSpline(f[beta_rows], t, beta_feature_matrix,
                                                          kx=3, ky=3)
            # linspace guarantees exactly len(mu) rows; arange with a float step can
            # produce one extra point through rounding
            f_beta = np.linspace(beta_band[0], beta_band[1], len(mu_feature_matrix),
                                 endpoint=False)
            beta_feature_matrix = beta_spline(f_beta, t)

            blocks.append(recale(mu_feature_matrix, 0.05))
            blocks.append(recale(beta_feature_matrix, 0.05))
        input_data.append(np.concatenate(blocks, axis=0) if blocks else None)
    return input_data



# default run function
def run_sig_processing(data_src, labels_src, band_type):
    """Load the dataset, segment every trial and attach CNN input images.

    Parameters
    ----------
    data_src, labels_src : str
        Directories holding the ``.gdf`` recordings and ``.mat`` label files.
    band_type : int
        0 -> subject-specific mu/beta bands selected with band-power features;
        1 -> bands selected with Burg AR PSD features;
        any other value -> fixed bands, mu 8-12 Hz and beta 13-30 Hz.

    Returns
    -------
    preprocessed_data : dict
        ``preprocessed_data[subject][session]`` is a DataFrame with one column per
        channel (each cell: the list of segments of one trial) plus an
        ``'input data'`` column holding the CNN images of that trial.
    labels : dict
        Labels as returned by ``get_data``.
    """
    # segmentation: segments_num windows of window_length s, the first one
    # start_time s into the trial, each following one shifted by time_slides s
    start_time, time_slides, window_length, segments_num = 3, 0.2, 2, 11

    data, labels, sfreq = get_data(data_src, labels_src)

    preprocessed_data = dict()
    for subject, sessions in data.items():
        subject_frames = preprocessed_data.setdefault(subject, dict())
        for session, channels in sessions.items():
            trials_frame = pd.DataFrame()
            for channel, trials in channels.items():
                trials_frame[channel] = [
                    preprocess_signal(trial, start_time, time_slides, window_length,
                                      segments_num, sfreq)
                    for trial in trials
                ]
            subject_frames[session] = trials_frame

    if band_type in (0, 1):
        combined_data, combined_labels = combine_processed_data(preprocessed_data, labels)
        mu_band = feature_band_selection(combined_data, combined_labels, sfreq, step=1,
                                         band_range=(4, 14), band_size=(4, 5, 6),
                                         features_type=band_type)
        beta_band = feature_band_selection(combined_data, combined_labels, sfreq, step=1,
                                           band_range=(14, 32), band_size=(4, 5, 6),
                                           features_type=band_type)
    else:
        mu_band = {subject: (8, 12) for subject in preprocessed_data}
        beta_band = {subject: (13, 30) for subject in preprocessed_data}

    # one CNN input image list per trial, stored as an extra column of each session frame
    for subject, sessions in preprocessed_data.items():
        for session_frame in sessions.values():
            session_frame['input data'] = session_frame.apply(
                get_input_data, axis=1, mu_band=mu_band[subject], beta_band=beta_band[subject])

    return preprocessed_data, labels



In [1]:
import numpy as np
from sklearn.model_selection import KFold
import pickle
import tensorflow as tf
from tensorflow import keras
import pandas as pd
import matplotlib.pyplot as plt

#from signalProcessing import run_sig_processing

# get train data and labels for each segment
def arrange_data(data, labels):
    """Flatten trials into per-segment CNN samples with binary targets.

    Every segment of trial ``idx`` gets a trailing channel axis and the target 0 when
    ``labels[idx][0] == 1``, otherwise 1.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        Stacked samples ``(n_segments, H, W, 1)`` and their targets ``(n_segments,)``.
    """
    samples = []
    targets = []
    for trial_idx, trial_segments in enumerate(data):
        for segment in trial_segments:
            samples.append(np.expand_dims(segment, axis=2))
            targets.append(0 if labels[trial_idx][0] == 1 else 1)
    return np.array(samples), np.array(targets)


# build model
def build_model(size_y, size_x):
    """Build and compile the single-convolution CNN for two-class classification.

    Parameters
    ----------
    size_y, size_x : int
        Height and width of the single-channel input images ``(size_y, size_x, 1)``.

    Returns
    -------
    keras.models.Model
        Compiled with categorical cross-entropy, SGD (learning rate 0.001,
        momentum 0.9) and an accuracy metric named ``'acc'``.
    """
    img_input = keras.layers.Input(shape=(size_y, size_x, 1))
    nf = size_y - 2  # number of convolution filters
    # each kernel spans the full image height and 3 columns -> output height 1
    x = keras.layers.Conv2D(filters=nf, kernel_size=(size_y, 3), activation='relu',
                            kernel_regularizer=keras.regularizers.l2(0))(img_input)
    # max over windows of 10 columns; the positional form MaxPooling2D(1, 10) would
    # mean pool_size=1, strides=10, i.e. subsampling with no pooling at all
    x = keras.layers.MaxPooling2D(pool_size=(1, 10))(x)
    x = keras.layers.Flatten()(x)
    x = keras.layers.Dropout(0.5)(x)
    # softmax: categorical cross-entropy expects a probability distribution over classes
    output = keras.layers.Dense(2, activation='softmax')(x)
    model = keras.models.Model(img_input, output)
    model.summary()
    model.compile(loss='categorical_crossentropy',
                  optimizer=keras.optimizers.SGD(learning_rate=0.001, momentum=0.9),
                  metrics=['acc'])
    return model

# evaluate trial-to-trial performance
def trial_evaluate(model, data, labels):
    """Return the fraction of trials whose segments are mostly classified correctly.

    Each trial is expanded into its segments with ``arrange_data`` and scored with
    ``model.evaluate``. A trial counts as correct when the reported segment accuracy
    is strictly above 0.5.

    Parameters
    ----------
    model : keras model
        Compiled so that ``evaluate`` returns ``(loss, accuracy)``.
    data : sequence
        Trials, each a sequence of 2-d segments.
    labels : sequence
        Per-trial label arrays (``labels[idx][0]`` is the class).
    """
    correct = 0.0
    for idx, trial in enumerate(data):
        segments, segment_targets = arrange_data(np.expand_dims(trial, axis=0),
                                                 np.expand_dims(labels[idx], axis=0))
        one_hot = keras.utils.to_categorical(segment_targets, num_classes=2)
        _, accuracy = model.evaluate(segments, one_hot)
        if accuracy > 0.5:
            correct += 1.0
    return correct / len(data)


# run classification
def _plot_history(history):
    """Plot training vs. validation accuracy, then loss, for one Keras ``History``."""
    for metric, title, ylabel in (('acc', 'model accuracy', 'accuracy'),
                                  ('loss', 'model loss', 'loss')):
        plt.plot(history.history[metric])
        plt.plot(history.history['val_' + metric])
        plt.title(title)
        plt.ylabel(ylabel)
        plt.xlabel('epoch')
        # the second curve comes from validation_split, not the held-out test fold
        plt.legend(['train', 'validation'], loc='upper left')
        plt.show()


def run_classification(data, labels, session=(1, 2, 3, 4, 5), random_state=None):
    """Run 10-fold cross-validated CNN classification for every subject.

    Parameters
    ----------
    data : dict
        ``data[subject]["session<N>"]['input data']`` holds the per-trial CNN inputs.
    labels : dict
        ``labels[subject]["session<N>"]`` holds the per-trial labels.
    session : iterable of int
        Session numbers whose trials are pooled before cross-validation.
    random_state : int or None, optional
        Seed for the fold shuffling; ``None`` (default) gives non-reproducible folds.

    Returns
    -------
    pandas.DataFrame
        One column per subject with the trial-to-trial accuracy of each fold.
    """
    kf = KFold(n_splits=10, shuffle=True, random_state=random_state)
    classification_acc = pd.DataFrame()
    for subject in data:
        # combine trials data of target sessions
        input_data = list()
        target_labels = list()
        for idx in session:
            session_key = "session" + str(idx)
            input_data.extend(data[subject][session_key]['input data'])
            target_labels.extend(labels[subject][session_key])
        input_data = np.array(input_data)
        target_labels = np.array(target_labels)

        # 10 fold cross-validation
        subject_acc = list()
        for count, (train_index, test_index) in enumerate(kf.split(input_data), start=1):
            train_data, train_labels = arrange_data(input_data[train_index], target_labels[train_index])
            test_data, test_labels = arrange_data(input_data[test_index], target_labels[test_index])

            size_y, size_x = train_data[0].shape[0:2]
            print(train_data.shape)

            train_labels = keras.utils.to_categorical(train_labels, num_classes=2)
            test_labels = keras.utils.to_categorical(test_labels, num_classes=2)

            model = build_model(size_y, size_x)

            print('Training ------------')
            history = model.fit(train_data, train_labels, validation_split=0.33, epochs=300,
                                batch_size=40)
            print(history.history.keys())
            _plot_history(history)

            print('\nTesting ------------')
            loss, accuracy = model.evaluate(test_data, test_labels)

            trial_acc = trial_evaluate(model, input_data[test_index], target_labels[test_index])
            print(count, subject)
            print('test loss: ', loss)
            print('test accuracy: ', accuracy)
            print('trial to trial accuracy: ', trial_acc)
            subject_acc.append(trial_acc)
        classification_acc[subject] = subject_acc
    return classification_acc




if __name__ == '__main__':
    # recordings (.gdf) and label files (.mat) live in the same folder here
    data_src = labels_src = r"U:\MEG-BCI\main\Deep Learning\subject234"
    # band_type=3 is neither 0 nor 1 -> fixed mu (8-12 Hz) / beta (13-30 Hz) bands
    data, labels = run_sig_processing(data_src, labels_src, band_type=3)

    # To skip preprocessing on later runs, pickle [data, labels] to 'temp_data.pkl'
    # here once and load it back instead of calling run_sig_processing.

    res = run_classification(data, labels)
    print(res)
    res.to_csv("BP_acc2_beta_gamma.csv", encoding="utf-8")
    print("cnn classification")

KeyboardInterrupt: 