-
Notifications
You must be signed in to change notification settings - Fork 1.8k
/
efficientnet_lite_builder.py
221 lines (189 loc) · 8.31 KB
/
efficientnet_lite_builder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model Builder for EfficientNet Edge Models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import logging
import tensorflow.compat.v1 as tf
import efficientnet_builder
import efficientnet_model
import utils
from lite import efficientnet_lite_model_qat
# Per-channel (R, G, B) mean and standard deviation for input normalization.
# Edge models use the inception-style values rather than ImageNet statistics
# because they behave better after post-training quantization.
MEAN_RGB = [127.0] * 3
STDDEV_RGB = [128.0] * 3
def efficientnet_lite_params(model_name):
  """Return the scaling parameters for an EfficientNet-Lite model name.

  Args:
    model_name: string such as 'efficientnet-lite0'. If it contains '-qat',
      that marker and everything after it are dropped before the lookup, so
      'efficientnet-lite0-qat' resolves like 'efficientnet-lite0'.

  Returns:
    A (width_coefficient, depth_coefficient, resolution, dropout_rate) tuple.

  Raises:
    KeyError: if the resulting name is not one of efficientnet-lite0..lite4.
  """
  qat_start = model_name.find('-qat')
  base_name = model_name[:qat_start] if qat_start != -1 else model_name

  # name -> (width_coefficient, depth_coefficient, resolution, dropout_rate)
  lite_scaling = {
      'efficientnet-lite0': (1.0, 1.0, 224, 0.2),
      'efficientnet-lite1': (1.0, 1.1, 240, 0.2),
      'efficientnet-lite2': (1.1, 1.2, 260, 0.3),
      'efficientnet-lite3': (1.2, 1.4, 280, 0.3),
      'efficientnet-lite4': (1.4, 1.8, 300, 0.3),
  }
  return lite_scaling[base_name]
# Per-stage block specifications shared by every EfficientNet-Lite variant.
# get_model_params() decodes them with efficientnet_builder.BlockDecoder; see
# that class for the exact encoding (presumably r=repeats, k=kernel size,
# s=strides, e=expansion ratio, i/o=input/output filters, se=squeeze-excitation
# ratio -- confirm against BlockDecoder). Note that efficientnet_lite() builds
# its GlobalParams with use_se=False even though the strings carry se0.25.
_DEFAULT_BLOCKS_ARGS = [
    'r1_k3_s11_e1_i32_o16_se0.25',
    'r2_k3_s22_e6_i16_o24_se0.25',
    'r2_k5_s22_e6_i24_o40_se0.25',
    'r3_k3_s22_e6_i40_o80_se0.25',
    'r3_k5_s11_e6_i80_o112_se0.25',
    'r4_k5_s22_e6_i112_o192_se0.25',
    'r1_k3_s11_e6_i192_o320_se0.25',
]
def efficientnet_lite(width_coefficient=None,
                      depth_coefficient=None,
                      dropout_rate=0.2,
                      survival_prob=0.8):
  """Build the GlobalParams used by every EfficientNet-Lite variant.

  No network is constructed here. Only the configuration namedtuple is
  returned (get_model_params() is the caller in this module).

  Args:
    width_coefficient: forwarded as GlobalParams.width_coefficient.
    depth_coefficient: forwarded as GlobalParams.depth_coefficient.
    dropout_rate: forwarded as GlobalParams.dropout_rate.
    survival_prob: forwarded as GlobalParams.survival_prob.

  Returns:
    An efficientnet_model.GlobalParams with the lite-specific settings.
  """
  return efficientnet_model.GlobalParams(
      # Architecture and scaling.
      blocks_args=_DEFAULT_BLOCKS_ARGS,
      width_coefficient=width_coefficient,
      depth_coefficient=depth_coefficient,
      depth_divisor=8,
      min_depth=None,
      fix_head_stem=True,  # Don't scale stem and head.
      num_classes=1000,
      data_format='channels_last',
      # Regularization.
      dropout_rate=dropout_rate,
      survival_prob=survival_prob,
      # Normalization. The TPU-specific batch norm is the default; the
      # alternative is tf.layers.BatchNormalization (build_model swaps in
      # utils.BatchNormalization for evaluation and fine-tuning).
      batch_norm=utils.TpuBatchNormalization,
      batch_norm_momentum=0.99,
      batch_norm_epsilon=1e-3,
      # Lite-specific choices.
      relu_fn=tf.nn.relu6,  # Relu6 is for easier quantization.
      # SE is not well supported on many lite devices.
      use_se=False,
      local_pooling=True,  # Special cases for tflite issues.
      clip_projection_output=False,
      # This flag is only read by the QAT version of the model.
      use_bfloat16=False)
def get_model_params(model_name, override_params):
  """Get the block args and global params for a given model.

  Args:
    model_name: string, the predefined model name, e.g. 'efficientnet-lite0'.
      A '-qat' suffix is tolerated because efficientnet_lite_params() strips
      it.
    override_params: dict (or None) of GlobalParams fields to override. It is
      not modified.

  Returns:
    A (blocks_args, global_params) tuple: the block args decoded by
    efficientnet_builder.BlockDecoder, and the GlobalParams.

  Raises:
    NotImplementedError: if model_name is not a predefined EfficientNet-Lite
      model. This includes unknown 'efficientnet-lite*' variants.
    ValueError: if override_params has fields not included in GlobalParams.
  """
  lite_params = None
  if model_name.startswith('efficientnet-lite'):
    try:
      lite_params = efficientnet_lite_params(model_name)
    except KeyError:
      # A name such as 'efficientnet-lite9' passes the prefix check but has
      # no table entry. Report it like any other undefined model (below)
      # instead of leaking a bare KeyError to callers.
      lite_params = None
  if lite_params is None:
    raise NotImplementedError('model name is not pre-defined: %s' % model_name)

  width_coefficient, depth_coefficient, _, dropout_rate = lite_params
  global_params = efficientnet_lite(
      width_coefficient, depth_coefficient, dropout_rate)

  if override_params:
    # ValueError will be raised here if override_params has fields not included
    # in global_params.
    global_params = global_params._replace(**override_params)

  decoder = efficientnet_builder.BlockDecoder()
  blocks_args = decoder.decode(global_params.blocks_args)

  logging.info('global_params= %s', global_params)
  return blocks_args, global_params
