-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathsaver.py
1854 lines (1604 loc) · 74.8 KB
/
saver.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=invalid-name
"""Save and restore variables.
Symbols in this file are deprecated. See replacements in
tensorflow/python/training/trackable and tensorflow/python/training/saving.
"""
import collections
import glob
import os.path
import threading
import time
import numpy as np
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.core.protobuf import trackable_object_graph_pb2
from tensorflow.python.checkpoint import checkpoint_management
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import errors
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_io_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model.pywrap_saved_model import metrics
from tensorflow.python.trackable import base as trackable
from tensorflow.python.training import py_checkpoint_reader
from tensorflow.python.training import training_util
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export
# TODO(allenl): Remove these aliases once all users are migrated off.
# Each name below re-binds the same-named function from
# `checkpoint_management` so it remains reachable through this module.
get_checkpoint_state = checkpoint_management.get_checkpoint_state
update_checkpoint_state = checkpoint_management.update_checkpoint_state
generate_checkpoint_state_proto = (
checkpoint_management.generate_checkpoint_state_proto)
latest_checkpoint = checkpoint_management.latest_checkpoint
checkpoint_exists = checkpoint_management.checkpoint_exists
get_checkpoint_mtimes = checkpoint_management.get_checkpoint_mtimes
remove_checkpoint = checkpoint_management.remove_checkpoint
# Captures the timestamp of the first Saver object instantiation or end of a
# save operation. Can be accessed by multiple Saver instances.
# Starts out as None, i.e. no timestamp has been recorded yet.
_END_TIME_OF_LAST_WRITE = None
# NOTE(review): presumably guards reads/writes of _END_TIME_OF_LAST_WRITE
# across threads -- confirm against the call sites.
_END_TIME_OF_LAST_WRITE_LOCK = threading.Lock()
# API label for cell name used in checkpoint metrics.
_SAVER_LABEL = "saver_v1"
def _get_duration_microseconds(start_time_seconds, end_time_seconds):
if end_time_seconds < start_time_seconds:
# Avoid returning negative value in case of clock skew.
return 0
return round((end_time_seconds - start_time_seconds) * 1000000)
def _get_checkpoint_size(prefix):
"""Calculates filesize of checkpoint based on prefix."""
size = 0
# Gather all files beginning with prefix (.index plus sharded data files).
files = glob.glob("{}*".format(prefix))
for file in files:
# Use TensorFlow's C++ FileSystem API.
size += metrics.CalculateFileSize(file)
return size
class BaseSaverBuilder:
  """Base class for Savers.

  Can be extended to create different Ops.
  """

  SaveSpec = saveable_object.SaveSpec
  SaveableObject = saveable_object.SaveableObject

  # Aliases for code which was moved but still has lots of users.
  VariableSaveable = saveable_object_util.ReferenceVariableSaveable
  ResourceVariableSaveable = saveable_object_util.ResourceVariableSaveable

  def __init__(self, write_version=saver_pb2.SaverDef.V2):
    """Creates a builder.

    Args:
      write_version: Checkpoint format version written by `save_op`; compared
        against `saver_pb2.SaverDef.V1` / `saver_pb2.SaverDef.V2`.
    """
    self._write_version = write_version

  def save_op(self, filename_tensor, saveables):
    """Create an Op to save 'saveables'.

    This is intended to be overridden by subclasses that want to generate
    different Ops.

    Args:
      filename_tensor: String Tensor.
      saveables: A list of BaseSaverBuilder.SaveableObject objects.

    Returns:
      An Operation that saves the variables.

    Raises:
      RuntimeError: (implementation detail) if "self._write_version" is an
        unexpected value.
    """
    # pylint: disable=protected-access
    tensor_names = []
    tensors = []
    tensor_slices = []
    for saveable in saveables:
      for spec in saveable.specs:
        tensor_names.append(spec.name)
        tensors.append(spec.tensor)
        tensor_slices.append(spec.slice_spec)
    if self._write_version == saver_pb2.SaverDef.V1:
      return io_ops._save(
          filename=filename_tensor,
          tensor_names=tensor_names,
          tensors=tensors,
          tensor_slices=tensor_slices)
    elif self._write_version == saver_pb2.SaverDef.V2:
      # "filename_tensor" is interpreted *NOT AS A FILENAME*, but as a prefix
      # of a V2 checkpoint: e.g. "/fs/train/ckpt-<step>/tmp/worker<i>-<step>".
      return io_ops.save_v2(filename_tensor, tensor_names, tensor_slices,
                            tensors)
    else:
      # Format with %s instead of "+": write_version is normally a protobuf
      # enum value (an int), and `str + int` would raise TypeError rather than
      # the RuntimeError documented above.
      raise RuntimeError(
          "Unexpected write_version: %s" % (self._write_version,))

  def bulk_restore(self, filename_tensor, saveables, preferred_shard,
                   restore_sequentially):
    """Restore all tensors contained in saveables.

    By default, this issues separate calls to `restore_op` for each saveable.
    Subclasses may override to load multiple saveables in a single call.

    Args:
      filename_tensor: String Tensor.
      saveables: List of BaseSaverBuilder.SaveableObject objects.
      preferred_shard: Int. Shard to open first when loading a sharded file.
      restore_sequentially: Unused.  Bool.  If true, each restore is sequential.

    Returns:
      A list of Tensors resulting from reading 'saveable' from
        'filename'.
    """
    del restore_sequentially
    all_tensors = []
    for saveable in saveables:
      # Rewrite the saveable's device through `set_cpu0`; a saveable without a
      # device is restored under `ops.device(None)`.
      if saveable.device:
        device = saveable_object_util.set_cpu0(saveable.device)
      else:
        device = None
      with ops.device(device):
        all_tensors.extend(
            self.restore_op(filename_tensor, saveable, preferred_shard))
    return all_tensors

  # pylint: disable=unused-argument
  def restore_op(self, filename_tensor, saveable, preferred_shard):
    """Create ops to restore 'saveable'.

    This is intended to be overridden by subclasses that want to generate
    different Ops.

    Args:
      filename_tensor: String Tensor.
      saveable: A BaseSaverBuilder.SaveableObject object.
      preferred_shard: Int. Shard to open first when loading a sharded file.

    Returns:
      A list of Tensors resulting from reading 'saveable' from
        'filename'.
    """
    # pylint: disable=protected-access
    # One single-tensor `restore_v2` call per spec; `[0]` takes its only
    # result.
    return [
        io_ops.restore_v2(filename_tensor, [spec.name], [spec.slice_spec],
                          [spec.dtype])[0]
        for spec in saveable.specs
    ]

  # pylint: enable=unused-argument

  def sharded_filename(self, filename_tensor, shard, num_shards):
    """Append sharding information to a filename.

    Args:
      filename_tensor: A string tensor.
      shard: Integer.  The shard for the filename.
      num_shards: An int Tensor for the number of shards.

    Returns:
      A string tensor.
    """
    return gen_io_ops.sharded_filename(filename_tensor, shard, num_shards)

  def _AddSaveOps(self, filename_tensor, saveables):
    """Add ops to save variables that are on the same shard.

    Args:
      filename_tensor: String Tensor.
      saveables: A list of SaveableObject objects.

    Returns:
      A tensor with the filename used to save.
    """
    save = self.save_op(filename_tensor, saveables)
    # Tie the returned filename to the save op through a control dependency.
    return control_flow_ops.with_dependencies([save], filename_tensor)

  def _AddShardedSaveOpsForV2(self, checkpoint_prefix, per_device):
    """Add ops to save the params per shard, for the V2 format.

    Note that the sharded save procedure for the V2 format is different from
    V1: there is a special "merge" step that merges the small metadata produced
    from each device.

    Args:
      checkpoint_prefix: scalar String Tensor.  Interpreted *NOT AS A FILENAME*,
        but as a prefix of a V2 checkpoint;
      per_device: A list of (device, BaseSaverBuilder.VarToSave) pairs, as
        returned by _GroupByDevices().

    Returns:
      An op to save the variables, which, when evaluated, returns the prefix
        "<user-fed prefix>" only and does not include the sharded spec suffix.
    """
    # IMPLEMENTATION DETAILS: most clients should skip.
    #
    # Suffix for any well-formed "checkpoint_prefix", when sharded.
    # Transformations:
    # * Users pass in "save_path" in save() and restore().  Say "myckpt".
    # * checkpoint_prefix gets fed <save_path><_SHARDED_SUFFIX>.
    # * If checkpoint_prefix is a S3 bucket path ".part" is appended to it
    # * Otherwise _temp/part is appended which is normalized relative to the OS
    # Example:
    #   During runtime, a temporary directory is first created, which contains
    #   files
    #
    #     <train dir>/myckpt_temp/
    #        part-?????-of-?????{.index, .data-00000-of-00001}
    #
    #   Before .save() finishes, they will be (hopefully, atomically) renamed to
    #
    #     <train dir>/
    #        myckpt{.index, .data-?????-of-?????}
    #
    #   Filesystems with eventual consistency (such as S3), don't need a
    #   temporary location. Using a temporary directory in those cases might
    #   cause situations where files are not available during copy.
    #
    # Users only need to interact with the user-specified prefix, which is
    # "<train dir>/myckpt" in this case.  Save() and Restore() work with the
    # prefix directly, instead of any physical pathname.  (On failure and
    # subsequent restore, an outdated and orphaned temporary directory can be
    # safely removed.)
    with ops.device("CPU"):
      _SHARDED_SUFFIX = array_ops.where(
          string_ops.regex_full_match(checkpoint_prefix, "^s3://.*"),
          constant_op.constant(".part"),
          constant_op.constant(os.path.normpath("_temp/part")))
      tmp_checkpoint_prefix = string_ops.string_join(
          [checkpoint_prefix, _SHARDED_SUFFIX])

    num_shards = len(per_device)
    sharded_saves = []
    sharded_prefixes = []
    num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
    last_device = None
    for shard, (device, saveables) in enumerate(per_device):
      last_device = device
      with ops.device(saveable_object_util.set_cpu0(device)):
        sharded_filename = self.sharded_filename(tmp_checkpoint_prefix, shard,
                                                 num_shards_tensor)
        sharded_prefixes.append(sharded_filename)
        sharded_saves.append(self._AddSaveOps(sharded_filename, saveables))

    with ops.control_dependencies([x.op for x in sharded_saves]):
      # Co-locates the merge step with the last device.
      with ops.device(saveable_object_util.set_cpu0(last_device)):
        # V2 format write path consists of a metadata merge step.  Once merged,
        # attempts to delete the temporary directory, "<user-fed prefix>_temp".
        merge_step = gen_io_ops.merge_v2_checkpoints(
            sharded_prefixes, checkpoint_prefix, delete_old_dirs=True)
        with ops.control_dependencies([merge_step]):
          # Returns the prefix "<user-fed prefix>" only.  DOES NOT include the
          # sharded spec suffix.
          return array_ops.identity(checkpoint_prefix)

  def _AddShardedSaveOps(self, filename_tensor, per_device):
    """Add ops to save the params per shard.

    Args:
      filename_tensor: a scalar String Tensor.
      per_device: A list of (device, BaseSaverBuilder.SaveableObject) pairs, as
        returned by _GroupByDevices().

    Returns:
      An op to save the variables.
    """
    if self._write_version == saver_pb2.SaverDef.V2:
      return self._AddShardedSaveOpsForV2(filename_tensor, per_device)

    num_shards = len(per_device)
    sharded_saves = []
    num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
    for shard, (device, saveables) in enumerate(per_device):
      with ops.device(device):
        sharded_filename = self.sharded_filename(filename_tensor, shard,
                                                 num_shards_tensor)
        sharded_saves.append(self._AddSaveOps(sharded_filename, saveables))
    # Return the sharded name for the save path.
    with ops.control_dependencies([x.op for x in sharded_saves]):
      return gen_io_ops.sharded_filespec(filename_tensor, num_shards_tensor)

  def _AddRestoreOps(self,
                     filename_tensor,
                     saveables,
                     restore_sequentially,
                     reshape,
                     preferred_shard=-1,
                     name="restore_all"):
    """Add operations to restore saveables.

    Args:
      filename_tensor: Tensor for the path of the file to load.
      saveables: A list of SaveableObject objects.
      restore_sequentially: True if we want to restore variables sequentially
        within a shard.
      reshape: True if we want to reshape loaded tensors to the shape of the
        corresponding variable.
      preferred_shard: Shard to open first when loading a sharded file.
      name: Name for the returned op.

    Returns:
      An Operation that restores the variables.
    """
    all_tensors = self.bulk_restore(filename_tensor, saveables, preferred_shard,
                                    restore_sequentially)

    assign_ops = []
    idx = 0
    # Load and optionally reshape on the CPU, as string tensors are not
    # available on the GPU.
    # TODO(touts): Re-enable restore on GPU when we can support annotating
    # string tensors as "HostMemory" inputs.
    for saveable in saveables:
      shapes = None
      if reshape:
        # Compute the shapes, let the restore op decide if and how to do
        # the reshape.
        shapes = []
        for spec in saveable.specs:
          v = spec.tensor
          shape = v.get_shape()
          if not shape.is_fully_defined():
            shape = array_ops.shape(v)
          shapes.append(shape)
      # `all_tensors` is flat; each saveable consumes one entry per spec.
      saveable_tensors = all_tensors[idx:idx + len(saveable.specs)]
      idx += len(saveable.specs)
      assign_ops.append(saveable.restore(saveable_tensors, shapes))

    # Create a Noop that has control dependencies from all the updates.
    return control_flow_ops.group(*assign_ops, name=name)

  def _AddShardedRestoreOps(self, filename_tensor, per_device,
                            restore_sequentially, reshape):
    """Add Ops to restore variables from multiple devices.

    Args:
      filename_tensor: Tensor for the path of the file to load.
      per_device: A list of (device, SaveableObject) pairs, as returned by
        _GroupByDevices().
      restore_sequentially: True if we want to restore variables sequentially
        within a shard.
      reshape: True if we want to reshape loaded tensors to the shape of the
        corresponding variable.

    Returns:
      An Operation that restores the variables.
    """
    sharded_restores = []
    for shard, (device, saveables) in enumerate(per_device):
      with ops.device(device):
        sharded_restores.append(
            self._AddRestoreOps(
                filename_tensor,
                saveables,
                restore_sequentially,
                reshape,
                preferred_shard=shard,
                name="restore_shard"))
    return control_flow_ops.group(*sharded_restores, name="restore_all")

  def _GroupByDevices(self, saveables):
    """Group Variable tensor slices per device.

    TODO(touts): Make sure that all the devices found are on different
    job/replica/task/cpu|gpu.  It would be bad if 2 were on the same device.
    It can happen if the devices are unspecified.

    Args:
      saveables: A list of BaseSaverBuilder.SaveableObject objects.

    Returns:
      A list of tuples: (device_name, BaseSaverBuilder.SaveableObject) tuples.
      The list is sorted by ascending device_name.

    Raises:
      ValueError: If the tensors of a saveable are on different devices.
    """
    per_device = collections.defaultdict(list)
    for saveable in saveables:
      canonical_device = {
          pydev.canonical_name(spec.device) for spec in saveable.specs
      }
      if len(canonical_device) != 1:
        raise ValueError("All tensors of a saveable object must be "
                         "on the same device: %s" % saveable.name)
      per_device[canonical_device.pop()].append(saveable)
    return sorted(per_device.items(), key=lambda t: t[0])

  def build(self,
            names_to_saveables,
            reshape=False,
            sharded=False,
            max_to_keep=5,
            keep_checkpoint_every_n_hours=10000.0,
            name=None,
            restore_sequentially=False,
            filename="model"):
    """Builds save/restore graph nodes or runs save/restore in eager mode.

    Args:
      names_to_saveables: A dictionary mapping name to a Variable or
        SaveableObject. Each name will be associated with the corresponding
        variable in the checkpoint.
      reshape: If True, allow restoring parameters from a checkpoint where
        the parameters have a different shape. This is only needed when you try
        to restore from a Dist-Belief checkpoint, and only some times.
      sharded: If True, shard the checkpoints, one per device that has Variable
        nodes.
      max_to_keep: Maximum number of checkpoints to keep.  As new checkpoints
        are created, old ones are deleted.  If None or 0, no checkpoints are
        deleted from the filesystem but only the last one is kept in the
        `checkpoint` file.  Presently the number is only roughly enforced.  For
        example in case of restarts more than max_to_keep checkpoints may be
        kept.
      keep_checkpoint_every_n_hours: How often checkpoints should be kept.
        Defaults to 10,000 hours.
      name: String.  Optional name to use as a prefix when adding operations.
      restore_sequentially: A Bool, which if true, causes restore of different
        variables to happen sequentially within each device.
      filename: If known at graph construction time, filename used for variable
        loading/saving. If None, then the default name "model" will be used.

    Returns:
      A SaverDef proto.

    Raises:
      TypeError: If 'names_to_saveables' is not a dictionary mapping string
        keys to variable Tensors.
      ValueError: If any of the keys or values in 'names_to_saveables' is not
        unique.
    """
    return self._build_internal(
        names_to_saveables=names_to_saveables,
        reshape=reshape,
        sharded=sharded,
        max_to_keep=max_to_keep,
        keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
        name=name,
        restore_sequentially=restore_sequentially,
        filename=filename)

  def _build_internal(self,
                      names_to_saveables,
                      reshape=False,
                      sharded=False,
                      max_to_keep=5,
                      keep_checkpoint_every_n_hours=10000.0,
                      name=None,
                      restore_sequentially=False,
                      filename="model",
                      build_save=True,
                      build_restore=True):
    """build() with option to only perform save and restore.

    Takes the same arguments as `build()`, plus:

    Args:
      build_save: If True, add the save ops.
      build_restore: If True, add the restore ops.

    Returns:
      A SaverDef proto. When executing eagerly its name fields hold the
      evaluated filename / save tensor values and `restore_op_name` is empty;
      otherwise they hold the graph names of the created tensors and op.

    Raises:
      ValueError: If eager execution is not enabled and either `build_save` or
        `build_restore` is False.
    """
    if not context.executing_eagerly() and (not build_save or
                                            not build_restore):
      raise ValueError("save and restore operations need to be built together "
                       " when eager execution is not enabled.")

    # Anything that is not already a dict is converted to one first.
    if not isinstance(names_to_saveables, dict):
      names_to_saveables = saveable_object_util.op_list_to_dict(
          names_to_saveables)
    saveables = saveable_object_util.validate_and_slice_inputs(
        names_to_saveables)
    if max_to_keep is None:
      max_to_keep = 0

    with ops.name_scope(name, "save",
                        [saveable.op for saveable in saveables]) as name:
      # Add a placeholder string tensor for the filename.
      filename_tensor = array_ops.placeholder_with_default(
          filename or "model", shape=(), name="filename")
      # Keep the name "Const" for backwards compatibility.
      filename_tensor = array_ops.placeholder_with_default(
          filename_tensor, shape=(), name="Const")

      # Add the save ops.
      if sharded:
        per_device = self._GroupByDevices(saveables)
        if build_save:
          save_tensor = self._AddShardedSaveOps(filename_tensor, per_device)
        if build_restore:
          restore_op = self._AddShardedRestoreOps(filename_tensor, per_device,
                                                  restore_sequentially, reshape)
      else:
        if build_save:
          save_tensor = self._AddSaveOps(filename_tensor, saveables)
        if build_restore:
          restore_op = self._AddRestoreOps(filename_tensor, saveables,
                                           restore_sequentially, reshape)

    # In the following use case, it's possible to have restore_ops be called
    # something else:
    # - Build inference graph and export a meta_graph.
    # - Import the inference meta_graph
    # - Extend the inference graph to a train graph.
    # - Export a new meta_graph.
    # Now the second restore_op will be called "restore_all_1".
    # As such, comment out the assert for now until we know whether supporting
    # such usage model makes sense.
    #
    # assert restore_op.name.endswith("restore_all"), restore_op.name
    if context.executing_eagerly():
      # Store the tensor values to the tensor_names.
      save_tensor_name = save_tensor.numpy() if build_save else ""
      return saver_pb2.SaverDef(
          filename_tensor_name=filename_tensor.numpy(),
          save_tensor_name=save_tensor_name,
          restore_op_name="",
          max_to_keep=max_to_keep,
          sharded=sharded,
          keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
          version=self._write_version)
    else:
      graph = ops.get_default_graph()
      # Do some sanity checking on collections containing
      # PartitionedVariables. If a saved collection has a PartitionedVariable,
      # the GraphDef needs to include concat ops to get the value (or there'll
      # be a lookup error on load).
      check_collection_list = graph.get_all_collection_keys()
      for collection_type in check_collection_list:
        for element in graph.get_collection(collection_type):
          if isinstance(element, variables.PartitionedVariable):
            try:
              graph.get_operation_by_name(element.name)
            except KeyError:
              # Create a concat op for this PartitionedVariable. The user may
              # not need it, but we'll try looking it up on MetaGraph restore
              # since it's in a collection.
              element.as_tensor()
      return saver_pb2.SaverDef(
          filename_tensor_name=filename_tensor.name,
          save_tensor_name=save_tensor.name,
          restore_op_name=restore_op.name,
          max_to_keep=max_to_keep,
          sharded=sharded,
          keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
          version=self._write_version)
