object_detector.py
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""APIs to train an object detection model."""
import os
import tempfile
from typing import Dict, Optional, Tuple, TypeVar
import tensorflow as tf
from tensorflow_examples.lite.model_maker.core import compat
from tensorflow_examples.lite.model_maker.core.api.api_util import mm_export
from tensorflow_examples.lite.model_maker.core.data_util import object_detector_dataloader
from tensorflow_examples.lite.model_maker.core.export_format import ExportFormat
from tensorflow_examples.lite.model_maker.core.task import configs
from tensorflow_examples.lite.model_maker.core.task import custom_model
from tensorflow_examples.lite.model_maker.core.task import model_spec as ms
from tensorflow_examples.lite.model_maker.core.task.model_spec import object_detector_spec
from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import label_util
from tflite_support.metadata_writers import object_detector as metadata_writer
from tflite_support.metadata_writers import writer_utils
T = TypeVar('T', bound='ObjectDetector')
@mm_export('object_detector.ObjectDetector')
class ObjectDetector(custom_model.CustomModel):
"""ObjectDetector class for inference and exporting to tflite."""
ALLOWED_EXPORT_FORMAT = (ExportFormat.TFLITE, ExportFormat.SAVED_MODEL,
ExportFormat.LABEL)
def __init__(
self,
model_spec: object_detector_spec.EfficientDetModelSpec,
label_map: Dict[int, str],
representative_data: Optional[
object_detector_dataloader.DataLoader] = None
) -> None:
"""Initializes the ObjectDetector class.
Args:
model_spec: Specification for the model.
label_map: Dict, map label integer ids to string label names such as {1:
'person', 2: 'notperson'}. 0 is the reserved key for `background` and
doesn't need to be included in `label_map`. Label names can't be
duplicated.
representative_data: Representative dataset for full integer
quantization. Used when converting the keras model to the TFLite model
with full interger quantization.
"""
super().__init__(model_spec, shuffle=None)
if model_spec.config.label_map and model_spec.config.label_map != label_map:
tf.compat.v1.logging.warn(
'Label map is not the same as the previous label_map in model_spec.')
model_spec.config.label_map = label_map
    # TODO(yuqili): num_classes = 1 has issues during training. Thus we
    # enforce a minimum of num_classes=2 for now.
model_spec.config.num_classes = max(2, max(label_map.keys()))
self.representative_data = representative_data
def create_model(self) -> tf.keras.Model:
self.model = self.model_spec.create_model()
return self.model
def _get_dataset_and_steps(
self,
data: object_detector_dataloader.DataLoader,
batch_size: int,
is_training: bool,
) -> Tuple[Optional[tf.data.Dataset], int, Optional[str]]:
"""Gets dataset, steps and annotations json file."""
if not data:
return None, 0, None
# TODO(b/171449557): Put this into DataLoader.
dataset = data.gen_dataset(
self.model_spec, batch_size, is_training=is_training)
steps = len(data) // batch_size
return dataset, steps, data.annotations_json_file
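
  # For example, with len(data) == 100 and batch_size == 8, this yields
  # 100 // 8 == 12 steps per pass; the final partial batch is dropped during
  # training, consistent with the `drop_remainder=True` requirement below.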
def train(self,
train_data: object_detector_dataloader.DataLoader,
validation_data: Optional[
object_detector_dataloader.DataLoader] = None,
epochs: Optional[int] = None,
batch_size: Optional[int] = None) -> tf.keras.Model:
"""Feeds the training data for training."""
if not self.model_spec.config.drop_remainder:
raise ValueError('Must set `drop_remainder=True` during training. '
'Otherwise it will fail.')
batch_size = batch_size if batch_size else self.model_spec.batch_size
# TODO(b/171449557): Upstream this to the parent class.
if len(train_data) < batch_size:
      raise ValueError('The size of the train_data (%d) can\'t be smaller '
                       'than batch_size (%d). To solve this problem, set '
                       'a smaller batch_size or increase the size of the '
                       'train_data.' % (len(train_data), batch_size))
if validation_data and len(validation_data) < batch_size:
tf.compat.v1.logging.warn(
          'The size of the validation_data (%d) is smaller than batch_size '
          '(%d). Ignoring the validation_data.' %
          (len(validation_data), batch_size))
validation_data = None
with self.model_spec.ds_strategy.scope():
self.create_model()
train_ds, steps_per_epoch, _ = self._get_dataset_and_steps(
train_data, batch_size, is_training=True)
validation_ds, validation_steps, val_json_file = self._get_dataset_and_steps(
validation_data, batch_size, is_training=False)
return self.model_spec.train(self.model, train_ds, steps_per_epoch,
validation_ds, validation_steps, epochs,
batch_size, val_json_file)
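
  # A minimal usage sketch (hypothetical data loaders; `create` below is the
  # usual entry point that builds the detector and calls this for you):
  #
  #   detector = ObjectDetector(spec, train_data.label_map, train_data)
  #   detector.train(train_data, validation_data, epochs=50, batch_size=8)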
def evaluate(self,
data: object_detector_dataloader.DataLoader,
batch_size: Optional[int] = None) -> Dict[str, float]:
"""Evaluates the model."""
batch_size = batch_size if batch_size else self.model_spec.batch_size
    # Don't drop the final smaller batch, so the whole dataset is evaluated.
self.model_spec.config.drop_remainder = False
ds = data.gen_dataset(self.model_spec, batch_size, is_training=False)
steps = (len(data) + batch_size - 1) // batch_size
# TODO(b/171449557): Upstream this to the parent class.
if steps <= 0:
      raise ValueError('The size of the data (%d) can\'t be smaller than '
                       'batch_size (%d). To solve this problem, set a '
                       'smaller batch_size or increase the size of the '
                       'data.' % (len(data), batch_size))
eval_metrics = self.model_spec.evaluate(self.model, ds, steps,
data.annotations_json_file)
    # Restore drop_remainder=True since it must be True during training;
    # otherwise training will fail (see the check in `train`).
self.model_spec.config.drop_remainder = True
return eval_metrics
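
  # A hedged example of consuming the result (the exact metric keys come from
  # the EfficientDet eval code and are assumptions here):
  #
  #   metrics = detector.evaluate(validation_data, batch_size=8)
  #   print(metrics.get('AP'))  # COCO-style mean average precision, if present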
def evaluate_tflite(
self, tflite_filepath: str,
data: object_detector_dataloader.DataLoader) -> Dict[str, float]:
"""Evaluate the TFLite model."""
ds = data.gen_dataset(self.model_spec, batch_size=1, is_training=False)
return self.model_spec.evaluate_tflite(tflite_filepath, ds, len(data),
data.annotations_json_file)
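
  # A sketch of scoring an exported model (the .tflite path is a placeholder):
  #
  #   tflite_metrics = detector.evaluate_tflite('model.tflite', validation_data)
  #
  # The batch size is fixed to 1 because the exported TFLite model processes
  # one image per invocation.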
def _export_saved_model(self, saved_model_dir: str) -> None:
"""Saves the model to Tensorflow SavedModel."""
self.model_spec.export_saved_model(self.model, saved_model_dir)
def _export_tflite(
self,
tflite_filepath: str,
quantization_config: configs.QuantizationConfigType = 'default',
with_metadata: bool = True,
export_metadata_json_file: bool = False) -> None:
"""Converts the retrained model to tflite format and saves it.
Args:
tflite_filepath: File path to save tflite model.
quantization_config: Configuration for post-training quantization. If
'default', sets the `quantization_config` by default according to
`self.model_spec`. If None, exports the float tflite model without
quantization.
with_metadata: Whether the output tflite model contains metadata.
export_metadata_json_file: Whether to export metadata in json file. If
True, export the metadata in the same directory as tflite model.Used
only if `with_metadata` is True.
"""
if quantization_config == 'default':
quantization_config = self.model_spec.get_default_quantization_config(
self.representative_data)
self.model_spec.export_tflite(self.model, tflite_filepath,
quantization_config)
if with_metadata:
with tempfile.TemporaryDirectory() as temp_dir:
        tf.compat.v1.logging.info(
            'The label file is packed inside the TFLite model as part of its '
            'metadata.')
label_filepath = os.path.join(temp_dir, 'labelmap.txt')
self._export_labels(label_filepath)
writer = metadata_writer.MetadataWriter.create_for_inference(
writer_utils.load_file(tflite_filepath),
[self.model_spec.config.mean_rgb],
[self.model_spec.config.stddev_rgb], [label_filepath])
writer_utils.save_file(writer.populate(), tflite_filepath)
if export_metadata_json_file:
metadata_json = writer.get_populated_metadata_json()
export_json_path = os.path.splitext(tflite_filepath)[0] + '.json'
with open(export_json_path, 'w') as f:
f.write(metadata_json)
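
  # `_export_tflite` is normally invoked through the public `export()` method
  # inherited from CustomModel. A hedged sketch (treat the argument names as
  # assumptions taken from the Model Maker docs):
  #
  #   detector.export(export_dir='.', tflite_filename='model.tflite')
  #
  # Passing `quantization_config=None` skips the default full-integer
  # quantization and exports a float model instead.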
def _export_labels(self, label_filepath: str) -> None:
"""Export labels to label_filepath."""
tf.compat.v1.logging.info('Saving labels in %s.', label_filepath)
num_classes = self.model_spec.config.num_classes
label_map = label_util.get_label_map(self.model_spec.config.label_map)
with tf.io.gfile.GFile(label_filepath, 'w') as f:
      # Skips label_map[0], which is the reserved background class. The labels
      # in the label file for TFLite metadata should start from the actual
      # labels, without the background.
for i in range(num_classes):
label = label_map[i + 1] if i + 1 in label_map else '???'
f.write(label + '\n')
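
  # The resulting label file has one label per line, ordered by class id
  # starting at 1, with missing ids filled in as '???'. For the label_map
  # example above ({1: 'person', 2: 'notperson'}) the file would contain:
  #
  #   person
  #   notperson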
@classmethod
def create(cls,
train_data: object_detector_dataloader.DataLoader,
model_spec: object_detector_spec.EfficientDetModelSpec,
validation_data: Optional[
object_detector_dataloader.DataLoader] = None,
             epochs: Optional[int] = None,
batch_size: Optional[int] = None,
train_whole_model: bool = False,
do_train: bool = True) -> T:
"""Loads data and train the model for object detection.
Args:
train_data: Training data.
model_spec: Specification for the model.
validation_data: Validation data. If None, skips validation process.
epochs: Number of epochs for training.
batch_size: Batch size for training.
train_whole_model: Boolean, False by default. If true, train the whole
model. Otherwise, only train the layers that are not match
`model_spec.config.var_freeze_expr`.
do_train: Whether to run training.
Returns:
An instance based on ObjectDetector.
"""
model_spec = ms.get(model_spec)
if epochs is not None:
model_spec.config.num_epochs = epochs
if batch_size is not None:
model_spec.config.batch_size = batch_size
if train_whole_model:
model_spec.config.var_freeze_expr = None
if compat.get_tf_behavior() not in model_spec.compat_tf_versions:
      raise ValueError('Incompatible TF versions. Expected {}, but got {}.'
                       .format(model_spec.compat_tf_versions,
                               compat.get_tf_behavior()))
object_detector = cls(model_spec, train_data.label_map, train_data)
if do_train:
      tf.compat.v1.logging.info('Retraining the model...')
object_detector.train(train_data, validation_data, epochs, batch_size)
else:
object_detector.create_model()
return object_detector
# Shortcut function.
create = ObjectDetector.create
mm_export('object_detector.create').export_constant(__name__, 'create')
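
# End-to-end usage sketch (dataset paths and the spec name are placeholders;
# `DataLoader.from_pascal_voc` is assumed from Model Maker's data_util API):
#
#   spec = ms.get('efficientdet_lite0')
#   train_data = object_detector_dataloader.DataLoader.from_pascal_voc(
#       'images/train', 'annotations/train', label_map={1: 'person'})
#   detector = create(train_data, model_spec=spec, epochs=20, batch_size=8)
#   metrics = detector.evaluate(train_data)
#   detector.export(export_dir='.')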