Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "TF Checkpoint V2: add "write_version" arg to tf.train.Saver." #5266

Merged
merged 1 commit into from Oct 29, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 0 additions & 1 deletion tensorflow/python/BUILD
Expand Up @@ -1537,7 +1537,6 @@ py_library(
":random_ops",
":sparse_ops",
":state_ops",
":string_ops",
":training_ops_gen",
":variable_scope",
":variables",
Expand Down
115 changes: 12 additions & 103 deletions tensorflow/python/training/saver.py
Expand Up @@ -48,7 +48,6 @@
from tensorflow.python.ops import gen_io_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
Expand Down Expand Up @@ -141,6 +140,8 @@ class BaseSaverBuilder(object):
Can be extended to create different Ops.
"""

_CHECKPOINT_FORMAT_VERSION = saver_pb2.SaverDef.V1

class SaveSpec(object):
"""Class used to describe tensor slices that need to be saved."""

Expand Down Expand Up @@ -210,8 +211,8 @@ def restore(self, restored_tensors, restored_shapes):
validate_shape=restored_shapes is None and
self.op.get_shape().is_fully_defined())

def __init__(self, write_version=saver_pb2.SaverDef.V1):
self._write_version = write_version
def __init__(self):
pass

def save_op(self, filename_tensor, saveables):
"""Create an Op to save 'saveables'.
Expand All @@ -225,10 +226,6 @@ def save_op(self, filename_tensor, saveables):

Returns:
An Operation that save the variables.

Raises:
RuntimeError: (implementation detail) if "self._write_version" is an
unexpected value.
"""
# pylint: disable=protected-access
tensor_names = []
Expand All @@ -240,19 +237,11 @@ def save_op(self, filename_tensor, saveables):
tensors.append(spec.tensor)
tensor_slices.append(spec.slice_spec)

if self._write_version == saver_pb2.SaverDef.V1:
return io_ops._save(
filename=filename_tensor,
tensor_names=tensor_names,
tensors=tensors,
tensor_slices=tensor_slices)
elif self._write_version == saver_pb2.SaverDef.V2:
# "filename_tensor" is interpreted *NOT AS A FILENAME*, but as a prefix
# of a V2 checkpoint: e.g. "/fs/train/ckpt-<step>/tmp/worker<i>-<step>".
return io_ops.save_v2(filename_tensor, tensor_names, tensor_slices,
tensors)
else:
raise RuntimeError("Unexpected write_version: " + self._write_version)
return io_ops._save(
filename=filename_tensor,
tensor_names=tensor_names,
tensors=tensors,
tensor_slices=tensor_slices)

def restore_op(self, filename_tensor, saveable, preferred_shard):
"""Create ops to restore 'saveable'.
Expand Down Expand Up @@ -309,76 +298,6 @@ def _AddSaveOps(self, filename_tensor, saveables):
save = self.save_op(filename_tensor, saveables)
return control_flow_ops.with_dependencies([save], filename_tensor)

def _AddShardedSaveOpsForV2(self, checkpoint_prefix, per_device):
"""Add ops to save the params per shard, for the V2 format.

Note that the sharded save procedure for the V2 format is different from
V1: there is a special "merge" step that merges the small metadata produced
from each device.

Args:
checkpoint_prefix: scalar String Tensor. Interpreted *NOT AS A
FILENAME*, but as a prefix of a V2 checkpoint;
per_device: A list of (device, BaseSaverBuilder.VarToSave) pairs, as
returned by _GroupByDevices().

Returns:
An op to save the variables, which, when evaluated, returns the prefix
"<user-fed prefix>" only and does not include the sharded spec suffix.
"""
# IMPLEMENTATION DETAILS: most clients should skip.
#
# Suffix for any well-formed "checkpoint_prefix", when sharded.
# Transformations:
# * Users pass in "save_path" in save() and restore(). Say "myckpt".
# * checkpoint_prefix gets fed <save_path><_SHARDED_SUFFIX>.
#
# Example:
# During runtime, a temporary directory is first created, which contains
# files
#
# <train dir>/myckpt_temp/
# part-?????-of-?????{.index, .data-00000-of-00001}
#
# Before .save() finishes, they will be (hopefully, atomically) renamed to
#
# <train dir>/
# myckpt{.index, .data-?????-of-?????}
#
# Users only need to interact with the user-specified prefix, which is
# "<train dir>/myckpt" in this case. Save() and Restore() work with the
# prefix directly, instead of any physical pathname. (On failure and
# subsequent restore, an outdated and orphaned temporary directory can be
# safely removed.)
_SHARDED_SUFFIX = "_temp_%s/part" % uuid.uuid4().hex
tmp_checkpoint_prefix = string_ops.string_join(
[checkpoint_prefix, _SHARDED_SUFFIX])

num_shards = len(per_device)
sharded_saves = []
sharded_prefixes = []
num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
last_device = None
for shard, (device, saveables) in enumerate(per_device):
last_device = device
with ops.device(device):
sharded_filename = self.sharded_filename(tmp_checkpoint_prefix, shard,
num_shards_tensor)
sharded_prefixes.append(sharded_filename)
sharded_saves.append(self._AddSaveOps(sharded_filename, saveables))

with ops.control_dependencies([x.op for x in sharded_saves]):
# Co-locates the merge step with the last device.
with ops.device(last_device):
# V2 format write path consists of a metadata merge step. Once merged,
# attempts to delete the temporary directory, "<user-fed prefix>_temp".
merge_step = gen_io_ops.merge_v2_checkpoints(
sharded_prefixes, checkpoint_prefix, delete_old_dirs=True)
with ops.control_dependencies([merge_step]):
# Returns the prefix "<user-fed prefix>" only. DOES NOT include the
# sharded spec suffix.
return array_ops.identity(checkpoint_prefix)

def _AddShardedSaveOps(self, filename_tensor, per_device):
"""Add ops to save the params per shard.

Expand All @@ -390,9 +309,6 @@ def _AddShardedSaveOps(self, filename_tensor, per_device):
Returns:
An op to save the variables.
"""
if self._write_version == saver_pb2.SaverDef.V2:
return self._AddShardedSaveOpsForV2(filename_tensor, per_device)

num_shards = len(per_device)
sharded_saves = []
num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
Expand Down Expand Up @@ -722,7 +638,7 @@ def build(self,
max_to_keep=max_to_keep,
sharded=sharded,
keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
version=self._write_version)
version=self._CHECKPOINT_FORMAT_VERSION)


