In [0]:
!pip install mne

from collections import OrderedDict
import numpy as np
import tensorflow as tf
from datetime import datetime
from sklearn.metrics import cohen_kappa_score
from mne.io import read_raw_edf, find_edf_events
import helper_functions as hf
from load_data import BCICompetition4Set2A
from flip_gradient import flip_gradient

In [0]:
# Abort early unless TensorFlow reports the first GPU device.
device_name = tf.test.gpu_device_name()
if device_name == '/device:GPU:0':
  print('Found GPU at: {}'.format(device_name))
else:
  raise SystemError('GPU device not found')

In [0]:
# Data helper functions
# Leave-one-subject-out configuration used by BCI_load_all_data().
test_subject = 1                  # held-out subject, written as a 1-based id
test_subject = test_subject - 1   # converted to a 0-based list index
r = 9                             # number of subject recordings that are loaded

def BCI_read_data(filename, labels_filename):
    """Load one recording and run the preprocessing chain on it.

    Steps, in order: load via ``BCICompetition4Set2A``, drop the stimulus and
    the three EOG channels (22 channels must remain), multiply the data by
    1e6, filter with ``hf.bandpass_cnt`` (arguments 0 and 38.0, order 3) and
    apply ``hf.exponential_running_standardize``.

    Args:
        filename: path of the recording passed to the loader.
        labels_filename: path of the matching labels file.

    Returns:
        The preprocessed continuous recording object returned by ``hf.mne_apply``.
    """
    loader = BCICompetition4Set2A(filename=filename, labels_filename=labels_filename)
    cnt = loader.load()
    cnt = cnt.drop_channels(['STI 014', 'EOG-left', 'EOG-central', 'EOG-right'])
    assert len(cnt.ch_names) == 22

    def rescale(a):
        # presumably volts -> microvolts; confirm against the loader's units
        return a * 1e6

    def band_limit(a):
        # Sampling rate is read from the recording at call time.
        return hf.bandpass_cnt(a, 0, 38.0, cnt.info['sfreq'],
                               filt_order=3,
                               axis=1)

    def standardize(a):
        # The helper is fed the transposed array and its result is transposed back.
        return hf.exponential_running_standardize(a.T, factor_new=1e-3,
                                                  init_block_size=1000,
                                                  eps=1e-4).T

    cnt = hf.mne_apply(rescale, cnt)
    cnt = hf.mne_apply(band_limit, cnt)
    cnt = hf.mne_apply(standardize, cnt)

    return cnt

def BCI_load_all_data():
    """Load every subject and build the leave-one-subject-out data splits.

    Reads the module-level ``r`` (number of subjects) and ``test_subject``
    (0-based index of the held-out subject). Recordings ``A01T`` .. ``A0<r>T``
    are preprocessed with ``BCI_read_data`` and cut into trials. All subjects
    except ``test_subject`` are pooled, shuffled and split 80/20 into a
    training and a validation set; the held-out subject is the test set.

    Returns:
        tuple: ``(iterator, train_set, valid_set, test_set)`` where
        ``iterator`` is an ``hf.CropsFromTrialsIterator``.
    """
    ival = [-500, 4000]
    marker_def = OrderedDict([('Left Hand', [1]), ('Right Hand', [2],),
                              ('Foot', [3]), ('Tongue', [4])])

    raw_files = [BCI_read_data("A0" + str(f) + "T.gdf",
                               "A0" + str(f) + "T.mat") for f in range(1, r+1)]

    train_set_list = [hf.create_signal_target_from_raw_mne(raw, marker_def, ival) for raw in raw_files]

    # Indices of the subjects used for training, in loading order.
    train_subjects = [i for i in range(r) if i != test_subject]

    train_set_X = np.concatenate([train_set_list[i].X for i in train_subjects])
    train_set_y = np.concatenate([train_set_list[i].y for i in train_subjects])
    # Subject labels are consecutive ids 0..len(train_subjects)-1. Each id is
    # repeated once per trial actually present for that subject, so the labels
    # stay aligned with X/y even when a recording does not yield 288 trials.
    train_set_z = np.concatenate([np.zeros(len(train_set_list[i].y)) + label
                                  for label, i in enumerate(train_subjects)])

    # Shuffle trials across subjects before the train/validation split.
    indices = np.arange(train_set_y.shape[0])
    np.random.shuffle(indices)

    train_set = hf.SignalAndTarget(train_set_X[indices, :, :], train_set_y[indices], train_set_z[indices])
    test_set = train_set_list[test_subject]
    train_set, valid_set = hf.split_into_two_sets(train_set, first_set_fraction=1-0.2)

    iterator = hf.CropsFromTrialsIterator(batch_size=60, input_time_length=1125, n_preds_per_input=4)

    return iterator, train_set, valid_set, test_set

