In [1]:
import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.model_selection import train_test_split

from tensorflow.keras.layers import Input, Conv1D, Activation, BatchNormalization
from tensorflow.keras.layers import MaxPool1D, Dropout, GlobalAveragePooling1D, Dense
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

2024-01-30 02:16:19.432633: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-01-30 02:16:19.464827: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-01-30 02:16:19.464857: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-01-30 02:16:19.465668: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-01-30 02:16:19.471040: I tensorflow/core/platform/cpu_feature_guar

In [2]:
wave_type = ['gamma', 'beta', 'alpha', 'theta', 'delta']

In [3]:
df = pd.DataFrame()
for wt in wave_type:
    df2 = pd.read_csv(wt + '.csv')
    df = pd.concat([df, df2.drop('label', axis=1)], axis=1)

df['label'] = df2['label']
df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,7.1,8.1,9.1,10,11,12,13,14,15,label
0,88.5,88.6,88.6,88.6,88.6,88.5,88.4,88.3,88.1,88.0,...,83.5,83.6,83.5,83.3,83.2,83.1,83.0,82.9,82.8,0.0
1,87.3,87.2,87.2,87.1,87.1,87.0,86.9,86.8,86.8,86.7,...,83.8,83.9,84.0,84.0,84.0,83.9,83.7,83.6,83.5,0.0
2,86.1,85.9,85.7,85.5,85.4,85.3,85.3,85.3,85.3,85.2,...,82.5,82.5,82.6,82.8,83.0,83.1,83.2,83.4,83.6,0.0
3,85.4,85.3,85.2,85.2,85.1,85.0,85.1,85.1,85.1,85.0,...,83.8,83.7,83.6,83.6,83.8,84.0,84.2,84.5,84.6,0.0
4,85.2,85.2,85.2,85.1,85.1,85.1,85.2,85.3,85.3,85.4,...,84.6,84.7,84.6,84.5,84.5,84.5,84.6,84.5,84.4,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
164,91.3,91.4,91.4,91.4,91.3,91.2,91.2,91.1,91.0,91.1,...,93.9,93.8,93.7,93.6,93.5,93.4,93.5,93.4,93.4,0.0
165,92.8,93.0,93.1,93.1,93.0,92.8,92.8,92.8,92.8,93.0,...,93.5,93.4,93.3,93.1,92.8,92.6,92.4,92.1,92.0,0.0
166,93.7,93.8,93.8,93.9,94.0,94.1,94.2,94.2,94.1,93.9,...,93.1,93.4,93.6,93.8,93.7,93.6,93.4,93.2,93.1,0.0
167,92.5,92.2,92.0,91.8,91.6,91.5,91.3,91.1,90.9,90.7,...,91.8,91.4,91.0,90.4,90.0,89.7,89.4,89.1,88.8,0.0


In [4]:
X, y = df.drop('label', axis=1), df['label']
y = y.astype(int)
M, m = X.max(), X.min()
X = (X - m) / (M - m)
X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=0.2, shuffle=True)

In [5]:
# torch conv1d default stride=1, padding=0
# keras conv1d default stride=1, padding="valid"

def conv_block(output, k, s, p):
    return tf.keras.Sequential([
                Conv1D(output, kernel_size=k, strides=s, padding=p), 
                Activation('gelu'),
                BatchNormalization(),
                Conv1D(output, kernel_size=1), 
                Activation('gelu'),
                BatchNormalization(),
                Conv1D(output, kernel_size=1), 
                Activation('gelu'),
                BatchNormalization()
            ])

In [6]:
def get_model():
    return tf.keras.Sequential([
        tf.keras.layers.Input(shape=(80,1)),
        conv_block(4, 2, 2, 'valid'),
        MaxPool1D(pool_size=3, strides=2, padding='same'),
        conv_block(16, 2, 1, 'same'),
        MaxPool1D(pool_size=3, strides=2, padding='same'),
        conv_block(64, 2, 1, 'same'),
        MaxPool1D(pool_size=3, strides=2, padding='same'),
        Dropout(0.4),
        conv_block(128, 2, 1, 'same'),
        GlobalAveragePooling1D(),
        Dense(1, 'sigmoid')
    ])

In [7]:
model = get_model()

In [8]:
model.build(input_shape=(None, 80, 1))
model.summary()

Model: "sequential_4"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 sequential (Sequential)     (None, 40, 4)             100       
                                                                 
 max_pooling1d (MaxPooling1  (None, 20, 4)             0         
 D)                                                              
                                                                 
 sequential_1 (Sequential)   (None, 20, 16)            880       
                                                                 
 max_pooling1d_1 (MaxPoolin  (None, 10, 16)            0         
 g1D)                                                            
                                                                 
 sequential_2 (Sequential)   (None, 10, 64)            11200     
                                                                 
 max_pooling1d_2 (MaxPoolin  (None, 5, 64)            

In [14]:
def train(model, epoch, X, y, X_val, y_val):
    es = EarlyStopping(monitor='val_accuracy', patience=20)
    mc = ModelCheckpoint('EEG_single_model_best.h5', monitor='val_accuracy', verbose=0)
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    model.fit(x=X, y=y, batch_size=8, epochs=epoch, validation_data=(X_val, y_val), callbacks=[es, mc])
    model.save('EEG_single_model_last.h5')

In [15]:
# X_train, X_valid, y_train, y_valid = preprocessing('gamma')
model = get_model()
train(model, 100, X_train, y_train, X_valid, y_valid)

Epoch 1/100
Epoch 2/100

  saving_api.save_model(


Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
