"""Utilities for saving/loading Trackable objects."""
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import abc
import collections
import copy
import functools
import glob
import inspect
import os
import threading
import time
import weakref
from tensorflow.core.protobuf import trackable_object_graph_pb2
from tensorflow.python.checkpoint import async_checkpoint_helper
from tensorflow.python.checkpoint import checkpoint_context
from tensorflow.python.checkpoint import checkpoint_management
from tensorflow.python.checkpoint import checkpoint_options
from tensorflow.python.checkpoint import functional_saver
from tensorflow.python.checkpoint import graph_view as graph_view_lib
from tensorflow.python.checkpoint import restore as restore_lib
from tensorflow.python.checkpoint import save_util
from tensorflow.python.checkpoint import save_util_v1
from tensorflow.python.checkpoint import util
from tensorflow.python.client import session as session_lib
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import monitoring
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_io_ops as io_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variable_v1
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import path_helpers
from tensorflow.python.saved_model.pywrap_saved_model import metrics
from tensorflow.python.trackable import autotrackable
from tensorflow.python.trackable import base
from tensorflow.python.trackable import data_structures
from tensorflow.python.training import py_checkpoint_reader
from tensorflow.python.training import saver as v1_saver_lib
from tensorflow.python.training.saving import saveable_object as saveable_object_lib
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util import object_identity
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
# The callable that provides the Keras default session needed for saving;
# registered via `register_session_provider` below.
_SESSION_PROVIDER = None
# Captures the timestamp of the first Checkpoint instantiation or end of a write
# operation. Shared across Checkpoint instances; guard reads/writes with
# `_END_TIME_OF_LAST_WRITE_LOCK`.
_END_TIME_OF_LAST_WRITE = None
_END_TIME_OF_LAST_WRITE_LOCK = threading.Lock()
# API labels for cell names used in checkpoint metrics.
_CHECKPOINT_V1 = "checkpoint_v1"
_CHECKPOINT_V2 = "checkpoint_v2"
# Async thread used for asynchronous checkpoint.
_ASYNC_CHECKPOINT_THREAD = None
def _get_duration_microseconds(start_time_seconds, end_time_seconds):
if end_time_seconds < start_time_seconds:
# Avoid returning negative value in case of clock skew.
return 0
return round((end_time_seconds - start_time_seconds) * 1000000)
@tf_export("__internal__.tracking.register_session_provider", v1=[])
def register_session_provider(session_provider):
  """Registers a callable that supplies a default session for saving.

  The provider is consulted by `get_session` below when no TF default session
  is active (used to obtain the Keras default session).

  Args:
    session_provider: A zero-argument callable returning a session.
  """
  global _SESSION_PROVIDER
  # TODO(scottzhu): Change it back to only allow one time setting for session
  # provider once we finished the keras repo split.
  # if _SESSION_PROVIDER is None:
  _SESSION_PROVIDER = session_provider
def get_session():
  """Returns the TF default session, or one from the registered provider.

  Returns None when neither a default session nor a provider is available.
  """
  # Prefer TF's default session since get_session from Keras has side-effects.
  session = ops.get_default_session()
  if session is not None:
    return session
  global _SESSION_PROVIDER
  if _SESSION_PROVIDER is None:
    return None
  return _SESSION_PROVIDER()  # pylint: disable=not-callable
def _get_checkpoint_size(prefix):
  """Returns the total size in bytes of all checkpoint files for `prefix`.

  Sums the sizes of every file whose name begins with `prefix` (the `.index`
  file plus any sharded data files), measured via TensorFlow's C++ FileSystem
  API.
  """
  matching_files = glob.glob("{}*".format(prefix))
  return sum(metrics.CalculateFileSize(path) for path in matching_files)
def _execute_callbacks(callbacks, save_path):
"""Executes a list of callback functions, providing `save_path` if needed."""
for callback in callbacks:
num_params = len(inspect.signature(callback).parameters)
if num_params == 0:
callback()
elif num_params == 1:
callback(save_path)
else:
raise AssertionError(
"Callback functions for checkpoint are required to have 0 or 1"
f"parameters, but this has {num_params} parameters: {callback}"
)
class ObjectGraphProtoPrettyPrinter:
  """Lazily traverses an object graph proto to pretty print names.

  If no calls to `node_names` are made this object has no performance
  overhead. On the other hand, it will only traverse the object graph once, so
  repeated naming is cheap after the first.
  """

  __slots__ = ["_object_graph_proto", "_node_name_cache"]

  def __init__(self, object_graph_proto):
    self._object_graph_proto = object_graph_proto
    self._node_name_cache = None  # Built on first access to `node_names`.

  @property
  def node_names(self):
    """Lazily creates a mapping from node id to ("path", "to", "root")."""
    if self._node_name_cache is None:
      self._node_name_cache = self._build_node_names()
    return self._node_name_cache

  def _build_node_names(self):
    """Walks the proto breadth-first and renders dotted path names."""
    nodes = self._object_graph_proto.nodes
    # Breadth-first traversal from the root; each reachable node gets its path
    # recorded as a tuple of local names.
    paths = {0: ("(root)",)}
    pending = collections.deque([0])
    while pending:
      current = pending.popleft()
      for child in nodes[current].children:
        if child.node_id in paths:
          continue
        paths[child.node_id] = paths[current] + (child.local_name,)
        pending.append(child.node_id)
    names = {node_id: ".".join(path) for node_id, path in paths.items()}
    # Slot variables are not children in the object graph; describe them in
    # terms of the optimizer that owns them and the variable they slot for.
    for node_id, node in enumerate(nodes):
      for slot_reference in node.slot_variables:
        names[slot_reference.slot_variable_node_id] = (
            f"{names[node_id]}'s state '{slot_reference.slot_name}' for "
            f"{names[slot_reference.original_variable_node_id]}")
    return names
