In [32]:
import tensorflow as tf
import math
from tensorflow.keras.layers import Conv2D, BatchNormalization, Dense, Activation, Add, GlobalAveragePooling2D, Input
from tensorflow.keras.layers import ZeroPadding2D, DepthwiseConv2D, multiply, Reshape, Reshape, Dropout, add
import tensorflow.keras.backend as K
from tensorflow.keras.models import Model
from copy import deepcopy

In [33]:
# Baseline arguments for each of the seven block stages.
# Row layout: (kernel_size, repeats, filters_in, filters_out, expand_ratio, strides).
# Every stage uses id_skip=True and se_ratio=0.25.
DEFAULT_BLOCKS_ARGS = [
    dict(kernel_size=kernel, repeats=reps, filters_in=f_in, filters_out=f_out,
         expand_ratio=expand, id_skip=True, strides=stride, se_ratio=0.25)
    for kernel, reps, f_in, f_out, expand, stride in (
        (3, 1, 32, 16, 1, 1),
        (3, 2, 16, 24, 6, 2),
        (5, 2, 24, 40, 6, 2),
        (3, 3, 40, 80, 6, 2),
        (5, 3, 80, 112, 6, 1),
        (5, 4, 112, 192, 6, 2),
        (3, 1, 192, 320, 6, 1),
    )
]

In [34]:
# Serialized kernel initializers: one for the convolution layers, one for the
# final Dense layer. Both are VarianceScaling configs in 'fan_out' mode.
CONV_KERNEL_INITIALIZER = dict(
    class_name='VarianceScaling',
    config=dict(scale=2.0, mode='fan_out', distribution='normal'),
)

DENSE_KERNEL_INITIALIZER = dict(
    class_name='VarianceScaling',
    config=dict(scale=1.0 / 3.0, mode='fan_out', distribution='uniform'),
)