class BulkSaverBuilder(BaseSaverBuilder):
  """SaverBuilder with support for bulk restoring multiple saveables."""

  def bulk_restore(self, filename_tensor, saveables, preferred_shard,
                   restore_sequentially):
    """Restores the specs of all saveables with a single `restore_v2` call.

    Args:
      filename_tensor: String Tensor.
      saveables: List of BaseSaverBuilder.SaveableObject objects.
      preferred_shard: Not referenced by this implementation.
      restore_sequentially: Ignored.

    Returns:
      The result of `io_ops.restore_v2` for the names, slice specs and dtypes
      of every spec, collected in saveable order; an empty list when there is
      no spec to restore.
    """
    # Ignored: bulk restore is internally sequential.
    del restore_sequentially
    restore_specs = [(spec.name, spec.slice_spec, spec.dtype)
                     for saveable in saveables
                     for spec in saveable.specs]
    if not restore_specs:
      # `zip(*[])` yields nothing, so the three-way unpack below would fail
      # with an opaque "not enough values to unpack" ValueError. Return an
      # empty list instead, as the base class does for empty input.
      return []
    names, slices, dtypes = zip(*restore_specs)
    # Load all tensors onto CPU 0 for compatibility with existing code.
    with ops.device("cpu:0"):
      return io_ops.restore_v2(filename_tensor, names, slices, dtypes)