def _GetCheckpointFilename(save_dir, latest_filename):
Expand Down Expand Up @@ -980,8 +896,7 @@ def __init__(self,
saver_def=None,
builder=None,
defer_build=False,
allow_empty=False,
write_version=saver_pb2.SaverDef.V1):
allow_empty=False):
"""Creates a `Saver`.

The constructor adds ops to save and restore variables.
Expand Down Expand Up @@ -1046,11 +961,6 @@ def __init__(self,
allow_empty: If `False` (default) raise an error if there are no
variables in the graph. Otherwise, construct the saver anyway and make
it a no-op.
write_version: controls what format to use when saving checkpoints. It
also affects certain filepath matching logic. Defaults to V1
currently, and will be switched to the more memory-efficient V2 format
in the future. If set to V2, the Saver is still able to restore from
old V1 checkpoints.

Raises:
TypeError: If `var_list` is invalid.
Expand All @@ -1072,7 +982,6 @@ def __init__(self,
self._is_built = False
self._allow_empty = allow_empty
self._is_empty = None
self._write_version = write_version
if not defer_build:
self.build()
if self.saver_def:
Expand All @@ -1085,7 +994,7 @@ def build(self):
self._is_built = True
if not self.saver_def:
if self._builder is None:
self._builder = BaseSaverBuilder(self._write_version)
self._builder = BaseSaverBuilder()
if self._var_list is None:
# pylint: disable=protected-access
self._var_list = variables._all_saveable_objects()
Expand Down