In [33]:
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, Dense, GlobalAveragePooling2D, Input
from tensorflow.keras.layers import MaxPooling2D, AveragePooling2D, Concatenate, Lambda
from tensorflow.keras.models import Model
import tensorflow.keras.backend as K

In [46]:
def _conv_blocks(inputs, filters, kernel_size, padding='same', activation='relu', strides=1, name=None, use_bias=False):
    """Apply Conv2D, then (optionally) BatchNormalization, then (optionally) an activation.

    Args:
        inputs: input tensor.
        filters: number of output channels of the convolution.
        kernel_size: int or [rows, cols] convolution window.
        padding: 'same' or 'valid', forwarded to Conv2D.
        activation: activation name applied last, or None to skip it.
        strides: convolution stride.
        name: optional base layer name. When given, the BN and activation
            layers are named '<name>_bn' and '<name>_ac'.
        use_bias: if True, the convolution has a bias and no BN layer is
            added. If False (default), an unbiased conv is followed by BN
            with scale=False.

    Returns:
        The resulting output tensor.
    """
    def _suffixed(suffix):
        # Derived layer names only exist when a base name was supplied.
        return name + suffix if name is not None else None

    out = Conv2D(filters, kernel_size, strides=strides, padding=padding,
                 use_bias=use_bias, name=name)(inputs)

    if not use_bias:
        # Normalize over the channel axis for the active data format.
        channel_axis = 3 if K.image_data_format() != 'channels_first' else 1
        out = BatchNormalization(axis=channel_axis, scale=False,
                                 name=_suffixed('_bn'))(out)

    if activation is not None:
        out = Activation(activation, name=_suffixed('_ac'))(out)

    return out

In [47]:
def inception_resnet_block(x, scale, block_type, block_idx, activation='relu'):
    """Add one Inception-ResNet block to the graph.

    The block runs parallel conv branches and concatenates them along the
    channel axis. It then projects the result back to the input's channel
    count with a linear 1x1 conv (bias, no BN), and adds it to the input as
    a scaled residual: ``x + scale * up``.

    Args:
        x: input tensor.
        scale: multiplier applied to the residual branch before the addition.
        block_type: one of 'block35' (Inception-ResNet-A), 'block17'
            (Inception-ResNet-B) or 'block8' (Inception-ResNet-C).
        block_idx: integer used to build unique layer names,
            e.g. 'block35_3'.
        activation: activation applied after the residual addition, or
            None for a linear output.

    Returns:
        Output tensor with the same shape as ``x``.

    Raises:
        ValueError: if ``block_type`` is not one of the supported names.
    """
    if block_type == 'block35':
        branch0 = _conv_blocks(x, 32, 1)
        branch1 = _conv_blocks(x, 32, 1)
        branch1 = _conv_blocks(branch1, 32, 3)
        branch2 = _conv_blocks(x, 32, 1)
        branch2 = _conv_blocks(branch2, 48, 3)
        branch2 = _conv_blocks(branch2, 64, 3)
        branches = [branch0, branch1, branch2]
    elif block_type == 'block17':
        branch0 = _conv_blocks(x, 192, 1)
        branch1 = _conv_blocks(x, 128, 1)
        # Factorized 7x7 convolution: 1x7 followed by 7x1.
        branch1 = _conv_blocks(branch1, 160, [1, 7])
        branch1 = _conv_blocks(branch1, 192, [7, 1])
        branches = [branch0, branch1]
    elif block_type == 'block8':
        branch0 = _conv_blocks(x, 192, 1)
        branch1 = _conv_blocks(x, 192, 1)
        # Factorized 3x3 convolution: 1x3 followed by 3x1.
        branch1 = _conv_blocks(branch1, 224, [1, 3])
        branch1 = _conv_blocks(branch1, 256, [3, 1])
        branches = [branch0, branch1]
    else:
        # Without this, an unknown type would reach the code below with
        # `branches` unbound and fail with an opaque UnboundLocalError.
        raise ValueError('Unknown Inception-ResNet block type. '
                         'Expects "block35", "block17" or "block8", '
                         'but got: ' + str(block_type))

    # Channel axis of the feature maps; agrees with the BN axis used in _conv_blocks.
    channel_axis = 1 if K.image_data_format() == 'channels_first' else 3

    block_name = block_type + '_' + str(block_idx)
    mixed = Concatenate(axis=channel_axis, name=block_name + '_mixed')(branches)
    # Linear 1x1 projection back to the input depth, so that the residual
    # addition below is shape-compatible.
    up = _conv_blocks(mixed, K.int_shape(x)[channel_axis], 1,
                      activation=None, use_bias=True,
                      name=block_name + '_conv')

    # Scaled residual connection: x + scale * up. `scale` is passed through
    # Lambda's `arguments` dict rather than captured from the enclosing scope.
    x = Lambda(lambda inputs, scale: inputs[0] + inputs[1] * scale,
               output_shape=K.int_shape(x)[1:],
               arguments={'scale': scale},
               name=block_name)([x, up])
    if activation is not None:
        x = Activation(activation, name=block_name + '_ac')(x)
    return x