In [0]:
# Load and preprocess all recordings, then build the batch iterator and the
# train / validation / test splits used by the training loop.
iterator, train_set, valid_set, test_set = BCI_load_all_data()

In [0]:
def sinc(band, t_right, width, sinc_0):
  """Build sinc kernels of length ``width``, one row per entry of ``band``.

  Args:
    band: cut-off values, one row per filter (broadcast against ``t_right``).
    t_right: strictly positive tap positions to the right of the centre tap.
    width: total number of taps in the returned kernel.
    sinc_0: column holding the centre-tap value, sinc(0) = 1.

  Returns:
    Tensor ``[left half, centre tap, right half]`` concatenated along axis 1.
  """
  angle = np.pi * 2 * band * t_right
  y_right = tf.math.sin(angle) / angle
  # The left half has one tap fewer than the right half when width is even.
  n_left = width // 2 - 1 + width % 2
  # Mirror the taps nearest the centre so the kernel is continuous around
  # sinc_0. Slicing after reversing would keep the farthest taps and drop
  # the one next to the centre.
  y_left = tf.reverse(y_right[:, :n_left], [1])
  y = tf.concat([y_left, sinc_0, y_right], 1)
  return y

# Constants for the sinc filter bank.
samples = 1125            # number of time samples per input crop
num_filters = 100         # number of band-pass filters
width = samples // 10     # taps per filter kernel
fs = 250                  # sampling rate in Hz

