# Malafeev CNN reference method

Malafeev, A., Hertig-Godeschalk, A., Schreier, D. R., Skorucak, J., Mathis, J., & Achermann, P. (2021). Automatic Detection of Microsleep Episodes With Deep Learning. Frontiers in Neuroscience, 15, 564098. https://doi.org/10.3389/fnins.2021.564098

This is a reference method with slightly modified implemenations

## Build model

In [1]:
# Build 16-CNN Model

import os

from keras import backend as K
from keras import optimizers

from keras.callbacks import History 
from keras.layers import concatenate
from keras.layers import Layer,Dense, Dropout, Input, Activation, TimeDistributed, Reshape
from keras.layers import  GRU, Bidirectional
from keras.layers import  Conv1D, Conv2D, MaxPooling2D, Flatten, BatchNormalization, LSTM, ZeroPadding2D, GlobalAveragePooling2D, SpatialDropout2D
from keras.layers.noise import GaussianNoise
from keras.models import Sequential
from keras.models import Model
from keras.preprocessing import sequence
from keras.utils import np_utils


def build_model(data_dim, n_channels, n_cl):
	eeg_channels = 1
	act_conv = 'relu'
	init_conv = 'glorot_normal'
	dp_conv = 0.3
	def cnn_block(input_shape):
		input = Input(shape=input_shape)
		x = GaussianNoise(0.0005)(input)
		x = Conv2D(32, (3, 1), strides=(1, 1), padding='same', kernel_initializer=init_conv)(x)
		x = BatchNormalization()(x)
		x = Activation(act_conv)(x)
		x = MaxPooling2D(pool_size=(2, 1), padding='same')(x)
		
		
		x = Conv2D(64, (3, 1), strides=(1, 1), padding='same', kernel_initializer=init_conv)(x)
		x = BatchNormalization()(x)
		x = Activation(act_conv)(x)
		x = MaxPooling2D(pool_size=(2, 1), padding='same')(x)
		for i in range(4):
			x = Conv2D(128, (3, 1), strides=(1, 1), padding='same', kernel_initializer=init_conv)(x)
			x = BatchNormalization()(x)
			x = Activation(act_conv)(x)
			x = MaxPooling2D(pool_size=(2, 1), padding='same')(x)
		for i in range(6):
			x = Conv2D(256, (3, 1), strides=(1, 1), padding='same', kernel_initializer=init_conv)(x)
			x = BatchNormalization()(x)
			x = Activation(act_conv)(x)
			x = MaxPooling2D(pool_size=(2, 1), padding='same')(x)
		flatten1 = Flatten()(x)
		cnn_eeg = Model(inputs=input, outputs=flatten1)
		return cnn_eeg
		
	hidden_units1  = 256
	dp_dense = 0.5

	eeg_channels = 1
	eog_channels = 2

	input_eeg = Input(shape=( data_dim, 1,  3))
	cnn_eeg = cnn_block(( data_dim, 1, 3))
	x_eeg = cnn_eeg(input_eeg)
	x = BatchNormalization()(x_eeg)
	x = Dropout(dp_dense)(x)
	x =  Dense(units=hidden_units1, activation=act_conv, kernel_initializer=init_conv)(x)
	x = BatchNormalization()(x)
	x = Dropout(dp_dense)(x)

	predictions = Dense(units=n_cl, activation='softmax', kernel_initializer=init_conv)(x)

	model = Model(inputs=[input_eeg] , outputs=[predictions])
	return [cnn_eeg, model]

cnn, model = build_model(3200, 3, 2)
model.summary()

Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 3200, 1, 3)]      0         
                                                                 
 model (Functional)          (None, 256)               1270528   
                                                                 
 batch_normalization_12 (Bat  (None, 256)              1024      
 chNormalization)                                                
                                                                 
 dropout (Dropout)           (None, 256)               0         
                                                                 
 dense (Dense)               (None, 256)               65792     
                                                                 
 batch_normalization_13 (Bat  (None, 256)              1024      
 chNormalization)                                          

## Load and pre-process data

In [2]:
import json

# Load some random subjects from train splits
with open("./splits/skorucack_splits.json") as f:
    splits = json.loads(f.read())

In [3]:

import numpy as np
from scipy.io import loadmat
from keras.utils.np_utils import to_categorical

