In [None]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Activation, Permute, Dropout
from tensorflow.keras.layers import Conv2D, MaxPooling2D, AveragePooling2D
from tensorflow.keras.layers import SeparableConv2D, DepthwiseConv2D
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import SpatialDropout2D
from tensorflow.keras.regularizers import l1_l2
from tensorflow.keras.layers import Input, Flatten
from tensorflow.keras.constraints import max_norm
from tensorflow.keras import backend as K


In [81]:
import copy
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
# import einops



class TemporalFilter(nn.Module):
    """Base module holding per-channel band parameters for parametric temporal filters.

    Each of the ``n_channels`` filters is described by a lower band edge ``fmin``,
    a ``bandwidth`` and a centre frequency ``freq`` (same frequency units as
    ``srate``). ``fmin`` / ``bandwidth`` passed as ``None`` become learnable
    ``nn.Parameter`` s (``coef_fmin`` / ``coef_bandwidth``); ``freq`` passed as
    ``None`` is derived as ``fmin + bandwidth / 2``. Any value that is given
    (scalar, list or tensor with 1 or ``n_channels`` entries) is stored as a fixed
    buffer broadcast to ``n_channels`` entries.

    Args:
        n_channels: number of filters.
        kernel_size: number of filter taps.
        srate: sampling rate used to convert tap indices into time.
        fmin: fixed lower band edge(s), or ``None`` to learn it.
        freq: fixed centre frequency(ies), or ``None`` to derive it.
        bandwidth: fixed bandwidth(s), or ``None`` to learn it.
        margin_bandwidth: learnable bandwidth is initialised ~ U[0, margin_bandwidth).
        fmin_variety: width of the uniform init range of a learnable ``fmin``.
        margin_fmin: offset of that range, i.e. fmin ~ U[margin_fmin, margin_fmin + fmin_variety).
        seed: initialisation seed; drawn from the global torch RNG when ``None``.
    """

    def __init__(
        self,
        n_channels,
        kernel_size,
        srate,
        fmin=None,
        freq=10,
        bandwidth=30,
        margin_bandwidth=25,
        fmin_variety=12,
        margin_fmin=4,
        seed=None,
    ):
        super().__init__()
        self.n_channels = n_channels
        self.kernel_size = kernel_size
        self.srate = srate
        # The raw arguments are kept: ``None`` checks on them decide, in
        # ``_create_frequencies``, whether a parameter or a buffer is used.
        self.fmin = fmin
        self.fmin_variety = fmin_variety
        self.margin_fmin = margin_fmin
        self.margin_bandwidth = margin_bandwidth
        self.bandwidth = bandwidth
        self.freq = freq

        # Tap times in seconds: exactly ``kernel_size`` samples centred on zero (for
        # even sizes the extra sample is on the negative side). The length must equal
        # ``kernel_size`` so it broadcasts against ``torch.hamming_window(kernel_size)``;
        # a -k/2..k/2 range would yield k + 1 taps for even k.
        half = self.kernel_size // 2
        self.register_buffer('_scale', torch.arange(-half, self.kernel_size - half) / self.srate)

        if seed is None:
            seed = int(torch.empty((), dtype=torch.int64).random_().item())

        if self.bandwidth is None:
            self.coef_bandwidth = nn.Parameter(self._create_parameters_bandwidth(self.n_channels, seed))
        else:
            self.register_buffer('_bandwidth', self._to_channel_tensor(bandwidth, 'bandwidth'))

        if self.fmin is None:
            self.coef_fmin = nn.Parameter(self._create_parameters_fmin(self.n_channels, seed))
        else:
            self.register_buffer('_fmin', self._to_channel_tensor(fmin, 'fmin'))

        if self.freq is not None:
            self.register_buffer('_freq', self._to_channel_tensor(freq, 'freq'))

    def _to_channel_tensor(self, value, name):
        """Return ``value`` as a 1-D tensor with exactly ``n_channels`` entries.

        Non-tensors are converted to float32. A single value is repeated for all
        channels. Anything other than 1 or ``n_channels`` entries trips the assert.
        """
        if not isinstance(value, torch.Tensor):
            value = torch.tensor(value, dtype=torch.float32)
        value = value.reshape(-1)
        assert value.shape[0] in (1, self.n_channels), (
            f'{name} must have 1 or {self.n_channels} entries, got {value.shape[0]}')
        if value.shape[0] != self.n_channels:
            value = value.repeat(self.n_channels)
        return value

    def _create_parameters_bandwidth(self, n_coef, seed):
        """Draw ``n_coef`` initial bandwidths uniformly from [0, margin_bandwidth)."""
        generator = torch.Generator()
        generator.manual_seed(seed + 1)
        return torch.rand(size=(n_coef,), generator=generator) * self.margin_bandwidth

    def _create_parameters_fmin(self, n_coef, seed):
        """Draw ``n_coef`` initial lower edges from [margin_fmin, margin_fmin + fmin_variety)."""
        generator = torch.Generator()
        # A different offset from the bandwidth generator: sharing ``seed + 1`` gave
        # both parameters the identical uniform draw, perfectly correlating them.
        generator.manual_seed(seed + 2)
        return torch.rand(size=(n_coef,), generator=generator) * self.fmin_variety + self.margin_fmin

    def _create_frequencies(self):
        """Return ``(bandwidth, freq_low, freq_high, freq)``, one entry per channel.

        ``freq_low`` is ``fmin``, ``freq_high`` is ``fmin + bandwidth``, and ``freq`` is
        the fixed centre frequency if one was given, else ``fmin + bandwidth / 2``.
        """
        bandwidth = self.coef_bandwidth if self.bandwidth is None else self._bandwidth
        fmin = self.coef_fmin if self.fmin is None else self._fmin
        freq = self._freq if self.freq is not None else fmin + bandwidth / 2
        return bandwidth, fmin, fmin + bandwidth, freq
    
    
    

