In [79]:
# Import all relevant libraries
import warnings
def fxn(): 
	warnings.warn("deprecated",DeprecationWarning)

with warnings.catch_warnings( ):
    warnings.simplefilter("ignore")
    fxn( )

# Keras imports
import keras
from keras.models import Sequential
from keras.layers import Permute, Flatten, Softmax, Dense, Conv1D, Conv2D, Conv2DTranspose, AveragePooling2D, Activation, Reshape, Dropout

# Other
import numpy as np
import h5py
import sklearn

In [85]:
# Load data from specific trial
def get_trial(trial_num):    
    trial = h5py.File('../data/A0' + str(trial_num) + 'T_slice.mat', 'r')
    X = np.copy(trial['image'])
    y = np.copy(trial['type'])
    y = y[0,0:X.shape[0]:1]
    y = np.asarray(y, dtype=np.int32)
    y -= 769                            # shift class labels to [0-3]
    X = np.nan_to_num(X)[:, :22, :]     # remove EOG channels
    return X, y

def get_all_trials():
    X_total = np.concatenate([get_trial(trial_num)[0] for trial_num in range(1, 9)], axis=0)
    y_total = np.concatenate([get_trial(trial_num)[1] for trial_num in range(1, 9)], axis=0)
    return X_total, y_total

def stratified_train_test_split(X, y, k):
    ''' Returns a stratified train/test split, for k number of splits.
    Return value is in the form [(train indices, test indices), ... for k folds ]
    '''
    skf = sklearn.model_selection.StratifiedKFold(n_splits=k)
    return skf.split(X, y)

(2304, 22, 1000) (2304,)


In [62]:
# Get the data from the first person
X, y = get_trial(1)

X_train = X
y_train = keras.utils.to_categorical(y, num_classes=4)
print(X_train.shape, y_train.shape)

# The data for each trial is of the shape (288, 22, 1000)
#   There are 288 samples per trial (12 of each class per "run", 4 classes, 6 "runs" 
#                                   at different time periods of the day)
#   There are 22 electrodes from the EEG (represents spatial aspect of the signals)
#   There are 1000 time units (4 seconds of data, sampled at 250Hz). The first 250 units
#                                   are when no movement occurs (but the cue is heard) and
#                                   the next 750 units are when the movement occurs
# The labels for each trial belong in one of 4 classes
#   0 - left
#   1 - right
#   2 - foot
#   3 - tongue

[3 2 1 0 0 1 2 3 1 2 0 0 0 3 1 1 0 0 2 0 1 3 3 2 0 3 3 1 3 3 1 0 1 2 2 2 3
 2 0 3 1 2 1 2 3 1 2 0 0 0 3 1 0 2 0 2 1 3 0 2 2 0 2 1 3 3 3 2 0 3 1 3 1 0
 2 1 0 2 2 0 2 3 3 1 0 1 3 1 3 2 1 1 1 2 3 0 1 3 0 2 2 3 0 0 2 1 3 3 3 1 0
 2 1 3 0 3 2 1 3 3 0 1 1 2 3 1 0 0 3 1 0 2 1 1 2 0 3 2 2 2 2 0 1 0 1 0 0 2
 2 1 2 3 0 3 0 0 1 3 2 1 3 2 3 2 3 1 1 3 0 1 1 1 2 3 0 3 0 2 0 3 0 2 0 1 2
 2 3 0 1 3 1 2 2 0 3 1 3 0 0 2 2 1 3 1 1 0 1 3 3 1 1 1 1 3 3 2 3 0 1 2 1 0
 3 0 3 0 0 0 0 2 2 3 1 2 2 2 3 2 0 2 0 3 1 3 3 2 3 3 2 1 3 2 0 1 1 1 2 1 3
 2 3 1 2 0 3 0 2 3 0 2 0 1 1 0 3 0 3 2 2 0 2 1 1 0 2 0 1 0]
(288, 22, 1000) (288, 4)


In [71]:
print(y_train)

[[0. 0. 0. 1.]
 [0. 0. 1. 0.]
 [0. 1. 0. 0.]
 ...
 [1. 0. 0. 0.]
 [0. 1. 0. 0.]
 [1. 0. 0. 0.]]


In [82]:
# Create CNN model

# input is of the form: (sample, spatial, temporal)
model = Sequential()

# Temporal convolution
model.add(Reshape((22, 1000, 1), input_shape=(22, 1000)))
model.add(Conv2D(filters=40, kernel_size=(1, 25), activation='elu', strides=1))
print(model.output_shape)

# Spatial convolution
model.add(Conv2D(filters=40, kernel_size=(22, 40), activation='elu', data_format="channels_first"))
print(model.output_shape)

# Mean pool
model.add(AveragePooling2D(pool_size=(1,75), strides=(1,15)))
print(model.output_shape)

# Dense layers
model.add(Flatten())
model.add(Dense(units=400, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(units=200, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(units=4, activation='softmax'))
print(model.output_shape)

model.compile(loss='categorical_crossentropy',
              optimizer='SGD',
              metrics=['accuracy'])

(None, 22, 976, 40)
(None, 40, 955, 1)
(None, 40, 59, 1)
(None, 4)


In [83]:
model.fit(X_train, y_train, epochs=3, batch_size=32)

Epoch 1/3
Epoch 2/3
Epoch 3/3


<keras.callbacks.History at 0x13fe89940>