class _CheckpointRestoreCoordinatorDeleter:
  """Deleter to avoid overriding _CheckpointRestoreCoordinator.__del__().

  Holds just enough state to emit "unrestored value" warnings when the
  coordinator is garbage collected.
  """

  __slots__ = [
      "expect_partial", "object_graph_proto", "matched_proto_ids",
      "unused_attributes"
  ]

  def __init__(self, expect_partial, object_graph_proto, matched_proto_ids,
               unused_attributes):
    self.expect_partial = expect_partial
    self.object_graph_proto = object_graph_proto
    self.matched_proto_ids = matched_proto_ids
    self.unused_attributes = unused_attributes

  def set_expect_partial(self, expect_partial):
    """Updates whether warnings are suppressed (mirrors the coordinator)."""
    self.expect_partial = expect_partial

  def __del__(self):
    if self.expect_partial:
      return
    if logging is None:
      # The logging module may have been unloaded when __del__ is called.
      log_fn = print
    else:
      log_fn = logging.warning
    unused_nodes_in_checkpoint = []
    unrestored_attributes_in_object = []
    pretty_printer = ObjectGraphProtoPrettyPrinter(self.object_graph_proto)
    # Checkpoint nodes with attributes that no Python object consumed.
    for node_id, node in enumerate(self.object_graph_proto.nodes):
      if not node.attributes:
        continue
      if node_id not in self.matched_proto_ids:
        unused_nodes_in_checkpoint.append(pretty_printer.node_names[node_id])
    # Python object attributes that found no value in the checkpoint.
    for node_id, attribute_name in self.unused_attributes.items():
      unrestored_attributes_in_object.append((
          pretty_printer.node_names[node_id], attribute_name))
    if unused_nodes_in_checkpoint or unrestored_attributes_in_object:
      # pylint:disable=line-too-long
      # Bug fix: a space was missing between the URL and "for details",
      # producing "...#restorefor details" in the logged warning.
      log_fn("Detecting that an object or model or tf.train.Checkpoint is being"
             " deleted with unrestored values. See the following logs for the "
             "specific values in question. To silence these warnings, use "
             "`status.expect_partial()`. See "
             "https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restore"
             " for details about the status object returned by the restore "
             "function.")
      # pylint:enable=line-too-long
      for node_path in unused_nodes_in_checkpoint:
        log_fn("Value in checkpoint could not be found in the restored object: "
               f"{node_path}")
      for node_path, attr in unrestored_attributes_in_object:
        log_fn("An attribute in the restored object could not be found in the "
               f"checkpoint. Object: {node_path}, attribute: {attr}")