In [None]:
class SincLayer1d(TemporalFilter):
    """Hamming-windowed sinc band-pass filter bank, applied depthwise with ``F.conv1d``.

    The kernels are rebuilt on every forward pass from the band edges returned by
    ``_create_frequencies``. Band parameters that ``TemporalFilter`` made learnable
    (those passed as ``None``) therefore receive gradients.

    Args:
        in_channels: number of input channels; must equal ``x.shape[-2]``.
        out_channels: number of filters (``TemporalFilter.n_channels``).
        kernel_size: number of filter taps.
        srate: sampling rate.
        fmin_init: lower band edge(s), or ``None`` to learn it.
        fmax_init: upper band edge(s). ``TemporalFilter`` has no upper-edge argument,
            so when ``bandwidth`` is ``None`` and both edges are given the bandwidth
            is fixed to ``fmax_init - fmin_init``; otherwise it is unused.
        freq: centre frequency(ies), or ``None`` to derive it (unused by this layer).
        bandwidth: bandwidth(s), or ``None`` to learn it.
        padding_mode: forwarded to ``TemporalPad``.
        seed: initialisation seed forwarded to ``TemporalFilter``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, srate, fmin_init, fmax_init, freq=None, bandwidth=None, padding_mode='zeros', seed=None):
        if bandwidth is None and fmin_init is not None and fmax_init is not None:
            # 1-D tensor so TemporalFilter's per-channel shape check accepts it.
            bandwidth = (torch.as_tensor(fmax_init, dtype=torch.float32)
                         - torch.as_tensor(fmin_init, dtype=torch.float32)).reshape(-1)
        # Forward by keyword: positional forwarding put fmax_init into TemporalFilter's
        # ``freq`` slot and ``bandwidth`` into ``margin_bandwidth``.
        super().__init__(out_channels, kernel_size, srate, fmin=fmin_init, freq=freq,
                         bandwidth=bandwidth, seed=seed)
        self.in_channels = in_channels
        self.pad = TemporalPad(padding='same', dim='1d', kernel_size=kernel_size, padding_mode=padding_mode, hilbert=False)
        # NOTE(review): torch.hamming_window defaults to periodic=True; a symmetric
        # window (periodic=False) is the usual choice for linear-phase FIR design.
        self.register_buffer('_hamming_window', torch.hamming_window(kernel_size).reshape((1, 1, -1)))

    def _create_filters(self, freq_low, freq_high):
        """Return band-pass kernels of shape ``(out_channels, 1, kernel_size)``."""
        _scale = self._scale.reshape((1, 1, -1))
        freq_low, freq_high = freq_low.reshape((-1, 1, 1)), freq_high.reshape((-1, 1, 1))
        # Band-pass = difference of two ideal low-pass sinc responses.
        filt_low = freq_low * torch.special.sinc(2 * freq_low * _scale)
        filt_high = freq_high * torch.special.sinc(2 * freq_high * _scale)
        return self._hamming_window * 2 * (filt_high - filt_low) / self.srate

    def forward(self, x):
        """Pad ``x``, then convolve each input channel with its band-pass kernels."""
        x = self.pad(x)
        # TemporalFilter returns (bandwidth, freq_low, freq_high, freq).
        _, freq_low, freq_high, _ = self._create_frequencies()
        filt = self._create_filters(freq_low, freq_high)
        assert self.in_channels == x.shape[-2]
        return F.conv1d(x, filt, groups=self.in_channels, padding='valid')
    
                                     
class SincLayer2d(TemporalFilter):
    """Hamming-windowed sinc band-pass filter bank, applied depthwise with ``F.conv2d``.

    The kernels have shape ``(out_channels, 1, 1, kernel_size)``, so they filter
    along the last axis only. They are rebuilt every forward pass from
    ``_create_frequencies``.

    Args:
        in_channels: number of input channels; must equal ``x.shape[-3]``.
        out_channels: number of filters (``TemporalFilter.n_channels``).
        kernel_size: number of filter taps.
        srate: sampling rate.
        fmin_init: lower band edge(s), or ``None`` to learn it.
        fmax_init: upper band edge(s). When ``bandwidth`` is ``None`` and both edges
            are given, the bandwidth is fixed to ``fmax_init - fmin_init``;
            otherwise it is unused.
        freq: centre frequency(ies), or ``None`` to derive it (unused by this layer).
        bandwidth: bandwidth(s), or ``None`` to learn it.
        padding_mode: forwarded to ``TemporalPad``.
        seed: initialisation seed forwarded to ``TemporalFilter``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, srate, fmin_init, fmax_init, freq=None, bandwidth=None, padding_mode='zeros', seed=None):
        if bandwidth is None and fmin_init is not None and fmax_init is not None:
            # 1-D tensor so TemporalFilter's per-channel shape check accepts it.
            bandwidth = (torch.as_tensor(fmax_init, dtype=torch.float32)
                         - torch.as_tensor(fmin_init, dtype=torch.float32)).reshape(-1)
        # Forward by keyword: positional forwarding put fmax_init into TemporalFilter's
        # ``freq`` slot and ``bandwidth`` into ``margin_bandwidth``.
        super().__init__(out_channels, kernel_size, srate, fmin=fmin_init, freq=freq,
                         bandwidth=bandwidth, seed=seed)
        self.in_channels = in_channels
        self.pad = TemporalPad(padding='same', dim='2d', kernel_size=kernel_size, padding_mode=padding_mode, hilbert=False)
        # NOTE(review): torch.hamming_window defaults to periodic=True; confirm that a
        # periodic rather than symmetric window is intended.
        self.register_buffer('_hamming_window', torch.hamming_window(kernel_size).reshape((1, 1, 1, -1)))

    def _create_filters(self, freq_low, freq_high):
        """Return band-pass kernels of shape ``(out_channels, 1, 1, kernel_size)``."""
        _scale = self._scale.reshape((1, 1, 1, -1))
        freq_low, freq_high = freq_low.reshape((-1, 1, 1, 1)), freq_high.reshape((-1, 1, 1, 1))
        # Band-pass = difference of two ideal low-pass sinc responses.
        filt_low = freq_low * torch.special.sinc(2 * freq_low * _scale)
        filt_high = freq_high * torch.special.sinc(2 * freq_high * _scale)
        return self._hamming_window * 2 * (filt_high - filt_low) / self.srate

    def forward(self, x):
        """Pad ``x``, then convolve each input channel with its band-pass kernels."""
        x = self.pad(x)
        # TemporalFilter returns (bandwidth, freq_low, freq_high, freq).
        _, freq_low, freq_high, _ = self._create_frequencies()
        filt = self._create_filters(freq_low, freq_high)
        assert self.in_channels == x.shape[-3]
        return F.conv2d(x, filt, groups=self.in_channels, padding='valid')
                                     
    
    
    
