In [1]:
import re
import collections

In [2]:
# Parameters shared by the whole model
GlobalParams = collections.namedtuple(
    'GlobalParams',
    (
        'width_coefficient',
        'depth_coefficient',
        'image_size',
        'dropout_rate',
        'num_classes',
        'batch_norm_momentum',
        'batch_norm_epsilon',
        'drop_connect_rate',
        'depth_divisor',
        'min_depth',
        'include_top',
    ),
)

# Parameters describing a single model block
BlockArgs = collections.namedtuple(
    'BlockArgs',
    (
        'num_repeat',
        'kernel_size',
        'stride',
        'expand_ratio',
        'input_filters',
        'output_filters',
        'se_ratio',
        'id_skip',
    ),
)

print(GlobalParams)
print(BlockArgs)

<class '__main__.GlobalParams'>
<class '__main__.BlockArgs'>


In [3]:
# Give every field of both namedtuple classes a default value of None,
# so instances can be built while specifying only some of the fields.
GlobalParams.__new__.__defaults__ = tuple(None for _ in GlobalParams._fields)
BlockArgs.__new__.__defaults__ = tuple(None for _ in BlockArgs._fields)

print(GlobalParams)

<class '__main__.GlobalParams'>


In [4]:
def decode_block_string(block_string):
    """Parse one block-description string into a ``BlockArgs`` tuple.

    The string is a ``_``-separated sequence of tokens, each made of a
    letter key immediately followed by its value, for example
    ``'r1_k3_s11_e1_i32_o16_se0.25'``.

    Args:
        block_string (str): Block description. The ``r``, ``k``, ``s``,
            ``e``, ``i`` and ``o`` tokens are required; ``se`` is optional.
            If the substring ``'noskip'`` appears anywhere in the string,
            ``id_skip`` is set to False.

    Returns:
        BlockArgs: The decoded values. ``stride`` is a one-element list
        holding the first digit of the ``s`` token.

    Raises:
        AssertionError: If ``block_string`` is not a ``str``, or if the
            ``s`` token is missing or is neither a single digit nor two
            identical digits.
        KeyError: If any of the ``r``, ``k``, ``e``, ``i`` or ``o`` tokens
            is missing.
    """
    assert isinstance(block_string, str)
    ops = block_string.split("_")

    options = {}
    for op in ops:
        # Split the leading letters (key) from the value, which starts at
        # the first digit; the capturing group keeps the value in the result.
        splits = re.split(r'(\d.*)', op)
        print(f"{op} ---> {splits}")

        # Tokens without any digit (e.g. 'noskip') yield a single element
        # and are not stored as key/value options.
        if len(splits) >= 2:
            key, value = splits[:2]
            options[key] = value

    # The presence check must guard BOTH alternatives; otherwise a missing
    # stride token surfaces as a KeyError from options['s'] instead of
    # failing this validation.
    assert 's' in options and (
        len(options['s']) == 1 or
        (len(options['s']) == 2 and options['s'][0] == options['s'][1])
    ), "stride token 's' must be present as one digit or two identical digits"

    return BlockArgs(
        num_repeat=int(options['r']),
        kernel_size=int(options['k']),
        stride=[int(options['s'][0])],
        expand_ratio=int(options['e']),
        input_filters=int(options['i']),
        output_filters=int(options['o']),
        se_ratio=float(options['se']) if 'se' in options else None,
        id_skip=('noskip' not in block_string))

In [5]:
def decode(string_list):
    """Decode a list of block-description strings.

    Args:
        string_list (list): Block-description strings, each in the format
            accepted by ``decode_block_string``.

    Returns:
        list: One decoded result per input string, in the same order.

    Raises:
        AssertionError: If ``string_list`` is not a ``list``.
    """
    assert isinstance(string_list, list)
    return [decode_block_string(spec) for spec in string_list]

In [6]:
# Raw block-description strings. They live under their own name so that
# `blocks_args` only ever refers to the decoded BlockArgs list, rather than
# being a list of str first and a list of BlockArgs afterwards.
block_strings = ['r1_k3_s11_e1_i32_o16_se0.25',
                 'r2_k3_s22_e6_i16_o24_se0.25',
                 'r2_k5_s22_e6_i24_o40_se0.25',
                 'r3_k3_s22_e6_i40_o80_se0.25',
                 'r3_k5_s11_e6_i80_o112_se0.25',
                 'r4_k5_s22_e6_i112_o192_se0.25',
                 'r1_k3_s11_e6_i192_o320_se0.25']

blocks_args = decode(block_strings)