In [35]:
# Computes the explicit zero-padding needed before a strided convolution.
def correct_pad(inputs, kernel_size):
    """Return the padding tuple for a 2D convolution with downsampling.

    Args:
        inputs: Tensor whose static shape is read with `K.int_shape`; axes 1
            and 2 are taken as the two spatial dimensions.
        kernel_size: An int, or a tuple/list of two ints.

    Returns:
        A tuple `((before_0, after_0), (before_1, after_1))` for the two
        spatial axes, in the form accepted by `ZeroPadding2D(padding=...)`.

    The "before" padding is reduced by one when the corresponding spatial size
    is even or unknown (`None`). Each axis is decided independently, so a shape
    with only one unknown spatial dimension is handled instead of raising
    `TypeError` on `None % 2`.
    """
    img_dim = 1
    input_size = K.int_shape(inputs)[img_dim:(img_dim + 2)]

    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)

    # Unknown sizes are treated like even sizes (adjust by one).
    adjust = tuple(1 if size is None else 1 - size % 2 for size in input_size)

    correct = (kernel_size[0] // 2, kernel_size[1] // 2)

    return ((correct[0] - adjust[0], correct[0]),
            (correct[1] - adjust[1], correct[1]))

In [36]:
def block(inputs, activation_fn=tf.nn.swish, drop_rate=0., name='',
          filters_in=32, filters_out=16, kernel_size=3, strides=1,
          expand_ratio=1, se_ratio=0., id_skip=True):
    """Build one expand -> depthwise -> (squeeze-excite) -> project block.

    Args:
        inputs: Input tensor (assumed 4-D channels-last, since batch norm is
            applied on axis 3).
        activation_fn: Activation used after the expansion and depthwise
            convolutions and inside the squeeze-and-excitation reduce conv.
        drop_rate: Dropout rate applied to the block output before the
            residual add; only used when the skip connection is active.
        name: Prefix prepended to every layer name created here.
        filters_in: Number of channels of `inputs`.
        filters_out: Number of channels produced by the projection conv.
        kernel_size: Kernel size of the depthwise convolution.
        strides: Stride of the depthwise convolution.
        expand_ratio: Channel multiplier of the expansion phase; `1` skips
            the expansion conv entirely.
        se_ratio: Squeeze-and-excitation ratio relative to `filters_in`; the
            SE branch is only built when `0 < se_ratio <= 1`.
        id_skip: Whether to add the residual connection (also requires
            `strides == 1` and `filters_in == filters_out`).

    Returns:
        The output tensor of the block.

    The projection/output phase is always executed. Previously it was nested
    under the squeeze-and-excitation condition, so calling the function with
    `se_ratio` outside `(0, 1]` (including the default `0.`) returned `None`.
    """
    bn_axis = 3
    filters = expand_ratio * filters_in

    # Expansion phase (skipped when expand_ratio == 1).
    if expand_ratio != 1:
        x = Conv2D(filters,
                   1,
                   padding='same',
                   use_bias=False,
                   kernel_initializer=CONV_KERNEL_INITIALIZER,
                   name=name + 'expand_conv')(inputs)
        x = BatchNormalization(axis=bn_axis, name=name + 'expand_bn')(x)
        x = Activation(activation_fn, name=name + 'expand_activation')(x)
    else:
        x = inputs

    # Depthwise phase: for stride 2 pad explicitly and convolve with 'valid'.
    if strides == 2:
        x = ZeroPadding2D(padding=correct_pad(x, kernel_size),
                          name=name + 'dwconv_pad')(x)
        conv_pad = 'valid'
    else:
        conv_pad = 'same'

    x = DepthwiseConv2D(kernel_size,
                        strides=strides,
                        padding=conv_pad,
                        use_bias=False,
                        depthwise_initializer=CONV_KERNEL_INITIALIZER,
                        name=name + 'dwconv')(x)
    x = BatchNormalization(axis=bn_axis, name=name + 'bn')(x)
    x = Activation(activation_fn, name=name + 'activation')(x)

    # Squeeze-and-excitation phase (optional).
    if 0 < se_ratio <= 1:
        filters_se = max(1, int(filters_in * se_ratio))
        se = GlobalAveragePooling2D(name=name + 'se_squeeze')(x)
        se = Reshape((1, 1, filters), name=name + 'se_reshape')(se)
        se = Conv2D(filters_se, 1,
                    padding='same',
                    activation=activation_fn,
                    kernel_initializer=CONV_KERNEL_INITIALIZER,
                    name=name + 'se_reduce')(se)
        se = Conv2D(filters, 1,
                    padding='same',
                    activation='sigmoid',
                    kernel_initializer=CONV_KERNEL_INITIALIZER,
                    name=name + 'se_expand')(se)
        x = multiply([x, se], name=name + 'se_excite')

    # Output phase: always project to filters_out, regardless of se_ratio.
    x = Conv2D(filters_out,
               (1, 1),
               padding='same',
               use_bias=False,
               kernel_initializer=CONV_KERNEL_INITIALIZER,
               name=name + 'project_conv')(x)
    x = BatchNormalization(axis=bn_axis, name=name + 'project_bn')(x)

    # Residual connection, only when input and output shapes match.
    if (id_skip is True and strides == 1 and filters_in == filters_out):
        if drop_rate > 0:
            # noise_shape drops the whole sample's residual branch at once.
            x = Dropout(drop_rate,
                        noise_shape=(None, 1, 1, 1),
                        name=name + 'drop')(x)
        x = add([x, inputs], name=name + 'add')

    return x

In [37]:
def EfficientNet(width_coefficient,
                 depth_coefficient,
                 default_size,
                 dropout_rate=0.2,
                 drop_connect_rate=0.2,
                 depth_divisor=8,
                 activation_fn=tf.nn.swish,
                 blocks_args=DEFAULT_BLOCKS_ARGS,
                 model_name='efficientnet',
                 weights='imagenet',
                 input_tensor=None,
                 input_shape=None,
                 pooling=None,
                 classes=1000,
                 **kwargs):
    """Assemble an EfficientNet classification model.

    Args:
        width_coefficient: Multiplier applied to every filter count.
        depth_coefficient: Multiplier applied to every stage's repeat count.
        default_size: Accepted for interface compatibility; not used here.
        dropout_rate: Dropout rate before the final Dense layer (skipped when
            not greater than 0).
        drop_connect_rate: Maximum per-block dropout rate; each block gets
            `drop_connect_rate * block_index / total_blocks`.
        depth_divisor: Filter counts are rounded to a multiple of this value.
        activation_fn: Activation function used throughout the network.
        blocks_args: List of per-stage keyword-argument dicts for `block`,
            each with an additional 'repeats' entry. It is deep-copied, so
            the caller's list is not modified.
        model_name: Name given to the returned `Model`.
        weights: `'imagenet'`, a path passed to `model.load_weights`, or
            `None` to skip weight loading.
        input_tensor: Optional tensor forwarded to `Input(tensor=...)`.
        input_shape: Optional input shape. When `None`, `[416, 416, 3]` is
            used (the value this notebook previously hard-coded for every
            call, ignoring the argument).
        pooling: Accepted for interface compatibility; not used here.
        classes: Number of units of the final softmax Dense layer.
        **kwargs: Accepted and ignored.

    Returns:
        A `Model` mapping the input image tensor to class probabilities.
    """
    # Honour a caller-supplied shape; keep the historical default otherwise.
    if input_shape is None:
        input_shape = [416, 416, 3]
    img_input = Input(tensor=input_tensor, shape=input_shape)

    bn_axis = 3

    # Keep filter counts divisible by `depth_divisor` (8 by default).
    def round_filters(filters, divisor=depth_divisor):
        """Round number of filters based on depth multiplier."""
        filters *= width_coefficient
        new_filters = max(divisor, int(filters + divisor / 2) // divisor * divisor)
        # Make sure that round down does not go down by more than 10%.
        if new_filters < 0.9 * filters:
            new_filters += divisor
        return int(new_filters)

    # Scaled repeat count, rounded up.
    def round_repeats(repeats):
        return int(math.ceil(depth_coefficient * repeats))

    # Build stem
    x = img_input
    x = ZeroPadding2D(padding=correct_pad(x, 3),
                      name='stem_conv_pad')(x)
    x = Conv2D(round_filters(32), 3,
               strides=2,
               padding='valid',
               use_bias=False,
               kernel_initializer=CONV_KERNEL_INITIALIZER,
               name='stem_conv')(x)
    x = BatchNormalization(axis=bn_axis, name='stem_bn')(x)
    x = Activation(activation_fn, name='stem_activation')(x)

    # Work on a copy so the caller's (or the module-level default) list of
    # dicts is never mutated by the pop/assignments below.
    blocks_args = deepcopy(blocks_args)

    b = 0
    # Total number of blocks actually built. This must use the *scaled*
    # repeats; summing the raw 'repeats' made the drop-connect rate exceed
    # `drop_connect_rate` whenever depth_coefficient > 1.
    blocks = float(sum(round_repeats(args['repeats']) for args in blocks_args))

    # Build every stage.
    for (i, args) in enumerate(blocks_args):
        assert args['repeats'] > 0
        args['filters_in'] = round_filters(args['filters_in'])
        args['filters_out'] = round_filters(args['filters_out'])

        # Repeat the block; only the first repeat may stride or change width.
        for j in range(round_repeats(args.pop('repeats'))):
            if j > 0:
                args['strides'] = 1
                args['filters_in'] = args['filters_out']
            x = block(x, activation_fn, drop_connect_rate * b / blocks,
                      name='block{}{}_'.format(i + 1, chr(j + 97)), **args)
            b += 1

    # Build top
    x = Conv2D(round_filters(1280), 1,
               padding='same',
               use_bias=False,
               kernel_initializer=CONV_KERNEL_INITIALIZER,
               name='top_conv')(x)
    x = BatchNormalization(axis=bn_axis, name='top_bn')(x)
    x = Activation(activation_fn, name='top_activation')(x)

    # Global average pooling feeds the classifier directly.
    x = GlobalAveragePooling2D(name='avg_pool')(x)
    if dropout_rate > 0:
        x = Dropout(dropout_rate, name='top_dropout')(x)

    x = Dense(classes,
              activation='softmax',
              kernel_initializer=DENSE_KERNEL_INITIALIZER,
              name='probs')(x)

    inputs = img_input

    model = Model(inputs, x, name=model_name)

    # Load weights.
    # NOTE(review): WEIGHTS_HASHES, BASE_WEIGHTS_PATH and get_file are not
    # defined in the cells visible here; the 'imagenet' branch needs them to be
    # provided elsewhere in the notebook -- confirm before relying on it.
    if weights == 'imagenet':
        file_suff = '_weights_tf_dim_ordering_tf_kernels_autoaugment.h5'
        file_hash = WEIGHTS_HASHES[model_name[-2:]][0]
        file_name = model_name + file_suff
        weights_path = get_file(file_name, BASE_WEIGHTS_PATH + file_name,
                                cache_subdir='models',
                                file_hash=file_hash)
        model.load_weights(weights_path)
    elif weights is not None:
        model.load_weights(weights)

    return model

In [38]:
def EfficientNetB0(weights=None,
                   input_tensor=None,
                   input_shape=None,
                   pooling=None,
                   classes=1000,
                   **kwargs):
    """Build the 'efficientnet-b0' model.

    Calls `EfficientNet` with width 1.0, depth 1.0, default size 224 and
    dropout rate 0.2; all arguments of this function are forwarded unchanged.
    """
    return EfficientNet(
        width_coefficient=1.0,
        depth_coefficient=1.0,
        default_size=224,
        dropout_rate=0.2,
        model_name='efficientnet-b0',
        weights=weights,
        input_tensor=input_tensor,
        input_shape=input_shape,
        pooling=pooling,
        classes=classes,
        **kwargs
    )


def EfficientNetB1(weights=None,
                   input_tensor=None,
                   input_shape=None,
                   pooling=None,
                   classes=1000,
                   **kwargs):
    """Build the 'efficientnet-b1' model.

    Calls `EfficientNet` with width 1.0, depth 1.1, default size 240 and
    dropout rate 0.2; all arguments of this function are forwarded unchanged.
    """
    return EfficientNet(
        width_coefficient=1.0,
        depth_coefficient=1.1,
        default_size=240,
        dropout_rate=0.2,
        model_name='efficientnet-b1',
        weights=weights,
        input_tensor=input_tensor,
        input_shape=input_shape,
        pooling=pooling,
        classes=classes,
        **kwargs
    )


def EfficientNetB2(weights=None,
                   input_tensor=None,
                   input_shape=None,
                   pooling=None,
                   classes=1000,
                   **kwargs):
    """Build the 'efficientnet-b2' model.

    Calls `EfficientNet` with width 1.1, depth 1.2, default size 260 and
    dropout rate 0.3; all arguments of this function are forwarded unchanged.
    """
    return EfficientNet(
        width_coefficient=1.1,
        depth_coefficient=1.2,
        default_size=260,
        dropout_rate=0.3,
        model_name='efficientnet-b2',
        weights=weights,
        input_tensor=input_tensor,
        input_shape=input_shape,
        pooling=pooling,
        classes=classes,
        **kwargs
    )


def EfficientNetB3(weights=None,
                   input_tensor=None,
                   input_shape=None,
                   pooling=None,
                   classes=1000,
                   **kwargs):
    """Build the 'efficientnet-b3' model.

    Calls `EfficientNet` with width 1.2, depth 1.4, default size 300 and
    dropout rate 0.3; all arguments of this function are forwarded unchanged.
    """
    return EfficientNet(
        width_coefficient=1.2,
        depth_coefficient=1.4,
        default_size=300,
        dropout_rate=0.3,
        model_name='efficientnet-b3',
        weights=weights,
        input_tensor=input_tensor,
        input_shape=input_shape,
        pooling=pooling,
        classes=classes,
        **kwargs
    )


def EfficientNetB4(weights=None,
                   input_tensor=None,
                   input_shape=None,
                   pooling=None,
                   classes=1000,
                   **kwargs):
    """Build the 'efficientnet-b4' model.

    Calls `EfficientNet` with width 1.4, depth 1.8, default size 380 and
    dropout rate 0.4; all arguments of this function are forwarded unchanged.
    """
    return EfficientNet(
        width_coefficient=1.4,
        depth_coefficient=1.8,
        default_size=380,
        dropout_rate=0.4,
        model_name='efficientnet-b4',
        weights=weights,
        input_tensor=input_tensor,
        input_shape=input_shape,
        pooling=pooling,
        classes=classes,
        **kwargs
    )


def EfficientNetB5(weights=None,
                   input_tensor=None,
                   input_shape=None,
                   pooling=None,
                   classes=1000,
                   **kwargs):
    """Build the 'efficientnet-b5' model.

    Calls `EfficientNet` with width 1.6, depth 2.2, default size 456 and
    dropout rate 0.4; all arguments of this function are forwarded unchanged.
    """
    return EfficientNet(
        width_coefficient=1.6,
        depth_coefficient=2.2,
        default_size=456,
        dropout_rate=0.4,
        model_name='efficientnet-b5',
        weights=weights,
        input_tensor=input_tensor,
        input_shape=input_shape,
        pooling=pooling,
        classes=classes,
        **kwargs
    )


def EfficientNetB6(weights=None,
                   input_tensor=None,
                   input_shape=None,
                   pooling=None,
                   classes=1000,
                   **kwargs):
    """Build the 'efficientnet-b6' model.

    Calls `EfficientNet` with width 1.8, depth 2.6, default size 528 and
    dropout rate 0.5; all arguments of this function are forwarded unchanged.
    """
    return EfficientNet(
        width_coefficient=1.8,
        depth_coefficient=2.6,
        default_size=528,
        dropout_rate=0.5,
        model_name='efficientnet-b6',
        weights=weights,
        input_tensor=input_tensor,
        input_shape=input_shape,
        pooling=pooling,
        classes=classes,
        **kwargs
    )


def EfficientNetB7(weights=None,
                   input_tensor=None,
                   input_shape=None,
                   pooling=None,
                   classes=1000,
                   **kwargs):
    """Build the 'efficientnet-b7' model.

    Calls `EfficientNet` with width 2.0, depth 3.1, default size 600 and
    dropout rate 0.5; all arguments of this function are forwarded unchanged.
    """
    return EfficientNet(
        width_coefficient=2.0,
        depth_coefficient=3.1,
        default_size=600,
        dropout_rate=0.5,
        model_name='efficientnet-b7',
        weights=weights,
        input_tensor=input_tensor,
        input_shape=input_shape,
        pooling=pooling,
        classes=classes,
        **kwargs
    )

In [40]:
model = EfficientNetB0()
model.summary()

Model: "efficientnet-b0"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_6 (InputLayer)            [(None, 416, 416, 3) 0                                            
__________________________________________________________________________________________________
stem_conv_pad (ZeroPadding2D)   (None, 417, 417, 3)  0           input_6[0][0]                    
__________________________________________________________________________________________________
stem_conv (Conv2D)              (None, 208, 208, 32) 864         stem_conv_pad[0][0]              
__________________________________________________________________________________________________
stem_bn (BatchNormalization)    (None, 208, 208, 32) 128         stem_conv[0][0]                  
____________________________________________________________________________________