/
models.py
722 lines (632 loc) · 30 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Code for model cloning, plus model-related API entries.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras import optimizer_v1
from tensorflow.python.keras.engine import functional
from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.engine import training_v1
from tensorflow.python.keras.engine.base_layer import AddMetric
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.engine.input_layer import Input
from tensorflow.python.keras.engine.input_layer import InputLayer
from tensorflow.python.keras.saving import model_config
from tensorflow.python.keras.saving import save
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import version_utils
from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export
# API entries importable from `keras.models`:
# Model classes (also referenced by the cloning helpers in this module).
Model = training.Model  # pylint: disable=invalid-name
Sequential = sequential.Sequential  # pylint: disable=invalid-name
Functional = functional.Functional  # pylint: disable=invalid-name
# Saving / loading, re-exported from `keras.saving.save`.
save_model = save.save_model
load_model = save.load_model
# Rebuilding a model from a config / YAML / JSON description, re-exported
# from `keras.saving.model_config`.
model_from_config = model_config.model_from_config
model_from_yaml = model_config.model_from_yaml
model_from_json = model_config.model_from_json
# Callable used to "clone" a layer while preserving (sharing) its weights.
def share_weights(layer):
  """Returns `layer` itself instead of a copy.

  Passing this as the `layer_fn` of the cloning helpers makes the clone reuse
  the original layer objects, so weights are shared rather than re-created.

  Args:
    layer: The layer to "clone".

  Returns:
    The very same `layer` object.
  """
  same_layer = layer
  return same_layer
def _clone_layer(layer):
return layer.__class__.from_config(layer.get_config())
def _insert_ancillary_layers(model, ancillary_layers, metrics_names, new_nodes):
  """Inserts ancillary layers into `model`, with metric layers in metric order.

  Layers that are not `AddMetric` keep their relative order and come first.
  They are followed by the `AddMetric` layers, sorted by the position of their
  `metric_name` in `metrics_names` so that they agree with the model's metric
  ordering.

  Args:
    model: The model receiving the layers (via `model._insert_layers`).
    ancillary_layers: Layers that are not used to compute the model outputs,
      e.g. those created by `add_loss` / `add_metric`.
    metrics_names: Sequence of metric names defining the `AddMetric` order.
    new_nodes: Iterable of nodes forwarded as `relevant_nodes`.

  Raises:
    ValueError: If an `AddMetric` layer's `metric_name` is not found by
      `metrics_names.index` (when `metrics_names` is a list).
  """
  plain_layers = []
  metric_layers = []
  for candidate in ancillary_layers:
    if isinstance(candidate, AddMetric):
      metric_layers.append(candidate)
    else:
      plain_layers.append(candidate)

  def _metric_position(metric_layer):
    return metrics_names.index(metric_layer.metric_name)

  ordered_layers = plain_layers + sorted(metric_layers, key=_metric_position)
  model._insert_layers(ordered_layers, relevant_nodes=list(new_nodes))
def _make_new_nodes(nodes_by_depth, layer_fn, layer_map, tensor_map):
  """Uses the layers in `layer_map` to make new nodes based on `nodes_by_depth`.

  Depth keys are visited from largest to smallest. For each node, its outbound
  layer is looked up in `layer_map`, or created with `layer_fn` and recorded
  there. An already-mapped `InputLayer` is never called again. A node is
  replayed only once every one of its input tensors has a counterpart in
  `tensor_map`. Replaying means calling the mapped layer on the mapped call
  arguments and recording each produced tensor in `tensor_map`.

  Args:
    nodes_by_depth: Dict mapping depth -> iterable of nodes. Provides the
      structure information used to create the new nodes.
    layer_fn: Function used to clone layers not yet present in `layer_map`.
    layer_map: Map from layers in the reference model to new layers. Updated
      in place.
    tensor_map: Map from tensors in the reference model to newly computed
      tensors. Updated in place.

  Returns:
    A set of the newly created nodes.
  """
  created_nodes = set()

  def _mapped(tensor):
    return tensor_map.get(tensor, tensor)

  for depth in sorted(nodes_by_depth, reverse=True):
    for node in nodes_by_depth[depth]:
      source_layer = node.outbound_layer
      if source_layer in layer_map:
        # Reuse the previously cloned layer, but never call an InputLayer
        # more than once.
        target_layer = layer_map[source_layer]
        if isinstance(target_layer, InputLayer):
          continue
      else:
        target_layer = layer_fn(source_layer)
        layer_map[source_layer] = target_layer

      # Only replay the node once all of its inputs have been recreated.
      if not all(t in tensor_map for t in nest.flatten(node.input_tensors)):
        continue

      call_args = nest.map_structure(_mapped, node.call_args)
      call_kwargs = nest.map_structure(_mapped, node.call_kwargs)
      outputs = target_layer(*call_args, **call_kwargs)

      # Thread-safe way to find the node this call created: index the layer's
      # inbound nodes with the Keras history of the first output.
      history = nest.flatten(outputs)[0]._keras_history
      created_nodes.add(target_layer._inbound_nodes[history.node_index])

      for old_tensor, new_tensor in zip(nest.flatten(node.output_tensors),
                                        nest.flatten(outputs)):
        tensor_map[old_tensor] = new_tensor
  return created_nodes