class SincHilbertLayer1d(TemporalFilter):
    """Sinc band-pass filter bank (``F.conv1d``) followed by a Hilbert magnitude.

    Band-pass kernels are built as in ``SincLayer1d``. Unless
    ``return_filtered=True``, the output is ``abs(HilbertLayer()(filtered))``.
    The extra ``pad.padding_hilbert`` samples added by ``TemporalPad`` are trimmed
    from both ends in either case.

    Args:
        in_channels: number of input channels; must equal ``x.shape[-2]``.
        out_channels: number of filters (``TemporalFilter.n_channels``).
        kernel_size: number of filter taps.
        srate: sampling rate.
        fmin_init: lower band edge(s), or ``None`` to learn it.
        fmax_init: upper band edge(s). When ``bandwidth`` is ``None`` and both edges
            are given, the bandwidth is fixed to ``fmax_init - fmin_init``;
            otherwise it is unused.
        freq: centre frequency(ies), or ``None`` to derive it (unused by this layer).
        bandwidth: bandwidth(s), or ``None`` to learn it.
        padding_mode: forwarded to ``TemporalPad``.
        seed: initialisation seed forwarded to ``TemporalFilter``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, srate, fmin_init, fmax_init, freq=None, bandwidth=None, padding_mode='zeros', seed=None):
        if bandwidth is None and fmin_init is not None and fmax_init is not None:
            # 1-D tensor so TemporalFilter's per-channel shape check accepts it.
            bandwidth = (torch.as_tensor(fmax_init, dtype=torch.float32)
                         - torch.as_tensor(fmin_init, dtype=torch.float32)).reshape(-1)
        # Forward by keyword: positional forwarding put fmax_init into TemporalFilter's
        # ``freq`` slot and ``bandwidth`` into ``margin_bandwidth``.
        super().__init__(out_channels, kernel_size, srate, fmin=fmin_init, freq=freq,
                         bandwidth=bandwidth, seed=seed)
        self.in_channels = in_channels
        self.pad = TemporalPad(padding='same', dim='1d', kernel_size=kernel_size, padding_mode=padding_mode, hilbert=True)
        self.register_buffer('_hamming_window', torch.hamming_window(kernel_size).reshape((1, 1, -1)))
        self.hilbert = HilbertLayer()

    def _create_filters(self, freq_low, freq_high):
        """Return band-pass kernels of shape ``(out_channels, 1, kernel_size)``."""
        _scale = self._scale.reshape((1, 1, -1))
        freq_low, freq_high = freq_low.reshape((-1, 1, 1)), freq_high.reshape((-1, 1, 1))
        filt_low = freq_low * torch.special.sinc(2 * freq_low * _scale)
        filt_high = freq_high * torch.special.sinc(2 * freq_high * _scale)
        return self._hamming_window * 2 * (filt_high - filt_low) / self.srate

    def forward(self, x, return_filtered=False):
        """Band-pass ``x``; return its Hilbert magnitude unless ``return_filtered``."""
        x = self.pad(x)
        # TemporalFilter returns (bandwidth, freq_low, freq_high, freq).
        _, freq_low, freq_high, _ = self._create_frequencies()
        filt = self._create_filters(freq_low, freq_high)
        assert self.in_channels == x.shape[-2]
        x = F.conv1d(x, filt, groups=self.in_channels, padding='valid')

        if not return_filtered:
            x = torch.abs(self.hilbert(x))
        # Explicit end index: ``p:-p`` would return an empty slice when p == 0.
        p = self.pad.padding_hilbert
        return x[..., p:x.shape[-1] - p]
    
    
class SincHilbertLayer2d(TemporalFilter):
    """Sinc band-pass filter bank (``F.conv2d``) followed by a Hilbert magnitude.

    Kernels have shape ``(out_channels, 1, 1, kernel_size)``. Unless
    ``return_filtered=True``, the output is ``abs(HilbertLayer()(filtered))``.
    ``pad.padding_hilbert`` samples are trimmed from both ends of the last axis.

    Args:
        in_channels: number of input channels; must equal ``x.shape[-3]``.
        out_channels: number of filters (``TemporalFilter.n_channels``).
        kernel_size: number of filter taps.
        srate: sampling rate.
        fmin_init: lower band edge(s), or ``None`` to learn it.
        fmax_init: upper band edge(s). When ``bandwidth`` is ``None`` and both edges
            are given, the bandwidth is fixed to ``fmax_init - fmin_init``;
            otherwise it is unused.
        freq: centre frequency(ies), or ``None`` to derive it (unused by this layer).
        bandwidth: bandwidth(s), or ``None`` to learn it.
        padding_mode: forwarded to ``TemporalPad``.
        seed: initialisation seed forwarded to ``TemporalFilter``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, srate, fmin_init, fmax_init, freq=None, bandwidth=None, padding_mode='zeros', seed=None):
        if bandwidth is None and fmin_init is not None and fmax_init is not None:
            # 1-D tensor so TemporalFilter's per-channel shape check accepts it.
            bandwidth = (torch.as_tensor(fmax_init, dtype=torch.float32)
                         - torch.as_tensor(fmin_init, dtype=torch.float32)).reshape(-1)
        # Forward by keyword: positional forwarding put fmax_init into TemporalFilter's
        # ``freq`` slot and ``bandwidth`` into ``margin_bandwidth``.
        super().__init__(out_channels, kernel_size, srate, fmin=fmin_init, freq=freq,
                         bandwidth=bandwidth, seed=seed)
        self.in_channels = in_channels
        self.pad = TemporalPad(padding='same', dim='2d', kernel_size=kernel_size, padding_mode=padding_mode, hilbert=True)
        # Buffer shape kept as (1, 1, k) for state_dict compatibility; it is reshaped
        # to 4-D in _create_filters.
        self.register_buffer('_hamming_window', torch.hamming_window(kernel_size).reshape((1, 1, -1)))
        self.hilbert = HilbertLayer()

    def _create_filters(self, freq_low, freq_high):
        """Return band-pass kernels of shape ``(out_channels, 1, 1, kernel_size)``."""
        _scale = self._scale.reshape((1, 1, 1, -1))
        window = self._hamming_window.reshape((1, 1, 1, -1))
        freq_low, freq_high = freq_low.reshape((-1, 1, 1, 1)), freq_high.reshape((-1, 1, 1, 1))
        filt_low = freq_low * torch.special.sinc(2 * freq_low * _scale)
        filt_high = freq_high * torch.special.sinc(2 * freq_high * _scale)
        return window * 2 * (filt_high - filt_low) / self.srate

    def forward(self, x, return_filtered=False):
        """Band-pass ``x``; return its Hilbert magnitude unless ``return_filtered``."""
        x = self.pad(x)
        # TemporalFilter returns (bandwidth, freq_low, freq_high, freq).
        _, freq_low, freq_high, _ = self._create_frequencies()
        filt = self._create_filters(freq_low, freq_high)
        assert self.in_channels == x.shape[-3]
        x = F.conv2d(x, filt, groups=self.in_channels, padding='valid')

        if not return_filtered:
            x = torch.abs(self.hilbert(x))
        # Explicit end index: ``p:-p`` would return an empty slice when p == 0.
        p = self.pad.padding_hilbert
        return x[..., p:x.shape[-1] - p]


    
    