class MalafeevStudy(object):

    fs = 200
    win_len = int(16 * fs)
    pad_len = int(win_len / 2)

    def __init__(self, signal_file, target_file):
        
        tmp = loadmat(signal_file, struct_as_record=False, squeeze_me=True)
        
        pre_process = lambda x: np.clip(x / 100, -1, 1)
        padding = lambda x: np.pad(x, pad_width=(self.pad_len, self.pad_len), mode="constant", constant_values = (0, 0))

        self.E1 = np.expand_dims(padding(pre_process(tmp['Data'].E1)),-1)
        self.E2 = np.expand_dims(padding(pre_process(tmp['Data'].E2)), -1)
        O1 = np.expand_dims(padding(pre_process(tmp['Data'].eeg_O1)), -1)
        O2 = np.expand_dims(padding(pre_process(tmp['Data'].eeg_O2)), -1)
        self.EEG = [O1, O2]

        targets = loadmat(target_file)['x']
        targets[targets!=1] = 0
        self.y = to_categorical(targets, num_classes=2) 

        self.num_win = len(self.y)


    def get_window_by_idx(self, i, c):
        
        
        start = int(i*self.fs)
        end = int(start+self.win_len)

        x = np.concatenate([self.E1[start:end], self.E2[start:end], self.EEG[c][start:end]],axis=1)
        x = np.expand_dims(x, -2)

        y = self.y[i]

        return x, y

    def get_sample_idx(self, study_idx):

        sample_idx = np.empty([self.num_win * 2, 3], dtype=int)
        for c in range(2):
            for i in range(self.num_win):
                ix=int(c*self.num_win+i)
                sample_idx[ix,...]=(study_idx, i, c)
        return sample_idx  

In [4]:

import tensorflow as tf

def flatten(l):
    return [item for sublist in l for item in sublist]

np.random.seed(42)

class Generator(tf.keras.utils.Sequence):

    def __init__(self, split, batch_size):
        self.studies = [MalafeevStudy(signal_file=f"Matlab/data/{f}", target_file=f"edf_data/{f.replace('.mat','_new.mat')}") for f in split]
        self.indices = flatten([tmp.get_sample_idx(i) for i, tmp in enumerate(self.studies)])
        self.batch_size = batch_size
        np.random.shuffle(self.indices)
        
    def __len__(self):
        return int(np.floor(len(self.indices) / self.batch_size))

    def __getitem__(self, index):
        
        batch_x = np.empty([self.batch_size, self.studies[0].win_len, 1, 3])
        batch_y = np.empty([self.batch_size, 2])

        idxs = self.indices[index*self.batch_size:(index+1)*self.batch_size]
        for i, idx in enumerate(idxs):
            x, y = self.get_study_windows_by_idx(idx)
            batch_x[i,...] = x
            batch_y[i,...] = y
        return batch_x, batch_y

    def get_study_windows_by_idx(self, index):

        return self.studies[index[0]].get_window_by_idx(index[1], index[2])

    def on_epoch_end(self):
        np.random.shuffle(self.indices)

train_data = Generator(split=splits['train'], batch_size=200)



## Train model

In [5]:
from keras import backend

fs = 200
win_sec = 16
win_len = win_sec * fs
n_classes = 2
n_channels = 3


_, model = build_model(win_len, n_channels, n_classes)

In [6]:
from sklearn.utils import compute_class_weight

# Compute class weights
y = flatten([x.y[...,1] for x in train_data.studies])
cls = np.arange(n_classes)
clw = compute_class_weight(class_weight="balanced", classes=cls, y=y)
class_weights = {0: clw[0], 1: clw[1]}

In [7]:
from tensorflow import keras
import os
import pickle

train_model = False
folder = "malafeev42_new"

if not os.path.exists(folder):
    os.mkdir(folder)


weight_file = os.path.join(folder,"CNN_weights.h5")
history_file = os.path.join(folder, "history")



if train_model:

    model.compile(optimizer=keras.optimizers.Nadam(learning_rate=0.002),
                loss=keras.losses.CategoricalCrossentropy(),
                metrics=keras.metrics.CategoricalAccuracy())


    with tf.device("/device:GPU:0"):
        history = model.fit(train_data, 
                            class_weight=class_weights,
                            epochs=3)
    
    import pickle
    model.save_weights(weight_file)

    with open(history_file, 'wb') as file_pi:
        pickle.dump(history.history, file_pi)

else:
    model.load_weights(weight_file)
    model.trainable = False

    with open(history_file, 'rb') as file_pi:
        history = pickle.load(file_pi)


