In [2]:
# Python standard libraries
import math
import os
import pickle

# Packages for computation and modelling
import numpy as np
import pandas as pd
import scipy.stats as stats
from scipy.stats import norm
import mne

# Packages for visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Self-defined packages
from swlda import SWLDA
from utils import *

# Magic command to reload packages whenever we run any later cells
%load_ext autoreload
%autoreload 2

In [3]:
# Speller board layout: one whitespace-separated string per row, split into
# cells and packed into a 2-D string array (row-major, top-left is "A").
BOARD = np.array([row.split() for row in (
    "A    B   C    D    E    F     G    H",
    "I    J   K    L    M    N     O    P",
    "Q    R   S    T    U    V     W    X",
    "Y    Z   Sp   1    2    3     4    5",
    "6    7   8    9    0    Prd   Ret  Bs",
    "?    ,   ;    \\    /    +     -    Alt",
    "Ctrl =   Del  Home UpAw End   PgUp Shft",
    "Save '   F2   LfAw DnAw RtAw  PgDn Pause",
    "Caps F5  Tab  EC   Esc  email !    Sleep",
)])
N_ROWS, N_COLS = BOARD.shape  # number of rows / columns on the board
M = BOARD.size                # total number of characters on the board

In [13]:
paradigm       = 'RC'  # display paradigm ('RC', 'CB', or 'RD')
NUM_TIMESTAMPS = 195   # number of timestamps in each window to record signals
EPOCH_SIZE     = 15    # required number of features in every epoch
CORE_CHANNELS  = ('EEG_Fz', 'EEG_Cz',  'EEG_P3',  'EEG_Pz',
                  'EEG_P4', 'EEG_PO7', 'EEG_PO8', 'EEG_Oz')
NUM_CORE_CHANNELS = len(CORE_CHANNELS)  # number of core electrodes
NUM_TRAIN_WORDS = 6  # number of training sessions for one participant
NUM_TEST_WORDS  = 6  # number of testing sessions for one participant

obj_index = 1  # the index of experiment object (participant)
# Participant folders are zero-padded to two digits (A01, ..., A09, A10, ...).
# The previous `obj > 10` test produced '010' for participant 10.
obj = f'{obj_index:02d}'

# Data root; set BCI_DATA_DIR to point elsewhere without editing the notebook.
directory = os.environ.get(
    'BCI_DATA_DIR',
    '/Users/zionshane/Desktop/Duke/Research/BCI_data/EDFData-StudyA')
obj_directory = f'{directory}/A{obj}/SE001'

In [5]:
# Load this participant's trained classifier.
# NOTE: pickle.load can execute arbitrary code; only load model files you trust.
model_path = f'./model/A{obj}-model.pkl'
with open(model_path, 'rb') as model_file:
    clf = pickle.load(model_file)

In [6]:
# Training-set arguments gathered once so they are easy to inspect or tweak.
train_load_kwargs = dict(
    dir=obj_directory,
    obj=obj,
    num_timestamps=NUM_TIMESTAMPS,
    epoch_size=EPOCH_SIZE,
    num_channels=NUM_CORE_CHANNELS,
    type=paradigm,
    mode='train',
    num_trials=NUM_TRAIN_WORDS,
)
train_features, train_response = load_data(**train_load_kwargs)

In [7]:
# Classifier scores on the training epochs, labelled target (1) / non-target (0).
scores = pd.DataFrame(clf.test(train_features), columns=['score'])
scores['is_target'] = train_response.astype('int')

# Fit one normal distribution to target scores and one to non-target scores;
# these serve as the score likelihoods in the dynamic-stopping update.
target_mask = scores['is_target'] == 1
nontarget_mask = scores['is_target'] == 0
mu_1, std_1 = norm.fit(scores.loc[target_mask, 'score'])
mu_0, std_0 = norm.fit(scores.loc[nontarget_mask, 'score'])
var_1, var_0 = std_1 ** 2, std_0 ** 2

In [16]:
# Evaluate Bayesian dynamic stopping on each held-out test file.
# For every trial the posterior over the M board characters starts uniform and
# is updated after each flash with the Gaussian likelihood (fitted above) of the
# classifier score: the target fit (mu_1, std_1) for flashed characters and the
# non-target fit (mu_0, std_0) for the rest. Flashing stops as soon as one
# character's posterior reaches P_threshold, or when the trial window ends.
test_file_indices = ['06', '07', '08', '09', '10']
participant_perform = []

# Dynamic-stopping parameters; loop-invariant, so defined once up front.
NUM_SEQ = 7
T_MAX = (N_ROWS + N_COLS) * NUM_SEQ  # maximum number of flashes in a trial
P_threshold = 0.9  # posterior probability required to stop early

