/
saving_utils.py
323 lines (264 loc) · 12.2 KB
/
saving_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils related to keras model saving."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import os
import six
from tensorflow.python.eager import def_function
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import losses
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import version_utils
from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util.compat import collections_abc
def extract_model_metrics(model):
  """Collects the metrics passed to `model.compile` into a name-keyed dict.

  Used when converting Keras models to Estimators and SavedModels.

  Args:
    model: A `tf.keras.Model` object.

  Returns:
    A dict from metric name to metric instance, or `None` when the model has
    no (or an empty/falsy) `_compile_metrics` attribute.
  """
  compiled = getattr(model, '_compile_metrics', None)
  if not compiled:
    return None
  # TODO(psv/kathywu): use this implementation in model to estimator flow.
  # `model.metrics` is deliberately avoided: it would also pick up metrics
  # registered through the `add_metric` API, which must be excluded here.
  by_name = {}
  for metric_fn in model._compile_metric_functions:  # pylint: disable=protected-access
    by_name[metric_fn.name] = metric_fn
  return by_name
def model_input_signature(model, keep_original_batch_size=False):
  """Returns the model's input signature as a single-element list.

  Keras requires tensor inputs to be passed as the first call argument, so
  the signature is always a list holding one (possibly nested) object. For
  example, a model with input {'feature1': <Tensor>, 'feature2': <Tensor>}
  has input signature [{'feature1': TensorSpec, 'feature2': TensorSpec}].

  Args:
    model: Keras Model object.
    keep_original_batch_size: If `False` (the default), the batch dimension of
      the returned specs is set to `None`; if `True`, the original batch size
      is kept.

  Returns:
    `None` when the model reports no save spec; otherwise a list containing
    either a single TensorSpec or a nested structure of TensorSpecs. The
    `training` argument is never part of the list.
  """
  save_spec = model._get_save_spec(  # pylint: disable=protected-access
      dynamic_batch=not keep_original_batch_size)
  if save_spec is None:
    return None
  save_spec = _enforce_names_consistency(save_spec)
  # A one-element Sequence already has the required shape. Dicts are not
  # Sequences, so a single-entry dict still falls through and gets wrapped.
  already_wrapped = (isinstance(save_spec, collections_abc.Sequence) and
                     len(save_spec) == 1)
  return save_spec if already_wrapped else [save_spec]
def raise_model_input_error(model):
  """Raises an error explaining that `model` has no known input shapes.

  Args:
    model: The model that could not be saved; it is formatted into the
      message via `str.format`.

  Raises:
    ValueError: always.
  """
  template = ('Model {} cannot be saved because the input shapes have not '
              'been set. Usually, input shapes are automatically determined '
              'from calling `.fit()` or `.predict()`. To manually set the '
              'shapes, call `model.build(input_shape)`.')
  raise ValueError(template.format(model))
def trace_model_call(model, input_signature=None):
  """Wraps the model's call in a tf.function suitable for exporting.

  The input signature is resolved in this order: the explicit argument, then
  the signature of `model.call` if it is already a `def_function.Function`,
  then `model_input_signature(model)`.

  Args:
    model: A Keras model.
    input_signature: optional, a list of tf.TensorSpec objects specifying the
      inputs to the model.

  Returns:
    A tf.function wrapping the model's call function with input signatures
    set. It returns a flat dict mapping output names to output tensors.

  Raises:
    ValueError: if input signature cannot be inferred from the model.
  """
  if input_signature is None and isinstance(model.call, def_function.Function):
    input_signature = model.call.input_signature
  if input_signature is None:
    input_signature = model_input_signature(model)
  if input_signature is None:
    raise_model_input_error(model)

  # TODO(mdan): Should the model's call be autographed by default?
  @def_function.function(input_signature=input_signature, autograph=False)
  def _wrapped_model(*args):
    """A concrete tf.function that wraps the model's call function."""
    # Keras invokes a single-input model on the bare tensor, not on a
    # one-element list, so mirror that here.
    if len(input_signature) == 1:
      call_inputs = args[0]
    else:
      call_inputs = list(args)

    with base_layer_utils.call_context().enter(
        model, inputs=call_inputs, build_graph=False, training=False,
        saving=True):
      raw_outputs = model(call_inputs, training=False)

    # The exported function must return a flat dict of named outputs.
    names = model.output_names  # Populated for Functional models.
    if names is None:  # Subclassed models: synthesize names instead.
      from tensorflow.python.keras.engine import compile_utils  # pylint: disable=g-import-not-at-top
      names = compile_utils.create_pseudo_output_names(raw_outputs)
    flat_outputs = nest.flatten(raw_outputs)
    return dict(zip(names, flat_outputs))

  return _wrapped_model
