In [None]:
from keras.models import Model
from keras.layers import Input
from keras.layers import Conv2D, DepthwiseConv2D, SeparableConv2D
from keras.layers import Activation
from keras.layers import BatchNormalization, Dropout
from keras.layers import AveragePooling2D
from keras.layers import Dense, Flatten

def EEGNet(F1, D, rate_dropout, input_shape, fs, n_classes):
    """
        EEGNet

        Build and compile a Keras model loosely following the EEGNet layout:
        a temporal convolution block, a second normalisation/pooling block
        and a softmax classifier.

        NOTE(review): the depthwise (spatial) convolution of Block 1 and the
        separable convolution of Block 2 are disabled in this implementation,
        so `D` currently has no effect on the returned model and Block 1
        contains two consecutive BatchNormalization layers. The layer order
        is kept as-is on purpose so that existing weights and results stay
        comparable.

        Inputs:

            F1: number of temporal filters in Block 1
            D: depth multiplier (number of spatial filters) in Block 1;
               unused while the depthwise/separable convolutions are disabled
            rate_dropout: dropout rate applied at the end of Block 1
                          (0.5 for within-subject, and 0.25 for cross-subject)
            input_shape: shape of one sample, (n_chns, n_temporal, 1)
            fs: sampling rate; must be at least 32 so that the Block 1
                pooling width int(fs / 32) is at least 1
            n_classes: number of classes (units of the softmax output layer)

        Returns:

            A Keras `Model` compiled with categorical cross-entropy loss,
            the Adam optimizer and the accuracy metric.

        Raises:

            ValueError: if `fs` is so small that int(fs / 32) < 1.

        ref:
            V. J. Lawhern, A. J. Solon, N. R. Waytowich, S. M. Gordon, C. P. Hung, and B. J. Lance,
            “EEGNet: a compact convolutional neural network for EEG-based brain–computer interfaces,”
            J. Neural Eng., vol. 15, no. 5, p. 056013, Jul. 2018.
    """

    # Validate up front: a pooling width of 0 would otherwise surface as an
    # opaque backend error deep inside AveragePooling2D.
    pool_width = int(fs / 32)
    if pool_width < 1:
        raise ValueError(
            'fs must be at least 32 so that the pooling width int(fs / 32) '
            'is >= 1; got fs={!r}'.format(fs))

    # Temporal kernel spans roughly half a second of signal.
    temporal_kernel = int(round(fs / 2))

    ###########
    # Block 1
    ###########
    # input_shape: (n_chns, n_temporal, 1)
    inputs = Input(shape=input_shape, name='main_input')
    x = Conv2D(F1, (1, temporal_kernel), activation='linear', strides=(1, 1),
               padding='SAME', data_format='channels_last')(inputs)
    x = BatchNormalization()(x)
    # NOTE(review): a DepthwiseConv2D over the channel axis was disabled here.
    # The disabled call passed F1 as the first positional argument, which in
    # Keras' DepthwiseConv2D is kernel_size; a working version would presumably
    # be DepthwiseConv2D((n_chns, 1), depth_multiplier=D, padding='valid')
    # with n_chns = input_shape[0] - confirm before re-enabling.
    x = BatchNormalization()(x)
    x = Activation('elu')(x)
    # Downsample the temporal axis by int(fs / 32), i.e. to about 32 Hz.
    x = AveragePooling2D(pool_size=(1, pool_width))(x)
    x = Dropout(rate_dropout)(x)

    ###########
    # Block 2
    ###########
    # NOTE(review): a SeparableConv2D with F2 = int(D * F1) filters and a
    # (1, 16) kernel was disabled here; without it this block only
    # normalises, activates and pools the Block 1 output.
    x = BatchNormalization()(x)
    x = Activation('elu')(x)
    x = AveragePooling2D(pool_size=(1, 8))(x)
    x = Flatten()(x)

    ############
    # Classifier
    ############
    outputs = Dense(n_classes, activation='softmax')(x)

    model = Model(inputs=inputs, outputs=outputs)
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

    return model