In [1]:
#https://www.tensorflow.org/alpha/tutorials/keras/basic_classification/
#https://arxiv.org/pdf/1703.05051.pdf %for cnn_eeg

In [2]:
from __future__ import absolute_import, division, print_function, unicode_literals

# TensorFlow and tf.keras
import tensorflow as tf
from tensorflow import keras

# Helper libraries
import numpy as np
import matplotlib.pyplot as plt

print(tf.__version__)

1.13.1


In [3]:
# Creates a graph.
with tf.device('/cpu:0'):
  a = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3], name='a')
  b = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[3, 2], name='b')
c = tf.matmul(a, b)
# Creates a session with log_device_placement set to True.
sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))
# Runs the op.
print(sess.run(c))

[[22. 28.]
 [49. 64.]]


In [4]:
from tensorflow.python.client import device_lib
print(device_lib.list_local_devices())

[name: "/device:CPU:0"
device_type: "CPU"
memory_limit: 268435456
locality {
}
incarnation: 1168088331858344299
, name: "/device:XLA_CPU:0"
device_type: "XLA_CPU"
memory_limit: 17179869184
locality {
}
incarnation: 12805910732115739980
physical_device_desc: "device: XLA_CPU device"
]


In [5]:
!pip3 install braindecode



In [6]:
import logging
import importlib
importlib.reload(logging) # see https://stackoverflow.com/a/21475297/1469195
log = logging.getLogger()
log.setLevel('INFO')
import sys
logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s',
                     level=logging.INFO, stream=sys.stdout)

In [7]:
import mne
import numpy as np
from mne.io import concatenate_raws
from braindecode.datautil.signal_target import SignalAndTarget

# First 50 subjects as train
physionet_paths = [ mne.datasets.eegbci.load_data(sub_id,[4,8,12,]) for sub_id in range(1,51)]
physionet_paths = np.concatenate(physionet_paths)
parts = [mne.io.read_raw_edf(path, preload=True,stim_channel='auto')
         for path in physionet_paths]

raw = concatenate_raws(parts)

picks = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False,
                   exclude='bads')

events = mne.find_events(raw, shortest_event=0, stim_channel='STI 014')

# Read epochs (train will be done only between 1 and 2s)
# Testing will be done with a running classifier
epoched = mne.Epochs(raw, events, dict(hands=2, feet=3), tmin=1, tmax=4.1, proj=False, picks=picks,
                baseline=None, preload=True)

# 51-55 as validation subjects
physionet_paths_valid = [mne.datasets.eegbci.load_data(sub_id,[4,8,12,]) for sub_id in range(51,56)]
physionet_paths_valid = np.concatenate(physionet_paths_valid)
parts_valid = [mne.io.read_raw_edf(path, preload=True,stim_channel='auto')
         for path in physionet_paths_valid]
raw_valid = concatenate_raws(parts_valid)

picks_valid = mne.pick_types(raw_valid.info, meg=False, eeg=True, stim=False, eog=False,
                   exclude='bads')

events_valid = mne.find_events(raw_valid, shortest_event=0, stim_channel='STI 014')

# Read epochs (train will be done only between 1 and 2s)
# Testing will be done with a running classifier
epoched_valid = mne.Epochs(raw_valid, events_valid, dict(hands=2, feet=3), tmin=1, tmax=4.1, proj=False, picks=picks_valid,
                baseline=None, preload=True)

train_X = (epoched.get_data() * 1e6).astype(np.float32)
train_y = (epoched.events[:,2] - 2).astype(np.int64) #2,3 -> 0,1
valid_X = (epoched_valid.get_data() * 1e6).astype(np.float32)
valid_y = (epoched_valid.events[:,2] - 2).astype(np.int64) #2,3 -> 0,1
#train_set = SignalAndTarget(train_X, y=train_y)
#valid_set = SignalAndTarget(valid_X, y=valid_y)

