In [8]:
import tensorflow as tf
from tensorflow import keras
from keras.models import Model
from keras.layers import Activation, AveragePooling2D, BatchNormalization, Concatenate, Conv2D, Dense, GlobalAveragePooling2D, GlobalMaxPooling2D, Input, Lambda, MaxPooling2D
from keras import backend as K
from tensorflow.keras.applications import InceptionResNetV2
from keras.datasets import mnist
import numpy as np


In [9]:
# Load the MNIST splits, then give each image array a trailing channel axis,
# convert it to float32 and divide every pixel value by 255.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = np.expand_dims(x_train, axis=-1).astype('float32') / 255
x_test = np.expand_dims(x_test, axis=-1).astype('float32') / 255

In [32]:
# Shape handed to Input(shape=...) by the model builders in this notebook
# (presumably height, width, channels — confirm against the backend data format).
# NOTE(review): the MNIST arrays prepared earlier only gain a single trailing
# channel axis, which does not look compatible with this 3-channel shape —
# confirm how images are resized/expanded before they reach the model.
input_shape = (224,224,3)
# Number of units in the final Dense layer, i.e. the length of each embedding vector.
embedding_size = 128

In [25]:
def modified_inception_resnet():
    """Build a three-layer convolutional model on the notebook-level ``input_shape``.

    The layers are, in order: a 7x7 convolution with 64 filters and stride 2,
    a 3x3 max pool with stride 2, and batch normalization. Both the
    convolution and the pool use ``'same'`` padding.

    Returns:
        A Keras ``Model`` mapping the input tensor to the batch-normalized
        feature map.
    """
    image_in = Input(shape=input_shape)
    features = Conv2D(64, kernel_size=7, strides=(2, 2), padding='same')(image_in)
    features = MaxPooling2D(pool_size=3, strides=2, padding='same')(features)
    features = BatchNormalization()(features)
    return Model(inputs=image_in, outputs=features)

# Instantiate the small model defined by modified_inception_resnet() and print
# its layer table so the output shapes and parameter counts can be checked.
model = modified_inception_resnet()
model.summary()


