# In [3]:
import tensorflow as tf
import numpy as np
from tensorflow import keras

class InceptionBlock(tf.keras.Model):
    """Inception-style block: four parallel branches joined along axis 3.

    Every branch receives the same input tensor:

    1. 1x1 conv -> batch norm -> ReLU
    2. 1x1 conv -> 3x3 conv -> batch norm -> ReLU
    3. 1x1 conv -> 5x5 conv -> batch norm -> ReLU
    4. 3x3 max pool (stride 1) -> 1x1 conv (no batch norm / activation)

    All convolutions and the pooling use ``padding='same'`` with unit
    stride, so the branch outputs can be concatenated directly.

    Args:
        channel_in: Integer used only to derive the filter count of each
            branch (it is never passed to a layer as an input depth).
            The concatenated output has
            ``channel_in // 2 + (channel_in // 3) * 2
            + (channel_in // 7) * 4 + channel_in // 4`` channels.
    """

    def __init__(self, channel_in):
        super().__init__()

        layers = keras.layers

        # Filter counts per branch, all derived from ``channel_in``.
        n_1x1 = channel_in // 2
        n_reduce = channel_in // 4  # shared by both 1x1 reductions and the pool projection
        n_3x3 = (channel_in // 3) * 2
        n_5x5 = (channel_in // 7) * 4

        # Branch 1: plain 1x1 convolution.
        self.conv1 = layers.Conv2D(n_1x1, kernel_size=(1, 1), padding='same')
        self.bn1 = layers.BatchNormalization()
        self.av1 = layers.Activation(tf.nn.relu)

        # Branch 2: 1x1 reduction followed by a 3x3 convolution.
        self.conv2_1 = layers.Conv2D(n_reduce, kernel_size=(1, 1), padding='same')
        self.conv2_2 = layers.Conv2D(n_3x3, kernel_size=(3, 3), padding='same')
        self.bn2 = layers.BatchNormalization()
        self.av2 = layers.Activation(tf.nn.relu)

        # Branch 3: 1x1 reduction followed by a 5x5 convolution.
        self.conv3_1 = layers.Conv2D(n_reduce, kernel_size=(1, 1), padding='same')
        self.conv3_2 = layers.Conv2D(n_5x5, kernel_size=(5, 5), padding='same')
        self.bn3 = layers.BatchNormalization()
        self.av3 = layers.Activation(tf.nn.relu)

        # Branch 4: size-preserving max pool, then a 1x1 projection.
        self.pool = layers.MaxPool2D(pool_size=(3, 3), strides=1, padding="same")
        self.pool_conv = layers.Conv2D(n_reduce, kernel_size=(1, 1), padding="same")

    def call(self, x):
        """Run all four branches on ``x`` and concatenate them on axis 3.

        NOTE(review): axis 3 is the channel axis only for 4-D
        channels-last input -- presumably (batch, height, width,
        channels); confirm against the caller.
        """
        branch_1x1 = self.av1(self.bn1(self.conv1(x)))

        branch_3x3 = self.conv2_2(self.conv2_1(x))
        branch_3x3 = self.av2(self.bn2(branch_3x3))

        branch_5x5 = self.conv3_2(self.conv3_1(x))
        branch_5x5 = self.av3(self.bn3(branch_5x5))

        branch_pool = self.pool_conv(self.pool(x))

        branches = [branch_1x1, branch_3x3, branch_5x5, branch_pool]
        return tf.concat(branches, axis=3)
    
class InceptionNet(tf.keras.Model):
    """Small Inception-style classifier built from ``InceptionBlock``s.

    Pipeline, applied in order:

    * 3x3 conv (64 filters) -> batch norm -> ReLU -> 2x2 max pool, stride 2
    * 3x3 conv (16 filters) -> batch norm -> ReLU -> 2x2 max pool, stride 2
    * five ``InceptionBlock``s
    * global average pooling
    * dense layer (1000 units, ReLU)
    * dense layer (``output_dim`` units, softmax)

    Args:
        input_shape: Accepted for interface compatibility but not used;
            the layers infer their input shape on the first call.
        output_dim: Number of units in the final softmax layer.

    NOTE(review): the ``channel_in`` values below are kept exactly as in
    the original design. ``InceptionBlock`` only uses that argument to
    size its filters, so it does not have to match the real depth of the
    incoming tensor -- and it does not (e.g. the first block emits
    8 + 10 + 8 + 4 = 30 channels while the next is built with 32).
    Confirm whether this is intentional before changing them.
    """

    def __init__(self, input_shape, output_dim):
        super().__init__()

        # Deliberately NOT stored as ``self._layers``: that private name is
        # used by Keras' own Layer/Model bookkeeping in several releases,
        # and overwriting it makes sublayer tracking depend on
        # implementation details. A plain list attribute under any other
        # name is still tracked, so all weights remain registered.
        self.layer_stack = [
            # Stem
            keras.layers.Conv2D(64, kernel_size=(3, 3), padding="same"),
            keras.layers.BatchNormalization(),
            keras.layers.Activation(tf.nn.relu),
            keras.layers.MaxPool2D(pool_size=(2, 2), strides=2),

            keras.layers.Conv2D(16, kernel_size=(3, 3), padding="same"),
            keras.layers.BatchNormalization(),
            keras.layers.Activation(tf.nn.relu),
            keras.layers.MaxPool2D(pool_size=(2, 2), strides=2),

            # Inception modules
            InceptionBlock(channel_in=16),
            InceptionBlock(channel_in=32),
            InceptionBlock(channel_in=59),
            InceptionBlock(channel_in=115),
            InceptionBlock(channel_in=226),

            # Classifier head
            keras.layers.GlobalAveragePooling2D(),
            keras.layers.Dense(1000, activation=tf.nn.relu),
            keras.layers.Dense(output_dim, activation=tf.nn.softmax),
        ]

    def call(self, x):
        """Feed ``x`` through every layer in order and return the softmax output."""
        # The stack is a flat list, so no nested-list handling is needed.
        for layer in self.layer_stack:
            x = layer(x)
        return x
                                