class _CheckpointRestoreCoordinator:
  """Holds the status of an object-based checkpoint load."""

  def __init__(self, object_graph_proto, save_path, save_path_tensor, reader,
               restore_op_cache, graph_view, options, saveables_cache):
    """Specify the checkpoint being loaded.

    Args:
      object_graph_proto: The TrackableObjectGraph protocol buffer associated
        with this checkpoint.
      save_path: A string, the path to the checkpoint, as returned by
        `tf.train.latest_checkpoint`.
      save_path_tensor: A string `Tensor` which contains or will be fed the save
        path.
      reader: A `CheckpointReader` for `save_path`. NOTE(review): this
        constructor dereferences `reader` directly below, so callers must pass
        a non-None reader here ("If None, initialize one itself" applies only
        to `restore_saveables`'s own `reader` argument).
      restore_op_cache: A dictionary shared between
        `_CheckpointRestoreCoordinator`s for the same Python objects, used to
        look up restore ops by name to avoid re-creating them across multiple
        `restore()` calls.
      graph_view: A graph_view_lib.ObjectGraphView object for the restored
        objects.
      options: A CheckpointOptions object.
      saveables_cache: An optional cache storing previously created
        SaveableObjects created for each Trackable. Maps Trackables to a
        dictionary of attribute names to Trackable.
    """
    self.options = options
    self.object_graph_proto = object_graph_proto
    self.restore_uid = ops.uid()
    # Maps from proto ids to lists of attributes which were in the checkpoint
    # but not loaded into any object, for error checking.
    self.unused_attributes = {}
    # Dictionary mapping from an id in the protocol buffer flat array to
    # Trackable Python objects. This mapping may be deferred if a
    # checkpoint is restored before all dependencies have been tracked. Uses
    # weak references so that partial restorations don't create reference cycles
    # (as objects with deferred dependencies will generally have references to
    # this object).
    self.object_by_proto_id = weakref.WeakValueDictionary()
    self.matched_proto_ids = set()
    # A set of all Python objects we've seen as dependencies, even if we didn't
    # use them (for example because of inconsistent references when
    # loading). Used to make status assertions fail when loading checkpoints
    # that don't quite match.
    self.all_python_objects = object_identity.ObjectIdentityWeakSet()
    self.save_path_tensor = save_path_tensor
    self.save_path_string = save_path
    self.dtype_map = reader.get_variable_to_dtype_map()
    self.shape_map = reader.get_variable_to_shape_map()
    # A NewCheckpointReader for the most recent checkpoint, for streaming Python
    # state restoration.
    # When graph building, contains a list of ops to run to restore objects from
    # this checkpoint.
    self.restore_ops = []
    self.restore_ops_by_name = restore_op_cache
    self.graph_view = graph_view
    self.new_restore_ops_callback = None
    # A mapping from optimizer proto ids to lists of slot variables to be
    # restored when the optimizer is tracked. Only includes slot variables whose
    # regular variables have already been created, and only for optimizer
    # objects which have not yet been created/tracked.
    self.deferred_slot_restorations = {}
    # A mapping from variable proto ids to lists of slot variables to be
    # restored when the variable is created/tracked. These get shifted over to
    # deferred_slot_restorations if the optimizer hasn't been created when that
    # happens.
    self.slot_restorations = collections.defaultdict(list)
    # Controls whether errors are printed in __del__ if some objects did not
    # match.
    self.expect_partial_attr = False
    if not self.options.experimental_skip_slot_variables:
      # Pre-index slot variable restorations by the id of the variable they
      # slot for, so they can be picked up when those variables are created.
      for node_index, node in enumerate(self.object_graph_proto.nodes):
        for slot_reference in node.slot_variables:
          # `node` refers to an `Optimizer`, since only these have slot
          # variables.
          self.slot_restorations[
              slot_reference.original_variable_node_id
          ].append(
              base._SlotVariableRestoration(  # pylint: disable=protected-access
                  optimizer_id=node_index,
                  slot_variable_id=slot_reference.slot_variable_node_id,
                  slot_name=slot_reference.slot_name,
              )
          )
    # A helper object holding the warning logic, so this class does not need to
    # define __del__ itself.
    self._deleter = _CheckpointRestoreCoordinatorDeleter(
        self.expect_partial_attr,
        self.object_graph_proto,
        self.matched_proto_ids,
        self.unused_attributes)
    self.saveables_cache = saveables_cache

  @property
  def expect_partial(self):
    """Whether unrestored-value warnings are suppressed on deletion."""
    return self.expect_partial_attr

  @expect_partial.setter
  def expect_partial(self, expect_partial):
    # Keep the deleter in sync so __del__ honors the latest setting.
    self.expect_partial_attr = expect_partial
    self._deleter.set_expect_partial(expect_partial)

  def new_restore_ops(self, new_ops):
    """Records newly created restore ops, notifying any streaming callback."""
    self.restore_ops.extend(new_ops)
    if self.new_restore_ops_callback:
      self.new_restore_ops_callback(new_ops)  # pylint: disable=not-callable

  def restore_saveables(
      self,
      tensor_saveables,
      python_positions,
      registered_savers=None,
      reader=None,
  ):
    """Run or build restore operations for SaveableObjects.

    Args:
      tensor_saveables: `SaveableObject`s which correspond to Tensors.
      python_positions: List of CheckpointPositions bound to `PythonState`
        objects which must be restored eagerly.
      registered_savers: a dict mapping saver names-> object name -> Trackable.
      reader: A `CheckpointReader`. If None, a new instance will be created.

    Returns:
      When graph building, a list of restore operations, either cached or newly
      created, to restore `tensor_saveables`.
    """
    if reader is None:
      reader = py_checkpoint_reader.NewCheckpointReader(self.save_path_string)
    restore_ops = []
    # Eagerly run restorations for Python state.
    for position in python_positions:
      key = position.object_proto.attributes[0].checkpoint_key
      position.trackable.deserialize(reader.get_tensor(key))
    # If we have new SaveableObjects, extract and cache restore ops.
    if tensor_saveables or registered_savers:
      flat_saveables = saveable_object_util.validate_and_slice_inputs(
          tensor_saveables)
      new_restore_ops = functional_saver.MultiDeviceSaver.from_saveables(
          flat_saveables,
          registered_savers).restore(self.save_path_tensor, self.options)
      if not context.executing_eagerly():
        for name, restore_op in sorted(new_restore_ops.items()):
          restore_ops.append(restore_op)
          # Each named restore op should be created at most once across
          # restore() calls; the cache enforces that.
          assert name not in self.restore_ops_by_name
          self.restore_ops_by_name[name] = restore_op
    return restore_ops