def _get_saver_or_default():
  """Returns the saver from SAVERS collection, or creates a default one.

  This method is used by other members of the training module, such as
  `Scaffold`, or `CheckpointSaverHook`.

  Returns:
    `Saver`.

  Raises:
    RuntimeError: If the SAVERS collection already has more than one item.
  """
  savers_key = ops.GraphKeys.SAVERS
  existing = ops.get_collection(savers_key)

  if not existing:
    # Nothing registered yet: build a default saver and register it under
    # the SAVERS key before handing it back.
    default_saver = Saver(sharded=True, allow_empty=True)
    if default_saver is not None:
      ops.add_to_collection(savers_key, default_saver)
    return default_saver

  if len(existing) > 1:
    raise RuntimeError(
        "More than one item in collection {}. "
        "Please indicate which one to use by passing it to the constructor."
        .format(savers_key))
  return existing[0]
@tf_export(v1=["train.Saver"])
class Saver:
# pylint: disable=line-too-long
"""Saves and restores variables.
@compatibility(TF2)
`tf.compat.v1.train.Saver` is not supported for saving and restoring
checkpoints in TF2. Please switch to `tf.train.Checkpoint` or
`tf.keras.Model.save_weights`, which perform a more robust [object-based
saving](https://www.tensorflow.org/guide/checkpoint#loading_mechanics).
### How to Rewrite Checkpoints
Please rewrite your checkpoints immediately using the object-based checkpoint
APIs.
You can load a name-based checkpoint written by `tf.compat.v1.train.Saver`
using `tf.train.Checkpoint.restore` or `tf.keras.Model.load_weights`. However,
you may have to change the names of the variables in your model to match the
variable names in the name-based checkpoint, which can be viewed with
`tf.train.list_variables(path)`.
Another option is to create an `assignment_map` that maps the name of the
variables in the name-based checkpoint to the variables in your model, eg:
```
{
'sequential/dense/bias': model.variables[0],
'sequential/dense/kernel': model.variables[1]
}
```
and use `tf.compat.v1.train.init_from_checkpoint(path, assignment_map)` to
restore the name-based checkpoint.
After restoring, re-encode your checkpoint
using `tf.train.Checkpoint.save` or `tf.keras.Model.save_weights`.
See the [Checkpoint compatibility](
https://www.tensorflow.org/guide/migrate#checkpoint_compatibility)
section of the migration guide for more details.
### Checkpoint Management in TF2
Use `tf.train.CheckpointManager` to manage checkpoints in TF2.
`tf.train.CheckpointManager` offers equivalent `keep_checkpoint_every_n_hours`
and `max_to_keep` parameters.
To recover the latest checkpoint,
```
checkpoint = tf.train.Checkpoint(model)
manager = tf.train.CheckpointManager(checkpoint)
status = checkpoint.restore(manager.latest_checkpoint)
```
`tf.train.CheckpointManager` also writes a [`CheckpointState` proto]
(https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/training/checkpoint_state.proto)
which contains the timestamp when each checkpoint was created.
### Writing `MetaGraphDef`s in TF2
To replace, `tf.compat.v1.train.Saver.save(write_meta_graph=True)`, use
`tf.saved_model.save` to write the `MetaGraphDef` (which is contained in
`saved_model.pb`).
@end_compatibility
See [Variables](https://tensorflow.org/guide/variables)
for an overview of variables, saving and restoring.
The `Saver` class adds ops to save and restore variables to and from
*checkpoints*. It also provides convenience methods to run these ops.
Checkpoints are binary files in a proprietary format which map variable names
to tensor values. The best way to examine the contents of a checkpoint is to
load it using a `Saver`.
Savers can automatically number checkpoint filenames with a provided counter.
This lets you keep multiple checkpoints at different steps while training a
model. For example you can number the checkpoint filenames with the training
step number. To avoid filling up disks, savers manage checkpoint files
automatically. For example, they can keep only the N most recent files, or
one checkpoint for every N hours of training.
You number checkpoint filenames by passing a value to the optional
`global_step` argument to `save()`:
```python
saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'
...
saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'
```
Additionally, optional arguments to the `Saver()` constructor let you control
the proliferation of checkpoint files on disk:
* `max_to_keep` indicates the maximum number of recent checkpoint files to
keep. As new files are created, older files are deleted. If None or 0,
no checkpoints are deleted from the filesystem but only the last one is
kept in the `checkpoint` file. Defaults to 5 (that is, the 5 most recent
checkpoint files are kept.)
* `keep_checkpoint_every_n_hours`: In addition to keeping the most recent
`max_to_keep` checkpoint files, you might want to keep one checkpoint file
for every N hours of training. This can be useful if you want to later
analyze how a model progressed during a long training session. For
example, passing `keep_checkpoint_every_n_hours=2` ensures that you keep
one checkpoint file for every 2 hours of training. The default value of
10,000 hours effectively disables the feature.
Note that you still have to call the `save()` method to save the model.
Passing these arguments to the constructor will not save variables
automatically for you.
A training program that saves regularly looks like:
```python
...
# Create a saver.
saver = tf.compat.v1.train.Saver(...variables...)
# Launch the graph and train, saving the model every 1,000 steps.
sess = tf.compat.v1.Session()
for step in range(1000000):
sess.run(..training_op..)
if step % 1000 == 0:
# Append the step number to the checkpoint name:
saver.save(sess, 'my-model', global_step=step)
```
In addition to checkpoint files, savers keep a protocol buffer on disk with
the list of recent checkpoints. This is used to manage numbered checkpoint
files and by `latest_checkpoint()`, which makes it easy to discover the path
to the most recent checkpoint. That protocol buffer is stored in a file named
'checkpoint' next to the checkpoint files.
If you create several savers, you can specify a different filename for the
protocol buffer file in the call to `save()`.
"""
# pylint: enable=line-too-long
def __init__(
    self,
    var_list=None,
    reshape=False,
    sharded=False,
    max_to_keep=5,
    keep_checkpoint_every_n_hours=10000.0,
    name=None,
    restore_sequentially=False,
    saver_def=None,
    builder=None,
    defer_build=False,
    allow_empty=False,
    write_version=saver_pb2.SaverDef.V2,
    pad_step_number=False,
    save_relative_paths=False,
    filename=None,
):
    """Creates a `Saver`.

    Unless `defer_build` is set or eager execution is enabled, the constructor
    also adds the ops that save and restore variables.

    Which variables are saved and restored is determined by `var_list`, which
    may be given in either of two forms:

    * A `dict` of names to variables: each key is the name under which the
      corresponding variable is saved to, or restored from, the checkpoint
      files.
    * A list of variables: each variable is keyed with its op name in the
      checkpoint files.

    For example:

    ```python
    v1 = tf.Variable(..., name='v1')
    v2 = tf.Variable(..., name='v2')

    # Pass the variables as a dict:
    saver = tf.compat.v1.train.Saver({'v1': v1, 'v2': v2})

    # Or pass them as a list.
    saver = tf.compat.v1.train.Saver([v1, v2])
    # Passing a list is equivalent to passing a dict with the variable op names
    # as keys:
    saver = tf.compat.v1.train.Saver({v.op.name: v for v in [v1, v2]})
    ```

    Note: `Saver` does not support the newer `AutoTrackable` API. Use the
    `tf.train.Checkpoint` class in that case.

    Setting `reshape=True` makes it possible to restore a variable from a save
    file in which it had a different shape, as long as the number of elements
    and the type are the same. This helps when a variable has been reshaped
    and must be reloaded from an older checkpoint.

    Setting `sharded=True` instructs the saver to shard checkpoints per
    device.

    Args:
      var_list: Either a list of `Variable`/`SaveableObject`, or a dictionary
        mapping names to `SaveableObject`s. When `None`, the list of all
        saveable objects is used.
      reshape: If `True`, parameters may be restored from a checkpoint in
        which the variables have a different shape.
      sharded: If `True`, shard the checkpoints, one per device.
      max_to_keep: Maximum number of recent checkpoints to keep. Defaults to
        5.
      keep_checkpoint_every_n_hours: How often to keep checkpoints. Defaults
        to 10,000 hours.
      name: String. Optional name used as a prefix when adding operations.
      restore_sequentially: A `Bool`. If true, different variables are
        restored sequentially within each device, which can lower memory
        usage when restoring very large models.
      saver_def: Optional `SaverDef` proto to use instead of running the
        builder. Only useful for specialty code that needs to recreate a
        `Saver` object for a previously built `Graph` that had a `Saver`; the
        proto should be the one returned by the `as_saver_def()` call of the
        `Saver` that was created for that `Graph`.
      builder: Optional `SaverBuilder` to use when no `saver_def` is provided.
        Defaults to `BulkSaverBuilder()`.
      defer_build: If `True`, the save and restore ops are not added until
        `build()` is called; `build()` should then be called before
        finalizing the graph or using the saver.
      allow_empty: If `False` (default), raise an error when there are no
        variables in the graph. Otherwise, construct the saver anyway and
        make it a no-op.
      write_version: Controls the format used when saving checkpoints; it
        also affects certain filepath matching logic. V2 is the recommended
        format: it is much more optimized than V1 in terms of memory required
        and latency incurred during restore. Regardless of this flag, the
        Saver is able to restore from both V2 and V1 checkpoints.
      pad_step_number: If True, pads the global step number in the checkpoint
        filepaths to some fixed width (8 by default). Off by default.
      save_relative_paths: If `True`, relative paths are written to the
        checkpoint state file. Needed if the user wants to copy the
        checkpoint directory and reload from the copied directory.
      filename: If known at graph construction time, the filename used for
        variable loading/saving.

    Raises:
      TypeError: If `var_list` is invalid.
      ValueError: If any of the keys or values in `var_list` are not unique,
        or if `defer_build` is set together with a non-empty `var_list`.
      RuntimeError: If eager execution is enabled and `var_list` does not
        specify a list of variables to save.

    @compatibility(eager)
    When eager execution is enabled, `var_list` must specify a `list` or `dict`
    of variables to save. Otherwise, a `RuntimeError` will be raised.

    Although Saver works in some cases when executing eagerly, it is
    fragile. Please switch to `tf.train.Checkpoint` or
    `tf.keras.Model.save_weights`, which perform a more robust object-based
    saving. These APIs will load checkpoints written by `Saver`.
    @end_compatibility
    """
    # Seed the module-level "end of last write" timestamp once, under its
    # lock; if it is already set it is left untouched.
    global _END_TIME_OF_LAST_WRITE
    with _END_TIME_OF_LAST_WRITE_LOCK:
        if _END_TIME_OF_LAST_WRITE is None:
            _END_TIME_OF_LAST_WRITE = time.time()

    # --- Argument validation -------------------------------------------
    if defer_build and var_list:
        raise ValueError(
            "If `var_list` is provided then build cannot be deferred. "
            "Either set defer_build=False or var_list=None.")

    if context.executing_eagerly():
        logging.warning(
            "Saver is deprecated, please switch to tf.train.Checkpoint or "
            "tf.keras.Model.save_weights for training checkpoints. When "
            "executing eagerly variables do not necessarily have unique names, "
            "and so the variable.name-based lookups Saver performs are "
            "error-prone.")
        if var_list is None:
            raise RuntimeError(
                "When eager execution is enabled, `var_list` must specify a list "
                "or dict of variables to save")

    # --- What is saved and how the checkpoint is written ----------------
    self._var_list = var_list
    self._filename = filename
    self._write_version = write_version
    self._sharded = sharded
    self._reshape = reshape
    self._pad_step_number = pad_step_number

    # --- Checkpoint retention policy ------------------------------------
    self._max_to_keep = max_to_keep
    self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours

    # --- Inputs to op construction --------------------------------------
    self._name = name
    self._restore_sequentially = restore_sequentially
    self._builder = builder
    self.saver_def = saver_def
    self._allow_empty = allow_empty

    # --- Build state; `_is_empty` stays None until a build decides it ----
    self._is_built = False
    self._is_empty = None

    # --- Checkpoint bookkeeping -----------------------------------------
    self._last_checkpoints = []
    self._checkpoints_to_be_deleted = []

    if context.executing_eagerly():
        # Nothing is built up front in eager mode, so the "keep one
        # checkpoint every N hours" deadline is started here instead of in
        # `_build()`.
        now = time.time()
        self._next_checkpoint_time = (
            now + self._keep_checkpoint_every_n_hours * 3600)
    elif not defer_build:
        self.build()

    # Whether supplied by the caller or produced by the build above, an
    # available saver_def is validated and dictates the write version.
    if self.saver_def:
        self._check_saver_def()
        self._write_version = self.saver_def.version

    self._save_relative_paths = save_relative_paths

    # For compatibility with object-based checkpoints, we may build a second
    # Saver to read the renamed keys.
    self._object_restore_saver = None
