In [1]:
# Python standard libraries
import math
import pickle
from pathlib import Path

# Packages for computation and modelling
import numpy as np
import pandas as pd
import scipy.stats as stats
from scipy.stats import norm
import mne

# Packages for visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Self-defined packages
from swlda import SWLDA
from utils import *

# Magic command to reload packages whenever we run any later cells
%load_ext autoreload
%autoreload 2

In [2]:
# Character board shown by the flashing display: 9 rows x 8 columns of key
# labels, one whitespace-separated string per row. A character's 1-based ID is
# its row-major position on this grid (row * N_COLS + col + 1).
_BOARD_ROWS = (
    "A     B   C    D     E     F      G     H",
    "I     J   K    L     M     N      O     P",
    "Q     R   S    T     U     V      W     X",
    "Y     Z   Sp   1     2     3      4     5",
    "6     7   8    9     0     Prd    Ret   Bs",
    "?     ,   ;    \\    /     +      -     Alt",
    "Ctrl  =   Del  Home  UpAw  End    PgUp  Shft",
    "Save  '   F2   LfAw  DnAw  RtAw   PgDn  Pause",
    "Caps  F5  Tab  EC    Esc   email  !     Sleep",
)
BOARD = np.array([row.split() for row in _BOARD_ROWS])
N_ROWS, N_COLS = BOARD.shape  # number of rows, number of columns
M = BOARD.size                # total number of characters on the board

In [3]:
# ---- Recording / epoching configuration ----
paradigm = 'RC'         # display paradigm: one of 'RC', 'CB', 'RD'
NUM_TIMESTAMPS = 195    # timestamps recorded in each signal window
EPOCH_SIZE = 15         # number of features required in every epoch

# EEG electrodes kept for classification.
CORE_CHANNELS = (
    'EEG_Fz', 'EEG_Cz', 'EEG_P3', 'EEG_Pz',
    'EEG_P4', 'EEG_PO7', 'EEG_PO8', 'EEG_Oz',
)
NUM_CORE_CHANNELS = len(CORE_CHANNELS)  # number of core electrodes

# Number of words per participant in the training and the testing split.
NUM_TRAIN_WORDS = NUM_TEST_WORDS = 5

In [8]:
# Offline AUC + simulated online Bayesian dynamic stopping.
# For each participant: load (or train and cache) an SWLDA classifier, record its
# AUC on the offline test words, fit Gaussian models to its target / non-target
# scores on the training data, then replay every online test file flash by flash,
# ending each trial once one character's posterior reaches P_THRESHOLD.

# Root directory of the EDF recordings.
# NOTE(review): absolute, machine-specific path -- adjust to your local copy.
DATA_DIR = '/Users/zionshane/Desktop/Duke/Research/BCI_data/EDFData-StudyA'
MODEL_DIR = Path('./model')  # cache of fitted SWLDA models, one pickle per participant

ONLINE_TEST_FILES = ['06', '07', '08', '09', '10']  # suffixes of the online test files
NUM_SEQ = 7                          # sequences per trial
T_MAX = (N_ROWS + N_COLS) * NUM_SEQ  # max number of flashes per trial; NOTE(review): not enforced below
P_THRESHOLD = 0.9                    # stop a trial once the best posterior reaches this


def _load_or_train_classifier(obj, train_features, train_response):
    """Return the cached SWLDA model for participant `obj`; train and cache it on a miss.

    Only a missing or unreadable cache file triggers retraining; any other error
    propagates instead of being silently swallowed.
    """
    model_path = MODEL_DIR / f'A{obj}-model.pkl'
    try:
        # SECURITY: pickle.load can execute arbitrary code -- only load model
        # files produced by this notebook.
        with open(model_path, 'rb') as f:
            return pickle.load(f)
    except (FileNotFoundError, EOFError, pickle.UnpicklingError):
        pass  # cache miss (or truncated/corrupt cache): fall through and retrain

    clf = SWLDA(penter=0.1, premove=0.15)
    clf.fit(train_features, train_response)
    MODEL_DIR.mkdir(parents=True, exist_ok=True)  # the open() below fails if the dir is missing
    with open(model_path, 'wb') as f:
        pickle.dump(clf, f)
    return clf


