Skip to content

Commit

Permalink
Change PolicySaver.save_checkpoint() to save files in the 'variables'…
Browse files Browse the repository at this point in the history
… sub-dir.

This makes the layout files for save_checkpoint() the same as for full models,
which is tidier and make it easier to combine will full model files.

Make the code more robust by using constants defined in the saved_model module
for the various components of saved paths. Still use the current values in
tests, because users will likely hardcode these values in their code.

Update all users and tests for save_checkpoint().

PiperOrigin-RevId: 302211945
Change-Id: I45ce03c5feee55e9fe1425f85b789e2216b328f0
  • Loading branch information
TF-Agents Team authored and Copybara-Service committed Mar 21, 2020
1 parent 3c72322 commit 7af6850
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 16 deletions.
20 changes: 14 additions & 6 deletions tf_agents/policies/policy_saver.py
Expand Up @@ -331,10 +331,15 @@ def save_checkpoint(self, export_dir):
without having to reload the saved_model, or saving multiple copies of the
`saved_model.pb` file.
The checkpoint is always named 'variables' without a counter added to it.
This makes is compatible with the checkpoint part of saved models, which
enables you to load a saved model made up from the graph part of a full
saved model and the variables part of a checkpoint.
The checkpoint is always created in the sub-directory 'variables/' and the
checkpoint file prefix used is 'variables'. The checkpoint files are as
follows:
* export_dir/variables/variables.index
* export_dir/variables/variables-xxxxx-of-xxxxx
This makes the files compatible with the checkpoint part of full saved
models, which enables you to load a saved model made up from the graph part
of a full saved model and the variables part of a checkpoint.
Args:
export_dir: Directory to save the checkpoint to.
Expand All @@ -343,11 +348,14 @@ def save_checkpoint(self, export_dir):
# train_step so the checkpoint can be combined with a saved graph from a
# full saved model.
checkpoint = tf.train.Checkpoint(
policy=self._policy, model_variables=self._policy.variables(),
policy=self._policy,
model_variables=self._policy.variables(),
train_step=self._train_step)
# Use write() to make sure that the file prefix is not modified by appending
# a save counter value.
checkpoint.write(file_prefix=os.path.join(export_dir, 'variables'))
checkpoint.write(
file_prefix=os.path.join(export_dir, tf.saved_model.VARIABLES_DIRECTORY,
tf.saved_model.VARIABLES_FILENAME))


def _function_with_flat_signature(function,
Expand Down
12 changes: 7 additions & 5 deletions tf_agents/policies/policy_saver_test.py
Expand Up @@ -570,8 +570,9 @@ def testUpdateWithCheckpoint(self):

# Update from checkpoint.
checkpoint = tf.train.Checkpoint(policy=reloaded_policy)
checkpoint.read(os.path.join(
checkpoint_path, 'variables')).assert_existing_objects_matched()
checkpoint_file_prefix = os.path.join(checkpoint_path, 'variables',
'variables')
checkpoint.read(checkpoint_file_prefix).assert_existing_objects_matched()

self.evaluate(
tf.compat.v1.initializers.variables(reloaded_policy.model_variables))
Expand Down Expand Up @@ -629,8 +630,9 @@ def testInferenceWithCheckpoint(self):

# Update from checkpoint.
checkpoint = tf.train.Checkpoint(policy=reloaded_policy)
checkpoint.read(os.path.join(
checkpoint_path, 'variables')).assert_existing_objects_matched()
checkpoint_file_prefix = os.path.join(checkpoint_path, 'variables',
'variables')
checkpoint.read(checkpoint_file_prefix).assert_existing_objects_matched()

self.evaluate(
tf.compat.v1.initializers.variables(reloaded_policy.model_variables))
Expand Down Expand Up @@ -707,7 +709,7 @@ def assert_val_equal_var(val, var):
# and variables from the checkpoint.
composite_path = os.path.join(self.get_temp_dir(), 'composite_model')
self.copy_tree(full_model_path, composite_path, skip_variables=True)
self.copy_tree(checkpoint_path, os.path.join(composite_path, 'variables'))
self.copy_tree(checkpoint_path, os.path.join(composite_path))

# Reload the composite model and check all variables are 2
reloaded_policy = tf.compat.v2.saved_model.load(composite_path)
Expand Down
13 changes: 12 additions & 1 deletion tf_agents/policies/py_tf_eager_policy.py
Expand Up @@ -175,8 +175,19 @@ def variables(self):
def update_from_checkpoint(self, checkpoint_path):
"""Allows users to update saved_model variables directly from a checkpoint.
`checkpoint_path` is a path that was passed to either `PolicySaver.save()`
or `PolicySaver.save_checkpoint()`. The policy looks for set of checkpoint
files with the file prefix `<checkpoint_path>/variables/variables'
Args:
checkpoint_path: Path to the checkpoint to restore and use to udpate this
policy.
"""
self._checkpoint.read(checkpoint_path).assert_existing_objects_matched()
file_prefix = os.path.join(checkpoint_path,
tf.saved_model.VARIABLES_DIRECTORY,
tf.saved_model.VARIABLES_FILENAME)
status = self._checkpoint.read(file_prefix)
# Check that all the variables in the policy were updated, but allow the
# checkpoint to have additional variables. This helps sharing checkpoints
# across policies.
status.assert_existing_objects_matched().expect_partial()
6 changes: 2 additions & 4 deletions tf_agents/policies/py_tf_eager_policy_test.py
Expand Up @@ -221,8 +221,7 @@ def testUpdateFromCheckpoint(self):
# Use evaluate to force a copy.
saved_model_variables = self.evaluate(eager_py_policy.variables())

eager_py_policy.update_from_checkpoint(
os.path.join(checkpoint_path, 'variables'))
eager_py_policy.update_from_checkpoint(checkpoint_path)

assert_np_not_equal = lambda a, b: self.assertFalse(np.equal(a, b).all())
tf.nest.map_structure(assert_np_not_equal, saved_model_variables,
Expand Down Expand Up @@ -257,8 +256,7 @@ def testInferenceFromCheckpoint(self):
# Use evaluate to force a copy.
saved_model_variables = self.evaluate(eager_py_policy.variables())

eager_py_policy.update_from_checkpoint(
os.path.join(checkpoint_path, 'variables'))
eager_py_policy.update_from_checkpoint(checkpoint_path)

assert_np_not_equal = lambda a, b: self.assertFalse(np.equal(a, b).all())
tf.nest.map_structure(assert_np_not_equal, saved_model_variables,
Expand Down

0 comments on commit 7af6850

Please sign in to comment.