class WaveletLayer1d(TemporalFilter):
    """Real Gabor-style (cosine x Gaussian) filter bank applied depthwise with ``F.conv1d``.

    Each kernel is ``A * cos(2*pi*freq*t) * exp(-t^2 / (2*sigma2))``, with
    ``sigma2 = 2*ln(2) / (pi*bandwidth)^2``. That choice makes the Gaussian's
    spectral full width at half maximum equal to ``bandwidth``.

    Args:
        in_channels: number of input channels; must equal ``x.shape[-2]``.
        out_channels: number of filters (``TemporalFilter.n_channels``).
        kernel_size: number of filter taps.
        srate: sampling rate.
        fmin_init: lower band edge(s), or ``None`` to learn it.
        fmax_init: upper band edge(s). When ``bandwidth`` is ``None`` and both edges
            are given, the bandwidth is fixed to ``fmax_init - fmin_init``;
            otherwise it is unused.
        freq: centre frequency(ies), or ``None`` to derive ``fmin + bandwidth / 2``.
        bandwidth: bandwidth(s), or ``None`` to learn it.
        padding_mode: forwarded to ``TemporalPad``.
        seed: initialisation seed forwarded to ``TemporalFilter``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, srate, fmin_init, fmax_init, freq=None, bandwidth=None, padding_mode='zeros', seed=None):
        if bandwidth is None and fmin_init is not None and fmax_init is not None:
            # 1-D tensor so TemporalFilter's per-channel shape check accepts it.
            bandwidth = (torch.as_tensor(fmax_init, dtype=torch.float32)
                         - torch.as_tensor(fmin_init, dtype=torch.float32)).reshape(-1)
        # Forward by keyword: positional forwarding put fmax_init into TemporalFilter's
        # ``freq`` slot and ``bandwidth`` into ``margin_bandwidth``.
        super().__init__(out_channels, kernel_size, srate, fmin=fmin_init, freq=freq,
                         bandwidth=bandwidth, seed=seed)
        self.in_channels = in_channels
        # 1-D layer, so 1-D padding (it was built with dim='2d').
        self.pad = TemporalPad(padding='same', dim='1d', kernel_size=kernel_size, padding_mode=padding_mode, hilbert=False)

    def _create_filters(self, freq, bandwidth):
        """Return kernels of shape ``(out_channels, 1, kernel_size)``."""
        _scale = self._scale.reshape((1, 1, -1))
        freq, bandwidth = freq.reshape((-1, 1, 1)), bandwidth.reshape((-1, 1, 1))
        sigma2 = (2 * math.log(2)) / (bandwidth * math.pi)**2
        filt = (2 * math.pi * sigma2)**(-1/2) / (self.srate / 2)
        filt = filt * torch.cos(2 * math.pi * freq * _scale)
        return filt * torch.exp(- _scale**2 / (2 * sigma2))

    def forward(self, x):
        """Pad ``x``, then convolve each input channel with its wavelet kernels."""
        x = self.pad(x)
        # TemporalFilter returns (bandwidth, freq_low, freq_high, freq).
        bandwidth, _, _, freq = self._create_frequencies()
        filt = self._create_filters(freq, bandwidth)
        assert self.in_channels == x.shape[-2]
        return F.conv1d(x, filt, groups=self.in_channels, padding='valid')

    
class WaveletLayer2d(TemporalFilter):
    """Real Gabor-style (cosine x Gaussian) filter bank applied depthwise with ``F.conv2d``.

    Kernels have shape ``(out_channels, 1, 1, kernel_size)`` and equal
    ``A * cos(2*pi*freq*t) * exp(-t^2 / (2*sigma2))``, with
    ``sigma2 = 2*ln(2) / (pi*bandwidth)^2``.

    Args:
        in_channels: number of input channels; must equal ``x.shape[-3]``.
        out_channels: number of filters (``TemporalFilter.n_channels``).
        kernel_size: number of filter taps.
        srate: sampling rate.
        fmin_init: lower band edge(s), or ``None`` to learn it.
        fmax_init: upper band edge(s). When ``bandwidth`` is ``None`` and both edges
            are given, the bandwidth is fixed to ``fmax_init - fmin_init``;
            otherwise it is unused.
        freq: centre frequency(ies), or ``None`` to derive ``fmin + bandwidth / 2``.
        bandwidth: bandwidth(s), or ``None`` to learn it.
        padding_mode: forwarded to ``TemporalPad``.
        seed: initialisation seed forwarded to ``TemporalFilter``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, srate, fmin_init, fmax_init, freq=None, bandwidth=None, padding_mode='zeros', seed=None):
        if bandwidth is None and fmin_init is not None and fmax_init is not None:
            # 1-D tensor so TemporalFilter's per-channel shape check accepts it.
            bandwidth = (torch.as_tensor(fmax_init, dtype=torch.float32)
                         - torch.as_tensor(fmin_init, dtype=torch.float32)).reshape(-1)
        # Forward by keyword: positional forwarding put fmax_init into TemporalFilter's
        # ``freq`` slot and ``bandwidth`` into ``margin_bandwidth``.
        super().__init__(out_channels, kernel_size, srate, fmin=fmin_init, freq=freq,
                         bandwidth=bandwidth, seed=seed)
        self.in_channels = in_channels
        self.pad = TemporalPad(padding='same', dim='2d', kernel_size=kernel_size, padding_mode=padding_mode, hilbert=False)

    def _create_filters(self, freq, bandwidth):
        """Return kernels of shape ``(out_channels, 1, 1, kernel_size)``."""
        _scale = self._scale.reshape((1, 1, 1, -1))
        freq, bandwidth = freq.reshape((-1, 1, 1, 1)), bandwidth.reshape((-1, 1, 1, 1))
        sigma2 = (2 * math.log(2)) / (bandwidth * math.pi)**2
        filt = (2 * math.pi * sigma2)**(-1/2) / (self.srate / 2)
        filt = filt * torch.cos(2 * math.pi * freq * _scale)
        return filt * torch.exp(- _scale**2 / (2 * sigma2))

    def forward(self, x):
        """Pad ``x``, then convolve each input channel with its wavelet kernels."""
        x = self.pad(x)
        # TemporalFilter returns (bandwidth, freq_low, freq_high, freq).
        bandwidth, _, _, freq = self._create_frequencies()
        filt = self._create_filters(freq, bandwidth)
        assert self.in_channels == x.shape[-3]
        return F.conv2d(x, filt, groups=self.in_channels, padding='valid')
        
        
        
        