def _fit_score_distributions(clf, features, response):
    """Fit a normal distribution to the classifier scores of target and of non-target epochs.

    Returns ((mu_target, std_target), (mu_nontarget, std_nontarget)).
    """
    scores = pd.DataFrame(clf.test(features), columns=['score'])
    scores['is_target'] = response.astype('int')
    target_params = norm.fit(scores.loc[scores['is_target'] == 1, 'score'])
    nontarget_params = norm.fit(scores.loc[scores['is_target'] == 0, 'score'])
    return target_params, nontarget_params


def _split_trial_windows(phase_onsets):
    """Pair each odd-indexed PhaseInSequence onset with the onset that follows it.

    Each (start, end) pair bounds one trial's flashing period. A trailing
    odd-indexed onset with no successor is ignored instead of raising IndexError.
    """
    return [(phase_onsets[i], phase_onsets[i + 1])
            for i in range(1, len(phase_onsets) - 1, 2)]


def _simulate_trial(start, end, flash_offset, stim_begin_time, flashing_schedule,
                    clf_scores, target_params, nontarget_params):
    """Replay one trial flash by flash with Bayesian dynamic stopping.

    Starting from a uniform prior over the board, each flash multiplies every
    character's probability by the normal likelihood of that flash's classifier
    score. The target model is used if the character was flashed, the non-target
    model otherwise, and the result is renormalised. Replay stops when the most
    likely character reaches P_THRESHOLD, the trial window ends, or the
    recording runs out of flashes.

    `flash_offset` is the index of this trial's first flash in `stim_begin_time`
    and `clf_scores`.

    Returns (posterior, num_flashes): the final (N_ROWS, N_COLS) posterior and the
    number of flashes consumed.
    """
    mu_1, std_1 = target_params
    mu_0, std_0 = nontarget_params
    char_ids = range(1, M + 1)  # 1-based, row-major character numbering

    posterior = np.full((N_ROWS, N_COLS), 1 / M)
    num_flashes = 0
    time = start  # NOTE(review): assumes the trial's first flash is keyed at the phase onset
    while time < end:
        flashed = flashing_schedule[time]
        y = clf_scores[flash_offset + num_flashes]
        # Only two distinct likelihood values exist per flash; compute each once.
        lik_flashed = norm.pdf(y, loc=mu_1, scale=std_1)
        lik_unflashed = norm.pdf(y, loc=mu_0, scale=std_0)
        flashed_mask = np.array([c in flashed for c in char_ids]).reshape(N_ROWS, N_COLS)
        posterior = posterior * np.where(flashed_mask, lik_flashed, lik_unflashed)
        posterior = posterior / posterior.sum()
        num_flashes += 1

        if posterior.max() >= P_THRESHOLD:
            break
        next_idx = flash_offset + num_flashes
        if next_idx == len(stim_begin_time):
            break  # no more flashes in this recording
        time = stim_begin_time[next_idx]
    return posterior, num_flashes


def _evaluate_online_file(test_file, clf, target_params, nontarget_params):
    """Simulate dynamic stopping on every trial of one online EDF test file.

    Returns {'truth': [...], 'select': [...], 'times': [...]}: the target
    characters, the selected character and the number of flashes used per trial.
    """
    raw_data = mne.io.read_raw_edf(test_file, preload=True, verbose=False)

    eeg_channels = mne.pick_channels_regexp(raw_data.info['ch_names'], 'EEG')
    # 60 Hz notch on EEG channels only (presumably mains interference).
    raw_data.notch_filter(freqs=60, picks=eeg_channels, verbose=False)
    online_epochs = get_core_epochs(raw_data)

    current_target = mne.find_events(raw_data, stim_channel='CurrentTarget',
                                     verbose=False)[:, 2]
    truth = eventIDs_to_strings(BOARD, current_target)

    phase_onsets = mne.find_events(raw_data, stim_channel='PhaseInSequence',
                                   verbose=False)[:, 0]
    trial_windows = _split_trial_windows(phase_onsets)

    online_features, _ = split_data(online_epochs,
                                    n_channels=NUM_CORE_CHANNELS,
                                    n_times=NUM_TIMESTAMPS,
                                    n_samples=EPOCH_SIZE)
    # Read once: the notch filter above does not touch the stimulus channel.
    stim_begin_time = mne.find_events(raw=raw_data, stim_channel='StimulusBegin',
                                      verbose=False)[:, 0]
    flashing_schedule = get_flashing_schedule(BOARD, raw_data, stim_begin_time)
    clf_scores = clf.test(data=online_features)

    # Assumes every trial has the same number of flashes.
    flashes_per_trial = len(stim_begin_time) // len(truth)

    trial_perform = {'truth': list(truth), 'select': [], 'times': []}
    for trial, (start, end) in enumerate(trial_windows):
        posterior, num_flashes = _simulate_trial(
            start, end, trial * flashes_per_trial, stim_begin_time,
            flashing_schedule, clf_scores, target_params, nontarget_params)
        max_loc = np.unravel_index(posterior.argmax(), posterior.shape)
        trial_perform['select'].append(BOARD[max_loc])
        trial_perform['times'].append(num_flashes)
    return trial_perform


