In [1]:
import tensorflow as tf

In [2]:
from tensorflow.keras import layers, models, Input
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, Dense, Flatten, Dropout,BatchNormalization,Activation
from tensorflow.keras.layers import MaxPooling2D, AveragePooling2D, Concatenate, Lambda,GlobalAveragePooling2D
from tensorflow.keras import backend as K

def conv2d_bn(x, filters, kernel_size, strides=1, padding='same', activation='relu', use_bias=False, name=None):
    """Apply a Conv2D layer, optionally followed by batch norm and an activation.

    Args:
        x: Input tensor.
        filters: Number of convolution output filters.
        kernel_size: Convolution kernel size (int or pair).
        strides: Convolution strides.
        padding: Convolution padding mode passed to Conv2D.
        activation: Activation name applied last, or None to skip it.
        use_bias: If True the convolution carries a bias and no batch norm is
            added; if False a BatchNormalization layer with ``scale=False``
            follows the convolution.
        name: Optional base layer name; the batch-norm and activation layers
            get ``'_bn'`` and ``'_ac'`` suffixes respectively.

    Returns:
        The output tensor of the last layer applied.
    """
    def _suffixed(suffix):
        # Derived layer names only exist when a base name was supplied.
        return None if name is None else name + suffix

    out = Conv2D(filters, kernel_size, strides=strides, padding=padding,
                 use_bias=use_bias, name=name)(x)

    if not use_bias:
        # Normalize over the channel dimension for either data format.
        channels_first = K.image_data_format() == 'channels_first'
        out = BatchNormalization(axis=1 if channels_first else 3,
                                 scale=False, name=_suffixed('_bn'))(out)

    if activation is not None:
        out = Activation(activation, name=_suffixed('_ac'))(out)

    return out

def inception_resnet_block(x, scale, block_type, block_idx, activation='relu'):
    """Append one scaled residual Inception-ResNet block to the graph.

    The input goes through parallel convolution branches. Their outputs are
    concatenated along the channel axis and projected back to the input's
    channel count with a biased 1x1 convolution (no batch norm, no
    activation). The projection is multiplied by ``scale`` and added to the
    input.

    Args:
        x: Input 4D tensor.
        scale: Multiplier applied to the residual branch before the addition.
        block_type: ``'block35'`` (branches: 1x1; 1x1 -> 3x3;
            1x1 -> 3x3 -> 3x3, all 32 filters) or ``'block17'`` (branches:
            1x1; 1x1 -> 1x3 -> 3x1, all 64 filters).
        block_idx: Index used to build unique layer names of the form
            ``'<block_type>_<block_idx>...'``.
        activation: Activation applied after the residual sum, or None.

    Returns:
        Output tensor with the same shape as ``x``.

    Raises:
        ValueError: If ``block_type`` is not one of the supported types.
    """
    # Resolve the channel dimension the same way conv2d_bn does. Hard-coding
    # index 3 (or relying on Concatenate's default axis=-1) would select a
    # spatial axis when the backend uses channels_first.
    channel_axis = 1 if K.image_data_format() == 'channels_first' else 3

    if block_type == 'block35':
        branch_0 = conv2d_bn(x, 32, 1)
        branch_1 = conv2d_bn(x, 32, 1)
        branch_1 = conv2d_bn(branch_1, 32, 3)
        branch_2 = conv2d_bn(x, 32, 1)
        branch_2 = conv2d_bn(branch_2, 32, 3)
        branch_2 = conv2d_bn(branch_2, 32, 3)
        branches = [branch_0, branch_1, branch_2]
    elif block_type == 'block17':
        branch_0 = conv2d_bn(x, 64, 1)
        branch_1 = conv2d_bn(x, 64, 1)
        # A 1x3 followed by a 3x1 convolution covers a 3x3 receptive field.
        branch_1 = conv2d_bn(branch_1, 64, [1, 3])
        branch_1 = conv2d_bn(branch_1, 64, [3, 1])
        branches = [branch_0, branch_1]
    else:
        raise ValueError('Unknown Inception-ResNet block type. '
                         'Expects "block35", "block17" '
                         'but got: ' + str(block_type))

    block_name = block_type + '_' + str(block_idx)
    mixed = Concatenate(axis=channel_axis, name=block_name + '_mixed')(branches)
    # Linear 1x1 projection (bias, no BN, no activation) back to the input's
    # channel count so the residual addition below is shape-compatible.
    up = conv2d_bn(mixed, K.int_shape(x)[channel_axis], 1, activation=None,
                   use_bias=True, name=block_name + '_conv')

    # Scaled residual sum: x + scale * up.
    x = Lambda(lambda inputs, scale: inputs[0] + inputs[1] * scale,
               output_shape=K.int_shape(x)[1:],
               arguments={'scale': scale},
               name=block_name)([x, up])
    if activation is not None:
        x = Activation(activation, name=block_name + '_ac')(x)
    return x