class ComplexWaveletLayer1d(TemporalFilter):
    """Complex Morlet filter bank applied depthwise with ``F.conv1d``.

    Each kernel is ``A * (exp(i*w*t) - exp(-w^2*sigma2/2)) * exp(-t^2 / (2*sigma2))``,
    with ``w = 2*pi*freq`` and ``sigma2 = 2*ln(2) / (pi*bandwidth)^2``. The
    subtracted constant makes the kernel zero-mean. ``forward`` returns the
    magnitude of the complex response, or the response to the kernel's real part
    when ``return_filtered=True``.

    Args:
        in_channels: number of input channels; must equal ``x.shape[-2]``.
        out_channels: number of filters (``TemporalFilter.n_channels``).
        kernel_size: number of filter taps.
        srate: sampling rate.
        fmin_init: lower band edge(s), or ``None`` to learn it.
        fmax_init: upper band edge(s). When ``bandwidth`` is ``None`` and both edges
            are given, the bandwidth is fixed to ``fmax_init - fmin_init``;
            otherwise it is unused.
        freq: centre frequency(ies), or ``None`` to derive ``fmin + bandwidth / 2``.
        bandwidth: bandwidth(s), or ``None`` to learn it.
        padding_mode: forwarded to ``TemporalPad``.
        seed: initialisation seed forwarded to ``TemporalFilter``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, srate, fmin_init, fmax_init, freq=None, bandwidth=None, padding_mode='zeros', seed=None):
        if bandwidth is None and fmin_init is not None and fmax_init is not None:
            # 1-D tensor so TemporalFilter's per-channel shape check accepts it.
            bandwidth = (torch.as_tensor(fmax_init, dtype=torch.float32)
                         - torch.as_tensor(fmin_init, dtype=torch.float32)).reshape(-1)
        # Forward by keyword: positional forwarding put fmax_init into TemporalFilter's
        # ``freq`` slot and ``bandwidth`` into ``margin_bandwidth``.
        super().__init__(out_channels, kernel_size, srate, fmin=fmin_init, freq=freq,
                         bandwidth=bandwidth, seed=seed)
        self.in_channels = in_channels
        self.pad = TemporalPad(padding='same', dim='1d', kernel_size=kernel_size, padding_mode=padding_mode, hilbert=False)

    def _create_filters(self, freq, bandwidth):
        """Return complex kernels of shape ``(out_channels, 1, kernel_size)``."""
        _scale = self._scale.reshape((1, 1, -1))
        freq, bandwidth = freq.reshape((-1, 1, 1)), bandwidth.reshape((-1, 1, 1))
        sigma2 = (2 * math.log(2)) / (bandwidth * math.pi)**2
        omega = 2 * math.pi * freq
        filt = (2 * math.pi * sigma2)**(-1/2) / (self.srate / 2)
        # Zero-mean correction: under the envelope exp(-t^2/(2*sigma2)) the mean of
        # exp(i*w*t) is exp(-w^2*sigma2/2). The previous exp(-w^2/2) lacked sigma2,
        # was dimensionally wrong and effectively zero.
        filt = filt * (torch.exp(1j * omega * _scale) - torch.exp(-0.5 * omega**2 * sigma2))
        return filt * torch.exp(- _scale**2 / (2 * sigma2))

    def forward(self, x, return_filtered=False):
        """Return ``|x * psi|``, or the real-part response when ``return_filtered``."""
        x = self.pad(x)
        # TemporalFilter returns (bandwidth, freq_low, freq_high, freq).
        bandwidth, _, _, freq = self._create_frequencies()
        filt = self._create_filters(freq, bandwidth)
        assert self.in_channels == x.shape[-2]

        if return_filtered:
            return F.conv1d(x, filt.real, groups=self.in_channels, padding='valid')
        x = F.conv1d(x.to(torch.complex64), filt, groups=self.in_channels, padding='valid')
        return torch.abs(x)
    
    
class ComplexWaveletLayer2d(TemporalFilter):
    """Complex Morlet filter bank applied depthwise with ``F.conv2d``.

    Kernels have shape ``(out_channels, 1, 1, kernel_size)`` and equal
    ``A * (exp(i*w*t) - exp(-w^2*sigma2/2)) * exp(-t^2 / (2*sigma2))``, with
    ``w = 2*pi*freq`` and ``sigma2 = 2*ln(2) / (pi*bandwidth)^2``.

    Args:
        in_channels: number of input channels; must equal ``x.shape[-3]``.
        out_channels: number of filters (``TemporalFilter.n_channels``).
        kernel_size: number of filter taps.
        srate: sampling rate.
        fmin_init: lower band edge(s), or ``None`` to learn it.
        fmax_init: upper band edge(s). When ``bandwidth`` is ``None`` and both edges
            are given, the bandwidth is fixed to ``fmax_init - fmin_init``;
            otherwise it is unused.
        freq: centre frequency(ies), or ``None`` to derive ``fmin + bandwidth / 2``.
        bandwidth: bandwidth(s), or ``None`` to learn it.
        padding_mode: forwarded to ``TemporalPad``.
        seed: initialisation seed forwarded to ``TemporalFilter``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, srate, fmin_init, fmax_init, freq=None, bandwidth=None, padding_mode='zeros', seed=None):
        if bandwidth is None and fmin_init is not None and fmax_init is not None:
            # 1-D tensor so TemporalFilter's per-channel shape check accepts it.
            bandwidth = (torch.as_tensor(fmax_init, dtype=torch.float32)
                         - torch.as_tensor(fmin_init, dtype=torch.float32)).reshape(-1)
        # Forward by keyword: positional forwarding put fmax_init into TemporalFilter's
        # ``freq`` slot and ``bandwidth`` into ``margin_bandwidth``.
        super().__init__(out_channels, kernel_size, srate, fmin=fmin_init, freq=freq,
                         bandwidth=bandwidth, seed=seed)
        self.in_channels = in_channels
        self.pad = TemporalPad(padding='same', dim='2d', kernel_size=kernel_size, padding_mode=padding_mode, hilbert=False)

    def _create_filters(self, freq, bandwidth):
        """Return complex kernels of shape ``(out_channels, 1, 1, kernel_size)``."""
        _scale = self._scale.reshape((1, 1, 1, -1))
        freq, bandwidth = freq.reshape((-1, 1, 1, 1)), bandwidth.reshape((-1, 1, 1, 1))
        sigma2 = (2 * math.log(2)) / (bandwidth * math.pi)**2
        omega = 2 * math.pi * freq
        filt = (2 * math.pi * sigma2)**(-1/2) / (self.srate / 2)
        # Zero-mean correction: under the envelope exp(-t^2/(2*sigma2)) the mean of
        # exp(i*w*t) is exp(-w^2*sigma2/2). The previous exp(-w^2/2) lacked sigma2.
        filt = filt * (torch.exp(1j * omega * _scale) - torch.exp(-0.5 * omega**2 * sigma2))
        return filt * torch.exp(- _scale**2 / (2 * sigma2))

    def forward(self, x, return_filtered=False):
        """Return ``|x * psi|``, or the real-part response when ``return_filtered``."""
        x = self.pad(x)
        # TemporalFilter returns (bandwidth, freq_low, freq_high, freq).
        bandwidth, _, _, freq = self._create_frequencies()
        filt = self._create_filters(freq, bandwidth)
        assert self.in_channels == x.shape[-3]

        if return_filtered:
            return F.conv2d(x, filt.real, groups=self.in_channels, padding='valid')
        x = F.conv2d(x.to(torch.complex64), filt, groups=self.in_channels, padding='valid')
        return torch.abs(x)