def _clone_functional_model(model, input_tensors=None, layer_fn=_clone_layer):
  """Clone a functional `Model` instance.

  Model cloning is similar to calling a model on new inputs,
  except that it creates new layers (and thus new weights) instead
  of sharing the weights of the existing layers.

  Input layers are always cloned, unless a Keras tensor produced by an
  `InputLayer` is supplied for that input, in which case that `InputLayer`
  is used.

  Arguments:
    model: Instance of `Model`.
    input_tensors: optional list of input tensors
      to build the model upon. If not provided,
      placeholders will be created. Raw (non-Keras) tensors are wrapped in a
      new `InputLayer`; Keras tensors coming from an `InputLayer` are used
      directly. A Keras tensor from any other layer is ignored (with a
      warning) and the original input layer is reused, as before.
    layer_fn: callable to be applied on non-input layers in the model. By
      default it clones the layer. Another example is to preserve the layer
      to share the weights. This is required when we create a per-replica
      copy of the model with distribution strategy; we want the weights to
      be shared but still feed inputs separately so we create new input
      layers.

  Returns:
    An instance of `Model` reproducing the behavior
    of the original model, on top of new inputs tensors,
    using newly instantiated weights (unless `layer_fn` shares them).

  Raises:
    ValueError: in case of invalid `model` argument value or `layer_fn`
      argument value, or if more input tensors are given than the model has
      inputs.
  """
  if not isinstance(model, Model):
    raise ValueError('Expected `model` argument '
                     'to be a `Model` instance, got %s' % (model,))
  if isinstance(model, Sequential):
    raise ValueError('Expected `model` argument '
                     'to be a functional `Model` instance, '
                     'got a `Sequential` instance instead: %s' % (model,))
  if not model._is_graph_network:
    raise ValueError('Expected `model` argument '
                     'to be a functional `Model` instance, '
                     'but got a subclass model instead.')
  # Validate before any new `Input` layers are created below.
  if not callable(layer_fn):
    raise ValueError('Expected `layer_fn` argument to be a callable.')

  new_input_layers = {}  # Maps original input layers to replacement layers.
  if input_tensors is not None:
    input_tensors = nest.flatten(input_tensors)
    original_input_layers = model._input_layers
    if len(input_tensors) > len(original_input_layers):
      raise ValueError(
          'Expected at most %d input tensor(s) to clone model %s, but got %d.'
          % (len(original_input_layers), model.name, len(input_tensors)))
    for original_input_layer, input_tensor in zip(original_input_layers,
                                                  input_tensors):
      if not K.is_keras_tensor(input_tensor):
        # The tensor does not come from a Keras layer: wrap it in a new
        # InputLayer so the clone can be built on top of it.
        name = original_input_layer.name
        input_tensor = Input(tensor=input_tensor,
                             name='input_wrapper_for_' + name)
        new_input_layers[original_input_layer] = (
            input_tensor._keras_history.layer)
        continue

      origin_layer = input_tensor._keras_history.layer
      if isinstance(origin_layer, InputLayer):
        # Build the clone on the InputLayer that produced the given tensor.
        # Mapping the original input layer to itself would silently ignore
        # the tensor the caller supplied.
        new_input_layers[original_input_layer] = origin_layer
      else:
        logging.warning(
            'Input tensor %s passed for input %s does not come from an '
            '`InputLayer`; it is ignored and the original input layer is '
            'reused.', input_tensor, original_input_layer.name)
        new_input_layers[original_input_layer] = original_input_layer

  model_configs, created_layers = _clone_layers_and_model_config(
      model, new_input_layers, layer_fn)
  # Reconstruct model from the config, using the cloned layers.
  input_tensors, output_tensors, created_layers = (
      functional.reconstruct_from_config(model_configs,
                                         created_layers=created_layers))
  metrics_names = model.metrics_names
  cloned_model = Model(input_tensors, output_tensors, name=model.name)
  # Layers not directly tied to outputs of the Model, such as loss layers
  # created in `add_loss` and `add_metric`.
  ancillary_layers = [
      layer for layer in created_layers.values()
      if layer not in cloned_model.layers
  ]
  # TODO(b/162887610): This may need to adjust the inbound node index if the
  # created layers had already been used to define other models.
  if ancillary_layers:
    new_nodes = nest.flatten([
        layer.inbound_nodes[1:]
        if functional._should_skip_first_node(layer)
        else layer.inbound_nodes for layer in created_layers.values()
    ])
    _insert_ancillary_layers(cloned_model, ancillary_layers, metrics_names,
                             new_nodes)
  return cloned_model