def InceptionResNetV2(input_shape=(29, 29, 1), classes=2, num_block35=1, num_block17=1):
    """Build a compact Inception-ResNet-v2-style image classifier.

    The network has a convolutional stem, Inception-ResNet-A blocks,
    Reduction-A (stride 2), Inception-ResNet-B blocks, Reduction-B (stride 1
    in this variant), a 1x1 convolution, global average pooling and a softmax
    classifier.

    Args:
        input_shape: Shape of one input sample, excluding the batch dimension.
            A tuple is used as the default to avoid a shared mutable default;
            lists are still accepted.
        classes: Number of output classes (units of the softmax layer).
        num_block35: Number of Inception-ResNet-A (``'block35'``) blocks.
        num_block17: Number of Inception-ResNet-B (``'block17'``) blocks.

    Returns:
        An uncompiled ``Model`` named ``'inception_resnet_v2'``.
    """
    # Concatenate along the channel dimension for either data format, matching
    # the BatchNormalization axis convention used in conv2d_bn.
    channel_axis = 1 if K.image_data_format() == 'channels_first' else 3

    inputs = Input(shape=input_shape)

    # Stem block (e.g. a 29x29x1 input becomes 13x13x128).
    x = conv2d_bn(inputs, 32, 3)
    x = conv2d_bn(x, 32, 3, padding='valid')
    x = MaxPooling2D(3, strides=2)(x)
    x = conv2d_bn(x, 64, 1)
    x = conv2d_bn(x, 128, 3)
    x = conv2d_bn(x, 128, 3)

    # Inception-ResNet-A blocks (shape-preserving residual blocks).
    for block_idx in range(num_block35):
        x = inception_resnet_block(x, scale=0.17, block_type='block35', block_idx=block_idx)

    # Reduction-A block: stride-2 'valid' branches shrink the spatial size
    # (13x13 -> 6x6 for the default input); output channels are
    # 192 + 128 + the incoming channel count (passed through by the pool).
    branch_0 = conv2d_bn(x, 192, 3, strides=2, padding='valid')
    branch_1 = conv2d_bn(x, 96, 1)
    branch_1 = conv2d_bn(branch_1, 96, 3)
    branch_1 = conv2d_bn(branch_1, 128, 3, strides=2, padding='valid')
    branch_pool = MaxPooling2D(3, strides=2, padding='valid')(x)
    branches = [branch_0, branch_1, branch_pool]
    x = Concatenate(axis=channel_axis, name='mixed_6a')(branches)

    # Inception-ResNet-B blocks (shape-preserving residual blocks).
    for block_idx in range(num_block17):
        x = inception_resnet_block(x, scale=0.1, block_type='block17', block_idx=block_idx)

    # Reduction-B block: the 3x3 'valid' ops here use stride 1, so each
    # spatial dimension only loses 2 pixels (6x6 -> 4x4 for the default
    # input) instead of being halved.
    branch_0 = conv2d_bn(x, 128, 1)
    branch_0 = conv2d_bn(branch_0, 192, 3, strides=1, padding='valid')
    branch_1 = conv2d_bn(x, 128, 1)
    branch_1 = conv2d_bn(branch_1, 128, 3, strides=1, padding='valid')
    branch_2 = conv2d_bn(x, 128, 1)
    branch_2 = conv2d_bn(branch_2, 128, 3)
    branch_2 = conv2d_bn(branch_2, 128, 3, strides=1, padding='valid')
    branch_pool = MaxPooling2D(3, strides=1, padding='valid')(x)
    branches = [branch_0, branch_1, branch_2, branch_pool]
    x = Concatenate(axis=channel_axis, name='mixed_7a')(branches)

    x = conv2d_bn(x, 896, 1, name='conv_7b')

    # Classification head.
    x = GlobalAveragePooling2D(name='avg_pool')(x)
    x = Dense(classes, activation='softmax', name='predictions')(x)

    # Build the model.
    model = Model(inputs, x, name='inception_resnet_v2')

    return model

# Build a 58-class model for 29x29 single-channel inputs and print its layers.
model = InceptionResNetV2(input_shape=[29, 29, 1], classes=58)
model.summary()


Model: "inception_resnet_v2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 29, 29, 1)]  0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 29, 29, 32)   288         input_1[0][0]                    
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 29, 29, 32)   96          conv2d[0][0]                     
__________________________________________________________________________________________________
activation (Activation)         (None, 29, 29, 32)   0           batch_normalization[0][0]        
________________________________________________________________________________