# Hamming-shaped window applied to every kernel.
n = np.arange(0, width)
window = 0.54 - 0.46 * np.cos(2 * np.pi * n / width)
# Tap positions right of the centre tap, in seconds. The band edges are
# expressed in Hz (they are initialised from fs), so the time axis must be
# divided by fs for 2*pi*f*t to be a valid phase. The right half always has
# width // 2 taps; sinc() trims the mirrored left half for even widths.
t_right = np.arange(1, width // 2 + 1) / fs
# Centre tap of every kernel: sinc(0) = 1.
sinc_0 = np.ones([num_filters, 1])

def sincnet_model(x, mode):
    """Spatial convolution followed by a learnable sinc band-pass filter bank.

    Each of the ``num_filters`` filters is the difference of two low-pass sinc
    kernels with learnable cut-offs ``low_f`` and ``high_f``, scaled to unit
    peak and multiplied by ``window``.

    Args:
        x: input tensor; the graph cell feeds a placeholder of shape
            [None, None, None, 1] (assumed [batch, 22 channels, time, 1] -
            TODO confirm against the iterator output).
        mode: training flag. Currently unused because the dropout layers are
            disabled.

    Returns:
        Activated, batch-normalised, max-pooled filter-bank features.
    """
    low_f = tf.get_variable("low_f", initializer=np.random.uniform(0, fs/2-fs/4, [num_filters, 1]))
    band_w = tf.get_variable("band_w", initializer=np.random.uniform(0, fs/4, [num_filters, 1]))

    # Upper cut-off; the abs() calls keep it non-negative during training.
    high_f = tf.math.abs(low_f + tf.math.abs(band_w))
    # Band-pass = low-pass(high_f) - low-pass(low_f); each low-pass kernel is
    # scaled by twice its own cut-off.
    low_pass1 = 2 * low_f * sinc(low_f, t_right, width, sinc_0)
    low_pass2 = 2 * high_f * sinc(high_f, t_right, width, sinc_0)
    band_pass = low_pass2 - low_pass1
    # Scale every filter by its own peak so that narrow-band filters are not
    # shrunk relative to the widest filter in the bank.
    band_pass = band_pass / tf.math.reduce_max(band_pass, axis=1, keepdims=True)

    filters = band_pass * window
    filters = tf.transpose(filters)
    # conv2d kernel layout: [height=1, width=taps, in_channels=1, out_channels].
    filters = tf.reshape(filters, (1, filters.shape[0], 1, filters.shape[1]))
    filters = tf.dtypes.cast(filters, tf.float32)

    # Spatial mixing: a (22, 1) kernel spans 22 rows of the input.
    conv1 = tf.layers.conv2d(x, int(22*1.5), (22, 1))
    # NOTE(review): `training` is not passed, so this layer runs with its
    # default (inference) behaviour. Enabling it also requires running the
    # UPDATE_OPS collection with the train op.
    batch1 = tf.layers.batch_normalization(conv1, momentum=0.1)
    # Move the mixed channels onto the height axis so the sinc kernels run
    # along time with a single input channel.
    perm1 = tf.transpose(batch1, perm=[0, 3, 2, 1])
    activation1 = tf.nn.leaky_relu(perm1)

    features = tf.nn.conv2d(activation1, filters, strides=[1, 1, 1, 1], padding="VALID")
    pool = tf.layers.max_pooling2d(inputs=features, pool_size=(12, 13), strides=(1, 1))
    batch = tf.layers.batch_normalization(pool, momentum=0.1)
    activation = tf.nn.leaky_relu(batch)

    return activation

In [0]:
def detect_subject_model(features, mode):
    """Subject classifier head on top of the filter-bank features.

    Args:
        features: feature tensor fed to the first convolution.
        mode: boolean tensor enabling dropout when true.

    Returns:
        tuple: ``(logits, features)`` - logits reshaped to ``[-1, 8]`` and the
        unmodified input features, passed through for the caller.
    """
    temporal = tf.layers.conv2d(
        inputs=features,
        filters=40,
        kernel_size=(1, 25),
        dilation_rate=(1, 3),
        strides=1,
        activation=None,
    )

    spatial = tf.layers.conv2d(
        inputs=temporal,
        filters=40,
        kernel_size=(18, 1),
        dilation_rate=(1, 1),
        strides=1,
        activation=None,
    )
    normed = tf.layers.batch_normalization(spatial, momentum=0.1)
    pooled = tf.layers.max_pooling2d(inputs=tf.nn.elu(normed), pool_size=(1, 60), strides=5)
    dropped = tf.layers.dropout(pooled, 0.5, training=mode)

    # Eight output filters: one logit per training subject id.
    head = tf.layers.conv2d(
        inputs=dropped,
        filters=8,
        kernel_size=(1, 30),
        dilation_rate=(1, 6),
        strides=1,
        activation=None,
    )

    return tf.reshape(head, [-1, 8]), features
  
def shallow_brain_model(features, sinc_feat, mode):
    """Four-class classifier operating on the input minus the sinc features.

    Args:
        features: raw input tensor.
        sinc_feat: features to subtract from the first convolution's output,
            or ``None`` to feed ``features`` straight into the classifier.
        mode: boolean tensor enabling dropout when true.

    Returns:
        Logits reshaped to ``[-1, 4]``.
    """
    # The first convolution is always created, even when sinc_feat is None.
    embedded = tf.layers.conv2d(
        inputs=features,
        filters=num_filters,
        kernel_size=(1, 124),
        strides=1,
        activation=tf.nn.relu,
    )

    trunk_input = features if sinc_feat is None else embedded - sinc_feat

    temporal = tf.layers.conv2d(
        inputs=trunk_input,
        filters=40,
        kernel_size=(1, 25),
        dilation_rate=(1, 3),
        strides=1,
        activation=None,
    )

    spatial = tf.layers.conv2d(
        inputs=temporal,
        filters=40,
        kernel_size=(18, 1),
        dilation_rate=(1, 1),
        strides=1,
        activation=None,
    )
    normed = tf.layers.batch_normalization(spatial, momentum=0.1)
    pooled = tf.layers.max_pooling2d(inputs=tf.nn.elu(normed), pool_size=(1, 60), strides=5)
    dropped = tf.layers.dropout(pooled, 0.5, training=mode)

    head = tf.layers.conv2d(
        inputs=dropped,
        filters=4,
        kernel_size=(1, 30),
        dilation_rate=(1, 6),
        strides=1,
        activation=None,
    )

    return tf.reshape(head, [-1, 4])

In [0]:
# Clear the default graph before the models are (re)built.
tf.reset_default_graph()

# Inputs fed at run time.
mode = tf.placeholder(tf.bool)                                         # dropout switch
X = tf.placeholder(tf.float32, shape=[None, None, None, 1])            # e.g. [60, 22, 1125, 1]
Y = tf.placeholder(tf.float32, shape=[None, None, None, num_filters])  # e.g. [60, 22, 1002, 100]
labels = tf.placeholder(tf.int32, shape=[None])
lr = tf.placeholder(tf.float32)

# Subject branch: sinc filter bank followed by the subject classifier.
sincnet = sincnet_model(X, mode)
subject_model = detect_subject_model(sincnet, mode)

# Class branch: takes the raw input plus the sinc features supplied through Y.
class_model = shallow_brain_model(X, Y, mode)

# Total number of trainable parameters in the graph.
param_counts = [np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()]
print(np.sum(param_counts))

In [0]:
# Both heads use sparse softmax cross-entropy and share the `labels`
# placeholder; the training loop feeds subject ids or class targets into it
# depending on which head is being run.
subject_loss = tf.losses.sparse_softmax_cross_entropy(
    labels=labels,
    logits=subject_model[0],
)
class_loss = tf.losses.sparse_softmax_cross_entropy(
    labels=labels,
    logits=class_model,
)

subject_loss_summ = tf.summary.scalar("Subject_loss", subject_loss)
class_loss_summ = tf.summary.scalar("Class_loss", class_loss)

# One Adam optimizer per head; the learning rate is supplied through `lr`.
class_optimizer = tf.train.AdamOptimizer(learning_rate=lr)
subject_optimizer = tf.train.AdamOptimizer(learning_rate=lr)

subject_op = subject_optimizer.minimize(
    loss=subject_loss,
    global_step=tf.train.get_global_step(),
)
class_op = class_optimizer.minimize(loss=class_loss)

In [0]:
epoch_number = 35
max_test_acc = 0

# Training driver. Every epoch runs three passes:
#   1. training   - one subject-head step, then one class-head step per batch;
#   2. validation - both heads evaluated with dropout disabled;
#   3. test       - class head only, on the held-out subject.
# The sinc features computed by the subject branch are fetched as numpy arrays
# and fed back through the Y placeholder into the class branch.
with tf.Session() as sess:

    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()

    print("{} Start training...".format(datetime.now()))

    for epoch in range(epoch_number):

        # Per-epoch accumulators for the subject head.
        all_pred_subject = []
        all_loss_subject = []
        all_subjects = []
        val_pred_subject = []
        val_loss_subject = []
        val_subjects = []

        # Per-epoch accumulators for the class head.
        all_pred_class = []
        all_loss_class = []
        all_targets = []
        val_pred_class = []
        val_loss_class = []
        val_targets = []
        test_pred_class = []
        test_loss_class = []
        test_targets = []

        batch_generator = iterator.get_batches(train_set, shuffle=True)
        batch_generator_2 = iterator.get_batches(valid_set, shuffle=True)
        batch_generator_3 = iterator.get_batches(test_set, shuffle=True)

        # ---- Training pass ----
        for inputs, targets, subjects in batch_generator:
            _, s_loss, subject_out = sess.run([subject_op, subject_loss, subject_model], feed_dict={X: inputs, mode:True, lr:0.0005, labels:subjects})

            # subject_model is the (logits, sinc features) pair.
            subject_logits = subject_out[0]
            sinc_feat = subject_out[1]

            prediction = np.argmax(subject_logits, 1)
            all_pred_subject.extend(prediction)
            all_loss_subject.append(s_loss)
            all_subjects.extend(subjects)

            _, c_loss, class_logits = sess.run([class_op, class_loss, class_model], feed_dict={X: inputs, Y: sinc_feat, mode:True, lr:0.001, labels:targets})

            prediction = np.argmax(class_logits, 1)
            all_pred_class.extend(prediction)
            all_loss_class.append(c_loss)
            all_targets.extend(targets)

        subject_accuracy = np.sum(np.equal(all_pred_subject, all_subjects)) / float(len(all_pred_subject))
        class_accuracy = np.sum(np.equal(all_pred_class, all_targets)) / float(len(all_pred_class))

        # The node count should stay constant across epochs; growth would mean
        # ops are being added to the graph inside the loop.
        num_nodes = len(tf.get_default_graph().as_graph_def().node)

        print("Epoch {} nodes {}".format(epoch, num_nodes))
        print("Subjects: Loss {:.6f}, Accuracy: {}, Kappa: {}".format(np.mean(all_loss_subject), subject_accuracy, cohen_kappa_score(all_pred_subject, all_subjects)))
        print("Class:    Loss {:.6f}, Accuracy: {}, Kappa: {}".format(np.mean(all_loss_class), class_accuracy, cohen_kappa_score(all_pred_class, all_targets)))

        # ---- Validation pass (no parameter updates, dropout off) ----
        for inputs, targets, subjects in batch_generator_2:
            s_loss, subject_out = sess.run([subject_loss, subject_model], feed_dict={X: inputs, mode:False, labels:subjects})

            subject_logits = subject_out[0]
            sinc_feat = subject_out[1]

            prediction = np.argmax(subject_logits, 1)
            val_pred_subject.extend(prediction)
            val_loss_subject.append(s_loss)
            val_subjects.extend(subjects)

            c_loss, class_logits = sess.run([class_loss, class_model], feed_dict={X: inputs, Y: sinc_feat, mode:False, labels:targets})

            prediction = np.argmax(class_logits, 1)
            val_pred_class.extend(prediction)
            val_loss_class.append(c_loss)
            val_targets.extend(targets)

        class_accuracy = np.sum(np.equal(val_pred_class, val_targets)) / float(len(val_pred_class))
        subject_accuracy = np.sum(np.equal(val_pred_subject, val_subjects)) / float(len(val_pred_subject))

        print("Subjects: Loss {:.6f}, Accuracy: {}, Kappa: {}".format(np.mean(val_loss_subject), subject_accuracy, cohen_kappa_score(val_pred_subject, val_subjects)))
        print("Class:    Loss {:.6f}, Accuracy: {}, Kappa: {}".format(np.mean(val_loss_class), class_accuracy, cohen_kappa_score(val_pred_class, val_targets)))

        # ---- Test pass on the held-out subject (class head only) ----
        for inputs, targets, subjects in batch_generator_3:
            sinc_feat = sess.run(sincnet, feed_dict={X: inputs, mode:False})

            c_loss, class_logits = sess.run([class_loss, class_model], feed_dict={X: inputs, Y: sinc_feat, mode:False, labels:targets})

            prediction = np.argmax(class_logits, 1)
            test_pred_class.extend(prediction)
            test_loss_class.append(c_loss)
            test_targets.extend(targets)

        class_accuracy = np.sum(np.equal(test_pred_class, test_targets)) / float(len(test_pred_class))

        # Track the best test accuracy seen over all epochs so far.
        if class_accuracy > max_test_acc:
            max_test_acc = class_accuracy

        # Report the mean loss over all test batches, matching the train and
        # validation lines, rather than the loss of the final batch only.
        print("TestC:    Loss {:.6f}, Accuracy: {}, max_test_acc: {}, Kappa: {}".format(np.mean(test_loss_class), class_accuracy, max_test_acc, cohen_kappa_score(test_pred_class, test_targets)))

#         Save checkpoints
#         saver.save(sess, "/content/gdrive/My Drive/checkpoints/sincnet_reversal/" + str(test_subject+1) + "/" + str(epoch) + "/model.ckpt" )

    print("{} Done training...".format(datetime.now()))