def _clone_layers_and_model_config(model, input_layers, layer_fn):
  """Clones all layers, and returns the model config without serializing layers.

  Only the node graph is read from the model config. The layer-serialization
  hook passed to `functional.get_network_config` records a replacement layer
  and returns an empty dict instead of a real layer config. This matters
  because `layer_fn` need not rely on `layer.get_config()`, and some custom
  layers do not define `get_config`; serializing them would raise.

  Replacement layers are chosen as follows:
    * layers that are keys of `input_layers` -> the mapped value;
    * other input layers of `model` -> `InputLayer(**layer.get_config())`;
    * every remaining layer -> `layer_fn(layer)`.

  Args:
    model: A Functional model.
    input_layers: Dictionary mapping input layers in `model` to new input
      layers.
    layer_fn: Function used to clone all non-input layers.

  Returns:
    A tuple `(config, created_layers)`: the model config object, and a dict
    mapping original layer names to the newly created layers.
  """
  created_layers = {}

  def _copy_layer(layer):
    if layer in input_layers:
      replacement = input_layers[layer]
    elif layer in model._input_layers:
      replacement = InputLayer(**layer.get_config())
    else:
      replacement = layer_fn(layer)
    created_layers[layer.name] = replacement
    # Dummy "serialization": the config only needs the node graph.
    return {}

  network_config = functional.get_network_config(
      model, serialize_layer_fn=_copy_layer)
  return network_config, created_layers
