# Build a CNN classifier for a P300 speller

<div style="text-align:justify; width: 97%">
Main reference: *CNN With Large Data Achieves True Zero-Training in Online P300 Brain-Computer Interface* by J. Lee et al. (2020)
</div>

In [89]:
# import necessary packages
## Python standard libraries
import math
import random
import os

## Packages for computation and modelling
import numpy as np
import pandas as pd
import scipy.stats as stats
from scipy.stats import norm
import torch
from torch.utils.data import DataLoader, TensorDataset
from torchinfo import summary
# from torcheeg.models import EEGNet
import mne
import pickle

## Packages for visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Self-defined packages
from swlda import SWLDA
import eegnet_utils
import swlda_utils

# Magic command to reload packages whenever we run any later cells
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
# Run on the first CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
print(device)

cpu


## Step 1. Data Preprocessing

In [3]:
# Speller keyboard layout: every inner list is one on-screen row of keys.
BOARD = np.array([
    ["A", "B", "C", "D", "E", "F", "G", "H"],
    ["I", "J", "K", "L", "M", "N", "O", "P"],
    ["Q", "R", "S", "T", "U", "V", "W", "X"],
    ["Y", "Z", "Sp", "1", "2", "3", "4", "5"],
    ["6", "7", "8", "9", "0", "Prd", "Ret", "Bs"],
    ["?", ",", ";", "\\", "/", "+", "-", "Alt"],
    ["Ctrl", "=", "Del", "Home", "UpAw", "End", "PgUp", "Shft"],
    ["Save", "'", "F2", "LfAw", "DnAw", "RtAw", "PgDn", "Pause"],
    ["Caps", "F5", "Tab", "EC", "Esc", "email", "!", "Sleep"],
])
N_ROWS, N_COLS = BOARD.shape  # grid dimensions: (rows, columns)
M = BOARD.size                # total number of keys on the board

In [27]:
# --- Experiment configuration ---------------------------------------------
# Display paradigm of the speller: 'RC', 'CB', or 'RD'.
paradigm = 'RC'

# Number of timestamps recorded in each signal window.
# NOTE(review): the tensor shapes saved further down show 206 time samples,
# not 195 — confirm this value matches the run that produced those outputs.
NUM_TIMESTAMPS = 195
# Number of electrode channels.
NUM_CHANNELS = 32
# Epoch size; equal to the window length because no aggregation is applied here.
EPOCH_SIZE = 195

# Number of training / testing words per participant.
NUM_TRAIN_WORDS = 5
NUM_TEST_WORDS = 5

# Two-digit participant identifiers used to build the 'A<id>' folder names.
obj_indices = [f'{subject_id:02d}'
               for subject_id in (1, 2, 3, 4, 5, 6, 7, 9, 14, 15, 16, 17, 19)]

In [10]:
# Root folder of the Study-A EDF recordings; it does not depend on the participant.
# NOTE(review): this cell climbs three levels ('../../..') while the later
# evaluation cells climb two ('../..') — confirm which depth is correct.
directory = os.path.abspath('../../..') + '/BCI_data/EDFData-StudyA'

# Keyword arguments shared by the train and test loading calls.
shared_load_kwargs = dict(num_timestamps=NUM_TIMESTAMPS,
                          epoch_size=EPOCH_SIZE,
                          num_channels=NUM_CHANNELS,
                          type=paradigm)

# One entry per participant; the label arrays are stored as column vectors
# so they can be stacked vertically afterwards.
train_X_list, train_Y_list = [], []
test_X_list, test_Y_list = [], []

for obj in obj_indices:
    obj_directory = directory + f'/A{obj}/SE001'

    # Training words of this participant
    train_features, train_response = eegnet_utils.load_data(
        dir=obj_directory, obj=obj, mode='train',
        num_words=NUM_TRAIN_WORDS, **shared_load_kwargs)
    train_X_list.append(train_features)
    train_Y_list.append(train_response.reshape((-1, 1)))

    # Testing words of this participant
    test_features, test_response = eegnet_utils.load_data(
        dir=obj_directory, obj=obj, mode='test',
        num_words=NUM_TEST_WORDS, **shared_load_kwargs)
    test_X_list.append(test_features)
    test_Y_list.append(test_response.reshape((-1, 1)))

In [11]:
# Pool the per-participant arrays into one tensor per split.
train_X = torch.from_numpy(np.vstack(train_X_list))
train_Y = torch.from_numpy(np.vstack(train_Y_list))
test_X  = torch.from_numpy(np.vstack(test_X_list))
test_Y  = torch.from_numpy(np.vstack(test_Y_list))

# Insert a singleton axis at position 1 so the features fit the EEGNet model.
# Uses the tensor's own method instead of round-tripping the tensor through
# NumPy (np.expand_dims + torch.from_numpy).
train_X = train_X.unsqueeze(1)  # saved run: [55692, 1, 32, 206]
test_X  = test_X.unsqueeze(1)   # saved run: [55692, 1, 32, 206]