class _NameBasedRestoreCoordinator:
  """Keeps the status of a name-based checkpoint restore."""

  def __init__(self, save_path, dtype_map=None):
    """Specify the name-based checkpoint to restore from.

    Args:
      save_path: The checkpoint prefix to read tensors from.
      dtype_map: Optional mapping from checkpoint tensor names to dtypes
        (used by `eager_restore` to look up tensors and their dtypes).
    """
    self.save_path = save_path
    self.dtype_map = dtype_map
    # A map from trackable objects to unused attribute names. We don't have
    # proto IDs when doing a name-based restore, so the map keys differ from
    # those in _CheckpointRestoreCoordinator.
    self.unused_attributes = object_identity.ObjectIdentityWeakKeyDictionary()
    self.restore_uid = ops.uid()

  def globally_named_object_attributes(self, trackable):
    """Create globally named SaveableObjects from attributes.

    If an object's attribute has no global name specified (default construction
    for the SaveableObject factory), records the failure in
    `self.unused_attributes` (which can then be used to make status assertions
    fail; see `NameBasedSaverStatus`).

    Args:
      trackable: An object to save.

    Yields:
      SaveableObjects for `trackable`'s attributes.
    """
    for (
        attribute_name,
        saveable_factory,
    ) in saveable_object_util.saveable_objects_from_trackable(
        trackable, tf1_saver=True,
    ).items():
      if callable(saveable_factory):
        try:
          # This saveable object factory does not have a default name= argument,
          # which means there's no way to save/restore it using a name-based
          # checkpoint. Ignore the error now and make sure assert_consumed()
          # fails.
          saveable = saveable_factory()
        except TypeError:
          self.unused_attributes.setdefault(trackable,
                                            []).append(attribute_name)
          continue
      else:
        saveable = saveable_factory
      names_to_saveables = saveable_object_util.op_list_to_dict(
          [saveable], convert_variable_to_tensor=False)
      for name, op in names_to_saveables.items():
        for saveable_object in saveable_object_util.saveable_objects_for_op(
            op=op, name=name):
          yield saveable_object

  def eager_restore(self, trackable):
    """Runs restore ops for `trackable`'s attributes."""
    # When graph building, we don't add any restore ops to the graph until
    # run_restore_ops/initialize_or_restore on the status object for name-based
    # checkpoints.
    assert context.executing_eagerly()
    for saveable in self.globally_named_object_attributes(trackable):
      restored_tensors = []
      tensor_missing = False
      for spec in saveable.specs:
        if spec.name in self.dtype_map:
          # Perform the checkpoint read on the CPU.
          with ops.device("cpu:0"):
            restored, = io_ops.restore_v2(
                prefix=self.save_path,
                tensor_names=[spec.name],
                shape_and_slices=[""],
                dtypes=[self.dtype_map[spec.name]],
                name="%s_checkpoint_read" % (spec.name,))
          restored_tensors.append(array_ops.identity(restored))
        else:
          tensor_missing = True
      if tensor_missing:
        # Record that this variable didn't match so assertions will fail.
        self.unused_attributes.setdefault(trackable, []).append(saveable.name)
      else:
        # Ignores values missing from the checkpoint, as with object-based
        # restore. Status assertions can be used to check exact matches,
        # although it's unlikely to ever happen for name-based checkpoints.
        saveable.restore(
            restored_tensors=restored_tensors, restored_shapes=None)
# TODO(allenl): If this ends up in a public API, consider adding LINT.If Change
# or consolidating the implementation with get_variable.
def _default_getter(name,
                    shape,
                    dtype,
                    initializer=None,
                    partition_info=None,
                    **kwargs):
  """A pared-down version of get_variable which does not reuse variables."""
  dtype = dtypes.as_dtype(dtype)
  shape_object = tensor_shape.as_shape(shape)
  with ops.init_scope():
    if initializer is None:
      # Fall back on the variable store's dtype-appropriate default.
      initializer, initializing_from_value = (
          variable_scope._get_default_variable_store()._get_default_initializer(  # pylint: disable=protected-access
              name=name,
              shape=shape_object,
              dtype=dtype))
    else:
      initializing_from_value = not callable(initializer)
    # Same logic as get_variable
    if initializing_from_value:
      if shape is not None:
        raise ValueError("If initializer is a constant, do not specify shape.")
      initial_value = initializer
    else:
      # Instantiate initializer if provided initializer is a type object.
      if isinstance(initializer, type(init_ops.Initializer)):
        initializer = initializer(dtype=dtype)
      shape_list = None if shape is None else shape_object.as_list()
      # Only forward partition_info to initializers that accept it.
      partial_kwargs = {"dtype": dtype}
      if "partition_info" in tf_inspect.getargspec(initializer).args:
        partial_kwargs["partition_info"] = partition_info
      initial_value = functools.partial(initializer, shape_list,
                                        **partial_kwargs)
    return variable_v1.VariableV1(
        initial_value=initial_value,
        name=name,
        dtype=dtype.base_dtype,
        use_resource=True,
        **kwargs)