def _remove_ancillary_layers(model, layer_map, layers):
"""Removes and returns any ancillary layers from `layers` based on `model`.
Ancillary layers are part of the model topology but not used to compute the
model outputs, e.g., layers from `add_loss` and `add_metric`.
Args:
model: A Keras Model.
layer_map: A map to from layers in the `model` to those in `layers`.
layers: A list of all layers.
Returns:
Two lists of layers: (1) `layers` with the ancillary layers removed, and (2)
the ancillary layers.
"""
ancillary_layers = [] # Additional layers for computing losses and metrics.
if not model._is_graph_network:
return layers, ancillary_layers
# Ancillary layers are those with depth < 0.
depths = [depth for depth in model._nodes_by_depth.keys() if depth < 0]
depths.sort(reverse=True) # Order topologically from inputs to outputs.
for depth in depths:
for node in model._nodes_by_depth[depth]:
ancillary_layers.append(layer_map[node.outbound_layer])
return [l for l in layers if l not in ancillary_layers], ancillary_layers
def _clone_sequential_model(model, input_tensors=None, layer_fn=_clone_layer):
  """Clone a `Sequential` model instance.

  Model cloning is similar to calling a model on new inputs,
  except that it creates new layers (and thus new weights) instead
  of sharing the weights of the existing layers.

  Arguments:
    model: Instance of `Sequential`.
    input_tensors: optional list (or tuple) containing exactly one input
      tensor to build the model upon. If not provided, placeholders will be
      created.
    layer_fn: callable to be applied on non-input layers in the model. By
      default it clones the layer. Another example is to preserve the layer
      to share the weights. This is required when we create a per-replica
      copy of the model with distribution strategy; we want the weights to
      be shared but still feed inputs separately so we create new input
      layers.

  Returns:
    An instance of `Sequential` reproducing the behavior
    of the original model, on top of new inputs tensors,
    using newly instantiated weights (unless `layer_fn` shares them).

  Raises:
    ValueError: in case of invalid `model` argument value or `layer_fn`
      argument value, if `input_tensors` does not hold exactly one tensor,
      or if that tensor is a Keras tensor not produced by an `InputLayer`.
  """
  if not isinstance(model, Sequential):
    raise ValueError('Expected `model` argument '
                     'to be a `Sequential` model instance, '
                     'but got: %s' % (model,))

  if not callable(layer_fn):
    raise ValueError('Expected `layer_fn` argument to be a callable.')

  if input_tensors is not None:
    # Normalize tuples to lists *before* counting: `to_list` wraps a tuple as
    # a single element, which would let e.g. a 2-tuple of tensors pass the
    # check below and silently drop every tensor but the first.
    if isinstance(input_tensors, tuple):
      input_tensors = list(input_tensors)
    input_tensors = generic_utils.to_list(input_tensors)
    if len(input_tensors) != 1:
      raise ValueError('To clone a `Sequential` model, we expect '
                       'at most one tensor '
                       'as part of `input_tensors`.')

  layers = []  # Layers needed to compute the model's outputs.
  layer_map = {}
  # Use model._layers to ensure that all layers are cloned. The model's layers
  # property will exclude the initial InputLayer (if it exists) in the model,
  # resulting in a different Sequential model structure.
  for layer in model._layers:
    if isinstance(layer, InputLayer) and input_tensors is not None:
      # If input tensors are provided, the original model's InputLayer is
      # overwritten with a different InputLayer.
      continue
    cloned_layer = (
        _clone_layer(layer)
        if isinstance(layer, InputLayer) else layer_fn(layer))
    layers.append(cloned_layer)
    layer_map[layer] = cloned_layer
  layers, ancillary_layers = _remove_ancillary_layers(model, layer_map, layers)

  if input_tensors is None:
    cloned_model = Sequential(layers=layers, name=model.name)
  else:
    # Overwrite the original model's input layer.
    x = input_tensors[0]
    if K.is_keras_tensor(x):
      origin_layer = x._keras_history.layer
      if isinstance(origin_layer, InputLayer):
        cloned_model = Sequential(
            layers=[origin_layer] + layers, name=model.name)
      else:
        raise ValueError('Cannot clone a `Sequential` model on top '
                         'of a tensor that comes from a Keras layer '
                         'other than an `InputLayer`. '
                         'Use the functional API instead.')
    else:
      input_tensor = Input(tensor=x, name='input_wrapper_for_' + str(x.name))
      input_layer = input_tensor._keras_history.layer
      cloned_model = Sequential(layers=[input_layer] + layers, name=model.name)

  if not ancillary_layers:
    return cloned_model

  tensor_map = {}  # Maps tensors from `model` to those in `cloned_model`.
  for depth, cloned_nodes in cloned_model._nodes_by_depth.items():
    nodes = model._nodes_by_depth[depth]
    # This should be safe in a Sequential model. In an arbitrary network, you
    # need to sort using the outbound layer of the node as a key.
    for cloned_node, node in zip(cloned_nodes, nodes):
      if isinstance(cloned_node.output_tensors, list):
        for j, output_tensor in enumerate(cloned_node.output_tensors):
          tensor_map[node.output_tensors[j]] = output_tensor
      else:
        tensor_map[node.output_tensors] = cloned_node.output_tensors
  # Ancillary nodes have negative depth.
  new_nodes = _make_new_nodes(
      {
          depth: nodes
          for depth, nodes in model._nodes_by_depth.items()
          if depth < 0
      }, layer_fn, layer_map, tensor_map)
  _insert_ancillary_layers(cloned_model, ancillary_layers, model.metrics_names,
                           new_nodes)
  return cloned_model