In [None]:

def EEGNet(nb_classes, Chans = 64, Samples = 128, 
             dropoutRate = 0.5, kernLength = 64, F1 = 8, 
             D = 2, F2 = 16, norm_rate = 0.25, dropoutType = 'Dropout'):
    """Build the EEGNet classifier as a Keras functional ``Model``.

    Pipeline: temporal ``Conv2D`` (F1 filters, kernel ``(1, kernLength)``) ->
    ``DepthwiseConv2D`` spanning all ``Chans`` rows (depth multiplier D, max-norm 1)
    -> ELU -> average pool (1, 4) -> dropout -> ``SeparableConv2D`` (F2, ``(1, 16)``)
    -> ELU -> average pool (1, 8) -> dropout -> dense softmax (max-norm ``norm_rate``).

    Args:
        nb_classes: number of output classes.
        Chans, Samples: the model input has shape ``(Chans, Samples, 1)``.
        dropoutRate: rate for both dropout layers.
        kernLength: length of the first temporal kernel.
        F1, D, F2: temporal filters, depth multiplier, separable-conv filters.
        norm_rate: max-norm constraint on the dense kernel.
        dropoutType: ``'Dropout'`` or ``'SpatialDropout2D'``.

    Raises:
        ValueError: if ``dropoutType`` is neither accepted string.
    """
    if dropoutType == 'Dropout':
        dropout_layer = Dropout
    elif dropoutType == 'SpatialDropout2D':
        dropout_layer = SpatialDropout2D
    else:
        raise ValueError('dropoutType must be one of SpatialDropout2D '
                         'or Dropout, passed as a string.')

    inputs = Input(shape=(Chans, Samples, 1))

    # Block 1: temporal convolution, then a per-feature-map spatial filter.
    x = Conv2D(F1, (1, kernLength), padding='same',
               input_shape=(Chans, Samples, 1), use_bias=False)(inputs)
    x = BatchNormalization()(x)
    x = DepthwiseConv2D((Chans, 1), use_bias=False, depth_multiplier=D,
                        depthwise_constraint=max_norm(1.))(x)
    x = BatchNormalization()(x)
    x = Activation('elu')(x)
    x = AveragePooling2D((1, 4))(x)
    x = dropout_layer(dropoutRate)(x)

    # Block 2: separable (depthwise + pointwise) convolution.
    x = SeparableConv2D(F2, (1, 16), use_bias=False, padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('elu')(x)
    x = AveragePooling2D((1, 8))(x)
    x = dropout_layer(dropoutRate)(x)

    x = Flatten(name='flatten')(x)
    x = Dense(nb_classes, name='dense', kernel_constraint=max_norm(norm_rate))(x)
    outputs = Activation('softmax', name='softmax')(x)
    return Model(inputs=inputs, outputs=outputs)



def EEGNet_SSVEP(nb_classes = 12, Chans = 8, Samples = 256, 
             dropoutRate = 0.5, kernLength = 256, F1 = 96, 
             D = 1, F2 = 96, dropoutType = 'Dropout'):
    """Build the SSVEP variant of EEGNet as a Keras functional ``Model``.

    The layer stack is the same as ``EEGNet``, with SSVEP-oriented defaults
    (longer temporal kernel, 96 filters, D = 1). Unlike ``EEGNet``, the final
    ``Dense`` layer has no max-norm constraint.

    Args:
        nb_classes: number of output classes.
        Chans, Samples: the model input has shape ``(Chans, Samples, 1)``.
        dropoutRate: rate for both dropout layers.
        kernLength: length of the first temporal kernel.
        F1, D, F2: temporal filters, depth multiplier, separable-conv filters.
        dropoutType: ``'Dropout'`` or ``'SpatialDropout2D'``.

    Raises:
        ValueError: if ``dropoutType`` is neither accepted string.
    """
    # (A stray ``\\`` line preceding this definition was a SyntaxError and has been removed.)
    if dropoutType == 'SpatialDropout2D':
        dropoutType = SpatialDropout2D
    elif dropoutType == 'Dropout':
        dropoutType = Dropout
    else:
        raise ValueError('dropoutType must be one of SpatialDropout2D '
                         'or Dropout, passed as a string.')

    input1   = Input(shape = (Chans, Samples, 1))

    ##################################################################
    block1       = Conv2D(F1, (1, kernLength), padding = 'same',
                                   input_shape = (Chans, Samples, 1),
                                   use_bias = False)(input1)
    block1       = BatchNormalization()(block1)
    block1       = DepthwiseConv2D((Chans, 1), use_bias = False, 
                                   depth_multiplier = D,
                                   depthwise_constraint = max_norm(1.))(block1)
    block1       = BatchNormalization()(block1)
    block1       = Activation('elu')(block1)
    block1       = AveragePooling2D((1, 4))(block1)
    block1       = dropoutType(dropoutRate)(block1)

    block2       = SeparableConv2D(F2, (1, 16),
                                   use_bias = False, padding = 'same')(block1)
    block2       = BatchNormalization()(block2)
    block2       = Activation('elu')(block2)
    block2       = AveragePooling2D((1, 8))(block2)
    block2       = dropoutType(dropoutRate)(block2)

    flatten      = Flatten(name = 'flatten')(block2)

    dense        = Dense(nb_classes, name = 'dense')(flatten)
    softmax      = Activation('softmax', name = 'softmax')(dense)

    return Model(inputs=input1, outputs=softmax)



def EEGNet_old(nb_classes, Chans = 64, Samples = 128, regRate = 0.0001,
           dropoutRate = 0.25, kernels = ((2, 32), (8, 4)), strides = (2, 4)):
    """Build the original (v1) EEGNet as a Keras functional ``Model``.

    An L1/L2-regularised spatial ``Conv2D`` (16 filters spanning all ``Chans``
    rows) is followed by a ``Permute``. Two strided conv blocks follow, using
    ``kernels[0]`` and ``kernels[1]``. Each block applies batch norm, ELU and
    dropout, and the model ends in a dense softmax.

    Args:
        nb_classes: number of output classes.
        Chans, Samples: the model input has shape ``(Chans, Samples, 1)``.
        regRate: L1/L2 rate for the first conv and L2 rate for the later convs.
        dropoutRate: dropout rate after every block.
        kernels: two kernel sizes for the second and third conv layers (only indexed,
            so a list is accepted as before; the default is now an immutable tuple).
        strides: strides of the second and third conv layers.
    """
    # Conv2D needs a rank-4 (batch, rows, cols, channels) tensor, so the input carries
    # an explicit trailing channel axis, matching ``input_shape`` below.
    input_main   = Input((Chans, Samples, 1))
    layer1       = Conv2D(16, (Chans, 1), input_shape=(Chans, Samples, 1),
                                 kernel_regularizer = l1_l2(l1=regRate, l2=regRate))(input_main)
    layer1       = BatchNormalization()(layer1)
    layer1       = Activation('elu')(layer1)
    layer1       = Dropout(dropoutRate)(layer1)

    # Swap the (collapsed) spatial axis with the filter axis so the 16 spatial
    # filters become the new "rows" for the following 2-D convolutions.
    permute_dims = 2, 1, 3
    permute1     = Permute(permute_dims)(layer1)

    layer2       = Conv2D(4, kernels[0], padding = 'same', 
                            kernel_regularizer=l1_l2(l1=0.0, l2=regRate),
                            strides = strides)(permute1)
    layer2       = BatchNormalization()(layer2)
    layer2       = Activation('elu')(layer2)
    layer2       = Dropout(dropoutRate)(layer2)

    layer3       = Conv2D(4, kernels[1], padding = 'same',
                            kernel_regularizer=l1_l2(l1=0.0, l2=regRate),
                            strides = strides)(layer2)
    layer3       = BatchNormalization()(layer3)
    layer3       = Activation('elu')(layer3)
    layer3       = Dropout(dropoutRate)(layer3)

    flatten      = Flatten(name = 'flatten')(layer3)

    dense        = Dense(nb_classes, name = 'dense')(flatten)
    softmax      = Activation('softmax', name = 'softmax')(dense)

    return Model(inputs=input_main, outputs=softmax)



def DeepConvNet(nb_classes, Chans = 64, Samples = 256,
                dropoutRate = 0.5):
    """Build the Deep ConvNet EEG classifier as a Keras functional ``Model``.

    The first block is a (1, 5) temporal convolution followed by a (Chans, 1)
    spatial convolution, 25 filters each. Three further (1, 5) conv blocks follow,
    with 50, 100 and 200 filters. Every block applies batch norm (eps 1e-5,
    momentum 0.9), ELU, (1, 2) max-pooling and dropout. Conv kernels carry a
    max-norm(2) constraint; the dense softmax layer carries max-norm(0.5).

    Args:
        nb_classes: number of output classes.
        Chans, Samples: the model input has shape ``(Chans, Samples, 1)``.
        dropoutRate: dropout rate after every block.
    """
    inputs = Input((Chans, Samples, 1))

    # Block 1: temporal conv then spatial conv, with no non-linearity in between.
    x = Conv2D(25, (1, 5), input_shape=(Chans, Samples, 1),
               kernel_constraint=max_norm(2., axis=(0, 1, 2)))(inputs)
    x = Conv2D(25, (Chans, 1), kernel_constraint=max_norm(2., axis=(0, 1, 2)))(x)
    x = BatchNormalization(epsilon=1e-05, momentum=0.9)(x)
    x = Activation('elu')(x)
    x = MaxPooling2D(pool_size=(1, 2), strides=(1, 2))(x)
    x = Dropout(dropoutRate)(x)

    # Blocks 2-4 share one layout and differ only in their filter count.
    for n_filters in (50, 100, 200):
        x = Conv2D(n_filters, (1, 5), kernel_constraint=max_norm(2., axis=(0, 1, 2)))(x)
        x = BatchNormalization(epsilon=1e-05, momentum=0.9)(x)
        x = Activation('elu')(x)
        x = MaxPooling2D(pool_size=(1, 2), strides=(1, 2))(x)
        x = Dropout(dropoutRate)(x)

    x = Flatten()(x)
    x = Dense(nb_classes, kernel_constraint=max_norm(0.5))(x)
    outputs = Activation('softmax')(x)
    return Model(inputs=inputs, outputs=outputs)


# need these for ShallowConvNet
def square(x):
    """Element-wise square via the Keras backend (ShallowConvNet's first activation)."""
    squared = K.square(x)
    return squared

def log(x):
    """Element-wise natural log after clipping ``x`` into [1e-7, 10000].

    The clip keeps ``log`` finite for zero or negative inputs and bounds large ones.
    """
    clipped = K.clip(x, min_value = 1e-7, max_value = 10000)
    return K.log(clipped)


def ShallowConvNet(nb_classes, Chans = 64, Samples = 128, dropoutRate = 0.5):
    """Build the Shallow ConvNet EEG classifier as a Keras functional ``Model``.

    Pipeline:
    - a (1, 13) temporal conv (40 filters), then a bias-free (Chans, 1) spatial conv (40 filters);
    - batch norm, then the ``square`` activation;
    - (1, 35) average pooling with stride (1, 7), then the clipped ``log`` activation;
    - dropout, then a dense softmax.

    Conv kernels use max-norm(2); the dense layer uses max-norm(0.5).

    Args:
        nb_classes: number of output classes.
        Chans, Samples: the model input has shape ``(Chans, Samples, 1)``.
        dropoutRate: dropout rate before the classifier.
    """
    inputs = Input((Chans, Samples, 1))

    x = Conv2D(40, (1, 13), input_shape=(Chans, Samples, 1),
               kernel_constraint=max_norm(2., axis=(0, 1, 2)))(inputs)
    x = Conv2D(40, (Chans, 1), use_bias=False,
               kernel_constraint=max_norm(2., axis=(0, 1, 2)))(x)
    x = BatchNormalization(epsilon=1e-05, momentum=0.9)(x)

    # square -> mean-pool -> log approximates log band power per window.
    x = Activation(square)(x)
    x = AveragePooling2D(pool_size=(1, 35), strides=(1, 7))(x)
    x = Activation(log)(x)

    x = Dropout(dropoutRate)(x)
    x = Flatten()(x)
    x = Dense(nb_classes, kernel_constraint=max_norm(0.5))(x)
    outputs = Activation('softmax')(x)
    return Model(inputs=inputs, outputs=outputs)