def add_variable(trackable,
                 name,
                 shape=None,
                 dtype=dtypes.float32,
                 initializer=None,
                 trainable=True):
  """Add a variable to a Trackable with no scope influence.

  Routes creation through the object's own custom-getter hook so the new
  variable is tracked as a dependency of `trackable`.
  """
  return trackable._add_variable_with_custom_getter(  # pylint: disable=protected-access
      name=name,
      shape=shape,
      dtype=dtype,
      initializer=initializer,
      getter=_default_getter,
      trainable=trainable)
def object_metadata(save_path):
  """Retrieves information about the objects in a checkpoint.

  Example usage:

  ```python
  object_graph = tf.contrib.checkpoint.object_metadata(
      tf.train.latest_checkpoint(checkpoint_directory))
  ckpt_variable_names = set()
  for node in object_graph.nodes:
    for attribute in node.attributes:
      ckpt_variable_names.add(attribute.full_name)
  ```

  Args:
    save_path: The path to the checkpoint, as returned by `save` or
      `tf.train.latest_checkpoint`.

  Returns:
    A parsed `tf.contrib.checkpoint.TrackableObjectGraph` protocol buffer.

  Raises:
    ValueError: If an object graph was not found in the checkpoint.
  """
  reader = py_checkpoint_reader.NewCheckpointReader(save_path)
  try:
    object_graph_string = reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY)
  except errors_impl.NotFoundError:
    # The graph key is only written by the TF2 object-based saver.
    raise ValueError(
        f"The specified checkpoint \"{save_path}\" does not appear to be "
        "object-based (saved with TF2) since it is missing the key "
        f"\"{base.OBJECT_GRAPH_PROTO_KEY}\". Likely it was created with the "
        "TF1 name-based saver and does not contain an object dependency graph.")
  object_graph_proto = trackable_object_graph_pb2.TrackableObjectGraph()
  object_graph_proto.ParseFromString(object_graph_string)
  return object_graph_proto
def list_objects(root_trackable):
  """Traverse the object graph and list all accessible objects.

  Looks for `Trackable` objects which are dependencies of
  `root_trackable`. Includes slot variables only if the variable they are
  slotting for and the optimizer are dependencies of `root_trackable`
  (i.e. if they would be saved with a checkpoint).

  Args:
    root_trackable: A `Trackable` object whose dependencies should be flattened.

  Returns:
    A flat list of objects.
  """
  graph_view = graph_view_lib.ObjectGraphView(root_trackable)
  return util.list_objects(graph_view)
def gather_initializers(root_trackable):
  """Traverse the object graph and find initialization ops.

  Looks for `Trackable` objects which are dependencies of
  `root_trackable` and which have an `initializer` property. Includes
  initializers for slot variables only if the variable they are slotting for and
  the optimizer are dependencies of `root_trackable` (i.e. if they would be
  saved with a checkpoint).

  Args:
    root_trackable: A `Trackable` object to gather initializers for.

  Returns:
    A list of initialization ops.
  """
  initializers = []
  for trackable in list_objects(root_trackable):
    initializer = getattr(trackable, "initializer", None)
    if initializer is not None:
      initializers.append(initializer)
  return initializers