@keras_export('keras.models.clone_model')
def clone_model(model, input_tensors=None, clone_function=None):
  """Clone any `Model` instance.

  Model cloning is similar to calling a model on new inputs, except that it
  creates new layers (and thus new weights) instead of sharing the weights of
  the existing layers.

  `Sequential` models are handled by `_clone_sequential_model`; every other
  model is handed to `_clone_functional_model`.

  Arguments:
    model: Instance of `Model` (could be a functional model or a `Sequential`
      model).
    input_tensors: optional list of input tensors to build the model upon. If
      not provided, placeholders will be created.
    clone_function: Callable used to clone each layer in the target model
      (except `InputLayer` instances). It receives the layer instance to be
      cloned and returns the layer instance to use in the model copy. When
      unspecified, it defaults to the serialization/deserialization round
      trip `lambda layer: layer.__class__.from_config(layer.get_config())`.
      A custom callable lets you customize the copy, e.g. by wrapping certain
      layers of interest (you might want to replace all `LSTM` instances with
      equivalent `Bidirectional(LSTM(...))` instances, for example).

  Returns:
    An instance of `Model` reproducing the behavior of the original model, on
    top of new inputs tensors, using newly instantiated weights. The cloned
    model might behave differently from the original model if a custom
    `clone_function` modifies the layer.

  Raises:
    ValueError: in case of invalid `model` argument value.
  """
  layer_fn = _clone_layer if clone_function is None else clone_function
  if isinstance(model, Sequential):
    clone_impl = _clone_sequential_model
  else:
    clone_impl = _clone_functional_model
  return clone_impl(model, input_tensors=input_tensors, layer_fn=layer_fn)
# "Clone" a subclassed model by resetting all of the attributes.
def _in_place_subclassed_model_reset(model):
  """Substitute for model cloning that works for subclassed models.

  Subclassed models cannot be cloned because their topology is not
  serializable. To "instantiate" an identical model in a new TF graph, we
  reuse the original model object, but we clear its state.

  After calling this function on a model instance, you can use the model
  instance as if it were a model clone (in particular you can use it in a new
  graph).

  This method clears the state of the input model. It is thus destructive.
  However the original state can be restored fully by calling
  `in_place_subclassed_model_state_restoration`.

  All replacement layers are created, and all validation is done, before any
  layer attribute of the model is replaced. A `ValueError` for an unsupported
  layer therefore does not leave the model's layers half-reset.

  Args:
    model: Instance of a Keras model created via subclassing.

  Raises:
    ValueError: In case the model uses a subclassed model as inner layer, or
      has attributes holding nested layers or lists of layers.
  """
  assert not model._is_graph_network  # Only makes sense for subclassed networks
  # Select correct base class for new Model.
  version_utils.swap_class(model.__class__, training.Model, training_v1.Model,
                           ops.executing_eagerly_outside_functions())
  # Retrieve all layers tracked by the model as well as their attribute names
  attributes_cache = {}
  for name in dir(model):
    # Skip the check of methods in tf.Module since they basically
    # recursively query all the other attributes within same module.
    if name == 'submodules':
      continue

    try:
      value = getattr(model, name)
    except (AttributeError, ValueError, TypeError):
      continue
    if isinstance(value, Layer):
      attributes_cache[name] = value
      assert value in model.layers
      if hasattr(value, 'layers') and value.layers:
        raise ValueError('We do not support the use of nested layers '
                         'in `model_to_estimator` at this time. Found nested '
                         'layer: %s' % value)
    elif isinstance(
        value, (list, tuple)) and name not in ('layers', '_layers', 'metrics',
                                               '_compile_metric_functions',
                                               '_output_loss_metrics'):
      # Handle case: list/tuple of layers (also tracked by the Network API).
      if value and all(isinstance(val, Layer) for val in value):
        raise ValueError('We do not support the use of list-of-layers '
                         'attributes in subclassed models used with '
                         '`model_to_estimator` at this time. Found list '
                         'model: %s' % name)

  # Build the fresh layers before touching the model, so that an unsupported
  # layer (or one not stored under any attribute name) raises while the
  # model's layers are still intact.
  layers_to_names = {value: key for key, value in attributes_cache.items()}
  replacements = []  # (attribute name, fresh layer), in original layer order.
  for layer in model._layers[:]:
    config = layer.get_config()
    # This will not work for nested subclassed models used as layers.
    # This would be theoretically possible to support, but would add
    # complexity. Only do it if users complain.
    if isinstance(layer, training.Model) and not layer._is_graph_network:
      raise ValueError('We do not support the use of nested subclassed models '
                       'in `model_to_estimator` at this time. Found nested '
                       'model: %s' % layer)
    fresh_layer = layer.__class__.from_config(config)
    replacements.append((layers_to_names[layer], fresh_layer))

  # Replace layers on the model with fresh layers, without attribute tracking.
  setattr_tracking = model._setattr_tracking
  model._setattr_tracking = False
  try:
    model._layers = []
    for attr_name, fresh_layer in replacements:
      setattr(model, attr_name, fresh_layer)
      model._layers.append(fresh_layer)

    # Cache original model build attributes (in addition to layers)
    if (not hasattr(model, '_original_attributes_cache') or
        model._original_attributes_cache is None):
      if model.built:
        attributes_to_cache = [
            'inputs',
            'outputs',
            'total_loss',
            'optimizer',
            'train_function',
            'test_function',
            'predict_function',
            '_training_endpoints',
            '_collected_trainable_weights',
            '_feed_inputs',
            '_feed_input_names',
            '_feed_input_shapes',
        ]
        for attr_name in attributes_to_cache:
          attributes_cache[attr_name] = getattr(model, attr_name)
    # NOTE(review): this assignment runs even when a cache already existed,
    # replacing it with the current layers; confirm that is intended for
    # repeated resets without an intermediate restoration.
    model._original_attributes_cache = attributes_cache
    _reset_build_compile_trackers(model)
  finally:
    model._setattr_tracking = setattr_tracking
