# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Logic for restoring checkpointed values for Trackables."""

import collections

from tensorflow.python.checkpoint import checkpoint_view
from tensorflow.python.checkpoint import functional_saver
from tensorflow.python.checkpoint import save_util_v1
from tensorflow.python.checkpoint import saveable_compat
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import registration
from tensorflow.python.trackable import base
from tensorflow.python.trackable import constants
from tensorflow.python.trackable import python_state
from tensorflow.python.trackable import trackable_utils
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.util import object_identity


class CheckpointPosition(object):
"""Indicates a position within a `_CheckpointRestoreCoordinator`."""
__slots__ = ["_checkpoint", "_proto_id", "skip_restore"]

  def __init__(self, checkpoint, proto_id):
"""Specify an object within a checkpoint.
Args:
checkpoint: A _CheckpointRestoreCoordinator object.
proto_id: The index of this object in TrackableObjectGraph.nodes.
"""
self._checkpoint = checkpoint
self._proto_id = proto_id
# This may be set to True if the registered saver cannot be used with this
# object.
self.skip_restore = False

  def restore(self, trackable, reader=None):
"""Restore this value into `trackable`."""
with ops.init_scope():
if self.bind_object(trackable):
# This object's correspondence with a checkpointed object is new, so
# process deferred restorations for it and its dependencies.
restore_ops = self._restore_descendants(reader)
if restore_ops:
self._checkpoint.new_restore_ops(restore_ops)

  def bind_object(self, trackable):
"""Set a checkpoint<->object correspondence.
Args:
trackable: The object to record a correspondence for.
Returns:
True if this is a new assignment, False if this object has already been
mapped to a checkpointed `Object` proto.
Raises:
AssertionError: If another object is already bound to the `Object` proto.
"""
checkpoint = self.checkpoint
checkpoint.all_python_objects.add(trackable)
current_assignment = checkpoint.object_by_proto_id.get(self._proto_id, None)
checkpoint.matched_proto_ids.add(self._proto_id)
if current_assignment is None:
checkpoint.object_by_proto_id[self._proto_id] = trackable
return True # New assignment
else:
# The object was already mapped for this checkpoint load, which means
# we don't need to do anything besides check that the mapping is
# consistent (if the dependency DAG is not a tree then there are
# multiple paths to the same object).
if current_assignment is not trackable:
logging.warning(
"Inconsistent references when loading the checkpoint into this "
"object graph. For example, in the saved checkpoint object, "
"`model.layer.weight` and `model.layer_copy.weight` reference the "
"same variable, while in the current object these are two different"
" variables. The referenced variables are:"
f"({current_assignment} and {trackable}).")
return False # Not a new assignment

  def is_simple_variable(self):
"""Determine whether this value is restorable with a Tensor initializer."""
attributes = self.object_proto.attributes
return (len(attributes) == 1 and
attributes[0].name == constants.VARIABLE_VALUE_KEY and
not self.object_proto.children)

  def value_tensors(self, shape_and_slices=None):
"""Create value `Tensor`s for this object's attributes.
Does not require that the Python object has been created. Used for
restore-on-create when executing eagerly.
Args:
shape_and_slices: A dict mapping from object attribute names to a shape
and slice string that will be passed to a RestoreV2 op. If the dict is
None or if an object attribute is not in the dict, the full tensor will
be restored.
Returns:
A dictionary mapping from object attribute names to `Tensor`s.
"""
value_tensors = {}
for serialized_tensor in self.object_proto.attributes:
checkpoint_key = serialized_tensor.checkpoint_key
dtype = self._checkpoint.dtype_map[checkpoint_key]
base_type = dtype.base_dtype
io_device = self._checkpoint.options.experimental_io_device or "cpu:0"
with ops.init_scope():
with ops.device(io_device):
          # Run the restore itself on the io_device (CPU, or the device
          # specified in the options).
if (shape_and_slices is not None and
serialized_tensor.name in shape_and_slices):
shape_and_slice = shape_and_slices[serialized_tensor.name]
else:
shape_and_slice = ""
value, = io_ops.restore_v2(
prefix=self._checkpoint.save_path_tensor,
tensor_names=[checkpoint_key],
shape_and_slices=[shape_and_slice],
dtypes=[base_type],
name="%s_checkpoint_read" % (serialized_tensor.name,))
# Copy the value to the current device if necessary.
value_tensors[serialized_tensor.name] = array_ops.identity(value)
return value_tensors

  def gather_ops_or_named_saveables(self):
"""Looks up or creates SaveableObjects which don't have cached ops.
Returns:
A tuple of (
existing_restore_ops: list,
named_saveables: dict,
python_positions: list,
registered_savers: dict)
"""
recorded_registered_saver = self.get_registered_saver_name()
if not (self.object_proto.attributes or recorded_registered_saver):
return [], {}, [], {}
existing_restore_ops = []
named_saveables = {}
python_positions = []
registered_savers = collections.defaultdict(dict)
saveable_factories = saveable_object_util.saveable_objects_from_trackable(
self.trackable)
saver_name = registration.get_registered_saver_name(self.trackable)
if recorded_registered_saver:
if not self.skip_restore:
name = self.object_proto.registered_saver.object_name
registered_savers[recorded_registered_saver][name] = self.trackable
# Else: Skip restoration of this Trackable. This skip only happens if the
# registered saver has enabled `option_restore`. Otherwise, an error would
# have been raised at `self.get_registered_saver_name()`.
elif saver_name:
# In this case, the checkpoint has a recorded serialized tensor but no
# registered saver, while the Trackable loading the checkpoint has
# migrated to the registered checkpoint functionality (TPUEmbedding is an
# example of this).
# Set the Trackable's object name to the first checkpoint key that is
# stored in checkpoint. If there is a use case that requires the other
# keys, then we can take another look at this.
registered_savers[saver_name] = {
self.object_proto.attributes[0].checkpoint_key: self.trackable
}
elif isinstance(self.trackable, python_state.PythonState):
python_positions.append(self)
elif saveable_factories.keys() == {
trackable_utils.SERIALIZE_TO_TENSORS_NAME
}:
existing_restore_ops, named_saveables = (
self._create_serialize_to_tensor_saveable(saveable_factories))
elif saveable_factories:
existing_restore_ops, named_saveables = (
self._create_saveables_by_attribute_name(saveable_factories))
else:
# If no registered savers were found, then it means that one or more
# serialized tensors were never used.
for serialized_tensor in self.object_proto.attributes:
self._checkpoint.unused_attributes.setdefault(
self._proto_id, []).append(serialized_tensor.name)
return (existing_restore_ops, named_saveables, python_positions,
registered_savers)

  def _create_serialize_to_tensor_saveable(self, saveable_factories):
"""Creates a saveable using the _serialize_to_tensor method."""
# Extract the saveable name from the checkpoint key. This will be used as
# the cache key or the name to pass to the saveable factory.
suffix = saveable_compat.get_saveable_name(self.trackable) or ""
saveable_name = _extract_saveable_name(
self.object_proto.attributes[0].checkpoint_key) + suffix
# Try to find the cached saveable (only in graph mode).
if not context.executing_eagerly():
existing_op = self._checkpoint.restore_ops_by_name.get(
saveable_name, None)
if existing_op is not None:
return [existing_op], {}
saveables_cache = self._checkpoint.saveables_cache.setdefault(
self.trackable, {})
if saveable_name in saveables_cache:
return [], {saveable_name: saveables_cache[saveable_name]}
saveable = saveable_factories[trackable_utils.SERIALIZE_TO_TENSORS_NAME](
name=saveable_name)
if not context.executing_eagerly():
saveables_cache[saveable_name] = saveable
return [], {saveable_name: saveable}

  def _create_saveables_by_attribute_name(self, saveable_factories):
"""Creates or caches SaveableObjects by matching the attribute names.

    The attribute name keys in `saveable_factories` are used to find the
corresponding attribute in the object proto. Attributes contain checkpoint
keys which are passed to the factory function to generate the
SaveableObject.
Args:
saveable_factories: a dict mapping attribute name to a callable factory
function that produces a SaveableObject.
Returns:
A tuple of (
existing_restore_ops: list,
named_saveables: dict)
"""
# Name saveables based on the name this object had when it was checkpointed.
named_saveables = {}
existing_restore_ops = []
# Forward compatibility code: when loading a future checkpoint, there may
# be multiple SerializedTensors mapped to a single saveable.
created_compat_names = set()
for serialized_tensor in self.object_proto.attributes:
if context.executing_eagerly():
existing_op = None
else:
existing_op = self._checkpoint.restore_ops_by_name.get(
serialized_tensor.checkpoint_key, None)
if existing_op is not None:
existing_restore_ops.append(existing_op)
continue
if any(serialized_tensor.name.startswith(name)
for name in created_compat_names):
continue # Saveable has already been created for this tensor.
# Only if we don't have cached ops for this SaveableObject, we'll see if
# the SaveableObject itself has been cached. If not, we'll make it, and
# either way we'll extract new ops from it (or if it has Python state to
# restore, we'll run that).
saveables_cache = self._checkpoint.saveables_cache
if saveables_cache is None:
# No SaveableObject caching when executing eagerly.
saveable = None
else:
# If we've already created and cached a SaveableObject for this
# attribute, we can re-use it to avoid re-creating some ops when graph
# building.
saveable_list = saveables_cache.get(self.trackable,
{}).get(serialized_tensor.name,
(None,))
if len(saveable_list) == 1:
# Almost every attribute will have exactly one SaveableObject.
saveable, = saveable_list
else:
# Don't use cached SaveableObjects for partitioned variables, which is
# the only case where we'd have a list of SaveableObjects. Op caching
# will catch them.
saveable = None
if saveable is not None:
# The name of this attribute has changed, so we need to re-generate
# the SaveableObject.
if serialized_tensor.checkpoint_key not in saveable.name:
saveable = None
del saveables_cache[self.trackable]
if saveable is None:
# If there was no cached SaveableObject, create one.
# Use the name to check if the Python object has the same attribute.
saveable = _get_saveable_from_factory(saveable_factories,
serialized_tensor,
created_compat_names)
if saveable is None:
# Purposefully does not throw an exception if attributes have been
# added or deleted. Stores unused attributes so an exception can be
# raised if the user decides to check that everything in the
# checkpoint was loaded.
self._checkpoint.unused_attributes.setdefault(
self._proto_id, []).append(serialized_tensor.name)
continue
if saveables_cache is not None:
saveables_cache.setdefault(self.trackable,
{})[serialized_tensor.name] = [saveable]
named_saveables[serialized_tensor.checkpoint_key] = saveable
return existing_restore_ops, named_saveables

  def restore_ops(self, reader=None):
"""Create or fetch restore ops for this object's attributes.
Requires that the `Trackable` Python object has been bound to an object
ID in the checkpoint.
Args:
reader: A `CheckpointReader`. If None, a new instance will be created.
Returns:
A list of operations when graph building, or an empty list when executing
eagerly.
"""
if self._has_registered_saver():
raise ValueError("Unable to run individual checkpoint restore for objects"
" with registered savers.")
(restore_ops, tensor_saveables, python_positions,
_) = self.gather_ops_or_named_saveables()
restore_ops.extend(
self._checkpoint.restore_saveables(
tensor_saveables, python_positions, reader=reader))
return restore_ops

  @property
  def checkpoint(self):
    return self._checkpoint

  @property
  def trackable(self):
    return self._checkpoint.object_by_proto_id[self._proto_id]

  @property
  def object_proto(self):
    return self._checkpoint.object_graph_proto.nodes[self._proto_id]

  @property
  def proto_id(self):
    return self._proto_id

  @property
  def restore_uid(self):
    return self._checkpoint.restore_uid

  def __repr__(self):
return repr(self.object_proto)

  def value_shape(self):
"""The shape of the VARIABLE_VALUE tensor.

    Returns:
      The `TensorShape` of the VARIABLE_VALUE tensor if found, otherwise None.
"""
for serialized_tensor in self.object_proto.attributes:
if serialized_tensor.name == constants.VARIABLE_VALUE_KEY:
return self._checkpoint.shape_map[serialized_tensor.checkpoint_key]
return None

  def _has_registered_saver(self):
return bool(self.object_proto.registered_saver.name)

  def get_registered_saver_name(self):
"""Returns the registered saver name defined in the Checkpoint."""
if self._has_registered_saver():
saver_name = self.object_proto.registered_saver.name
try:
registration.validate_restore_function(self.trackable, saver_name)
except ValueError as e:
if registration.get_strict_predicate_restore(saver_name):
raise e
self.skip_restore = True
return saver_name
return None

  def create_slot_variable_position(self, optimizer_object, variable,
slot_variable_id, slot_name):
"""Generates CheckpointPosition for a slot variable.
Args:
optimizer_object: Optimizer that owns the slot variable.
variable: Variable associated with the slot variable.
slot_variable_id: ID of the slot variable.
slot_name: Name of the slot variable.
Returns:
If there is a slot variable in the `optimizer_object` that has not been
bound to the checkpoint, this function returns a tuple of (
new `CheckpointPosition` for the slot variable,
the slot variable itself).
"""
slot_variable_position = CheckpointPosition(
checkpoint=self.checkpoint, proto_id=slot_variable_id)
# pylint: disable=protected-access
slot_variable = optimizer_object._create_or_restore_slot_variable(
slot_variable_position=slot_variable_position,
variable=variable,
slot_name=slot_name)
# pylint: enable=protected-access
if (slot_variable is not None and
slot_variable_position.bind_object(slot_variable)):
return slot_variable_position, slot_variable
else:
return None, None

  def create_child_position(self, node_id):
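    """Creates a CheckpointPosition for the child checkpoint node `node_id`."""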
return CheckpointPosition(checkpoint=self.checkpoint, proto_id=node_id)

  def _restore_descendants(self, reader=None):
"""Restore the bound Trackable and dependencies (may be deferred)."""
# Attempt a breadth-first traversal, since presumably the user has more
# control over shorter paths. If we don't have all of the dependencies at
# this point, the end result is not breadth-first (since other deferred
# traversals will happen later).
    # You may be wondering why elements in the `visit_queue` are tuples that
    # contain both a CheckpointPosition and its Trackable. The reason is that
# Optimizers will not keep a strong reference to slot vars for
# ShardedVariables. The slot variable must be kept in memory until the
# restore saveables have been created.
visit_queue = collections.deque([(self, self.trackable)])
restore_ops = []
tensor_saveables = {}
python_positions = []
registered_savers = collections.defaultdict(dict)
while visit_queue:
current_position, _ = visit_queue.popleft()
# Restore using the ops defined in a Saveable or registered function.
(new_restore_ops, new_tensor_saveables, new_python_positions,
new_registered_savers) = current_position._single_restore() # pylint: disable=protected-access
restore_ops.extend(new_restore_ops)
tensor_saveables.update(new_tensor_saveables)
python_positions.extend(new_python_positions)
for saver_name, trackable_map in new_registered_savers.items():
registered_savers[saver_name].update(trackable_map)
# Pass the restoration to the dependencies.
_queue_children_for_restoration(current_position, visit_queue)
_queue_slot_variables(current_position, visit_queue)
restore_ops.extend(
current_position.checkpoint.restore_saveables(
tensor_saveables,
python_positions,
registered_savers,
reader=reader))
return restore_ops

  def _single_restore(self):
"""Restores the trackable."""
trackable = self.trackable
trackable._maybe_initialize_trackable() # pylint: disable=protected-access
checkpoint = self.checkpoint
# If the UID of this restore is lower than our current update UID, we don't
# need to actually restore the object.
if checkpoint.restore_uid > trackable._update_uid: # pylint: disable=protected-access
restore_ops, tensor_saveables, python_positions, registered_savers = (
self.gather_ops_or_named_saveables())
trackable._update_uid = checkpoint.restore_uid # pylint: disable=protected-access
else:
restore_ops = ()
tensor_saveables = {}
python_positions = ()
registered_savers = {}
return restore_ops, tensor_saveables, python_positions, registered_savers


def restore_nodes(save_path, nodes_to_restore):
"""Restores nodes from a dict.
Requires that the `Trackable` Python object has been bound to an object
ID in the checkpoint.

  Args:
    save_path: a string representing the path to the checkpoint.
    nodes_to_restore: a dict mapping `node_id` to the `trackable` to be
      restored.
"""
if save_path is None:
raise ValueError("save_path cannot be empty.")
if not isinstance(nodes_to_restore, dict):
raise ValueError(
"Expecting a dictionary of node_id to Trackable for nodes_to_restore.")
ckpt_view = checkpoint_view.CheckpointView(save_path)
ckpt_view_descendants = ckpt_view.descendants()
for node_id, trackable in nodes_to_restore.items():
# node_id does not have a corresponding Checkpoint value.
if (node_id not in ckpt_view_descendants or
ckpt_view._object_graph_proto.nodes[ # pylint: disable=protected-access
node_id] is None):
raise ValueError(
f"The expected node_id: {node_id} to Trackable {trackable} to "
"restore does not exist in the checkpoint.")
    # The Trackable mapped to this node_id is missing or is not a Trackable.
if trackable is None or not isinstance(trackable, base.Trackable):
raise ValueError(
f"Expecting a valid Trackable to node_id: {node_id} but got "
f"trackable: {trackable}."
)
serialized_tensors = object_identity.ObjectIdentityDictionary()
for node_id, current_trackable in nodes_to_restore.items():
ckpt_contains_serialized_tensors = ckpt_view._object_graph_proto.nodes[ # pylint: disable=protected-access
node_id].attributes
node = ckpt_view._object_graph_proto.nodes[node_id] # pylint: disable=protected-access
trackable_has_serialize_to_tensor = saveable_object_util.trackable_has_serialize_to_tensor(
current_trackable)
if not trackable_has_serialize_to_tensor:
if not node.attributes:
if saveable_object_util.saveable_objects_from_trackable(
current_trackable):
raise ValueError(
f"Trackable {current_trackable} expects checkpointed values but "
"checkpoint does not contain serialized tensors for node_id: "
f"{node_id}.")
else:
continue
object_names = object_identity.ObjectIdentityDictionary()
object_names[current_trackable] = trackable_utils.extract_object_name(
node.attributes[0].checkpoint_key)
checkpoint_factory_map, _ = save_util_v1.get_checkpoint_factories_and_keys(
object_names, None)
saveable_objects = save_util_v1.generate_saveable_objects(
checkpoint_factory_map)[0]
if len(node.attributes) != len(saveable_objects):
raise ValueError("Size for saveable_objects for Trackable: "
f"{len(saveable_objects)} did not match the size for "
"serialized_tensors for checkpoint: "
f"{len(node.attributes)}.")
current_trackable = saveable_object_util.SaveableCompatibilityConverter(
current_trackable, saveable_objects)
serialized_tensors[
current_trackable] = current_trackable._serialize_to_tensors() # pylint: disable=protected-access
trackable_expects_ckpted_value = bool(serialized_tensors[current_trackable])
if trackable_expects_ckpted_value and not ckpt_contains_serialized_tensors:
raise ValueError(
f"Trackable {current_trackable} expects checkpointed values but "
"checkpoint does not contain serialized tensors for node_id: "
f"{node_id}.")
if not trackable_expects_ckpted_value and ckpt_contains_serialized_tensors:
raise ValueError(
f"Trackable {current_trackable} does not expect checkpointed "
"values but checkpoint contains serialized tensors: "
f"{ckpt_contains_serialized_tensors} for node_id: {node_id}.")
if len(node.attributes) != len(serialized_tensors[current_trackable]):
raise ValueError("Size for serialized_tensors for Trackable: "
f"{len(serialized_tensors[current_trackable])} did not "
"match size for serialized_tensors for checkpoint: "
f"{len(node.attributes)}.")
if not trackable_has_serialize_to_tensor:
functional_saver.MultiDeviceSaver(serialized_tensors).restore(save_path)
else:
      # Converts attribute.name to attribute.checkpoint_key since that's what
      # the restore method expects, i.e., converts "a" to "/.ATTRIBUTES/a".
serialized_tensors_renamed = object_identity.ObjectIdentityDictionary()
serialized_tensors_renamed[current_trackable] = {}
for attribute in node.attributes:
name = attribute.name
checkpoint_key = attribute.checkpoint_key
serialized_tensors_renamed[current_trackable][
checkpoint_key] = serialized_tensors[current_trackable][name]
functional_saver.MultiDeviceSaver(serialized_tensors_renamed).restore(
save_path)


def _queue_children_for_restoration(checkpoint_position, visit_queue):
"""Queues the restoration of trackable's children or defers them."""
# pylint: disable=protected-access
trackable = checkpoint_position.trackable
for child in checkpoint_position.object_proto.children:
# trackable._lookup_dependency can be expensive so first check if this node
# already has an object correspondence. If so we skip this node.
correspondence = checkpoint_position.checkpoint.object_by_proto_id.get(
child.node_id, None
)
if correspondence is not None:
continue
child_position = checkpoint_position.create_child_position(child.node_id)
local_object = trackable._lookup_dependency(child.local_name)
child_proto = child_position.object_proto
if local_object is None:
# We don't yet have a dependency registered with this name. Save it
# in case we do.
if child_proto.HasField("has_checkpoint_values"):
has_value = child_proto.has_checkpoint_values.value
else:
# If the field is not set, do a simple check to see if the dependency
# has children and/or checkpointed values.
has_value = bool(
child_proto.children or child_proto.attributes or
child_proto.slot_variables or
child_proto.HasField("registered_saver"))
if has_value:
trackable._deferred_dependencies.setdefault(child.local_name,
[]).append(child_position)
else:
if child_position.bind_object(trackable=local_object):
# This object's correspondence is new, so dependencies need to be
# visited. Delay doing it so that we get a breadth-first dependency
# resolution order (shallowest paths first). The caller is responsible
# for emptying visit_queue.
visit_queue.append((child_position, local_object))
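

# Records a single deferred slot-variable restoration: the slot variable named
# `slot_name` for `original_variable` is stored at node `slot_variable_id` in
# the checkpoint, and is restored once its optimizer object is created and
# tracked (see `_queue_slot_variables`).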
_DeferredSlotVariableRestoration = collections.namedtuple(
"_DeferredSlotVariableRestoration", [
"original_variable",
"slot_variable_id",
"slot_name",
])


def _queue_slot_variables(checkpoint_position, visit_queue):
"""Queues slot variables for restoration."""
trackable = checkpoint_position.trackable
checkpoint = checkpoint_position.checkpoint
for deferred_slot_restoration in (checkpoint.deferred_slot_restorations.pop(
checkpoint_position.proto_id, ())):
slot_variable_position, slot_variable = (
checkpoint_position.create_slot_variable_position(
trackable, deferred_slot_restoration.original_variable,
deferred_slot_restoration.slot_variable_id,
deferred_slot_restoration.slot_name))
if slot_variable_position is not None:
visit_queue.append((slot_variable_position, slot_variable))
for slot_restoration in checkpoint.slot_restorations.pop(
checkpoint_position.proto_id, ()):
optimizer_object = checkpoint.object_by_proto_id.get(
slot_restoration.optimizer_id, None)
if optimizer_object is None:
# The optimizer has not yet been created or tracked. Record in the
# checkpoint that the slot variables need to be restored when it is.
checkpoint.deferred_slot_restorations.setdefault(
slot_restoration.optimizer_id, []).append(
_DeferredSlotVariableRestoration(
original_variable=trackable,
slot_variable_id=slot_restoration.slot_variable_id,
slot_name=slot_restoration.slot_name))
    # `optimizer_object` can be a `Checkpoint` when the user only needs the
# attributes the optimizer holds, such as `iterations`. In those cases,
# it would not have the optimizer's `_create_or_restore_slot_variable`
# method.
elif hasattr(optimizer_object, "_create_or_restore_slot_variable"):
slot_variable_position, slot_variable = (
checkpoint_position.create_slot_variable_position(
optimizer_object, trackable, slot_restoration.slot_variable_id,
slot_restoration.slot_name))
if slot_variable_position is not None:
visit_queue.append((slot_variable_position, slot_variable))


def _extract_saveable_name(checkpoint_key):
  # Slice the checkpoint key up to the end of the "{...}.ATTRIBUTES/" substring.
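  # For example (hypothetical key), "model/dense/.ATTRIBUTES/VARIABLE_VALUE"
  # is truncated to "model/dense/.ATTRIBUTES/".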
search_key = trackable_utils.OBJECT_ATTRIBUTES_NAME + "/"
return checkpoint_key[:checkpoint_key.index(search_key) + len(search_key)]


def _get_saveable_from_factory(saveable_factories, serialized_tensor,
created_compat_names):
"""Returns the saveable generated from the factory method."""
matched_factory = None
# The `expected_factory_name` is used to find the right saveable factory,
# while the `factory_input_name` is the value that is passed to the factory
# method to instantiate the SaveableObject.
expected_factory_name = serialized_tensor.name
factory_input_name = serialized_tensor.checkpoint_key
# Case 1: the name already exactly matches a key in saveable_factories.
if expected_factory_name in saveable_factories:
matched_factory = saveable_factories[expected_factory_name]
# Case 2: (Forward compat) The serialized name is composed of
# "factory_name" + "SUFFIX". Get the matching factory name.
if matched_factory is None:
for factory_name, factory in saveable_factories.items():
if expected_factory_name.startswith(factory_name):
if matched_factory is not None:
# This condition is met in the extreme edge case where the object
# returns two saveable factories with similar names. This is very
        # unlikely because there are zero objects inside TensorFlow that use
# more than one saveable factory.
raise ValueError("Forward compatibility load error: Unable to load "
"checkpoint saved in future version of TensorFlow. "
"Please update your version of TensorFlow to the "
"version in which the checkpoint was saved.")
matched_factory = factory
factory_input_name = _extract_saveable_name(
serialized_tensor.checkpoint_key) + factory_name
created_compat_names.add(factory_name)
if callable(matched_factory):
return matched_factory(name=factory_input_name)
return matched_factory