Skip to content

Commit

Permalink
Add support for caching modules on GCS.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 199759306
  • Loading branch information
TensorFlow Hub Authors authored and andresusanopinto committed Jun 8, 2018
1 parent f7c1a46 commit 6860621
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 22 deletions.
1 change: 1 addition & 0 deletions RELEASE.md
Expand Up @@ -15,6 +15,7 @@ limitations under the License.

# Current version (0.2.0-dev)
* Under development.
* Add support for caching modules on GCS.

# Release 0.1.0
* Initial TensorFlow Hub release.
Expand Down
76 changes: 58 additions & 18 deletions tensorflow_hub/e2e_test.py
Expand Up @@ -24,8 +24,8 @@
import tensorflow as tf
import tensorflow_hub as hub

from tensorflow_hub import resolver
from tensorflow_hub import test_utils
from tensorflow_hub import tf_utils


class End2EndTest(tf.test.TestCase):
Expand Down Expand Up @@ -54,7 +54,7 @@ def _list_module_files(self, module_dir):
files.append(f)
return files

def testHttpLocations(self):
def test_http_locations(self):
spec = hub.create_module_spec(self._stateless_module_fn)
m = hub.Module(spec, name="test_module")
out = m(10)
Expand All @@ -77,23 +77,63 @@ def testHttpLocations(self):
with tf.Session() as sess:
self.assertAllClose(sess.run(out), 121)

def testUnknownHandleFormat(self):
# Test caching using custom filesystem (file://) to make sure that the
# TF Hub library can operate on such paths.
try:
hub.Module("s3://my_module.zip")
except resolver.UnsupportedHandleError as e:
self.assertStartsWith(
str(e), "unsupported handle format 's3://my_module.zip'. No "
"resolvers found that can successfully resolve it.")

try:
non_existant_module = os.path.join(self.get_temp_dir(), "missing_module")
hub.Module(non_existant_module)
except resolver.UnsupportedHandleError as e:
self.assertStartsWith(
str(e), "unsupported handle format '%s'. No "
"resolvers found that can successfully resolve it." %
non_existant_module)

root_dir = "file://%s" % self.get_temp_dir()
cache_dir = "%s_%s" % (root_dir, "cache")
tf.gfile.MakeDirs(cache_dir)
os.environ["TFHUB_CACHE_DIR"] = cache_dir
m = hub.Module("http://localhost:%d/test_module.tgz" % self.server_port)
out = m(11)
with tf.train.MonitoredSession() as sess:
self.assertAllClose(sess.run(out), 121)

cache_content = sorted(tf.gfile.ListDirectory(cache_dir))
tf.logging.info("Cache context: %s", str(cache_content))
self.assertEqual(2, len(cache_content))
self.assertTrue(cache_content[1].endswith(".descriptor.txt"))
module_files = sorted(tf.gfile.ListDirectory(
os.path.join(cache_dir, cache_content[0])))
self.assertListEqual(["saved_model.pb", "tfhub_module.pb"], module_files)
finally:
os.unsetenv("TFHUB_CACHE_DIR")