In [48]:
def InceptionResNetV2(input_shape=(299, 299, 3),
                      classes=1000):
    """Build an Inception-ResNet-v2 image classifier.

    Args:
        input_shape: shape of a single input image, excluding the batch
            dimension. The argument is honoured; it is no longer overwritten
            with a fixed 299x299x3 shape. Several stem and reduction layers
            use 'valid' padding with stride 2. With this architecture, each
            spatial dimension must therefore be at least 75 for the final
            feature map to be non-empty.
        classes: number of output units of the final softmax layer.

    Returns:
        A ``Model`` named 'inception_resnet_v2' that maps an image batch to
        per-class softmax probabilities.
    """
    # Channel axis for the Concatenate layers; agrees with the BN axis in _conv_blocks.
    channel_axis = 1 if K.image_data_format() == 'channels_first' else 3

    img_input = Input(shape=input_shape)

    # Stem (shapes below are for the default 299 x 299 x 3 input):
    # 299 x 299 x 3 -> 35 x 35 x 192
    x = _conv_blocks(img_input, 32, 3, strides=2, padding='valid')
    x = _conv_blocks(x, 32, 3, padding='valid')
    x = _conv_blocks(x, 64, 3)
    x = MaxPooling2D(3, strides=2, padding='valid')(x)
    x = _conv_blocks(x, 80, 1)
    x = _conv_blocks(x, 192, 3, padding='valid')
    x = MaxPooling2D(3, strides=2)(x)

    # Mixed 5b: 35 x 35 x 192 -> 35 x 35 x 320 (64 + 96 + 64 + 96 channels)
    branch0 = AveragePooling2D(3, strides=1, padding='same')(x)
    branch0 = _conv_blocks(branch0, 64, 1)
    branch1 = _conv_blocks(x, 96, 1)
    branch2 = _conv_blocks(x, 48, 1)
    branch2 = _conv_blocks(branch2, 64, 5)
    branch3 = _conv_blocks(x, 64, 1)
    branch3 = _conv_blocks(branch3, 96, 3)
    branch3 = _conv_blocks(branch3, 96, 3)
    branches = [branch0, branch1, branch2, branch3]
    x = Concatenate(axis=channel_axis, name='mixed_5b')(branches)

    # 10x Inception-ResNet-A block: 35 x 35 x 320 -> 35 x 35 x 320
    for block_idx in range(1, 11):
        x = inception_resnet_block(x,
                                   scale=0.17,
                                   block_type='block35',
                                   block_idx=block_idx)

    # Reduction-A block (mixed 6a): 35 x 35 x 320 -> 17 x 17 x 1088
    # (384 + 384 + 320 pass-through channels)
    branch_0 = _conv_blocks(x, 384, 3, strides=2, padding='valid')
    branch_1 = _conv_blocks(x, 256, 1)
    branch_1 = _conv_blocks(branch_1, 256, 3)
    branch_1 = _conv_blocks(branch_1, 384, 3, strides=2, padding='valid')
    branch_pool = MaxPooling2D(3, strides=2, padding='valid')(x)
    branches = [branch_0, branch_1, branch_pool]
    x = Concatenate(axis=channel_axis, name='mixed_6a')(branches)

    # 20x Inception-ResNet-B block: 17 x 17 x 1088 -> 17 x 17 x 1088
    for block_idx in range(1, 21):
        x = inception_resnet_block(x,
                                   scale=0.1,
                                   block_type='block17',
                                   block_idx=block_idx)

    # Reduction-B block (mixed 7a): 17 x 17 x 1088 -> 8 x 8 x 2080
    # (384 + 288 + 320 + 1088 pass-through channels)
    branch_0 = _conv_blocks(x, 256, 1)
    branch_0 = _conv_blocks(branch_0, 384, 3, strides=2, padding='valid')
    branch_1 = _conv_blocks(x, 256, 1)
    branch_1 = _conv_blocks(branch_1, 288, 3, strides=2, padding='valid')
    branch_2 = _conv_blocks(x, 256, 1)
    branch_2 = _conv_blocks(branch_2, 288, 3)
    branch_2 = _conv_blocks(branch_2, 320, 3, strides=2, padding='valid')
    branch_pool = MaxPooling2D(3, strides=2, padding='valid')(x)
    branches = [branch_0, branch_1, branch_2, branch_pool]
    x = Concatenate(axis=channel_axis, name='mixed_7a')(branches)

    # 10x Inception-ResNet-C block: 8 x 8 x 2080 -> 8 x 8 x 2080.
    # The first nine use scale 0.2 with ReLU; the last one is unscaled
    # (scale=1) and linear (no activation).
    for block_idx in range(1, 10):
        x = inception_resnet_block(x,
                                   scale=0.2,
                                   block_type='block8',
                                   block_idx=block_idx)
    x = inception_resnet_block(x,
                               scale=1.,
                               activation=None,
                               block_type='block8',
                               block_idx=10)

    # Final 1x1 conv: 8 x 8 x 2080 -> 8 x 8 x 1536
    x = _conv_blocks(x, 1536, 1, name='conv_7b')

    # Classification head: global average pool + softmax.
    x = GlobalAveragePooling2D(name='avg_pool')(x)
    x = Dense(classes, activation='softmax', name='predictions')(x)

    # Build the model
    return Model(img_input, x, name='inception_resnet_v2')

In [49]:
# Build the network with its default 299 x 299 x 3 input and a 1000-way
# softmax head, then print the per-layer summary table.
model = InceptionResNetV2(classes=1000)
model.summary()

Model: "inception_resnet_v2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_10 (InputLayer)           [(None, 299, 299, 3) 0                                            
__________________________________________________________________________________________________
conv2d_659 (Conv2D)             (None, 149, 149, 32) 864         input_10[0][0]                   
__________________________________________________________________________________________________
batch_normalization_658 (BatchN (None, 149, 149, 32) 96          conv2d_659[0][0]                 
__________________________________________________________________________________________________
activation_657 (Activation)     (None, 149, 149, 32) 0           batch_normalization_658[0][0]    
________________________________________________________________________________