def _reset_build_compile_trackers(model):
  """Reset the build/compile state trackers of `model`.

  Marks the model as not built and not compiled, and sets `inputs`, `outputs`
  and `optimizer` to `None`. When not executing eagerly outside functions, it
  also clears `_v1_compile_was_called`. No other build/compile attributes are
  cleared; they are expected to be over-written by a later build/compile.
  This is somewhat fragile, because other code checks for the presence of
  such attributes as evidence of having been built/compiled.

  Args:
    model: the model that is being reset
  """
  # Build state.
  model.built = False
  for attr_name in ('inputs', 'outputs'):
    setattr(model, attr_name, None)

  # Compile state.
  model._is_compiled = False  # pylint:disable=protected-access
  if not ops.executing_eagerly_outside_functions():
    model._v1_compile_was_called = False
  model.optimizer = None
def in_place_subclassed_model_state_restoration(model):
  """Restores the original state of a model after it was "reset".

  This undoes `_in_place_subclassed_model_reset`, which is called by
  `clone_and_build_model` when `in_place_reset` is True. If no attribute cache
  is present, the model is instead reset to the state of a never-called model.

  Args:
    model: Instance of a Keras model created via subclassing, on which
      `_in_place_subclassed_model_reset` was previously called.
  """
  assert not model._is_graph_network
  cached_attributes = getattr(model, '_original_attributes_cache', None)
  if cached_attributes is None:
    # Nothing cached: restore to the state of a never-called model.
    _reset_build_compile_trackers(model)
    return

  # Models have sticky attribute assignment, so add back the previous
  # attributes and re-track Layers under their original names with tracking
  # disabled, avoiding dependencies on "utility" attributes that Models exempt
  # when they are constructed.
  previous_tracking = model._setattr_tracking
  model._setattr_tracking = False
  model._layers = []
  for attr_name, attr_value in cached_attributes.items():
    setattr(model, attr_name, attr_value)
    if isinstance(attr_value, Layer):
      model._layers.append(attr_value)
  model._original_attributes_cache = None
  model._setattr_tracking = previous_tracking