@tf_contextlib.contextmanager
def capture_dependencies(template):
  """Capture variables created within this scope as `Template` dependencies.

  Requires that `template.variable_scope` is active.

  This scope is intended as a compatibility measure, allowing a trackable
  object to add dependencies on variables created in a block of code which is
  not aware of object-based saving (and instead uses variable names
  heavily). This is how `Template` objects add dependencies on variables and
  sub-`Template`s. Where possible, use `tf.compat.v1.make_template` directly.

  Args:
    template: The `Template` object to register dependencies with.

  Yields:
    None (when used as a context manager).
  """
  name_prefix = template.variable_scope.name

  def _trackable_custom_creator(next_creator,
                                name,
                                initial_value,
                                trackable_parent=None,
                                **kwargs):
    """A variable creation hook which adds Trackable dependencies.

    Set for example during a `Template`'s first wrapped function
    execution. Ensures that (a) `template` depends on any trackable
    objects using their own `capture_dependencies` scope inside this scope which
    create variables, and (b) that any variables not in a more deeply nested
    scope are added as dependencies directly.

    The `trackable_parent` argument is passed between custom creators but
    ignored when the variable object itself is created. This argument indicates
    (if not `None`) that a more deeply nested scope has already added the
    variable as a dependency, and that parent scopes should add a dependency on
    that object rather than on the variable directly.

    Args:
      next_creator: See `variable_scope.variable_creator_scope`; the next
        creator in the chain.
      name: The (full, scope-influenced) name of the variable. The `name_prefix`
        itself is stripped for the purposes of object-based dependency tracking,
        but scopes opened within this scope are respected.
      initial_value: See `variable_scope.variable_creator_scope`. Taken
        explicitly so the argument can be re-named and used with
        `Trackable._add_variable_with_custom_getter`.
      trackable_parent: If not None, a more deeply nested trackable object and
        its name prefix which were passed to `capture_dependencies` to add a
        dependency on (rather than depending on the variable directly).
      **kwargs: Passed through to the next creator.

    Returns:
      The output of `next_creator`: the fetched/created variable object.
    """

    def _call_next_creator_renaming_initializer(initializer, **inner_kwargs):
      inner_kwargs.pop("name")  # Ignored; this is the scope-stripped name which
      # we don't want to propagate.
      return next_creator(initial_value=initializer, name=name, **inner_kwargs)

    if name is not None and name.startswith(name_prefix):
      scope_stripped_name = name[len(name_prefix) + 1:]
      if not trackable_parent:
        # No deeper scope claimed this variable: track it on `template`
        # directly, creating it through the renaming getter above.
        return template._add_variable_with_custom_getter(  # pylint: disable=protected-access
            initializer=initial_value,
            name=scope_stripped_name,
            getter=_call_next_creator_renaming_initializer,
            # Disable error checking for Trackable. Exceptions are instead
            # raised if necessary when the object-based saver tries to
            # save/restore the object.
            overwrite=True,
            trackable_parent=(template, name_prefix),
            **kwargs)
      else:
        # A nested scope already owns the variable; depend on that trackable
        # parent object instead of the variable itself.
        parent_object, parent_name_prefix = trackable_parent
        template._track_trackable(  # pylint: disable=protected-access
            parent_object,
            name=parent_name_prefix[len(name_prefix) + 1:],
            overwrite=True)
    return next_creator(
        name=name,
        initial_value=initial_value,
        trackable_parent=(template, name_prefix),
        **kwargs)

  with variable_scope.variable_creator_scope(_trackable_custom_creator):
    yield
class _LoadStatus:
"""Abstract base for load status callbacks."""
@abc.abstractmethod
def assert_consumed(self):
"""Raises an exception unless a non-trivial restoration has completed."""
pass
@abc.abstractmethod
def assert_existing_objects_matched(self):
"""Raises an exception unless existing Python objects have been matched."""
pass
@abc.abstractmethod
def assert_nontrivial_match(self):
"""Raises an exception if only the root object matched."""
pass
@abc.abstractmethod
def run_restore_ops(self, session=None):
"""Runs restore ops from the checkpoint. Requires a valid checkpoint."""
pass
@abc.abstractmethod
def initialize_or_restore(self, session=None):
"""Runs restore ops from the checkpoint, or initializes variables."""
pass
def expect_partial(self):
"""Silence warnings about incomplete checkpoint restores."""
return self
@tf_export("__internal__.tracking.streaming_restore", v1=[])
def streaming_restore(status, session=None):
  """When graph building, runs restore ops as soon as they come in.

  Args:
    status: A _LoadStatus objects from an object-based saver's restore().
      Streaming restore from name-based checkpoints is not currently supported.
    session: A session to run new restore ops in.
  """
  if context.executing_eagerly():
    # Streaming restore is the default/only behavior when executing eagerly.
    return
  if session is None:
    session = get_session()
  if isinstance(status, NameBasedSaverStatus):
    raise NotImplementedError(
        "Streaming restore not supported from name-based checkpoints when "
        "graph building. File a feature request if this limitation bothers "
        "you. As a workaround, consider either using tf.train.Checkpoint to "
        "load name-based checkpoints or enabling eager execution.")
  status.run_restore_ops(session=session)

  # pylint: disable=protected-access
  def _run_new_ops(new_ops):
    # Execute freshly created restore ops immediately rather than waiting for
    # another run_restore_ops call.
    session.run(new_ops, feed_dict=status._feed_dict)

  status._checkpoint.new_restore_ops_callback = _run_new_ops
  # pylint: enable=protected-access
def _objects_with_attributes(full_list):
  """Filters out objects with no direct variable dependencies for assertions."""
  filtered = []
  for trackable in full_list:
    if saveable_object_util.saveable_objects_from_trackable(trackable):
      filtered.append(trackable)
  return filtered