obj_indices = ['01', '02', '03', '04', '05', '06', '07',
               '09', '14', '15', '16', '17', '19']

AUCs = []             # offline test AUC, one per participant
all_performance = []  # per participant: one result dict per online test file

for obj in obj_indices:
    obj_directory = f'{DATA_DIR}/A{obj}/SE001'

    train_features, train_response = load_data(dir=obj_directory,
                                               obj=obj,
                                               num_timestamps=NUM_TIMESTAMPS,
                                               epoch_size=EPOCH_SIZE,
                                               num_channels=NUM_CORE_CHANNELS,
                                               type=paradigm,
                                               mode='train',
                                               num_words=NUM_TRAIN_WORDS)

    test_features, test_response = load_data(dir=obj_directory,
                                             obj=obj,
                                             num_timestamps=NUM_TIMESTAMPS,
                                             epoch_size=EPOCH_SIZE,
                                             num_channels=NUM_CORE_CHANNELS,
                                             type=paradigm,
                                             mode='test',
                                             num_words=NUM_TEST_WORDS)

    clf = _load_or_train_classifier(obj, train_features, train_response)
    AUCs.append(clf.test(test_features, test_response))

    target_params, nontarget_params = _fit_score_distributions(
        clf, train_features, train_response)

    participant_perform = [
        _evaluate_online_file(
            f'{obj_directory}/Test/{paradigm}/A{obj}_SE001{paradigm}_Test{file_index}.edf',
            clf, target_params, nontarget_params)
        for file_index in ONLINE_TEST_FILES
    ]
    all_performance.append(participant_perform)

In [9]:
all_performance  # per participant: one {'truth', 'select', 'times'} dict per online test file

[[{'truth': ['D', 'R', 'I', 'V', 'I', 'N', 'G'],
   'select': ['D', 'R', 'I', 'V', 'I', 'N', 'G'],
   'times': [56, 49, 31, 40, 51, 71, 52]},
  {'truth': ['Q', 'U', 'I', 'C', 'K', 'L', 'Y'],
   'select': ['Q', 'U', 'I', 'C', 'K', 'L', 'Y'],
   'times': [56, 67, 46, 47, 85, 87, 41]},
  {'truth': ['T', 'O', 'W', 'A', 'R', 'D', 'S'],
   'select': ['L', 'O', 'W', 'A', 'R', 'D', 'S'],
   'times': [42, 39, 37, 96, 68, 99, 37]},
  {'truth': ['D', 'A', 'Y', 'L', 'I', 'G', 'H', 'T'],
   'select': ['L', 'A', 'Y', 'L', 'I', 'G', 'H', 'T'],
   'times': [64, 43, 47, 42, 13, 57, 53, 42]},
  {'truth': ['5', '1', '4', '9', '7', '3', '6'],
   'select': ['5', '1', '4', '9', '7', 'Sp', '6'],
   'times': [83, 35, 69, 61, 20, 119, 67]}],
 [{'truth': ['D', 'R', 'I', 'V', 'I', 'N', 'G'],
   'select': ['D', 'R', 'I', 'V', 'I', 'N', 'G'],
   'times': [42, 34, 31, 35, 32, 56, 18]},
  {'truth': ['Q', 'U', 'I', 'C', 'K', 'L', 'Y'],
   'select': ['R', 'U', 'I', 'C', 'K', 'L', 'Y'],
   'times': [27, 35, 58, 43, 25,

In [10]:
AUCs  # offline test-set AUC per participant, in obj_indices order

[0.851303329973965,
 0.8656961241286638,
 0.8225198412698413,
 0.9090524481397498,
 0.823358108675569,
 0.9414645796590241,
 0.8652557319223986,
 0.6910378348870412,
 0.8434996220710507,
 0.7373162845385068,
 0.7374490845721005,
 0.7826840304022843,
 0.7807481943394642]