def model_metadata(model, include_optimizer=True, require_config=True):
  """Returns a dictionary containing the model metadata.

  Args:
    model: The Keras model being saved.
    include_optimizer: Whether to also record the training configuration and
      optimizer configuration.
    require_config: If `True`, a `NotImplementedError` raised by
      `model.get_config()` is propagated. Otherwise the model config is
      omitted and only the class name is recorded.

  Returns:
    A dict with keys `keras_version`, `backend` and `model_config`. It also
    has `training_config` (including `optimizer_config`) when the optimizer
    is saved.

  Raises:
    NotImplementedError: if `get_config` is unimplemented and
      `require_config` is `True`, or if the optimizer is a
      `RestoredOptimizer` (loaded from a SavedModel).
  """
  from tensorflow.python.keras import __version__ as keras_version  # pylint: disable=g-import-not-at-top
  from tensorflow.python.keras.optimizer_v2 import optimizer_v2  # pylint: disable=g-import-not-at-top

  model_config = {'class_name': model.__class__.__name__}
  try:
    model_config['config'] = model.get_config()
  except NotImplementedError:
    if require_config:
      raise

  metadata = {
      'keras_version': str(keras_version),
      'backend': K.backend(),
      'model_config': model_config,
  }

  if not (model.optimizer and include_optimizer):
    return metadata

  if isinstance(model.optimizer, optimizers.TFOptimizer):
    # Raw TF optimizers expose no state or config to serialize.
    logging.warning(
        'TensorFlow optimizers do not '
        'make it possible to access '
        'optimizer attributes or optimizer state '
        'after instantiation. '
        'As a result, we cannot save the optimizer '
        'as part of the model save file. '
        'You will have to compile your model again after loading it. '
        'Prefer using a Keras optimizer instead '
        '(see keras.io/optimizers).')
    return metadata

  if not model._compile_was_called:  # pylint: disable=protected-access
    return metadata

  compile_args = model._get_compile_args(user_metrics=False)  # pylint: disable=protected-access
  compile_args.pop('optimizer', None)  # The optimizer is recorded separately.
  metadata['training_config'] = _serialize_nested_config(compile_args)

  if isinstance(model.optimizer, optimizer_v2.RestoredOptimizer):
    raise NotImplementedError(
        'As of now, Optimizers loaded from SavedModel cannot be saved. '
        "If you're calling `model.save` or `tf.keras.models.save_model`,"
        ' please set the `include_optimizer` option to `False`. For '
        '`tf.saved_model.save`, delete the optimizer from the model.')

  metadata['training_config']['optimizer_config'] = {
      'class_name':
          generic_utils.get_registered_name(model.optimizer.__class__),
      'config':
          model.optimizer.get_config(),
  }
  return metadata
def should_overwrite(filepath, overwrite):
  """Decides whether writing to `filepath` may proceed.

  Args:
    filepath: Destination path.
    overwrite: If truthy, permission is granted without checking the path.

  Returns:
    `True` when `overwrite` is set or no regular file exists at `filepath`.
    Otherwise, the result of asking the user via
    `ask_to_proceed_with_overwrite`.
  """
  if overwrite or not os.path.isfile(filepath):
    return True
  # An existing file would be clobbered without explicit permission: ask.
  return ask_to_proceed_with_overwrite(filepath)
def compile_args_from_training_config(training_config, custom_objects=None):
  """Return model.compile arguments from training config.

  Args:
    training_config: Dict of saved compile settings. It must contain
      `'optimizer_config'`. The keys `'loss'`, `'metrics'`,
      `'weighted_metrics'`, `'sample_weight_mode'` and `'loss_weights'` are
      optional and default to `None` when absent.
    custom_objects: Optional dict of custom objects made available while
      deserializing.

  Returns:
    A dict of keyword arguments for `model.compile`: `optimizer`, `loss`,
    `metrics`, `weighted_metrics`, `loss_weights` and `sample_weight_mode`.

  Raises:
    KeyError: if `'optimizer_config'` is missing from `training_config`.
  """
  if custom_objects is None:
    custom_objects = {}

  with generic_utils.CustomObjectScope(custom_objects):
    optimizer_config = training_config['optimizer_config']
    optimizer = optimizers.deserialize(optimizer_config)

    # Recover losses.
    loss = None
    loss_config = training_config.get('loss', None)
    if loss_config is not None:
      loss = _deserialize_nested_config(losses.deserialize, loss_config)

    # Recover metrics.
    metrics = None
    metrics_config = training_config.get('metrics', None)
    if metrics_config is not None:
      metrics = _deserialize_nested_config(_deserialize_metric, metrics_config)

    # Recover weighted metrics.
    weighted_metrics = None
    weighted_metrics_config = training_config.get('weighted_metrics', None)
    if weighted_metrics_config is not None:
      weighted_metrics = _deserialize_nested_config(_deserialize_metric,
                                                    weighted_metrics_config)

    # `training_config` is a dict, so the key must be looked up directly.
    # `hasattr(training_config, 'sample_weight_mode')` is always False for a
    # dict key and silently discarded any saved value.
    sample_weight_mode = training_config.get('sample_weight_mode', None)
    # Tolerate configs without this key; `None` is compile()'s default.
    loss_weights = training_config.get('loss_weights', None)

  return dict(
      optimizer=optimizer,
      loss=loss,
      metrics=metrics,
      weighted_metrics=weighted_metrics,
      loss_weights=loss_weights,
      sample_weight_mode=sample_weight_mode)