def build(self):
    """Adds the save and restore ops for this saver (graph mode only).

    Delegates to `_build()` using the filename given at construction time,
    with both saving and restoring enabled.

    Raises:
      RuntimeError: If eager execution is enabled.
    """
    if not context.executing_eagerly():
        self._build(self._filename, build_save=True, build_restore=True)
        return
    raise RuntimeError("Use save/restore instead of build in eager mode.")
def _build_eager(self, checkpoint_path, build_save, build_restore):
    """Forwards to `_build()` with the given checkpoint path and flags.

    Args:
      checkpoint_path: Passed to `_build()` as its `checkpoint_path`.
      build_save: Passed to `_build()` as its `build_save` flag.
      build_restore: Passed to `_build()` as its `build_restore` flag.
    """
    build_flags = {"build_save": build_save, "build_restore": build_restore}
    self._build(checkpoint_path, **build_flags)
def _build(self, checkpoint_path, build_save, build_restore):
"""Builds saver_def."""
if not context.executing_eagerly():
if self._is_built:
return
self._is_built = True
if not self.saver_def or context.executing_eagerly():
if self._builder is None:
self._builder = BulkSaverBuilder(self._write_version)
if self._var_list is None:
# pylint: disable=protected-access
self._var_list = variables._all_saveable_objects()
if not self._var_list:
if self._allow_empty:
self._is_empty = True
return
else:
raise ValueError("No variables to save")
self._is_empty = False
self.saver_def = self._builder._build_internal( # pylint: disable=protected-access
self._var_list,
reshape=self._reshape,
sharded=self._sharded,
max_to_keep=self._max_to_keep,
keep_checkpoint_every_n_hours=self._keep_checkpoint_every_n_hours,
name=self._name,
restore_sequentially=self._restore_sequentially,
filename=checkpoint_path,
build_save=build_save,
build_restore=build_restore)
elif self.saver_def and self._name:
# Since self._name is used as a name_scope by builder(), we are
# overloading the use of this field to represent the "import_scope" as
# well.
self.saver_def.filename_tensor_name = ops.prepend_name_scope(
self.saver_def.filename_tensor_name, self._name)
self.saver_def.save_tensor_name = ops.prepend_name_scope(
self.saver_def.save_tensor_name, self._name)
self.saver_def.restore_op_name = ops.prepend_name_scope(
self.saver_def.restore_op_name, self._name)
self._check_saver_def()
if not context.executing_eagerly():
# Updates next checkpoint time.
# Set in __init__ when executing eagerly.
self._next_checkpoint_time = (