def build_model(images,
                model_name,
                training,
                override_params=None,
                model_dir=None,
                fine_tuning=False,
                features_only=False,
                pooled_features_only=False):
  """A helper function to create a model and return predicted logits.

  Args:
    images: input images tensor.
    model_name: string, the predefined model name. If it contains '-qat', the
      quantization-aware-training model
      (efficientnet_lite_model_qat.FunctionalModel) is built. The '-qat' part
      is stripped before the name is used for the parameter lookup, the
      variable scope and the saved params file.
    training: boolean, whether the model is constructed for training.
    override_params: A dictionary of params for overriding. Fields must exist in
      efficientnet_model.GlobalParams. The caller's dictionary is copied and
      never modified.
    model_dir: string, optional model dir for saving configs.
    fine_tuning: boolean, whether the model is used for finetuning.
    features_only: build the base feature network only (excluding final
      1x1 conv layer, global pooling, dropout and fc head).
    pooled_features_only: build the base network for features extraction (after
      1x1 conv layer and global pooling, but before dropout and fc head).

  Returns:
    logits: the logits tensor of classes. If features_only or
      pooled_features_only is set, this is the features / pooled features
      tensor instead.
    endpoints: the endpoints for each layer.

  Raises:
    NotImplementedError: when model_name specifies an undefined model.
    ValueError: when override_params has invalid fields.
  """
  assert isinstance(images, tf.Tensor)
  assert not (features_only and pooled_features_only)

  # Work on a private copy. The overrides added below ('survival_prob',
  # 'batch_norm') previously leaked into the caller's dict, so a dict reused
  # for a later training call would silently keep the inference batch norm.
  override_params = dict(override_params) if override_params else None

  # For backward compatibility.
  if override_params and override_params.get('drop_connect_rate', None):
    override_params['survival_prob'] = 1 - override_params['drop_connect_rate']

  with_quantization_aware_training = '-qat' in model_name
  if with_quantization_aware_training:
    model_name = model_name[:model_name.find('-qat')]

  if not training or fine_tuning:
    # Replace the TPU-specific batch norm chosen in efficientnet_lite() with
    # utils.BatchNormalization for evaluation and fine-tuning.
    override_params = override_params or {}
    override_params['batch_norm'] = utils.BatchNormalization

  blocks_args, global_params = get_model_params(model_name, override_params)

  if model_dir:
    # Record the resolved configuration once; an existing file is kept.
    param_file = os.path.join(model_dir, 'model_params.txt')
    if not tf.gfile.Exists(param_file):
      if not tf.gfile.Exists(model_dir):
        tf.gfile.MakeDirs(model_dir)
      with tf.gfile.GFile(param_file, 'w') as f:
        logging.info('writing to %s', param_file)
        f.write('model_name= %s\n\n' % model_name)
        f.write('global_params= %s\n\n' % str(global_params))
        f.write('blocks_args= %s\n\n' % str(blocks_args))

  with tf.variable_scope(model_name):
    if with_quantization_aware_training:
      model = efficientnet_lite_model_qat.FunctionalModel(
          model_name, blocks_args, global_params, features_only,
          pooled_features_only)
      outputs = model(images, training=training)[0]
    else:
      model = efficientnet_model.Model(blocks_args, global_params)
      outputs = model(
          images,
          training=training,
          features_only=features_only,
          pooled_features_only=pooled_features_only)

  if features_only:
    outputs = tf.identity(outputs, 'features')
  elif pooled_features_only:
    outputs = tf.identity(outputs, 'pooled_features')
  else:
    outputs = tf.identity(outputs, 'logits')
  return outputs, model.endpoints
def build_model_base(images, model_name, training, override_params=None):
  """Create a base feature network and return the features before pooling.

  Args:
    images: input images tensor.
    model_name: string, the predefined model name.
    training: boolean, whether the model is constructed for training.
    override_params: A dictionary of params for overriding. Fields must exist in
      efficientnet_model.GlobalParams. The caller's dictionary is copied and
      never modified.

  Returns:
    features: base features before pooling.
    endpoints: the endpoints for each layer.

  Raises:
    NotImplementedError: when model_name specifies an undefined model.
    ValueError: when override_params has invalid fields.
  """
  assert isinstance(images, tf.Tensor)

  # Work on a private copy so the backward-compatibility override below does
  # not leak 'survival_prob' into the caller's dict.
  override_params = dict(override_params) if override_params else None

  # For backward compatibility.
  if override_params and override_params.get('drop_connect_rate', None):
    override_params['survival_prob'] = 1 - override_params['drop_connect_rate']

  blocks_args, global_params = get_model_params(model_name, override_params)

  # NOTE(review): unlike build_model, a '-qat' suffix is not stripped here.
  # The variable scope keeps it, and the float efficientnet_model.Model is
  # built rather than the QAT model. This function also keeps the default
  # batch norm when training is False. Confirm both are intended.
  with tf.variable_scope(model_name):
    model = efficientnet_model.Model(blocks_args, global_params)
    features = model(images, training=training, features_only=True)

  features = tf.identity(features, 'features')
  return features, model.endpoints