Skip to content

Commit

Permalink
Feed asset paths in tf.save_model.load when loading 1.x SavedModels
Browse files Browse the repository at this point in the history
Adds a simple table test.

PiperOrigin-RevId: 233773824
  • Loading branch information
allenlavoie authored and tensorflower-gardener committed Feb 13, 2019
1 parent a06adad commit 24ed029
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 5 deletions.
13 changes: 10 additions & 3 deletions tensorflow/python/saved_model/load_v1_in_v2.py
Expand Up @@ -95,9 +95,16 @@ def load(self, tags):
with wrapped.graph.as_default():
init_op = loader_impl.get_init_op(meta_graph_def)
if init_op is not None:
# TODO(allenl): Deal with assets
wrapped.prune(feeds=[],
fetches=[wrapped.graph.as_graph_element(init_op)])()
asset_feed_tensors = []
asset_paths = []
for tensor_name, value in loader_impl.get_asset_tensors(
self._export_dir, meta_graph_def).items():
asset_feed_tensors.append(wrapped.graph.as_graph_element(tensor_name))
asset_paths.append(tracking.TrackableAsset(value))
init_fn = wrapped.prune(
feeds=asset_feed_tensors,
fetches=[wrapped.graph.as_graph_element(init_op)])
init_fn(*[path.asset_path for path in asset_paths])
signature_functions = self._extract_signatures(wrapped, meta_graph_def)
root = tracking.AutoCheckpointable()
root.signatures = signature_serialization.create_signature_map(
Expand Down
51 changes: 51 additions & 0 deletions tensorflow/python/saved_model/load_v1_in_v2_test.py
Expand Up @@ -19,18 +19,22 @@
from __future__ import print_function

import os
import shutil

from tensorflow.python.client import session as session_lib
from tensorflow.python.eager import backprop
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.saved_model import builder_impl
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import save
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import simple_save
from tensorflow.python.saved_model import utils_impl
Expand Down Expand Up @@ -149,6 +153,53 @@ def test_multi_meta_graph_loading(self):
self.evaluate(second_imported.signatures["second_key"](
second_start=constant_op.constant(2.))))

def _v1_asset_saved_model(self):
export_graph = ops.Graph()
vocab_path = os.path.join(self.get_temp_dir(), "vocab.txt")
with open(vocab_path, "w") as f:
f.write("alpha\nbeta\ngamma\n")
with export_graph.as_default():
initializer = lookup_ops.TextFileInitializer(
vocab_path,
key_dtype=dtypes.string,
key_index=lookup_ops.TextFileIndex.WHOLE_LINE,
value_dtype=dtypes.int64,
value_index=lookup_ops.TextFileIndex.LINE_NUMBER)
table = lookup_ops.HashTable(
initializer, default_value=-1)
start = array_ops.placeholder(
shape=None, dtype=dtypes.string, name="in")
output = table.lookup(start, name="out")
with session_lib.Session() as session:
session.run([table.initializer])
path = os.path.join(self.get_temp_dir(), "saved_model", str(ops.uid()))
simple_save.simple_save(
session,
path,
inputs={"start": start},
outputs={"output": output},
legacy_init_op=table.initializer)
file_io.delete_file(vocab_path)
return path

def test_asset_loading(self):
first_path = self._v1_asset_saved_model()
imported = load.load(first_path)
fn = imported.signatures["serving_default"]
self.assertAllClose({"output": [2, 0]},
fn(start=constant_op.constant(["gamma", "alpha"])))
second_path = os.path.join(self.get_temp_dir(), "saved_model",
str(ops.uid()))
save.save(imported, second_path, signatures=imported.signatures)
shutil.rmtree(first_path)
self.skipTest(
"TODO(b/124321570): save TrackableAssets and make re-saving initialize "
"correctly")
second_import = load.load(second_path)
fn = second_import.signatures["serving_default"]
self.assertAllClose({"output": [2, 0]},
fn(start=constant_op.constant(["gamma", "alpha"])))


if __name__ == "__main__":
test.main()
4 changes: 2 additions & 2 deletions tensorflow/python/saved_model/loader_impl.py
Expand Up @@ -88,7 +88,7 @@ def parse_saved_model(export_dir):
_parse_saved_model = parse_saved_model


def _get_asset_tensors(export_dir, meta_graph_def_to_load, import_scope=None):
def get_asset_tensors(export_dir, meta_graph_def_to_load, import_scope=None):
"""Gets the asset tensors, if defined in the meta graph def to load.
Args:
Expand Down Expand Up @@ -393,7 +393,7 @@ def run_init_ops(self, sess, tags, import_scope=None):
meta_graph_def = self.get_meta_graph_def_from_tags(tags)
with sess.graph.as_default():
# Get asset tensors, if any.
asset_tensors_dictionary = _get_asset_tensors(
asset_tensors_dictionary = get_asset_tensors(
self._export_dir, meta_graph_def, import_scope=import_scope)

init_op = get_init_op(meta_graph_def, import_scope)
Expand Down

0 comments on commit 24ed029

Please sign in to comment.