def test_module_export_vocab_on_custom_fs(self):
  """Exports a module whose vocab asset lives on a file:// filesystem path.

  Verifies that asset files reachable only through a TensorFlow filesystem
  scheme (here file://) are copied into the exported module's assets/
  directory rather than being dropped or mishandled by os.path logic.
  """
  # file:// forces paths through TensorFlow's filesystem layer instead of
  # plain local-path handling.
  root_dir = "file://%s" % self.get_temp_dir()
  export_dir = "%s_%s" % (root_dir, "export")
  tf.gfile.MakeDirs(export_dir)
  # Create a module with a vocab file located on a custom filesystem.
  vocab_dir = os.path.join(root_dir, "vocab_location")
  tf.gfile.MakeDirs(vocab_dir)
  vocab_filename = os.path.join(vocab_dir, "tokens.txt")
  tf_utils.atomic_write_string_to_file(vocab_filename, "one", False)

  def create_assets_module_fn():
    # Factory so the module_fn closes over vocab_filename at definition time.

    def assets_module_fn():
      indices = tf.placeholder(dtype=tf.int64, name="indices")
      # The lookup table references the file:// vocab file, which makes it an
      # asset of the module.
      table = tf.contrib.lookup.index_to_string_table_from_file(
          vocabulary_file=vocab_filename, default_value="UNKNOWN")
      outputs = table.lookup(indices)
      hub.add_signature(inputs=indices, outputs=outputs)

    return assets_module_fn

  with tf.Graph().as_default():
    assets_module_fn = create_assets_module_fn()
    spec = hub.create_module_spec(assets_module_fn)
    embedding_module = hub.Module(spec)
    with tf.Session() as sess:
      # Tables must be initialized before export so the asset is resolved.
      sess.run(tf.tables_initializer())
      embedding_module.export(export_dir, sess)

  # The export should contain the standard module layout plus the vocab file
  # copied under assets/.
  module_files = tf.gfile.ListDirectory(export_dir)
  self.assertListEqual(
      ["assets", "saved_model.pb", "tfhub_module.pb", "variables"],
      sorted(module_files))
  module_files = tf.gfile.ListDirectory(os.path.join(export_dir, "assets"))
  self.assertListEqual(["tokens.txt"], module_files)

# Run the test suite when executed directly.
if __name__ == "__main__":
  tf.test.main()
4 changes: 2 additions & 2 deletions tensorflow_hub/resolver.py
Expand Up @@ -223,7 +223,7 @@ def _lock_file_contents(task_uid):

def _lock_filename(module_dir):
"""Returns lock file name."""
return os.path.abspath(module_dir) + ".lock"
return tf_utils.absolute_path(module_dir) + ".lock"


def _module_dir(lock_filename):
Expand All @@ -249,7 +249,7 @@ def _task_uid_from_lock_file(lock_filename):

def _temp_download_dir(module_dir, task_uid):
"""Returns the name of a temporary directory to download module to."""
return "{}.{}.tmp".format(os.path.abspath(module_dir), task_uid)
return "{}.{}.tmp".format(tf_utils.absolute_path(module_dir), task_uid)


def _dir_size(directory):
Expand Down
4 changes: 2 additions & 2 deletions tensorflow_hub/saved_model_lib.py
Expand Up @@ -57,8 +57,8 @@ def _get_asset_filename(export_dir, asset_filename):
filename = os.path.join(
tf.compat.as_bytes(assets_dir),
tf.compat.as_bytes(asset_filename))
if not os.path.abspath(filename).startswith(
os.path.abspath(assets_dir)):
if not tf_utils.absolute_path(filename).startswith(
tf_utils.absolute_path(assets_dir)):
raise ValueError(
"Asset filename (%s) points outside assets_dir" % asset_filename)
return filename
Expand Down
12 changes: 12 additions & 0 deletions tensorflow_hub/tf_utils.py
Expand Up @@ -187,3 +187,15 @@ def bytes_to_readable_str(num_bytes, include_b=False):
if include_b:
result += "B"
return result


def absolute_path(path):
  """Returns the absolute path for `path`.

  This implementation avoids calling os.path.abspath(path) if `path` already
  represents an absolute TensorFlow filesystem location (e.g. <fs type>://),
  since os.path.abspath would mangle the scheme separator.

  Args:
    path: Path (str or bytes-like; converted with str()) to compute the
      absolute path from.

  Returns:
    `path` unchanged if it contains a filesystem scheme separator ("://");
    otherwise the result of os.path.abspath(path).
  """
  # "://" anywhere in the path marks a TensorFlow filesystem URI
  # (e.g. gs://bucket/..., file:///tmp/...), which is already absolute.
  return path if "://" in str(path) else os.path.abspath(path)

0 comments on commit 6860621

Please sign in to comment.