Model: "functional_15"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_10 (InputLayer)        [(None, 224, 224, 3)]     0         
_________________________________________________________________
conv2d_413 (Conv2D)          (None, 112, 112, 64)      9472      
_________________________________________________________________
max_pooling2d_12 (MaxPooling (None, 56, 56, 64)        0         
_________________________________________________________________
batch_normalization_410 (Bat (None, 56, 56, 64)        256       
Total params: 9,728
Trainable params: 9,600
Non-trainable params: 128
_________________________________________________________________


In [27]:
def conv2d_bn(x, filters, kernel_size, strides=1, padding='same', activation='relu', use_bias=False):
    """Apply a 2D convolution, optionally followed by batch norm and an activation.

    Args:
        x: Input tensor.
        filters: Number of convolution filters.
        kernel_size: Convolution kernel size (int or pair of ints).
        strides: Convolution strides.
        padding: Convolution padding mode.
        activation: Activation passed to ``Activation``; ``None`` skips the
            activation layer.
        use_bias: Whether the convolution has a bias term. Batch normalization
            (without a learned scale) is inserted only when this is False.

    Returns:
        The output tensor of the last layer that was applied.
    """
    conv = Conv2D(filters, kernel_size, strides=strides, padding=padding, use_bias=use_bias)
    out = conv(x)

    if not use_bias:
        # Normalize over the channel axis, whose position depends on the data format.
        norm_axis = 3
        if K.image_data_format() == 'channels_first':
            norm_axis = 1
        out = BatchNormalization(axis=norm_axis, scale=False)(out)

    if activation is None:
        return out
    return Activation(activation)(out)

In [30]:
def inception_resnet_block(x, scale, block_type, activation='relu'):
    """Add one Inception-ResNet residual block on top of ``x``.

    Several convolutional branches are computed from ``x``, concatenated along
    the channel axis, projected back to the channel count of ``x`` with a 1x1
    convolution (with bias, no batch norm, no activation), and then added to
    ``x`` as ``x + scale * projection``.

    Args:
        x: Input tensor.
        scale: Factor applied to the projected branch output before it is
            added to ``x``.
        block_type: ``"A"``, ``"B"`` or ``"C"``; selects the branch layout.
        activation: Activation applied after the residual sum; ``None`` skips it.

    Returns:
        The output tensor of the block, with the same channel count as ``x``.

    Raises:
        ValueError: If ``block_type`` is not ``"A"``, ``"B"`` or ``"C"``.
    """
    if block_type == "A":
        branch_0 = conv2d_bn(x, 32, 1)
        branch_1 = conv2d_bn(x, 32, 1)
        branch_1 = conv2d_bn(branch_1, 32, 3)
        branch_2 = conv2d_bn(x, 32, 1)
        branch_2 = conv2d_bn(branch_2, 48, 3)
        branch_2 = conv2d_bn(branch_2, 64, 3)
        branches = [branch_0, branch_1, branch_2]
    elif block_type == "B":
        branch_0 = conv2d_bn(x, 192, 1)
        branch_1 = conv2d_bn(x, 128, 1)
        branch_1 = conv2d_bn(branch_1, 160, (1, 7))
        branch_1 = conv2d_bn(branch_1, 192, (7, 1))
        branches = [branch_0, branch_1]
    elif block_type == "C":
        branch_0 = conv2d_bn(x, 192, 1)
        branch_1 = conv2d_bn(x, 192, 1)
        branch_1 = conv2d_bn(branch_1, 224, (1, 3))
        branch_1 = conv2d_bn(branch_1, 256, (3, 1))
        branches = [branch_0, branch_1]
    else:
        raise ValueError('Invalid Block Type')

    channel_axis = 1 if K.image_data_format() == "channels_first" else 3
    # Concatenate on the channel axis explicitly: the layer's default (last axis)
    # is only the channel axis for channels_last data.
    concat = Concatenate(axis=channel_axis)(branches)

    # 1x1 projection back to the input's channel count so the residual sum is valid.
    up = conv2d_bn(concat, K.int_shape(x)[channel_axis], 1, activation=None, use_bias=True)

    # Scaled residual connection: inputs[0] is x, inputs[1] is the projection.
    x = Lambda(lambda inputs, scale: inputs[0] + inputs[1] * scale,
               output_shape=K.int_shape(x)[1:],
               arguments={'scale': scale})([x, up])
    if activation is not None:
        x = Activation(activation)(x)
    return x

In [31]:
def InceptionResNetV2():
    """Build an Inception-ResNet-style network that outputs L2-normalized embeddings.

    The network reads the notebook-level ``input_shape`` and ``embedding_size``
    and is assembled from ``conv2d_bn`` and ``inception_resnet_block``:
    stem -> 5 x block A -> reduction A -> 10 x block B -> reduction B ->
    5 x block C -> global average pooling -> Dense(embedding_size) ->
    L2 normalization along axis 1.

    Note: this definition rebinds the name ``InceptionResNetV2`` that the
    notebook also imports from ``tensorflow.keras.applications``; after this
    cell runs, the name refers to this function.

    Returns:
        A Keras ``Model`` mapping an image tensor to an embedding vector.
    """
    # Branch outputs must be joined on the channel axis, whose index depends
    # on the backend data format.
    channel_axis = 1 if K.image_data_format() == 'channels_first' else 3

    inputs = Input(shape=input_shape)

    # Stem
    x = conv2d_bn(inputs, 32, 3, strides=2, padding='valid')
    x = conv2d_bn(x, 32, 3, padding='valid')
    x = conv2d_bn(x, 64, 3)
    branch_0 = MaxPooling2D(3, strides=2, padding='valid')(x)
    branch_1 = conv2d_bn(x, 96, 3, strides=2, padding='valid')
    branches = [branch_0, branch_1]
    concat = Concatenate(axis=channel_axis)(branches)
    branch_0 = conv2d_bn(concat, 64, 1)
    branch_0 = conv2d_bn(branch_0, 96, 3, padding='valid')
    branch_1 = conv2d_bn(concat, 64, 1)
    branch_1 = conv2d_bn(branch_1, 64, (7, 1))
    branch_1 = conv2d_bn(branch_1, 64, (1, 7))
    branch_1 = conv2d_bn(branch_1, 96, 3, padding='valid')
    branches = [branch_0, branch_1]
    concat = Concatenate(axis=channel_axis)(branches)
    # Both branches downsample with a 3x3 window, stride 2 and 'valid' padding
    # so their spatial sizes match for the concatenation. The pooling layer
    # must be applied to `concat`; a bare layer object cannot be concatenated.
    branch_0 = conv2d_bn(concat, 192, 3, strides=2, padding='valid')
    branch_1 = MaxPooling2D(3, strides=2, padding='valid')(concat)
    branches = [branch_0, branch_1]
    x = Concatenate(axis=channel_axis)(branches)

    # 5 x Block A
    for _ in range(5):
        x = inception_resnet_block(x, 0.1, block_type='A')

    # Reduction-A
    branch_0 = conv2d_bn(x, 384, 3, strides=2, padding='valid')
    branch_1 = conv2d_bn(x, 256, 1)
    branch_1 = conv2d_bn(branch_1, 256, 3)
    branch_1 = conv2d_bn(branch_1, 384, 3, strides=2, padding='valid')
    branch_pool = MaxPooling2D(3, strides=2, padding='valid')(x)
    branches = [branch_0, branch_1, branch_pool]
    x = Concatenate(axis=channel_axis)(branches)

    # 10 x Block B
    for _ in range(10):
        x = inception_resnet_block(x, 0.1, block_type='B')

    # Reduction-B
    branch_0 = conv2d_bn(x, 256, 1)
    branch_0 = conv2d_bn(branch_0, 384, 3, strides=2, padding='valid')
    branch_1 = conv2d_bn(x, 256, 1)
    branch_1 = conv2d_bn(branch_1, 288, 3, strides=2, padding='valid')
    branch_2 = conv2d_bn(x, 256, 1)
    branch_2 = conv2d_bn(branch_2, 288, 3)
    branch_2 = conv2d_bn(branch_2, 320, 3, strides=2, padding='valid')
    branch_pool = MaxPooling2D(3, strides=2, padding='valid')(x)
    branches = [branch_0, branch_1, branch_2, branch_pool]
    x = Concatenate(axis=channel_axis)(branches)

    # 5 x Block C
    for _ in range(5):
        x = inception_resnet_block(x, 0.1, block_type='C')

    # Average pool, project to the embedding size, then scale each embedding
    # to unit L2 norm.
    x = GlobalAveragePooling2D()(x)
    x = Dense(embedding_size)(x)
    x = Lambda(lambda x: K.l2_normalize(x, axis=1))(x)
    model = Model(inputs, x)
    return model
    
    
    

In [None]:
# Build the embedding network with the InceptionResNetV2() builder defined in
# this notebook (capital "V"; `InceptionResNetv2` is not defined anywhere and
# raised NameError).
model = InceptionResNetV2()
model