Extracting EDF parameters from /home/mjd/mne_data/MNE-eegbci-data/physiobank/database/eegmmidb/S001/S001R04.edf...
EDF file detected
EDF annotations detected (consider using raw.find_edf_events() to extract them)
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19999  =      0.000 ...   124.994 secs...
Used Annotations descriptions: ['T0', 'T2', 'T1']
Extracting EDF parameters from /home/mjd/mne_data/MNE-eegbci-data/physiobank/database/eegmmidb/S001/S001R08.edf...
EDF file detected
EDF annotations detected (consider using raw.find_edf_events() to extract them)
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19999  =      0.000 ...   124.994 secs...
Used Annotations descriptions: ['T0', 'T1', 'T2']
Extracting EDF parameters from /home/mjd/mne_data/MNE-eegbci-data/physiobank/database/eegmmidb/S001/S001R12.edf...
EDF file detected
EDF annotations detected (consider using raw.find_edf_events() to extract them)
Setting channel inf

EDF annotations detected (consider using raw.find_edf_events() to extract them)
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19679  =      0.000 ...   122.994 secs...
Used Annotations descriptions: ['T0', 'T1', 'T2']
Extracting EDF parameters from /home/mjd/mne_data/MNE-eegbci-data/physiobank/database/eegmmidb/S008/S008R08.edf...
EDF file detected
EDF annotations detected (consider using raw.find_edf_events() to extract them)
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19679  =      0.000 ...   122.994 secs...
Used Annotations descriptions: ['T0', 'T1', 'T2']
Extracting EDF parameters from /home/mjd/mne_data/MNE-eegbci-data/physiobank/database/eegmmidb/S008/S008R12.edf...
EDF file detected
EDF annotations detected (consider using raw.find_edf_events() to extract them)
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19679  =      0.000 ...   122.994 secs...
Used Annotations descriptions: 

Reading 0 ... 19679  =      0.000 ...   122.994 secs...
Used Annotations descriptions: ['T0', 'T1', 'T2']
Extracting EDF parameters from /home/mjd/mne_data/MNE-eegbci-data/physiobank/database/eegmmidb/S015/S015R08.edf...
EDF file detected
EDF annotations detected (consider using raw.find_edf_events() to extract them)
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19679  =      0.000 ...   122.994 secs...
Used Annotations descriptions: ['T0', 'T2', 'T1']
Extracting EDF parameters from /home/mjd/mne_data/MNE-eegbci-data/physiobank/database/eegmmidb/S015/S015R12.edf...
EDF file detected
EDF annotations detected (consider using raw.find_edf_events() to extract them)
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19679  =      0.000 ...   122.994 secs...
Used Annotations descriptions: ['T0', 'T2', 'T1']
Extracting EDF parameters from /home/mjd/mne_data/MNE-eegbci-data/physiobank/database/eegmmidb/S016/S016R04.edf...
EDF file de

EDF file detected
EDF annotations detected (consider using raw.find_edf_events() to extract them)
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19999  =      0.000 ...   124.994 secs...
Used Annotations descriptions: ['T0', 'T1', 'T2']
Extracting EDF parameters from /home/mjd/mne_data/MNE-eegbci-data/physiobank/database/eegmmidb/S022/S022R12.edf...
EDF file detected
EDF annotations detected (consider using raw.find_edf_events() to extract them)
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19999  =      0.000 ...   124.994 secs...
Used Annotations descriptions: ['T0', 'T1', 'T2']
Extracting EDF parameters from /home/mjd/mne_data/MNE-eegbci-data/physiobank/database/eegmmidb/S023/S023R04.edf...
EDF file detected
EDF annotations detected (consider using raw.find_edf_events() to extract them)
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19679  =      0.000 ...   122.994 secs...
Used Annotati