def _deserialize_nested_config(deserialize_fn, config):
"""Deserializes arbitrary Keras `config` using `deserialize_fn`."""
def _is_single_object(obj):
if isinstance(obj, dict) and 'class_name' in obj:
return True # Serialized Keras object.
if isinstance(obj, six.string_types):
return True # Serialized function or string.
return False
if config is None:
return None
if _is_single_object(config):
return deserialize_fn(config)
elif isinstance(config, dict):
return {
k: _deserialize_nested_config(deserialize_fn, v)
for k, v in config.items()
}
elif isinstance(config, (tuple, list)):
return [_deserialize_nested_config(deserialize_fn, obj) for obj in config]
raise ValueError('Saved configuration not understood.')
def _serialize_nested_config(config):
  """Serializes the callable leaves of a nested structure of Keras objects.

  Callable leaves go through `generic_utils.serialize_keras_object`. All other
  leaves are returned as-is, and the nesting is preserved by
  `nest.map_structure`.
  """
  def _to_serializable(leaf):
    if not callable(leaf):
      return leaf
    return generic_utils.serialize_keras_object(leaf)

  return nest.map_structure(_to_serializable, config)
def _deserialize_metric(metric_config):
  """Deserializes a metric config, passing special shorthand strings through.

  Args:
    metric_config: A serialized metric (string name or config dict).

  Returns:
    `metric_config` unchanged for 'accuracy', 'acc', 'crossentropy' and 'ce'.
    Otherwise, the result of `metrics.deserialize(metric_config)`.
  """
  from tensorflow.python.keras import metrics as metrics_module  # pylint:disable=g-import-not-at-top
  # compile() special-cases these shorthands based on model output shape, so
  # they must stay as strings. A sequence (not a set) is used because
  # `metric_config` may be an unhashable dict.
  compile_shorthands = ('accuracy', 'acc', 'crossentropy', 'ce')
  if metric_config not in compile_shorthands:
    return metrics_module.deserialize(metric_config)
  return metric_config
def _enforce_names_consistency(specs):
  """Ensures either every spec in `specs` is named or none is.

  If only some leaves have a non-`None` `name`, every leaf is replaced by a
  deep copy whose `_name` is set to `None`. Otherwise `specs` is returned
  unchanged.
  """
  def _is_named(spec):
    return getattr(spec, 'name', None) is not None

  def _unnamed_copy(spec):
    duplicate = copy.deepcopy(spec)
    if hasattr(duplicate, 'name'):
      duplicate._name = None  # pylint:disable=protected-access
    return duplicate

  named_flags = [_is_named(leaf) for leaf in nest.flatten(specs)]
  if any(named_flags) and not all(named_flags):
    return nest.map_structure(_unnamed_copy, specs)
  return specs
def try_build_compiled_arguments(model):
  """Attempts to build a loaded model's compiled loss and metrics.

  This only acts on non-v1 models whose `outputs` are not `None`. It builds
  `model.compiled_loss` and `model.compiled_metrics` from `model.outputs`.
  Failure is treated as best-effort: it is logged as a warning rather than
  raised. Per the warning text, the metrics get built once the model is
  trained or evaluated.

  Args:
    model: A compiled Keras model.
  """
  if (not version_utils.is_v1_layer_or_model(model) and
      model.outputs is not None):
    try:
      model.compiled_loss.build(model.outputs)
      model.compiled_metrics.build(model.outputs, model.outputs)
    except Exception:  # pylint: disable=broad-except
      # Stay best-effort, but do not swallow KeyboardInterrupt/SystemExit the
      # way a bare `except:` would.
      logging.warning(
          'Compiled the loaded model, but the compiled metrics have yet to '
          'be built. `model.compile_metrics` will be empty until you train '
          'or evaluate the model.')