Skip to content

Commit

Permalink
Correctly set the experimental_io_device when restoring variable from…
Browse files Browse the repository at this point in the history
… a checkpoint.

PiperOrigin-RevId: 320222381
Change-Id: I30187c7777ab8056e48004ef5e4ae747edc32227
  • Loading branch information
kenfranko committed Jul 13, 2020
1 parent 14b2d68 commit b8694e3
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions tensorflow/python/training/tracking/base.py
Expand Up @@ -293,9 +293,10 @@ def value_tensors(self):
checkpoint_key = serialized_tensor.checkpoint_key
dtype = self._checkpoint.dtype_map[checkpoint_key]
base_type = dtype.base_dtype
io_device = self._checkpoint.options.experimental_io_device or "cpu:0"
with ops.init_scope():
with ops.device("/cpu:0"):
# Run the restore itself on the CPU.
with ops.device(io_device):
# Run the restore itself on the io_device(CPU or specified).
value, = io_ops.restore_v2(
prefix=self._checkpoint.save_path_tensor,
tensor_names=[checkpoint_key],
Expand Down

0 comments on commit b8694e3

Please sign in to comment.