**04. Building the inception block**

<img src = 'INCEPTION_BLOCK.png' height = '400' />

In [45]:
# DL needs
import tensorflow as tf
import keras as kr

# Data needs
import pandas as pd
from sklearn.model_selection import train_test_split

# Numerical computation needs
import numpy as np

# plotting needs
import matplotlib.pyplot as plt
import matplotlib_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

# Reproducibility: seed Python's `random`, NumPy and the Keras backend in one
# call. tf.random.set_seed alone leaves the NumPy / Python RNGs unseeded, so
# anything drawing from them (e.g. NumPy-based shuffling) would still vary
# between runs.
random_seed = 42
kr.utils.set_random_seed(random_seed)
tf.random.set_seed(random_seed)  # explicit TF seed, in case Keras' backend is not TF

import warnings
# Silence UserWarnings raised from keras modules (noise in notebook output).
warnings.filterwarnings("ignore", category=UserWarning, module="keras")



In [46]:
@kr.utils.register_keras_serializable(package='InceptionBlock')
class InceptionBlock(kr.layers.Layer):
    """Expansion -> parallel depthwise branches -> projection -> transpose-conv block.

    Pipeline (see ``call``):
      1. Expansion conv to ``input_channels * expansion_factor`` channels,
         then BatchNorm and ReLU.
      2. ``depthwise_width`` parallel ``DepthwiseConv2D -> BatchNorm`` branches.
         Branch ``i`` uses ``depthwise_kernel_sizes[i]`` and
         ``depthwise_dilation_rates[i]``.
      3. Concatenate all branch outputs plus the expanded tensor itself, then ReLU.
      4. 1x1 projection conv to ``output_channels``, then BatchNorm.
      5. ``Conv2DTranspose`` with stride ``transpose_stride``, then BatchNorm.

    Args:
        input_channels: Only sizes the expansion conv
            (``input_channels * expansion_factor`` filters). The conv takes
            its input channel count from the tensor it actually receives.
        output_channels: Filters of the projection and transpose convs.
        expansion_factor: Multiplier for the expansion width.
        conv_stride: Stride of every depthwise conv.
            NOTE: the expanded tensor is concatenated unchanged with the
            branch outputs. With 'same' padding, any stride other than 1
            shrinks the branches but not that tensor, so the concatenation
            only lines up for ``conv_stride == 1``.
        expansion_kernel_size: Kernel size of the expansion conv.
        depthwise_width: Number of parallel depthwise branches.
        depthwise_kernel_sizes: Per-branch kernel sizes. Needs at least
            ``depthwise_width`` entries; a single int is used for every branch.
        depthwise_dilation_rates: Per-branch dilation rates, with the same
            rules as ``depthwise_kernel_sizes``.
        transpose_kernel_size: Kernel size of the transpose conv.
        transpose_stride: Stride (upsampling factor) of the transpose conv.
        **kwargs: Forwarded to the Keras ``Layer`` constructor (e.g. ``name``).

    Raises:
        ValueError: If a per-branch sequence has fewer than ``depthwise_width``
            entries.
    """

    def __init__(
        self,
        input_channels,
        output_channels,
        expansion_factor,
        conv_stride=1,
        expansion_kernel_size=1,
        depthwise_width=3,
        depthwise_kernel_sizes=(3, 3, 3),
        depthwise_dilation_rates=(1, 2, 3),
        transpose_kernel_size=3,
        transpose_stride=2,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Prefix for all sublayer names. self.name is `name` when one was given,
        # otherwise Keras' auto-generated name. This avoids 'None_*' names for
        # name=None and clashing names between several unnamed blocks.
        self._block_name = self.name

        kernel_sizes = self._per_branch(
            depthwise_kernel_sizes, depthwise_width, 'depthwise_kernel_sizes')
        dilation_rates = self._per_branch(
            depthwise_dilation_rates, depthwise_width, 'depthwise_dilation_rates')

        # Constructor arguments, kept so get_config() can rebuild the layer.
        self.input_channels = input_channels
        self.output_channels = output_channels
        self.expansion_factor = expansion_factor
        self.conv_stride = conv_stride
        self.expansion_kernel_size = expansion_kernel_size
        self.depthwise_width = depthwise_width
        self.depthwise_kernel_sizes = kernel_sizes
        self.depthwise_dilation_rates = dilation_rates
        self.transpose_kernel_size = transpose_kernel_size
        self.transpose_stride = transpose_stride

        expanded_channels = input_channels * expansion_factor

        # Expansion
        self.expand_conv = tf.keras.layers.Conv2D(
            filters=expanded_channels,
            kernel_size=expansion_kernel_size,
            padding='same',
            use_bias=False,
            name=f'{self._block_name}_expand',
        )
        self.bn_e = tf.keras.layers.BatchNormalization(name=f'{self._block_name}_expand_BN')
        self.relu_e = tf.keras.layers.ReLU(name=f'{self._block_name}_expand_relu')

        # Parallel depthwise branches: [DepthwiseConv2D, BatchNormalization] pairs.
        # Each branch gets its OWN kernel size / dilation rate.
        self.depthwise = [
            [
                tf.keras.layers.DepthwiseConv2D(
                    kernel_size=kernel_size,
                    strides=self.conv_stride,
                    padding='same',
                    use_bias=False,
                    name=f'{self._block_name}_depthwise_{i + 1}',
                    dilation_rate=dilation_rate,
                ),
                tf.keras.layers.BatchNormalization(
                    name=f'{self._block_name}_depthwise_BN_{i + 1}'),
            ]
            for i, (kernel_size, dilation_rate) in enumerate(zip(kernel_sizes, dilation_rates))
        ]

        # Concatenation of the branches (+ expanded tensor), followed by ReLU
        self.concat = tf.keras.layers.Concatenate(name=f'{self._block_name}_concat')
        self.relu_concat = tf.keras.layers.ReLU(name=f'{self._block_name}_concat_relu')

        # Projection
        self.project_conv = tf.keras.layers.Conv2D(
            filters=output_channels,
            kernel_size=1,
            padding='same',
            use_bias=False,
            name=f'{self._block_name}_project',
        )
        self.bn_p = tf.keras.layers.BatchNormalization(name=f'{self._block_name}_project_BN')

        # Transpose convolution (spatial upsampling by transpose_stride)
        self.transpose_conv = tf.keras.layers.Conv2DTranspose(
            filters=output_channels,
            kernel_size=transpose_kernel_size,
            strides=self.transpose_stride,
            padding='same',
            use_bias=False,
            name=f'{self._block_name}_transpose_conv',
        )
        self.bn_tc = tf.keras.layers.BatchNormalization(
            name=f'{self._block_name}_transpose_conv_BN')

    @staticmethod
    def _per_branch(values, width, arg_name):
        """Return `values` as a tuple of exactly `width` per-branch entries."""
        if isinstance(values, int):
            return (values,) * width
        values = tuple(values)
        if len(values) < width:
            raise ValueError(
                f'{arg_name} needs at least depthwise_width={width} entries, '
                f'got {len(values)}: {values}'
            )
        # Extra trailing entries are ignored (e.g. the 3-entry defaults with width=2).
        return values[:width]

    def call(self, inputs, training=False):
        """Run the block.

        Args:
            inputs: Input feature map.
            training: Forwarded to every BatchNormalization sublayer.

        Returns:
            BatchNorm output of the transpose convolution.
        """
        # Expansion
        x = self.expand_conv(inputs)
        x = self.bn_e(x, training=training)
        x = self.relu_e(x)

        # Depthwise branches: BatchNorm(DepthwiseConv(x)) each
        branch_outputs = [
            dw_bn(dw_conv(x), training=training)
            for dw_conv, dw_bn in self.depthwise
        ]
        # The expanded tensor is appended unchanged as an extra branch.
        branch_outputs.append(x)

        # Concatenation
        x = self.concat(branch_outputs)
        x = self.relu_concat(x)

        # Projection
        x = self.project_conv(x)
        x = self.bn_p(x, training=training)

        # Transpose convolution
        x = self.transpose_conv(x)
        x = self.bn_tc(x, training=training)

        return x

    def get_config(self):
        """Return the constructor arguments so the layer can be rebuilt from config."""
        config = super().get_config()
        config.update({
            'input_channels': self.input_channels,
            'output_channels': self.output_channels,
            'expansion_factor': self.expansion_factor,
            'conv_stride': self.conv_stride,
            'expansion_kernel_size': self.expansion_kernel_size,
            'depthwise_width': self.depthwise_width,
            'depthwise_kernel_sizes': list(self.depthwise_kernel_sizes),
            'depthwise_dilation_rates': list(self.depthwise_dilation_rates),
            'transpose_kernel_size': self.transpose_kernel_size,
            'transpose_stride': self.transpose_stride,
        })
        return config

In [47]:
# Smoke test: a single InceptionBlock applied to a 7x7x320 feature map.
inputs = tf.keras.layers.Input(shape=(7, 7, 320))
inception_1 = InceptionBlock(
    input_channels=320,
    output_channels=192,
    expansion_factor=6,
    conv_stride=1,
    depthwise_width=3,
    depthwise_kernel_sizes=[3, 5, 7],
    depthwise_dilation_rates=[1, 2, 3],
    transpose_kernel_size=3,
    transpose_stride=2,
    name='Inception_1',
)
outputs = inception_1(inputs)

model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
model.summary()


**05. Decoder creation**
<br>
<img src="QUICKSAL.png" width = "700"/>

**Inception layer params**
* expansion factor = 6 for all inception blocks
* strides (depthwise conv, transpose conv) = (1, 2) in all inception blocks
* I have taken the liberty of varying some parameters, since the paper leaves some details unspecified (e.g. the kernel sizes of the depthwise layers).
<br>

|Inception layer|Input size|output channels (c)|depthwise width (number of branches)|depthwise kernel sizes (one per branch)|
|---|---|---|---|---|
|Inception_5|7<sup>2</sup> x 320|96|2|[1,3]|
|Inception_4|14<sup>2</sup> x 192|32|3|[1,3,5]|
|Inception_3|28<sup>2</sup> x 64|24|3|[1,3,5]|
|Inception_2|56<sup>2</sup> x 48|16|4|[1,3,5,7]|
|Inception_1|112<sup>2</sup> x 32|16|4|[1,3,5,7]|

In [49]:
# decoder creation

@kr.utils.register_keras_serializable(package='QUICKSAL_decoder')
class QUICKSAL_decoder(tf.keras.layers.Layer):
    """QUICKSAL decoder: five InceptionBlocks fused with encoder skip features.

    Starting from the deepest feature map, each stage upsamples (transpose
    stride 2). Its output is concatenated with the next encoder feature map,
    and the result feeds the next stage.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        def inception(name, input_channels, output_channels, kernel_sizes):
            # Shared settings for every stage: expansion 6, strides (1, 2),
            # transpose kernel 3, and no dilation in any depthwise branch.
            width = len(kernel_sizes)
            return InceptionBlock(
                input_channels=input_channels,
                output_channels=output_channels,
                expansion_factor=6,
                conv_stride=1,
                depthwise_width=width,
                depthwise_kernel_sizes=list(kernel_sizes),
                depthwise_dilation_rates=[1] * width,
                transpose_kernel_size=3,
                transpose_stride=2,
                name=name,
            )

        # Inception blocks (per-stage settings match the "Inception layer params" table).
        # NOTE(review): inc1 receives concat(conv1_out, inc2_out). With the
        # 32-channel `conv1` input used in this notebook's demo, that is
        # 32 + 16 = 48 channels, but input_channels=32 here. The expansion is
        # therefore 32*6 rather than 48*6 filters. Confirm which is intended.
        self.inc1 = inception('inception_1', 32, 16, [1, 3, 5, 7])
        self.inc2 = inception('inception_2', 48, 16, [1, 3, 5, 7])
        self.inc3 = inception('inception_3', 64, 24, [1, 3, 5])
        self.inc4 = inception('inception_4', 192, 32, [1, 3, 5])
        self.inc5 = inception('inception_5', 320, 96, [1, 3])

        # Skip-connection concatenations (concat_k builds the input of inc_k)
        self.concat_1 = tf.keras.layers.Concatenate(name='dec_concat_1')
        self.concat_2 = tf.keras.layers.Concatenate(name='dec_concat_2')
        self.concat_3 = tf.keras.layers.Concatenate(name='dec_concat_3')
        self.concat_4 = tf.keras.layers.Concatenate(name='dec_concat_4')

    def call(self, inputs, training=None):
        """Decode the encoder features, deepest first.

        Args:
            inputs: Sequence of five tensors
                ``(conv1_out, bir2_out, bir3_out, bir5_out, bir7_out)``. In this
                notebook's demo their spatial sizes are 112, 56, 28, 14 and 7.
            training: Forwarded explicitly to every InceptionBlock, so their
                BatchNorm layers use the correct mode. ``None`` lets Keras fall
                back to the enclosing call context.

        Returns:
            The output of ``inception_1``.
        """
        conv1_out, bir2_out, bir3_out, bir5_out, bir7_out = inputs

        inc5_out = self.inc5(bir7_out, training=training)

        inc4_in = self.concat_4([bir5_out, inc5_out])
        inc4_out = self.inc4(inc4_in, training=training)

        inc3_in = self.concat_3([bir3_out, inc4_out])
        inc3_out = self.inc3(inc3_in, training=training)

        inc2_in = self.concat_2([bir2_out, inc3_out])
        inc2_out = self.inc2(inc2_in, training=training)

        inc1_in = self.concat_1([conv1_out, inc2_out])
        inc1_out = self.inc1(inc1_in, training=training)

        return inc1_out

In [51]:
# Smoke test: wire the decoder to placeholder encoder feature maps.
encoder_feature_specs = [
    ('conv1', (112, 112, 32)),
    ('bir2', (56, 56, 24)),
    ('bir3', (28, 28, 32)),
    ('bir5', (14, 14, 96)),
    ('bir7', (7, 7, 320)),
]
inputs = [tf.keras.layers.Input(shape, name=feature_name)
          for feature_name, shape in encoder_feature_specs]

decoder = QUICKSAL_decoder(name='QUICKSAL_decoder')
outputs = decoder(inputs)

model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
model.summary()


***-- Continued in the next notebook --***