# Labels were stacked as column vectors (N, 1); drop only that column axis so
# a split containing a single sample stays 1-D instead of collapsing to 0-D.
train_Y = train_Y.squeeze(1)  # saved run: [55692]
test_Y  = test_Y.squeeze(1)   # saved run: [55692]

In [15]:
# Sanity check: print (features shape, labels shape) for the train and test splits.
for split_X, split_Y in ((train_X, train_Y), (test_X, test_Y)):
    print(split_X.shape, split_Y.shape)

torch.Size([55692, 1, 32, 206]) torch.Size([55692])
torch.Size([55692, 1, 32, 206]) torch.Size([55692])


In [None]:
# Optimisation hyper-parameters
BATCH_SIZE = 32   # samples per mini-batch
LR = 1e-3         # learning rate handed to the model wrapper
EPOCHS = 1        # passes over the training set

# Pair features with labels for each split.
trainset = TensorDataset(train_X, train_Y)
testset = TensorDataset(test_X, test_Y)

# Only the training batches are shuffled; the test order stays fixed.
trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True)
testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)

## Train and test the EEGNet classifier

In [None]:
# Fresh EEGNet on the selected device, wrapped in the project's Model helper.
eegnet = eegnet_utils.EEGNet().to(device)
model = eegnet_utils.Model(eegnet, lr=LR)

# NOTE(review): the test loader is passed as the validation loader, so the
# "val_*" entries of the history are measured on the test words — confirm
# that this is intended.
fit_options = dict(trainloader=trainloader,
                   validloader=testloader,
                   epochs=EPOCHS,
                   monitor=["acc", "val_acc"])
history = model.fit(**fit_options)

eegnet_utils.plot_acc_and_loss(history=history)

In [25]:
# Display the metrics returned by model.fit (loss / acc / val_loss / val_acc in the saved run).
history

{'loss': array([0.20812429]),
 'acc': array([0.86949652]),
 'val_loss': array([0.35603786]),
 'val_acc': array([0.86976586])}

In [93]:
# Evaluate the saved EEGNet checkpoint on each participant's test words separately.
obj_indices     = ['01', '02', '03', '04', '05', '06', '07',
                   '09', '14', '15', '16', '17', '19']
eegnet = eegnet_utils.EEGNet().to(device)
model_path = './model/best_eegnet_model.pt'
model = eegnet_utils.Model(eegnet)
model.load(filepath=model_path)

# Dedicated name so the training BATCH_SIZE constant is not overwritten by this cell.
EVAL_BATCH_SIZE = 1
# The data root does not depend on the participant, so resolve it once.
directory = os.path.abspath('../..') + '/BCI_data/EDFData-StudyA'

# NOTE(review): the saved output of this cell is exactly 15/17 (0.8824) for
# every participant. With a 9x8 board in the 'RC' paradigm, presumably 2 of
# every 17 flashes contain the target, so this is also the score of a model
# that labels every epoch as non-target. Verify the predictions and consider
# a class-balanced metric alongside accuracy.
eegnet_clf_test_accs = []
for obj in obj_indices:
    obj_directory = directory + f'/A{obj}/SE001'

    test_features,test_response   = eegnet_utils.load_data(dir=obj_directory,
                                              obj=obj,
                                              num_timestamps=NUM_TIMESTAMPS,
                                              epoch_size=EPOCH_SIZE,
                                              num_channels=NUM_CHANNELS,
                                              type=paradigm,
                                              mode='test',
                                              num_words=NUM_TEST_WORDS)

    # Per-participant tensors get their own names: reusing test_X / test_Y /
    # testset here would silently replace the pooled test split built earlier
    # in the notebook with the last participant's data.
    subj_X = torch.from_numpy(test_features).unsqueeze(1)  # add the singleton axis EEGNet expects
    subj_Y = torch.from_numpy(test_response).squeeze()

    subj_loader = DataLoader(dataset=TensorDataset(subj_X, subj_Y),
                             batch_size=EVAL_BATCH_SIZE, shuffle=False)
    test_loss, test_acc = model.evaluate(subj_loader)
    eegnet_clf_test_accs.append(test_acc)

In [90]:
# Input shape fed to torchinfo for the layer-by-layer report.
# NOTE(review): 4284 looks like the per-participant test-set size
# (55692 / 13) and 32 / 206 mirror the feature tensor shape — confirm.
summary_input_shape = (4284, 1, 32, 206)
summary(model.model, input_size=summary_input_shape)

