In [None]:
import gc
import os

import numpy as np
import scipy.io as sio  
from matplotlib import pyplot
from skimage.transform import resize
from sklearn.model_selection import train_test_split

import keras
from keras import optimizers
from keras import regularizers
from keras.datasets import mnist
from keras.initializers import glorot_uniform
from keras.layers import Input, Conv2D, Dense, Flatten, AveragePooling2D, BatchNormalization
from keras.losses import categorical_crossentropy
from keras.models import Model
from keras.preprocessing.image import ImageDataGenerator

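Helper: `interpolation` resizes every 3-channel image in a batch to a square `bigNum × bigNum` image (image interpolation); inputs with a different channel count return `None`.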

In [None]:
def interpolation(data, bigNum):
    """Resize every 3-channel sample in ``data`` to a (bigNum, bigNum, 3) image.

    Each sample is resized independently with ``skimage.transform.resize`` using
    ``mode='symmetric'`` for the boundary handling.

    Args:
        data: indexable batch of images whose last axis must have size 3.
        bigNum: side length of the square output images.

    Returns:
        A float64 array of shape (len(data), bigNum, bigNum, 3), or ``None`` when
        the last axis of ``data`` is not 3 (no exception is raised).

    NOTE(review): with the default ``preserve_range=False``, skimage rescales
    integer inputs to floats in [0, 1]; confirm the .mat data is already float.
    """
    target_shape = (bigNum, bigNum, 3)
    if data.shape[-1] == 3:
        resized = np.zeros((len(data),) + target_shape)
        for idx, sample in enumerate(data):
            resized[idx] = resize(sample, target_shape, mode='symmetric')
        return resized
    # Only 3-channel input is supported; anything else yields None.
    return None

Set parameters

In [None]:
# Paths and hyper-parameters used throughout the notebook.
dataDir = './data/'    # root folder with one sub-folder per electrode position
modelDir = './model'   # folder holding the pretrained weights (nnclf.h5)
interpShape = 64       # side length (pixels) every image is resized to
augNum = 1             # number of augmented copies of the training set to add

# Seed NumPy's global RNG so NumPy-driven randomness (presumably including the
# ImageDataGenerator random shifts) is repeatable on Restart & Run All.
# Keras/TensorFlow backend seeds are NOT set here.
SEED = 42
np.random.seed(SEED)

Load the baseline training data and hold out 20% of it as a test set

In [None]:
# Training data comes from the 'baseline' electrode position. Both .mat files
# are read under the 'data_test' / 'label_test' keys.
# NOTE(review): those keys are read from the *train* files — confirm intended.
position = 'baseline'
baseline_dir = dataDir + '/' + position + '/'

train_mat = sio.loadmat(baseline_dir + '/data_train')
X_train = train_mat['data_test']

labels_mat = sio.loadmat(baseline_dir + '/label_train')
label_train = labels_mat['label_test']

# Hold out 20% of the baseline recordings as a test split (fixed random_state).
split = train_test_split(X_train, label_train.reshape(-1), test_size=0.2, random_state=42)
X_train, X_test, y_train, y_test = split



__Image interpolation__: resize the training and held-out test images to `interpShape × interpShape`

In [None]:
# Resize both splits to interpShape x interpShape (3 channels).
# NOTE: X_train is rebound to its resized version, so re-running this cell
# resizes already-resized data; re-run the loading cell first for a clean state.
X_train = interpolation(X_train, interpShape)
X_test_baseline = interpolation(X_test, interpShape)
# interpolation() returns None instead of raising when the last axis is not 3;
# fail here with a clear message rather than later inside the data generator.
assert X_train is not None and X_test_baseline is not None, \
    'interpolation() expects images whose last axis has 3 channels'

Perform __data augmentation__

In [None]:
# Random shifts of up to 15% of the image width / height.
width_shift_range = 0.15
height_shift_range = 0.15
num_classes = 6  # number of label classes used for one-hot encoding

datagen = ImageDataGenerator(width_shift_range=width_shift_range,
                             height_shift_range=height_shift_range)
# Per the Keras docs, fit() only computes statistics for featurewise
# normalisation / ZCA whitening, which are not enabled here; it is harmless.
datagen.fit(X_train)

# Generate `augNum` shifted copies of the whole training set.
# batch_size=len(X_train) makes the first batch the full set, and shuffle=False
# keeps every copy row-aligned with y_train (the label tiling below relies on it).
n_preview = min(9, len(X_train))  # 3x3 preview grid; avoids IndexError on tiny sets
X_train_aug = []
for _ in range(augNum):
    augmented_flow = datagen.flow(X_train, y_train, batch_size=len(X_train), shuffle=False)
    X_train_aug.append(next(iter(augmented_flow))[0])

    # Preview the first images of this copy (channel 0 only).
    print('examples of augmented data')
    for k in range(n_preview):
        pyplot.subplot(330 + 1 + k)
        pyplot.imshow(X_train_aug[-1][k, :, :, 0], cmap=pyplot.get_cmap('YlOrRd'))
    pyplot.show()

# Originals first, then each augmented copy; labels repeat in the same order.
# A single concatenate avoids re-copying the growing array on every iteration.
X_train_with_aug = np.concatenate([X_train] + X_train_aug, axis=0)
y_train_with_aug = np.tile(y_train, augNum + 1).astype(np.float64)

# Free the intermediates (no loop variable keeps the last copy alive).
# NOTE: after this `del`, re-running this cell requires re-running the loading
# and interpolation cells first.
del X_train, X_train_aug
gc.collect()

y_train_with_aug_onehot = keras.utils.to_categorical(y_train_with_aug, num_classes)
y_test_onehot_baseline = keras.utils.to_categorical(y_test, num_classes)

Build the __dilated convolutional neural network__

In [None]:
def NNclf(input_shape, num_classes=6):
    """Build the dilated-convolution classifier.

    Architecture: three 3x3 conv layers (8 filters, dilation 3, ReLU, 'same'
    padding, L2(0.001) kernel regularisation), each followed by 2x2 average
    pooling, then Flatten, a 128-unit ReLU dense layer (L2(0.001)) and a
    softmax output layer. All kernels use glorot_uniform(seed=0).

    Layer names ('e1', 'e2', 'e3', 'e4', 'dense_layer1', 'output_layer') are
    kept stable because weights are loaded with ``load_weights(..., by_name=True)``.

    Args:
        input_shape: shape of one input sample, e.g. (interpShape, interpShape, 3).
        num_classes: number of output classes (default 6, as used by the notebook).

    Returns:
        An uncompiled keras ``Model``.
    """
    input_img = Input(input_shape)

    # Three conv + pool stages; only the last pooling layer is named ('e4').
    # name=None lets Keras auto-name the first two pooling layers, as before.
    x = input_img
    for conv_name, pool_name in (('e1', None), ('e2', None), ('e3', 'e4')):
        x = Conv2D(8, (3, 3), activation='relu', dilation_rate=(3, 3), padding='same',
                   name=conv_name, kernel_regularizer=regularizers.l2(0.001),
                   kernel_initializer=glorot_uniform(seed=0))(x)
        x = AveragePooling2D((2, 2), padding='same', name=pool_name)(x)

    x = Flatten()(x)
    x = Dense(units=128, activation='relu', name='dense_layer1',
              kernel_regularizer=regularizers.l2(0.001),
              kernel_initializer=glorot_uniform(seed=0))(x)
    output_layer = Dense(units=num_classes, activation='softmax', name='output_layer',
                         kernel_initializer=glorot_uniform(seed=0))(x)
    return Model(inputs=input_img, outputs=output_layer)

You can directly use the pretrained model for classification.

In [None]:
# Build the network and load the pretrained weights from modelDir.
# by_name=True loads weights into layers whose names match those in the file.
nnclf = NNclf((interpShape, interpShape, 3))
nnclf.load_weights(os.path.join(modelDir, 'nnclf.h5'), by_name=True)
# compile() is only needed so evaluate() has a loss and metric; the optimizer is
# unused here ('adam' uses the same 0.001 default rate as before and avoids the
# deprecated `lr=` keyword).
nnclf.compile(loss=categorical_crossentropy, optimizer='adam', metrics=['acc'])

print('electrode position: baseline')
model_evaluate = [nnclf.evaluate(X_test_baseline, y_test_onehot_baseline)]
print('model_evaluate', model_evaluate)

# Evaluate on every shifted electrode position. A dedicated variable name is used
# so the X_test split from the loading cell is not clobbered.
positions = ['right_distal', 'right_proximal', 'left_distal', 'left_proximal', 'random']
for position in positions:
    print('electrode position:' + position)
    position_dir = os.path.join(dataDir, position)
    X_test_raw = sio.loadmat(os.path.join(position_dir, 'data_test'))['data_test']
    label_test = sio.loadmat(os.path.join(position_dir, 'label_test'))['label_test'].T
    y_test_onehot = keras.utils.to_categorical(label_test.reshape(-1), 6)
    X_test_position = interpolation(X_test_raw, interpShape)

    model_evaluate = [nnclf.evaluate(X_test_position, y_test_onehot)]
    print('model_evaluate', model_evaluate)
    print(' ')
    print(' ')

Alternatively, you can run the following code to train your own model.

In [None]:
# Train from scratch: 20 epochs at 1e-3, then recompile with a fresh Adam
# optimizer (new instance, so its state starts over) and fine-tune for 20 more
# epochs at 1e-4.
# `learning_rate=` replaces the deprecated `lr=`. Some newer Keras versions warn
# and ignore `lr`, which would silently skip the lower fine-tuning rate.
# NOTE(review): validation_split takes the *last* 5% of X_train_with_aug. Those
# rows are augmented copies of images whose originals are in the training portion,
# so validation accuracy is optimistic.
nnclf = NNclf((interpShape, interpShape, 3))
nnclf.summary()
for learning_rate in (0.001, 0.0001):
    nnclf.compile(loss=categorical_crossentropy,
                  optimizer=optimizers.Adam(learning_rate=learning_rate),
                  metrics=['acc'])
    nnclf.fit(x=X_train_with_aug, y=y_train_with_aug_onehot, batch_size=32, epochs=20,
              shuffle=True, validation_split=0.05)

In [None]:
# To keep the freshly trained weights, uncomment the line below
# (note: it overwrites the pretrained weights file loaded earlier).
# nnclf.save_weights('./model/nnclf.h5')

# Evaluate the freshly trained model, first on the held-out baseline split,
# then on every shifted electrode position.
print('electrode position: baseline')
baseline_result = nnclf.evaluate(X_test_baseline, y_test_onehot_baseline)
model_evaluate = [baseline_result]
print('model_evaluate', model_evaluate)

positions = ['right_distal', 'right_proximal', 'left_distal', 'left_proximal', 'random']
for position in positions:
    print('electrode position:{}'.format(position))
    position_prefix = '{}/{}/'.format(dataDir, position)
    X_test_raw = sio.loadmat(position_prefix + 'data_test')['data_test']
    label_test = sio.loadmat(position_prefix + '/label_test')['label_test'].T
    y_test_onehot = keras.utils.to_categorical(label_test.reshape(-1,), 6)
    X_test = interpolation(X_test_raw, interpShape)

    model_evaluate = [nnclf.evaluate(X_test, y_test_onehot)]
    print('model_evaluate', model_evaluate)
    print(' ')