r1 ---> ['r', '1', '']
k3 ---> ['k', '3', '']
s11 ---> ['s', '11', '']
e1 ---> ['e', '1', '']
i32 ---> ['i', '32', '']
o16 ---> ['o', '16', '']
se0.25 ---> ['se', '0.25', '']
r2 ---> ['r', '2', '']
k3 ---> ['k', '3', '']
s22 ---> ['s', '22', '']
e6 ---> ['e', '6', '']
i16 ---> ['i', '16', '']
o24 ---> ['o', '24', '']
se0.25 ---> ['se', '0.25', '']
r2 ---> ['r', '2', '']
k5 ---> ['k', '5', '']
s22 ---> ['s', '22', '']
e6 ---> ['e', '6', '']
i24 ---> ['i', '24', '']
o40 ---> ['o', '40', '']
se0.25 ---> ['se', '0.25', '']
r3 ---> ['r', '3', '']
k3 ---> ['k', '3', '']
s22 ---> ['s', '22', '']
e6 ---> ['e', '6', '']
i40 ---> ['i', '40', '']
o80 ---> ['o', '80', '']
se0.25 ---> ['se', '0.25', '']
r3 ---> ['r', '3', '']
k5 ---> ['k', '5', '']
s11 ---> ['s', '11', '']
e6 ---> ['e', '6', '']
i80 ---> ['i', '80', '']
o112 ---> ['o', '112', '']
se0.25 ---> ['se', '0.25', '']
r4 ---> ['r', '4', '']
k5 ---> ['k', '5', '']
s22 ---> ['s', '22', '']
e6 ---> ['e', '6', '']
i112 ---> ['i', '112', '']
o1

In [7]:
# Show each decoded BlockArgs tuple on its own line.
for block in blocks_args:
    print(block)

BlockArgs(num_repeat=1, kernel_size=3, stride=[1], expand_ratio=1, input_filters=32, output_filters=16, se_ratio=0.25, id_skip=True)
BlockArgs(num_repeat=2, kernel_size=3, stride=[2], expand_ratio=6, input_filters=16, output_filters=24, se_ratio=0.25, id_skip=True)
BlockArgs(num_repeat=2, kernel_size=5, stride=[2], expand_ratio=6, input_filters=24, output_filters=40, se_ratio=0.25, id_skip=True)
BlockArgs(num_repeat=3, kernel_size=3, stride=[2], expand_ratio=6, input_filters=40, output_filters=80, se_ratio=0.25, id_skip=True)
BlockArgs(num_repeat=3, kernel_size=5, stride=[1], expand_ratio=6, input_filters=80, output_filters=112, se_ratio=0.25, id_skip=True)
BlockArgs(num_repeat=4, kernel_size=5, stride=[2], expand_ratio=6, input_filters=112, output_filters=192, se_ratio=0.25, id_skip=True)
BlockArgs(num_repeat=1, kernel_size=3, stride=[1], expand_ratio=6, input_filters=192, output_filters=320, se_ratio=0.25, id_skip=True)


In [None]:
def efficientnet_params(model_name):
    """Look up the scaling coefficients for a named EfficientNet variant.

    Args:
        model_name (str): Name of the model, e.g. ``'efficientnet-b0'``.

    Returns:
        tuple: ``(width_coefficient, depth_coefficient, image_size,
        dropout_rate)`` for the requested model.

    Raises:
        KeyError: If ``model_name`` is not one of the known model names.
    """
    # Each row: name, width_coefficient, depth_coefficient, image_size, dropout_rate
    rows = (
        ('efficientnet-b0', 1.0, 1.0, 224, 0.2),
        ('efficientnet-b1', 1.0, 1.1, 240, 0.2),
        ('efficientnet-b2', 1.1, 1.2, 260, 0.3),
        ('efficientnet-b3', 1.2, 1.4, 300, 0.3),
        ('efficientnet-b4', 1.4, 1.8, 380, 0.4),
        ('efficientnet-b5', 1.6, 2.2, 456, 0.4),
        ('efficientnet-b6', 1.8, 2.6, 528, 0.5),
        ('efficientnet-b7', 2.0, 3.1, 600, 0.5),
        ('efficientnet-b8', 2.2, 3.6, 672, 0.5),
        ('efficientnet-l2', 4.3, 5.3, 800, 0.5),
    )
    coefficients_by_name = {name: tuple(coeffs) for name, *coeffs in rows}
    return coefficients_by_name[model_name]

In [None]:
def round_filters(filters, global_params):
    """Scale a filter count by the width multiplier and round it.

    Reads ``width_coefficient``, ``depth_divisor`` and ``min_depth`` from
    ``global_params``.

    Args:
        filters: Number of filters (channels) to be scaled.
        global_params (namedtuple): Global params of the model.

    Returns:
        ``filters`` unchanged when ``width_coefficient`` is falsy; otherwise
        the scaled filter count rounded to a multiple of ``depth_divisor``
        (and not below the minimum depth), as an ``int``.
    """
    width_mult = global_params.width_coefficient
    if not width_mult:
        # No width scaling configured: hand the input back untouched.
        return filters

    step = global_params.depth_divisor
    # When no minimum depth is set, fall back to the divisor itself.
    lower_bound = global_params.min_depth or step

    scaled = filters * width_mult

    # Same formula as the official TensorFlow implementation: round to the
    # nearest multiple of `step`, then clamp from below.
    nearest_multiple = int(scaled + step / 2) // step * step
    rounded = max(lower_bound, nearest_multiple)

    # Never let rounding shrink the scaled value by more than 10%.
    if rounded < 0.9 * scaled:
        rounded += step

    return int(rounded)