class CheckpointLoadStatus(_LoadStatus):
  """Checks the status of checkpoint loading and manages restore ops.

  Returned from `Saver.restore`. Since `restore` may defer the loading of values
  in the checkpoint which don't yet have corresponding Python objects,
  `CheckpointLoadStatus` provides a callback to verify that checkpoint loading
  is complete (`assert_consumed`).

  When graph building, `restore` does not run restore ops itself since their
  creation may be deferred. The `run_restore_ops` method must be called once all
  Python objects with values to restore have been created and added to the
  dependency graph (this does not necessarily have to be the whole checkpoint;
  calling `run_restore_ops` while `assert_consumed` fails is supported and will
  partially restore the checkpoint).

  See `Saver.restore` for usage examples.
  """

  def __init__(self, checkpoint, feed_dict, graph_view, options):
    """Wraps a partially-loaded checkpoint for status checks.

    Args:
      checkpoint: The in-progress `_CheckpointRestoreCoordinator`-style object
        tracking restore state (object graph proto, restore ops, etc.).
      feed_dict: Feeds required when running the restore ops in a session.
      graph_view: The object-graph view rooted at the object being restored.
      options: `CheckpointOptions` used for restoring.
    """
    self._checkpoint = checkpoint
    self._feed_dict = feed_dict
    self._object_graph_view = graph_view
    # Keep a reference to the root, since object_graph_view might only have a
    # weakref.
    self._root = graph_view.root
    # CheckpointOptions used for restoring
    self._options = options

  def assert_consumed(self):
    """Asserts that all objects in the checkpoint have been created/matched.

    Returns:
      `self` for chaining.

    Raises:
      AssertionError: If there are any Python objects in the dependency graph
        which have not been restored from this checkpoint or a later `restore`,
        or if there are any checkpointed values which have not been matched to
        Python objects.
    """
    pretty_printer = ObjectGraphProtoPrettyPrinter(
        self._checkpoint.object_graph_proto)
    # First run the weaker check so its (more specific) errors surface first.
    self.assert_existing_objects_matched()
    # When slot variables are deliberately skipped, collect their node ids so
    # they are not reported as unresolved below.
    ignore_node_ids = []
    if self._options.experimental_skip_slot_variables:
      for node in self._checkpoint.object_graph_proto.nodes:
        for sv in node.slot_variables:
          ignore_node_ids.append(sv.slot_variable_node_id)
    for node_id, node in enumerate(self._checkpoint.object_graph_proto.nodes):
      if not node.attributes:
        # Only raise exceptions for the nodes with attributes themselves. Either
        # they're ultimately not important, or they have a child with an
        # attribute.
        continue
      if node_id in ignore_node_ids:
        continue
      trackable = self._checkpoint.object_by_proto_id.get(node_id, None)
      if trackable is None:
        raise AssertionError(
            "Unresolved object in checkpoint "
            f"{pretty_printer.node_names[node_id]}: {node}")
    if (
        not self._options.experimental_skip_slot_variables
        and self._checkpoint.slot_restorations
    ):
      # Sanity check; this collection should be clear if everything has been
      # restored.
      raise AssertionError(
          f"Unresolved slot restorations: {self._checkpoint.slot_restorations}")
    if self._checkpoint.unused_attributes:
      unused_attribute_messages = []
      for node_id, attribute in self._checkpoint.unused_attributes.items():
        obj = self._checkpoint.object_by_proto_id[node_id]
        unused_attribute_messages.append(
            f"{pretty_printer.node_names[node_id]} ({obj}): {attribute}")
      joined_attribute_messages = "\n".join(unused_attribute_messages)
      raise AssertionError(
          "Unused attributes in these objects (the attributes exist in the "
          f"checkpoint but were not restored):\n{joined_attribute_messages}")
    return self

  def assert_existing_objects_matched(self):
    """Asserts that trackable Python objects have been matched.

    Note that this is a weaker assertion than `assert_consumed`. It will only
    fail for existing Python objects which are (transitive) dependencies of the
    root object and which do not have an entry in the checkpoint.

    It will not fail, for example, if a `tf.keras.Layer` object has not yet been
    built and so has not created any `tf.Variable` objects.

    Returns:
      `self` for chaining.

    Raises:
      AssertionError: If a Python object exists in the transitive dependencies
        of the root object but does not have a value in the checkpoint.
    """
    for node_id, node in enumerate(self._checkpoint.object_graph_proto.nodes):
      trackable = self._checkpoint.object_by_proto_id.get(node_id, None)
      # A matched object whose _update_uid predates this restore never actually
      # received a value from this checkpoint.
      if (trackable is not None and
          trackable._update_uid < self._checkpoint.restore_uid):  # pylint: disable=protected-access
        raise AssertionError(
            f"Object {node} not assigned a value from checkpoint.")
    for trackable_object in util.list_objects(
        self._object_graph_view, self._options.experimental_skip_slot_variables
    ):
      # Remove data structures that do not contain any variables from
      # restoration checks.
      if isinstance(
          trackable_object, data_structures.TrackableDataStructure
      ) and not trackable_object._trackable_children(  # pylint: disable=protected-access
          save_type=base.SaveType.CHECKPOINT
      ):
        continue
      self._checkpoint.all_python_objects.add(trackable_object)
    # Objects with saveable attributes that were never bound to a checkpoint
    # node are the mismatches; identity (not equality) comparison is required.
    unused_python_objects = (
        object_identity.ObjectIdentitySet(
            _objects_with_attributes(
                self._checkpoint.all_python_objects)) -
        object_identity.ObjectIdentitySet(
            self._checkpoint.object_by_proto_id.values()))
    if unused_python_objects:
      num_unused_python_objects = len(list(unused_python_objects))
      # Display max number of 10 variables in error message.
      num_variables_to_show = min(10, num_unused_python_objects)
      raise AssertionError(
          f"Found {num_unused_python_objects} Python objects that were "
          "not bound to checkpointed values, likely due to changes in the "
          f"Python program. Showing {num_variables_to_show} of "
          f"{num_unused_python_objects} unmatched objects: "
          f"{list(unused_python_objects)[:num_variables_to_show]}")
    return self

  def assert_nontrivial_match(self):
    """Raises an exception if only the root object matched."""
    for trackable_object in util.list_objects(
        self._object_graph_view, self._options.experimental_skip_slot_variables
    ):
      self._checkpoint.all_python_objects.add(trackable_object)
    # object_by_proto_id has at most the root in it when nothing (else)
    # matched; distinguish "program/checkpoint mismatch" from "nothing to
    # load at all" for the error message.
    if len(self._checkpoint.object_by_proto_id) <= 1:
      unused_python_objects = object_identity.ObjectIdentitySet(
          _objects_with_attributes(self._checkpoint.all_python_objects)
      ) - object_identity.ObjectIdentitySet(
          self._checkpoint.object_by_proto_id.values()
      )
      if unused_python_objects:
        raise AssertionError(
            "Nothing except the root object matched a checkpointed value. "
            "Typically this means that the checkpoint does not match the "
            "Python program. The following objects have no matching "
            f"checkpointed value: {list(unused_python_objects)}"
        )
      else:
        raise AssertionError(
            "Nothing to load. No dependencies have been added to "
            f"{self._object_graph_view.root} yet."
        )
    return self

  def run_restore_ops(self, session=None):
    """Run operations to restore objects in the dependency graph."""
    if context.executing_eagerly():
      return  # Run eagerly
    if session is None:
      session = get_session()
    session.run(self._checkpoint.restore_ops, feed_dict=self._feed_dict)

  def initialize_or_restore(self, session=None):
    """Run operations to initialize or restore objects in the dependency graph.

    Any objects in the dependency graph which have initializers but are not in
    the checkpoint will have those initializers run, unless those variables are
    being restored by a later call to `tf.train.Checkpoint.restore()`.

    This method has a sibling in `InitializationOnlyStatus` which instead
    initializes variables. That type is returned if no checkpoint is specified
    in `Saver.restore`.

    Args:
      session: The session to run init/restore ops in. If `None`, uses the
        default session.
    """
    if context.executing_eagerly():
      return  # Initialization and restoration ops are run eagerly
    if session is None:
      session = get_session()
    all_objects = util.list_objects(self._object_graph_view)
    already_initialized_objects = object_identity.ObjectIdentitySet(
        self._checkpoint.object_by_proto_id.values())
    # Initialize only variables that were neither matched by this checkpoint
    # nor already assigned by a more recent restore (_update_uid check); the
    # missing-attribute default makes un-restored variables eligible.
    initializers_for_non_restored_variables = [
        c.initializer for c in all_objects
        if hasattr(c, "initializer")
        and c not in already_initialized_objects
        and (getattr(c, "_update_uid", self._checkpoint.restore_uid - 1)
             < self._checkpoint.restore_uid)
    ]
    self.run_restore_ops(session=session)
    session.run(initializers_for_non_restored_variables)

  def expect_partial(self):
    """Silence warnings about incomplete checkpoint restores."""
    self._checkpoint.expect_partial = True
    return self
class InitializationOnlyStatus(_LoadStatus):
"""Returned from `Saver.restore` when no checkpoint has been specified.
Objects of this type have the same `assert_consumed` method as
`CheckpointLoadStatus`, but it always fails. However,
`initialize_or_restore` works on objects of both types, and will
initialize variables in `InitializationOnlyStatus` objects or restore them
otherwise.
"""
def __init__(self, object_graph_view, restore_uid):
  """Records the graph view and restore uid for later initialization.

  Args:
    object_graph_view: The object-graph view rooted at the object to
      initialize.
    restore_uid: Restore uid associated with this (non-)restore.
  """
  self._object_graph_view = object_graph_view
  self._restore_uid = restore_uid
  # graph_view may hold the root only weakly; pin it with a strong reference
  # so it stays alive as long as this status object does.
  self._root = object_graph_view.root
def assert_consumed(self):