Creating raw.info structure...
Reading 0 ... 19839  =      0.000 ...   123.994 secs...
Used Annotations descriptions: ['T0', 'T1', 'T2']
Extracting EDF parameters from /home/mjd/mne_data/MNE-eegbci-data/physiobank/database/eegmmidb/S029/S029R12.edf...
EDF file detected
EDF annotations detected (consider using raw.find_edf_events() to extract them)
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19839  =      0.000 ...   123.994 secs...
Used Annotations descriptions: ['T0', 'T1', 'T2']
Extracting EDF parameters from /home/mjd/mne_data/MNE-eegbci-data/physiobank/database/eegmmidb/S030/S030R04.edf...
EDF file detected
EDF annotations detected (consider using raw.find_edf_events() to extract them)
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19839  =      0.000 ...   123.994 secs...
Used Annotations descriptions: ['T0', 'T2', 'T1']
Extracting EDF parameters from /home/mjd/mne_data/MNE-eegbci-data/physiobank/database/eegmmidb/

Extracting EDF parameters from /home/mjd/mne_data/MNE-eegbci-data/physiobank/database/eegmmidb/S036/S036R12.edf...
EDF file detected
EDF annotations detected (consider using raw.find_edf_events() to extract them)
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19679  =      0.000 ...   122.994 secs...
Used Annotations descriptions: ['T0', 'T1', 'T2']
Extracting EDF parameters from /home/mjd/mne_data/MNE-eegbci-data/physiobank/database/eegmmidb/S037/S037R04.edf...
EDF file detected
EDF annotations detected (consider using raw.find_edf_events() to extract them)
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19839  =      0.000 ...   123.994 secs...
Used Annotations descriptions: ['T0', 'T2', 'T1']
Extracting EDF parameters from /home/mjd/mne_data/MNE-eegbci-data/physiobank/database/eegmmidb/S037/S037R08.edf...
EDF file detected
EDF annotations detected (consider using raw.find_edf_events() to extract them)
Setting channel inf

EDF annotations detected (consider using raw.find_edf_events() to extract them)
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19679  =      0.000 ...   122.994 secs...
Used Annotations descriptions: ['T0', 'T1', 'T2']
Extracting EDF parameters from /home/mjd/mne_data/MNE-eegbci-data/physiobank/database/eegmmidb/S044/S044R04.edf...
EDF file detected
EDF annotations detected (consider using raw.find_edf_events() to extract them)
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19679  =      0.000 ...   122.994 secs...
Used Annotations descriptions: ['T0', 'T2', 'T1']
Extracting EDF parameters from /home/mjd/mne_data/MNE-eegbci-data/physiobank/database/eegmmidb/S044/S044R08.edf...
EDF file detected
EDF annotations detected (consider using raw.find_edf_events() to extract them)
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19679  =      0.000 ...   122.994 secs...
Used Annotations descriptions: 

Reading 0 ... 19679  =      0.000 ...   122.994 secs...
Used Annotations descriptions: ['T0', 'T1', 'T2']
Trigger channel has a non-zero initial value of 1 (consider using initial_event=True to detect this event)
Removing orphaned offset at the beginning of the file.
3610 events found
Event IDs: [1 2 3]
2250 matching events found
No baseline correction applied
Not setting metadata
Loading data for 2250 events and 497 original time points ...
10 bad epochs dropped
Extracting EDF parameters from /home/mjd/mne_data/MNE-eegbci-data/physiobank/database/eegmmidb/S051/S051R04.edf...
EDF file detected
EDF annotations detected (consider using raw.find_edf_events() to extract them)
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19839  =      0.000 ...   123.994 secs...
Used Annotations descriptions: ['T0', 'T1', 'T2']
Extracting EDF parameters from /home/mjd/mne_data/MNE-eegbci-data/physiobank/database/eegmmidb/S051/S051R08.edf...
EDF file detected
EDF annotations

In [8]:
train_X.shape

(2240, 64, 497)

In [9]:
train_y.shape

(2240,)

In [10]:
train_X=train_X.transpose(0,2,1)
train_X=np.reshape(train_X,[-1,497,64,1])