## Evaluate test data

In [8]:

def predict_study(mdl, study: MalafeevStudy):

    n_classes = mdl.output_shape[-1]
    n_channels = mdl.input_shape[-1]

    ch_x = np.empty([study.num_win, study.win_len, 1, n_channels])
    ch_y = np.empty([study.num_win])
    study_preds = np.empty([study.num_win, n_classes, n_channels])

    idxs = study.get_sample_idx(0)

    for ch in range(n_channels):
        
        ch_idx = idxs[np.where(idxs[:,2]==ch)[0],1:3]
        
        for i, idx in enumerate(ch_idx):
            x, y = study.get_window_by_idx(*idx)
            ch_x[i,...] = x
            ch_y[i,...] = y[-1]

        study_preds[...,ch] = mdl.predict_on_batch(ch_x)

    return np.mean(study_preds, axis=2), ch_y    

In [9]:

y_pred = []
y_true = []
ids = []

for fi in splits['test']:
    print(f"Predicting study: {fi}")
    ids.append(fi)

    sig_file = f"Matlab/data/{fi}.mat"
    y_file = f"edf_data/{fi}_new.mat"
    study = MalafeevStudy(signal_file=sig_file, target_file=y_file)

    study_prob, study_y = predict_study(model, study)
    y_pred.append(np.argmax(study_prob, axis=1))
    y_true.append(study_y)


Predicting study: 0pai


ResourceExhaustedError: Graph execution error:

