# Q1 - SwishNet Implementation using TensorFlow

## Import Libraries

In [6]:
import tensorflow as tf
from tensorflow.keras import Sequential, Model, Input
from tensorflow.keras.layers import Conv1D, Multiply, Add, Concatenate, GlobalAveragePooling1D, Activation

import numpy as np
# Seed every RNG the notebook depends on. NumPy generates the random test
# inputs, and Keras weight initialisation draws from TensorFlow's RNG.
# Seeding both makes the printed predictions reproducible on a fresh kernel.
SEED = 2021
np.random.seed(SEED)
tf.compat.v1.set_random_seed(SEED)

In [7]:
# Report the installed TensorFlow version (the recorded run used 1.15.0).
tf.version.VERSION

'1.15.0'

In [2]:
def __causal_gated_conv1D(x = None, filters = 16, length = 6, strides = 1):
    """Causal gated 1-D convolution: sigmoid(conv(x)) * tanh(conv(x)).

    Two fresh ``Conv1D`` layers are created on every call (no weight sharing).
    They share one configuration: ``filters // 2`` output channels, kernel size
    ``length``, causal padding and stride 1. One applies a sigmoid and the
    other a tanh. Their element-wise product is the result, so the output has
    ``filters // 2`` channels.

    Args:
        x: Input tensor. When ``None``, the block is not applied yet; a
            one-argument callable is returned instead, enabling the
            ``__causal_gated_conv1D(...)(tensor)`` calling style.
        filters: Combined filter budget; each branch receives half of it.
        length: Convolution kernel size.
        strides: Despite its name, this value is passed as the *dilation
            rate*. The convolution stride itself is fixed at 1.

    Returns:
        The gated output tensor, or a callable taking the input tensor when
        ``x`` is None.
    """
    def _apply(inputs):
        # Identical settings for both branches; only the activation differs.
        conv_kwargs = dict(filters=filters // 2,
                           kernel_size=length,
                           dilation_rate=strides,
                           strides=1,
                           padding="causal")
        gate = Conv1D(activation="sigmoid", **conv_kwargs)(inputs)
        signal = Conv1D(activation="tanh", **conv_kwargs)(inputs)
        return Multiply()([gate, signal])

    return _apply if x is None else _apply(x)

In [3]:
def SwishNet(input_shape, classes, width_multiply=1):
    """Build the SwishNet classifier as a Keras functional ``Model``.

    Every convolution below is a causal gated conv (``__causal_gated_conv1D``).

    Layout:
      * blocks 1-2: kernel-3 and kernel-6 units in parallel, concatenated;
      * block 3: the same two-branch unit, added back onto its own input;
      * blocks 4-6: kernel-3 residual units with dilation 3, 2 and 2;
      * blocks 7-8: two kernel-3, dilation-2 units that both read the block-6
        output (16 and 32 filters respectively);
      * head: outputs of blocks 5, 6, 7 and 8 are concatenated, then a 1x1
        conv to ``classes`` channels, global average pooling and softmax.

    Args:
        input_shape: Shape of one sample without the batch axis, laid out as
            (time steps, channels) for ``Conv1D``.
        classes: Number of output classes.
        width_multiply: Factor applied to every filter count. NOTE(review):
            presumably a positive integer; non-integer values would produce
            non-integer filter counts.

    Returns:
        An uncompiled ``Model`` whose output has shape (batch, classes) and
        holds softmax probabilities.
    """
    def two_branch(inputs, filters):
        # Kernel-3 and kernel-6 gated convs side by side, joined on channels.
        short_ctx = __causal_gated_conv1D(filters=filters, length=3)(inputs)
        long_ctx = __causal_gated_conv1D(filters=filters, length=6)(inputs)
        return Concatenate()([short_ctx, long_ctx])

    def dilated(filters, dilation):
        # Kernel-3 gated conv; the helper's `strides` argument is the dilation.
        return __causal_gated_conv1D(filters=filters, length=3, strides=dilation)

    x_input = Input(shape = input_shape)

    # Blocks 1-2: plain two-branch units.
    h = two_branch(x_input, 16 * width_multiply)
    h = two_branch(h, 8 * width_multiply)

    # Block 3: two-branch unit with an identity shortcut.
    block3 = two_branch(h, 8 * width_multiply)
    h = Add()([h, block3])

    # Blocks 4-6: dilated residual units. Block 4 only feeds the residual
    # path; the outputs of blocks 5 and 6 are also tapped for the head.
    taps = []
    for block_no, dilation in ((4, 3), (5, 2), (6, 2)):
        branch = dilated(16 * width_multiply, dilation)(h)
        h = Add()([h, branch])
        if block_no != 4:
            taps.append(branch)

    # Blocks 7-8: both read the block-6 output and go straight to the head.
    taps.append(dilated(16 * width_multiply, 2)(h))
    taps.append(dilated(32 * width_multiply, 2)(h))

    # Head: fuse the taps, project to class scores, pool over time, normalise.
    out = Concatenate()(taps)
    out = Conv1D(filters=classes, kernel_size=1)(out)
    out = GlobalAveragePooling1D()(out)
    out = Activation("softmax")(out)

    return Model(inputs=x_input, outputs=out)

In [4]:
# Smoke test: build a 2-class SwishNet for inputs of 16 time steps x 20
# channels, show its layer table, and run one forward pass on random data.
# The shape constants are shared so the model and the test batch cannot drift.
N_STEPS, N_CHANNELS, N_CLASSES, N_SAMPLES = 16, 20, 2, 2

net = SwishNet(input_shape=(N_STEPS, N_CHANNELS), classes=N_CLASSES)
net.summary()

probs = net.predict(np.random.randn(N_SAMPLES, N_STEPS, N_CHANNELS))

# The softmax head must yield one probability distribution per sample.
assert probs.shape == (N_SAMPLES, N_CLASSES)
assert np.allclose(probs.sum(axis=1), 1.0, atol=1e-5)
probs

Instructions for updating:
If using Keras pass *_constraint arguments to layers.
Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 16, 20)       0                                            
__________________________________________________________________________________________________
conv1d_1 (Conv1D)               (None, 16, 8)        488         input_1[0][0]                    
__________________________________________________________________________________________________
conv1d_2 (Conv1D)               (None, 16, 8)        488         input_1[0][0]                    
__________________________________________________________________________________________________
conv1d_3 (Conv1D)               (None, 16, 8)        968         input_1[0][0]                    
___________

[[0.4833677  0.51663226]
 [0.47845703 0.5215429 ]]
