Skip to content

Commit

Permalink
Make sure async_policy_saver gets closed in tests to avoid hanging in…
Browse files Browse the repository at this point in the history
… OSS tests.

PiperOrigin-RevId: 302460421
Change-Id: Ie1a52a7b29b9bb25da7dec69ad1a89273f2392f0
  • Loading branch information
Oscar Ramirez authored and Copybara-Service committed Mar 23, 2020
1 parent 7af6850 commit 93c6b1b
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
1 change: 0 additions & 1 deletion broken_tests.txt
@@ -1,2 +1 @@
replay_buffers.tfrecord_replay_buffer_test # b/140896267
policies.async_policy_saver_test # b/151865318
12 changes: 12 additions & 0 deletions tf_agents/policies/async_policy_saver_test.py
Expand Up @@ -38,6 +38,9 @@ def testSave(self):
async_saver.flush()

saver.save.assert_called_once_with(save_path)
# Have to close the saver to avoid hanging threads that will prevent OSS
# tests from finishing.
async_saver.close()

def testCheckpointSave(self):
saver = mock.create_autospec(policy_saver.PolicySaver, instance=True)
Expand All @@ -52,6 +55,9 @@ def testCheckpointSave(self):
async_saver.flush()

saver.save_checkpoint.assert_called_once_with(checkpoint_path)
# Have to close the saver to avoid hanging threads that will prevent OSS
# tests from finishing.
async_saver.close()

def testBlockingSave(self):
saver = mock.create_autospec(policy_saver.PolicySaver, instance=True)
Expand All @@ -64,6 +70,9 @@ def testBlockingSave(self):
async_saver.save(path2, blocking=True)

saver.save.assert_has_calls([mock.call(path1), mock.call(path2)])
# Have to close the saver to avoid hanging threads that will prevent OSS
# tests from finishing.
async_saver.close()

def testBlockingCheckpointSave(self):
saver = mock.create_autospec(policy_saver.PolicySaver, instance=True)
Expand All @@ -76,6 +85,9 @@ def testBlockingCheckpointSave(self):
async_saver.save_checkpoint(path2, blocking=True)

saver.save_checkpoint.assert_has_calls([mock.call(path1), mock.call(path2)])
# Have to close the saver to avoid hanging threads that will prevent OSS
# tests from finishing.
async_saver.close()

def testClose(self):
saver = mock.create_autospec(policy_saver.PolicySaver, instance=True)
Expand Down

0 comments on commit 93c6b1b

Please sign in to comment.