Skip to content

Commit

Permalink
Write exports to subdir in tests.
Browse files Browse the repository at this point in the history
Listing the top level dir is not stable across OSes.

PiperOrigin-RevId: 268556626
  • Loading branch information
EugenHotaj authored and cweill committed Sep 12, 2019
1 parent e081336 commit 77199f9
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions adanet/core/estimator_test.py
Expand Up @@ -1071,13 +1071,14 @@ def serving_input_fn():
export_saved_model_fn = getattr(estimator, "export_saved_model", None)
if not callable(export_saved_model_fn):
export_saved_model_fn = estimator.export_savedmodel
export_dir_base = os.path.join(self.test_subdirectory, "export")
export_saved_model_fn(
export_dir_base=self.test_subdirectory,
export_dir_base=export_dir_base,
serving_input_receiver_fn=serving_input_fn)
if export_subnetworks:
saved_model = saved_model_utils.read_saved_model(
os.path.join(self.test_subdirectory,
tf.io.gfile.listdir(self.test_subdirectory)[0]))
os.path.join(export_dir_base,
tf.io.gfile.listdir(export_dir_base)[0]))
export_signature_def = saved_model.meta_graphs[0].signature_def
self.assertIn("subnetwork_logits", export_signature_def.keys())
self.assertIn("subnetwork_last_layer", export_signature_def.keys())
Expand Down Expand Up @@ -1534,13 +1535,14 @@ def serving_input_fn():
export_saved_model_fn = getattr(estimator, "export_saved_model", None)
if not callable(export_saved_model_fn):
export_saved_model_fn = estimator.export_savedmodel
export_dir_base = os.path.join(self.test_subdirectory, "export")
export_saved_model_fn(
export_dir_base=self.test_subdirectory,
export_dir_base=export_dir_base,
serving_input_receiver_fn=serving_input_fn)
if export_subnetworks:
saved_model = saved_model_utils.read_saved_model(
os.path.join(self.test_subdirectory,
tf.io.gfile.listdir(self.test_subdirectory)[0]))
os.path.join(export_dir_base,
tf.io.gfile.listdir(export_dir_base)[0]))
export_signature_def = saved_model.meta_graphs[0].signature_def
self.assertIn("subnetwork_logits_head1", export_signature_def.keys())
self.assertIn("subnetwork_logits_head2", export_signature_def.keys())
Expand Down

0 comments on commit 77199f9

Please sign in to comment.