for file_index in test_file_indices:
    test_file = (f'{directory}/A{obj}/SE001/Test/{paradigm}/'
                 f'A{obj}_SE001{paradigm}_Test{file_index}.edf')
    raw_data = mne.io.read_raw_edf(test_file, preload=True, verbose=False)

    eeg_channels = mne.pick_channels_regexp(raw_data.info['ch_names'], 'EEG')
    raw_data.notch_filter(freqs=60, picks=eeg_channels, verbose=False)
    test_epochs = get_core_epochs(raw_data)

    current_target_events = mne.find_events(raw_data, stim_channel='CurrentTarget',
                                            verbose=False)
    current_target = current_target_events[:, 2]
    truth = eventIDs_to_strings(BOARD, current_target)

    # Onsets at odd positions are paired with the following onset as the
    # (start, end) sample window of each trial. Stopping one element short keeps
    # an unpaired trailing onset from indexing past the end of the array.
    phases_events = mne.find_events(raw_data, stim_channel='PhaseInSequence',
                                    verbose=False)
    phases_appears = phases_events[:, 0]
    trial_phases = [(phases_appears[i], phases_appears[i + 1])
                    for i in range(1, len(phases_appears) - 1, 2)]

    test_features, test_response = split_data(test_epochs,
                                              n_channels=NUM_CORE_CHANNELS,
                                              n_times=NUM_TIMESTAMPS,
                                              n_samples=EPOCH_SIZE)
    stim_begin_events = mne.find_events(raw=raw_data, stim_channel='StimulusBegin',
                                        verbose=False)
    stim_begin_time = stim_begin_events[:, 0]
    num_stimuli = len(stim_begin_time)

    flashing_schedule = get_flashing_schedule(BOARD, raw_data, stim_begin_time)
    clf_scores = clf.test(data=test_features)

    # Stimuli per trial; assumes every trial has the same number of flashes.
    ACTUAL_T_MAX = num_stimuli // len(truth)

    trial_perform = {'truth': list(truth), 'select': [], 'times': []}

    for trial, (start, end) in enumerate(trial_phases):
        P_all = np.full((N_ROWS, N_COLS), 1 / M)  # uniform prior
        first_stim = trial * ACTUAL_T_MAX  # index of this trial's first stimulus
        num_flashes = 0
        flash_time = start
        k = 0

        while flash_time <= end:
            num_flashes += 1
            flashed = flashing_schedule[flash_time]
            # Classifier score for the current flash
            y = clf_scores[first_stim + k]
            # The likelihood depends only on flashed vs. not flashed, so evaluate
            # each Gaussian once per flash rather than once per board cell.
            like_flashed = stats.norm.pdf(x=y, loc=mu_1, scale=std_1)
            like_unflashed = stats.norm.pdf(x=y, loc=mu_0, scale=std_0)
            for i in range(N_ROWS):
                for j in range(N_COLS):
                    ch_index = N_COLS * i + j + 1  # 1-based character id
                    P_all[i, j] *= like_flashed if ch_index in flashed else like_unflashed
            # Normalize to a posterior distribution
            P_all = P_all / P_all.sum()
            if P_all.max() >= P_threshold:
                break
            k += 1
            if first_stim + k >= num_stimuli:
                break  # recording ended before the threshold was reached
            flash_time = stim_begin_time[first_stim + k]

        max_loc = np.unravel_index(P_all.argmax(), P_all.shape)
        trial_perform['select'].append(BOARD[max_loc])
        trial_perform['times'].append(num_flashes)

    participant_perform.append(trial_perform)

In [17]:
# One dict per test file: 'truth' holds the target characters, 'select' the
# characters chosen by dynamic stopping, and 'times' the number of flashes
# used for each character.
participant_perform

[{'truth': ['D', 'R', 'I', 'V', 'I', 'N', 'G'],
  'select': ['D', 'V', 'I', 'Q', 'I', 'J', 'C'],
  'times': [38, 70, 30, 80, 54, 28, 61]},
 {'truth': ['Q', 'U', 'I', 'C', 'K', 'L', 'Y'],
  'select': ['R', '3', 'I', 'E', 'J', 'EC', 'Y'],
  'times': [45, 29, 61, 30, 79, 84, 38]},
 {'truth': ['T', 'O', 'W', 'A', 'R', 'D', 'S'],
  'select': ['T', 'O', 'X', 'A', 'R', 'D', 'S'],
  'times': [21, 76, 119, 27, 50, 43, 37]},
 {'truth': ['D', 'A', 'Y', 'L', 'I', 'G', 'H', 'T'],
  'select': ['J', 'I', 'PgUp', 'L', 'I', 'PgDn', 'H', 'T'],
  'times': [64, 18, 14, 39, 12, 33, 69, 36]},
 {'truth': ['5', '1', '4', '9', '7', '3', '6'],
  'select': ['5', '1', '3', '9', '7', 'L', '6'],
  'times': [83, 35, 63, 23, 20, 57, 42]}]