valid_X=valid_X.transpose(0,2,1)
valid_X=np.reshape(valid_X,[-1,497,64,1])

train_X.shape

(2240, 497, 64, 1)

In [11]:
train_X= tf.keras.utils.normalize(train_X, axis=1)
valid_X= tf.keras.utils.normalize(valid_X, axis=1)

In [12]:
#Build the model
n_ch=64
model = keras.Sequential([
    #conv_pool_block_1
    keras.layers.Conv2D(filters=25, kernel_size=(10,1),strides=(1, 1), padding='valid', activation=None, input_shape=(497,64,1)),

    keras.layers.Conv2D(filters=25, kernel_size=(1,n_ch),strides=(1, 1), padding='valid', activation=None),
    keras.layers.MaxPool2D(pool_size=(3,1)),
    keras.layers.Dropout(0.3),

    #conv_pool_block_2
    keras.layers.Conv2D(filters=50, kernel_size=(10,1),strides=(1, 1), padding='valid', activation=None),
    keras.layers.MaxPool2D(pool_size=(3,1)),
    keras.layers.Dropout(0.3),
    
    #conv_pool_block_3
    keras.layers.Conv2D(filters=100, kernel_size=(10,1),strides=(1, 1), padding='valid', activation=None),
    keras.layers.MaxPool2D(pool_size=(3,1)),
    keras.layers.Dropout(0.3),
    
    #conv_pool_block_4
    keras.layers.Conv2D(filters=200, kernel_size=(10,1),strides=(1, 1), padding='valid', activation=None),
    keras.layers.MaxPool2D(pool_size=(3,1)),
    keras.layers.Dropout(0.3),

    #classification Layer
    keras.layers.Flatten(),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(1, activation='sigmoid')
    ])

# Take a look at the model summary
model.summary()

Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              (None, 488, 64, 25)       275       
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 488, 1, 25)        40025     
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 162, 1, 25)        0         
_________________________________________________________________
dropout (Dropout)            (None, 162, 1, 25)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 153, 1, 50)        12550     
_________________________________________________________________
max_pooling2d_1 (MaxP

In [13]:
adam_my=keras.optimizers.Adam(lr=0.0625 * 0.01, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.01, amsgrad=False)


In [14]:
#Compile the model
model.compile(optimizer=adam_my,
              loss='binary_crossentropy',
              metrics=['accuracy'])

In [15]:
#Train the model
model.fit(train_X, train_y, batch_size=128, epochs=30)

Instructions for updating:
Use tf.cast instead.
Epoch 1/30
Epoch 2/30
Epoch 3/30
Epoch 4/30
Epoch 5/30
Epoch 6/30
Epoch 7/30
Epoch 8/30
Epoch 9/30
Epoch 10/30
Epoch 11/30
Epoch 12/30
Epoch 13/30
Epoch 14/30
Epoch 15/30
Epoch 16/30
Epoch 17/30
Epoch 18/30
Epoch 19/30
Epoch 20/30
Epoch 21/30
Epoch 22/30
Epoch 23/30
Epoch 24/30
Epoch 25/30
Epoch 26/30
Epoch 27/30
Epoch 28/30
Epoch 29/30
Epoch 30/30


<tensorflow.python.keras.callbacks.History at 0x7f98eb3b4978>

In [16]:
#Evaluate accuracy
test_loss, test_acc = model.evaluate(valid_X,valid_y)

print('\nTest accuracy:', test_acc)


Test accuracy: 0.81696427


In [17]:
#Make predictions
predictions = model.predict(valid_X)

In [22]:
predictions[0:10]>0.5

array([[False],
       [ True],
       [ True],
       [False],
       [ True],
       [False],
       [ True],
       [False],
       [ True],
       [False]])

In [21]:
valid_y[0:10]

array([0, 1, 1, 0, 1, 0, 1, 0, 1, 0])