def clone_and_build_model(
    model, input_tensors=None, target_tensors=None, custom_objects=None,
    compile_clone=True, in_place_reset=False, optimizer_iterations=None,
    optimizer_config=None):
  """Clone a `Model` and build/compile it with the same settings used before.

  This function can be run in the same graph or in a separate graph from the
  model. When using a separate graph, `in_place_reset` must be `False`.

  Note that, currently, the clone produced from this function may not work with
  TPU DistributionStrategy. Try at your own risk.

  Args:
    model: `tf.keras.Model` object. Can be Functional, Sequential, or
      sub-classed.
    input_tensors: Optional list or dictionary of input tensors to build the
      model upon. If not provided, placeholders will be created.
    target_tensors: Optional list of target tensors for compiling the model. If
      not provided, placeholders will be created.
    custom_objects: Optional dictionary mapping string names to custom classes
      or functions.
    compile_clone: Boolean, whether to compile model clone (default `True`).
    in_place_reset: Boolean, whether to reset the model in place. Only used if
      the model is a subclassed model. A subclassed model is first cloned via
      `from_config(get_config())`. If `get_config` raises
      `NotImplementedError`, this argument must be set to `True` (default
      `False`), and the original model object is reset and reused. To restore
      the original model, use the function
      `in_place_subclassed_model_state_restoration(model)`.
    optimizer_iterations: An iterations variable that will be incremented by the
      optimizer if the clone is compiled. This argument is used when a Keras
      model is cloned into an Estimator model function, because Estimators
      create their own global step variable.
    optimizer_config: Optimizer config dictionary or list of dictionary
      returned from `get_config()`. This argument should be defined if
      `clone_and_build_model` is called in a different graph or session from
      the original model, and the optimizer is an instance of `OptimizerV2`.

  Returns:
    Clone of the model.

  Raises:
    ValueError: Cloning fails in the following cases
      - cloning a subclassed model that does not implement `get_config` with
        `in_place_reset` set to False.
      - compiling the clone when the original model has not been compiled.
  """
  # Grab optimizer now, as we reset-in-place for subclassed models, but
  # want to maintain access to the original optimizer.
  orig_optimizer = model.optimizer
  if compile_clone and not orig_optimizer:
    raise ValueError(
        'Error when cloning model: compile_clone was set to True, but the '
        'original model has not been compiled.')

  if compile_clone:
    original_compile_args = model._get_compile_args()  # pylint: disable=protected-access
    # Allows this method to be robust to switching graph and eager classes.
    model._get_compile_args = lambda: original_compile_args

  with CustomObjectScope(custom_objects or {}):
    if model._is_graph_network:
      clone = clone_model(model, input_tensors=input_tensors)
    elif isinstance(model, Sequential):
      clone = clone_model(model, input_tensors=input_tensors)
      if (not clone._is_graph_network and model._build_input_shape is not None):
        if ops.executing_eagerly_outside_functions():
          clone.build(model._build_input_shape)
        else:
          clone._set_inputs(
              K.placeholder(
                  model._build_input_shape, dtype=model.inputs[0].dtype))
    else:
      try:
        # Prefer cloning the model if serial/deserial logic is implemented for
        # subclassed model.
        clone = model.__class__.from_config(model.get_config())
      except NotImplementedError:
        logging.warning('This model is a subclassed model. Please implement '
                        '`get_config` and `from_config` to better support '
                        'cloning the model.')
        if not in_place_reset:
          raise ValueError(
              'This model is a subclassed model. '
              'Such a model cannot be cloned, but there is a workaround where '
              'the model is reset in-place. To use this, please set the '
              'argument `in_place_reset` to `True`. This will reset the '
              'attributes in the original model. To restore the attributes, '
              'call `in_place_subclassed_model_state_restoration(model)`.')
        clone = model
        _in_place_subclassed_model_reset(clone)
      if input_tensors is not None:
        if isinstance(input_tensors, (list, tuple)) and len(input_tensors) == 1:
          input_tensors = input_tensors[0]
        clone._set_inputs(input_tensors)

  if compile_clone:
    # Edit a shallow copy: `original_compile_args` is what the patched
    # `model._get_compile_args` returns, so the clone-specific optimizer,
    # target tensors and cloned metrics set below must not leak into later
    # calls on the original model.
    compile_args = dict(original_compile_args)
    if isinstance(orig_optimizer, optimizer_v1.TFOptimizer):
      optimizer = optimizer_v1.TFOptimizer(
          orig_optimizer.optimizer, optimizer_iterations)
      K.track_tf_optimizer(optimizer)
    else:
      if not isinstance(orig_optimizer, (tuple, list)):
        orig_optimizer = [orig_optimizer]
      if optimizer_config is None:
        optimizer = [
            opt.__class__.from_config(opt.get_config())
            for opt in orig_optimizer
        ]
      elif isinstance(optimizer_config, dict):
        optimizer = [orig_optimizer[0].__class__.from_config(optimizer_config)]
      else:
        # optimizer config is list of dict, same order as orig_optimizer.
        optimizer = [
            opt.__class__.from_config(opt_config)
            for (opt, opt_config) in zip(orig_optimizer, optimizer_config)
        ]
      if optimizer_iterations is not None:
        for opt in optimizer:
          opt.iterations = optimizer_iterations

      if len(optimizer) == 1:
        optimizer = optimizer[0]

    compile_args['optimizer'] = optimizer
    if target_tensors is not None:
      compile_args['target_tensors'] = target_tensors
    # Ensure Metric objects in new model are separate from existing model.
    compile_args['metrics'] = metrics_module.clone_metrics(
        compile_args['metrics'])
    compile_args['weighted_metrics'] = metrics_module.clone_metrics(
        compile_args['weighted_metrics'])
    clone.compile(**compile_args)

  return clone