Layer (type:depth-idx)                   Output Shape              Param #
EEGNet                                   [4284, 2]                 --
├─Sequential: 1-1                        [4284, 64, 32, 207]       --
│    └─Conv2d: 2-1                       [4284, 64, 32, 207]       4,096
│    └─BatchNorm2d: 2-2                  [4284, 64, 32, 207]       128
├─Sequential: 1-2                        [4284, 256, 1, 51]        --
│    └─Conv2d: 2-3                       [4284, 256, 1, 207]       8,192
│    └─BatchNorm2d: 2-4                  [4284, 256, 1, 207]       512
│    └─ELU: 2-5                          [4284, 256, 1, 207]       --
│    └─AvgPool2d: 2-6                    [4284, 256, 1, 51]        --
│    └─Dropout: 2-7                      [4284, 256, 1, 51]        --
├─Sequential: 1-3                        [4284, 256, 1, 6]         --
│    └─Conv2d: 2-8                       [4284, 256, 1, 52]        4,096
│    └─Conv2d: 2-9                       [4284, 256, 1, 52]        65,536


In [78]:
# Per-participant EEGNet test accuracies, in the order of obj_indices.
eegnet_clf_test_accs

[0.8823529411764706,
 0.8823529411764706,
 0.8823529411764706,
 0.8823529411764706,
 0.8823529411764706,
 0.8823529411764706,
 0.8823529411764706,
 0.8823529411764706,
 0.8823529411764706,
 0.8823529411764706,
 0.8823529411764706,
 0.8823529411764706,
 0.8823529411764706]

In [50]:
# NOTE(review): repeats the display in the previous cell; the saved single-element
# output comes from an earlier execution (In [50]) — candidate for removal.
eegnet_clf_test_accs

[0.8823529411764706]

## Train and test the SWLDA classifier (benchmark)

In [81]:
# Benchmark: fit (or load a cached) SWLDA classifier per participant and
# score it on that participant's test words.
obj_indices = ['01', '02', '03', '04', '05', '06', '07',
               '09', '14', '15', '16', '17', '19']
CORE_CHANNELS = ('EEG_Fz', 'EEG_Cz',  'EEG_P3',  'EEG_Pz',
                 'EEG_P4', 'EEG_PO7', 'EEG_PO8', 'EEG_Oz')
# Derived from the tuple so the count cannot drift from the channel list (8).
NUM_CORE_CHANNELS = len(CORE_CHANNELS)

# The data root does not depend on the participant, so resolve it once.
directory = os.path.abspath('../..') + '/BCI_data/EDFData-StudyA'

swlda_clf_test_accs = []

for obj in obj_indices:
    obj_directory = directory + f'/A{obj}/SE001'

    train_features,train_response = swlda_utils.load_data(dir=obj_directory,
                                              obj=obj,
                                              num_timestamps=NUM_TIMESTAMPS,
                                              epoch_size=EPOCH_SIZE,
                                              num_channels=NUM_CORE_CHANNELS,
                                              type=paradigm,
                                              mode='train',
                                              num_words=NUM_TRAIN_WORDS)

    test_features,test_response   = swlda_utils.load_data(dir=obj_directory,
                                              obj=obj,
                                              num_timestamps=NUM_TIMESTAMPS,
                                              epoch_size=EPOCH_SIZE,
                                              num_channels=NUM_CORE_CHANNELS,
                                              type=paradigm,
                                              mode='test',
                                              num_words=NUM_TEST_WORDS)

    model_file = f'./model/A{obj}-swlda-model.pkl'
    try:
        # Reuse a previously fitted classifier when a cache file exists.
        # `with` guarantees the handle is closed (the bare open() leaked it).
        # NOTE(review): pickle.load executes arbitrary code from the file;
        # only load caches written by this notebook. A cached model is also
        # reused even if the training data or the SWLDA thresholds changed.
        with open(model_file, 'rb') as f:
            clf = pickle.load(f)
    except (OSError, EOFError, pickle.UnpicklingError,
            AttributeError, ImportError, IndexError):
        # Missing or unreadable cache: fit from scratch. Only cache-related
        # failures are caught, so KeyboardInterrupt and genuine bugs are no
        # longer swallowed as they were by the bare `except:`.
        clf = SWLDA(penter=0.1, premove=0.15)
        clf.fit(train_features, train_response)
        # save the classifier as a standalone model file
        os.makedirs(os.path.dirname(model_file), exist_ok=True)
        with open(model_file, 'wb') as f:
            pickle.dump(clf, f)

    # Threshold the classifier scores at 0.5 and compare with the true labels.
    pred = clf.test(test_features) > 0.5
    test_acc = sum(pred == test_response) / len(pred)
    swlda_clf_test_accs.append(test_acc)

In [82]:
# Per-participant SWLDA test accuracies, in the order of obj_indices.
swlda_clf_test_accs

[0.8828197945845004,
 0.8895891690009337,
 0.8839869281045751,
 0.8930905695611578,
 0.8828197945845004,
 0.892623716153128,
 0.888422035480859,
 0.8830532212885154,
 0.8825863678804855,
 0.8821195144724556,
 0.8818860877684407,
 0.8823529411764706,
 0.8823529411764706]

In [19]:
# Mean SWLDA test accuracy over all participants.
sum(swlda_clf_test_accs) / len(swlda_clf_test_accs)

0.8852079293255765