Detected at node 'model_3/model_2/batch_normalization_14/FusedBatchNormV3' defined at (most recent call last):
    File "c:\Users\SIOS\Anaconda3\envs\u-sleep\lib\runpy.py", line 196, in _run_module_as_main
      return _run_code(code, main_globals, None,
    File "c:\Users\SIOS\Anaconda3\envs\u-sleep\lib\runpy.py", line 86, in _run_code
      exec(code, run_globals)
    File "c:\Users\SIOS\Anaconda3\envs\u-sleep\lib\site-packages\ipykernel_launcher.py", line 17, in <module>
      app.launch_new_instance()
    File "c:\Users\SIOS\Anaconda3\envs\u-sleep\lib\site-packages\traitlets\config\application.py", line 985, in launch_instance
      app.start()
    File "c:\Users\SIOS\Anaconda3\envs\u-sleep\lib\site-packages\ipykernel\kernelapp.py", line 712, in start
      self.io_loop.start()
    File "c:\Users\SIOS\Anaconda3\envs\u-sleep\lib\site-packages\tornado\platform\asyncio.py", line 215, in start
      self.asyncio_loop.run_forever()
    File "c:\Users\SIOS\Anaconda3\envs\u-sleep\lib\asyncio\base_events.py", line 600, in run_forever
      self._run_once()
    File "c:\Users\SIOS\Anaconda3\envs\u-sleep\lib\asyncio\base_events.py", line 1896, in _run_once
      handle._run()
    File "c:\Users\SIOS\Anaconda3\envs\u-sleep\lib\asyncio\events.py", line 80, in _run
      self._context.run(self._callback, *self._args)
    File "c:\Users\SIOS\Anaconda3\envs\u-sleep\lib\site-packages\ipykernel\kernelbase.py", line 510, in dispatch_queue
      await self.process_one()
    File "c:\Users\SIOS\Anaconda3\envs\u-sleep\lib\site-packages\ipykernel\kernelbase.py", line 499, in process_one
      await dispatch(*args)
    File "c:\Users\SIOS\Anaconda3\envs\u-sleep\lib\site-packages\ipykernel\kernelbase.py", line 406, in dispatch_shell
      await result
    File "c:\Users\SIOS\Anaconda3\envs\u-sleep\lib\site-packages\ipykernel\kernelbase.py", line 730, in execute_request
      reply_content = await reply_content
    File "c:\Users\SIOS\Anaconda3\envs\u-sleep\lib\site-packages\ipykernel\ipkernel.py", line 383, in do_execute
      res = shell.run_cell(
    File "c:\Users\SIOS\Anaconda3\envs\u-sleep\lib\site-packages\ipykernel\zmqshell.py", line 528, in run_cell
      return super().run_cell(*args, **kwargs)
    File "c:\Users\SIOS\Anaconda3\envs\u-sleep\lib\site-packages\IPython\core\interactiveshell.py", line 2940, in run_cell
      result = self._run_cell(
    File "c:\Users\SIOS\Anaconda3\envs\u-sleep\lib\site-packages\IPython\core\interactiveshell.py", line 2995, in _run_cell
      return runner(coro)
    File "c:\Users\SIOS\Anaconda3\envs\u-sleep\lib\site-packages\IPython\core\async_helpers.py", line 129, in _pseudo_sync_runner
      coro.send(None)
    File "c:\Users\SIOS\Anaconda3\envs\u-sleep\lib\site-packages\IPython\core\interactiveshell.py", line 3194, in run_cell_async
      has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
    File "c:\Users\SIOS\Anaconda3\envs\u-sleep\lib\site-packages\IPython\core\interactiveshell.py", line 3373, in run_ast_nodes
      if await self.run_code(code, result, async_=asy):
    File "c:\Users\SIOS\Anaconda3\envs\u-sleep\lib\site-packages\IPython\core\interactiveshell.py", line 3433, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "C:\Users\SIOS\AppData\Local\Temp\ipykernel_12580\2164236807.py", line 13, in <module>
      study_prob, study_y = predict_study(model, study)
    File "C:\Users\SIOS\AppData\Local\Temp\ipykernel_12580\1204788849.py", line 21, in predict_study
      study_preds[...,ch] = mdl.predict_on_batch(ch_x)
    File "c:\Users\SIOS\Anaconda3\envs\u-sleep\lib\site-packages\keras\engine\training.py", line 2474, in predict_on_batch
      outputs = self.predict_function(iterator)
    File "c:\Users\SIOS\Anaconda3\envs\u-sleep\lib\site-packages\keras\engine\training.py", line 2041, in predict_function
      return step_function(self, iterator)
    File "c:\Users\SIOS\Anaconda3\envs\u-sleep\lib\site-packages\keras\engine\training.py", line 2027, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "c:\Users\SIOS\Anaconda3\envs\u-sleep\lib\site-packages\keras\engine\training.py", line 2015, in run_step
      outputs = model.predict_step(data)
    File "c:\Users\SIOS\Anaconda3\envs\u-sleep\lib\site-packages\keras\engine\training.py", line 1983, in predict_step
      return self(x, training=False)
    File "c:\Users\SIOS\Anaconda3\envs\u-sleep\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "c:\Users\SIOS\Anaconda3\envs\u-sleep\lib\site-packages\keras\engine\training.py", line 557, in __call__
      return super().__call__(*args, **kwargs)
    File "c:\Users\SIOS\Anaconda3\envs\u-sleep\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "c:\Users\SIOS\Anaconda3\envs\u-sleep\lib\site-packages\keras\engine\base_layer.py", line 1097, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "c:\Users\SIOS\Anaconda3\envs\u-sleep\lib\site-packages\keras\utils\traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "c:\Users\SIOS\Anaconda3\envs\u-sleep\lib\site-packages\keras\engine\functional.py", line 510, in call
      return self._run_internal_graph(inputs, training=training, mask=mask)
    File "c:\Users\SIOS\Anaconda3\envs\u-sleep\lib\site-packages\keras\engine\functional.py", line 667, in _run_internal_graph
      outputs = node.layer(*args, **kwargs)
    File "c:\Users\SIOS\Anaconda3\envs\u-sleep\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "c:\Users\SIOS\Anaconda3\envs\u-sleep\lib\site-packages\keras\engine\training.py", line 557, in __call__
      return super().__call__(*args, **kwargs)
    File "c:\Users\SIOS\Anaconda3\envs\u-sleep\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "c:\Users\SIOS\Anaconda3\envs\u-sleep\lib\site-packages\keras\engine\base_layer.py", line 1097, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "c:\Users\SIOS\Anaconda3\envs\u-sleep\lib\site-packages\keras\utils\traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "c:\Users\SIOS\Anaconda3\envs\u-sleep\lib\site-packages\keras\engine\functional.py", line 510, in call
      return self._run_internal_graph(inputs, training=training, mask=mask)
    File "c:\Users\SIOS\Anaconda3\envs\u-sleep\lib\site-packages\keras\engine\functional.py", line 667, in _run_internal_graph
      outputs = node.layer(*args, **kwargs)
    File "c:\Users\SIOS\Anaconda3\envs\u-sleep\lib\site-packages\keras\utils\traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "c:\Users\SIOS\Anaconda3\envs\u-sleep\lib\site-packages\keras\engine\base_layer.py", line 1097, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "c:\Users\SIOS\Anaconda3\envs\u-sleep\lib\site-packages\keras\utils\traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "c:\Users\SIOS\Anaconda3\envs\u-sleep\lib\site-packages\keras\layers\normalization\batch_normalization.py", line 850, in call
      outputs = self._fused_batch_norm(inputs, training=training)
    File "c:\Users\SIOS\Anaconda3\envs\u-sleep\lib\site-packages\keras\layers\normalization\batch_normalization.py", line 660, in _fused_batch_norm
      output, mean, variance = control_flow_util.smart_cond(
    File "c:\Users\SIOS\Anaconda3\envs\u-sleep\lib\site-packages\keras\utils\control_flow_util.py", line 108, in smart_cond
      return tf.__internal__.smart_cond.smart_cond(
    File "c:\Users\SIOS\Anaconda3\envs\u-sleep\lib\site-packages\keras\layers\normalization\batch_normalization.py", line 649, in _fused_batch_norm_inference
      return tf.compat.v1.nn.fused_batch_norm(
Node: 'model_3/model_2/batch_normalization_14/FusedBatchNormV3'
OOM when allocating tensor with shape[2400,32,3200,1] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
	 [[{{node model_3/model_2/batch_normalization_14/FusedBatchNormV3}}]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.
 [Op:__inference_predict_function_2870]

In [None]:
from sklearn.metrics import recall_score, precision_score, f1_score, cohen_kappa_score, confusion_matrix

y_true=[(y==1)*1 for y in y_true]
y_hat = np.concatenate(y_pred)
y = np.concatenate(y_true)==1

recall = recall_score(y, y_hat)
precision = precision_score(y, y_hat)
f1 = f1_score(y, y_hat)
kappa = cohen_kappa_score(y, y_hat)

print(f"Recall:\t\t{recall:.2f}")
print(f"Precision:\t{precision:.2f}")
print(f"F1-Score:\t{f1:.2f}")
print(f"Cohen's kappa:\t{kappa:.2f}")

Recall:		0.63
Precision:	0.88
F1-Score:	0.73
Cohen's kappa:	0.71


In [None]:
from scipy.io import savemat

ids = splits['test']
out = {"id": ids,
    "yHat": y_pred,
    "yTrue": y_true}

savemat("Matlab/malafeev42_new.mat", mdict=out)


  narr = np.asanyarray(source)


# Evaluate train data to check overfitting


In [None]:
train_pred = []
train_true = []
train_ids = []

for fi in splits['train']:
    print(f"Predicting study: {fi}")
    train_ids.append(fi)

    sig_file = f"Matlab/data/{fi}.mat"
    y_file = f"edf_data/{fi}_new.mat"
    study = MalafeevStudy(signal_file=sig_file, target_file=y_file)

    study_prob, study_y = predict_study(model, study)
    train_pred.append(np.argmax(study_prob, axis=1))
    train_true.append(study_y)

Predicting study: AXbm
Predicting study: hT38
Predicting study: Msy4
Predicting study: pPpj
Predicting study: go56
Predicting study: SOZ3
Predicting study: X7s0
Predicting study: muls
Predicting study: 5bSg
Predicting study: Nzhl
Predicting study: C1Wu
Predicting study: cblr
Predicting study: svlu
Predicting study: YHLr
Predicting study: EHED
Predicting study: G7PJ
Predicting study: DYYI
Predicting study: RfL0
Predicting study: ZYFG
Predicting study: 0ncr
Predicting study: iSqw
Predicting study: Dr51
Predicting study: UwK6
Predicting study: Ivfn
Predicting study: LR2s
Predicting study: Zpwh
Predicting study: DjrT
Predicting study: bkx9
Predicting study: MS6u
Predicting study: fNe4
Predicting study: 9098
Predicting study: EMcQ
Predicting study: YOh8
Predicting study: UsSz
Predicting study: y5We
Predicting study: 9JQY
Predicting study: zaca
Predicting study: 3J4W
Predicting study: sNMf
Predicting study: d3ET
Predicting study: JCpz
Predicting study: oOMR
Predicting study: RM1S
Predicting 

In [None]:
out = {"id": train_ids,
    "yHat": train_pred,
    "yTrue": train_true}

savemat("Matlab/training_performance/malafeev42_new.mat", mdict=out)


  narr = np.asanyarray(source)
