Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tfjs-converter/python/run-python-tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ TEST_FILES="$(find "${SCRIPTS_DIR}" -name '*_test.py')"

pip install virtualenv

TMP_VENV_DIR="$(mktemp -d --suffix=_venv)"
TMP_VENV_DIR="$(mktemp -u).venv"
virtualenv -p "python" "${TMP_VENV_DIR}"
source "${TMP_VENV_DIR}/bin/activate"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
from tensorflow.python.grappler import cluster as gcluster
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.saved_model.load import load
from tensorflow.python.saved_model import loader
from tensorflow.python.tools import saved_model_utils
from tensorflow.python.training.saver import export_meta_graph
from google.protobuf.json_format import MessageToDict
import tensorflow_hub as hub
Expand Down Expand Up @@ -272,15 +274,20 @@ def _check_signature_in_model(saved_model, signature_name):
saved_model.signatures.keys()))


def _freeze_saved_model_v1(saved_model_dir, saved_model_tags,
                           output_node_names):
  """Freeze a TF 1.x SavedModel into a constants-only tf.Graph.

  The SavedModel is loaded into a live session first so that the stored
  variable values are actually restored; freezing against a fresh,
  uninitialized session would fail (or bake in garbage values).

  Args:
    saved_model_dir: Path to the SavedModel directory on disk.
    saved_model_tags: List of MetaGraphDef tags identifying which graph
      to load (e.g. ['serve']).
    output_node_names: List of node names to keep as graph outputs when
      folding variables into constants.

  Returns:
    A tf.Graph whose variables have been converted to constants.
  """
  with tf.compat.v1.Session() as sess:
    # Restore the checkpointed variable values into this session.
    loader.load(sess, saved_model_tags, saved_model_dir)
    # Read the GraphDef for the tagged MetaGraph straight from disk;
    # tags are joined because get_meta_graph_def takes a comma string.
    input_graph_def = saved_model_utils.get_meta_graph_def(
        saved_model_dir, ','.join(saved_model_tags)).graph_def
    frozen_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
        sess, input_graph_def, output_node_names)

    # Re-import the frozen GraphDef into a standalone graph with no
    # name prefix so node names match output_node_names exactly.
    frozen_graph = tf.Graph()
    with frozen_graph.as_default():
      tf.import_graph_def(frozen_graph_def, name='')

    return frozen_graph

def _freeze_saved_model_v2(concrete_func):
return convert_to_constants.convert_variables_to_constants_v2(
Expand Down Expand Up @@ -336,8 +343,8 @@ def convert_tf_saved_model(saved_model_dir,
try:
frozen_graph = _freeze_saved_model_v2(concrete_func)
except BaseException:
frozen_graph = _freeze_saved_model_v1(
concrete_func.graph, output_node_names)
frozen_graph = _freeze_saved_model_v1(saved_model_dir, saved_model_tags,
output_node_names)

optimize_graph(frozen_graph, output_node_names, output_graph,
model.tensorflow_version,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,44 @@ def _create_saved_model_v1(self):

builder.save()

def _create_saved_model_v1_with_hashtable(self):
  """Create a TensorFlow SavedModel V1 with unused hash table for testing.

  Builds a graph containing a 2x2 matmul (the only signature output) plus
  a hash table that nothing in the signature depends on, then writes it to
  self._tmp_dir/SAVED_MODEL_DIR via the V1 SavedModelBuilder.
  """

  graph = tf.Graph()
  with graph.as_default():
    # NOTE(review): `tf.placeholder` / `tf.contrib.lookup` only exist in
    # TF 1.x APIs while other lines use `tf.compat.v1.*` — presumably this
    # targets a TF 1.x release where both spellings resolve; confirm.
    x = tf.placeholder('float32', [2, 2])
    w = tf.compat.v1.get_variable('w', shape=[2, 2])
    output = tf.compat.v1.matmul(x, w)
    init_op = w.initializer

    # Add a hash table that is not used by the output.
    keys = tf.constant(['key'])
    values = tf.constant([1])
    initializer = tf.contrib.lookup.KeyValueTensorInitializer(keys, values)
    table = tf.contrib.lookup.HashTable(initializer, -1)

    # Create a builder.
    save_dir = os.path.join(self._tmp_dir, SAVED_MODEL_DIR)
    builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(save_dir)

    with tf.compat.v1.Session() as sess:
      # Run the initializer on `w`.
      sess.run(init_op)
      # Initialize the table too, so the session state is fully valid
      # before the builder snapshots variables.
      table.init.run()

      # Export a single 'serving_default' signature mapping x -> output;
      # the hash table is intentionally left out of the signature.
      builder.add_meta_graph_and_variables(
          sess, [tf.compat.v1.saved_model.tag_constants.SERVING],
          signature_def_map={
              "serving_default":
                  tf.compat.v1.saved_model \
                      .signature_def_utils.predict_signature_def(
                          inputs={"x": x},
                          outputs={"output": output})
          },
          assets_collection=None)

    builder.save()

def _create_saved_model_with_fusable_conv2d(self):
"""Test a basic model with fusable conv2d."""
layers = [
Expand Down Expand Up @@ -192,32 +230,62 @@ def double_module_fn():
def test_convert_saved_model_v1(self):
  """Converts a V1 SavedModel and checks the emitted tfjs artifacts."""
  self._create_saved_model_v1()

  input_dir = os.path.join(self._tmp_dir, SAVED_MODEL_DIR)
  # Write artifacts to a subdirectory so they don't clobber the input.
  output_dir = os.path.join(input_dir, 'js')
  tf_saved_model_conversion_v2.convert_tf_saved_model(
      input_dir,
      output_dir
  )

  # The model has a single 2x2 float variable `w`, so exactly one shard.
  expected_weights_manifest = [{
      'paths': ['group1-shard1of1.bin'],
      'weights': [{'dtype': 'float32', 'name': 'w', 'shape': [2, 2]}]}]

  # Check model.json and weights manifest.
  with open(os.path.join(output_dir, 'model.json'), 'rt') as f:
    model_json = json.load(f)
    self.assertTrue(model_json['modelTopology'])
    weights_manifest = model_json['weightsManifest']
    self.assertEqual(weights_manifest, expected_weights_manifest)
    # Check meta-data in the artifact JSON.
    self.assertEqual(model_json['format'], 'graph-model')
    self.assertEqual(
        model_json['convertedBy'],
        'TensorFlow.js Converter v%s' % version.version)
    self.assertEqual(model_json['generatedBy'],
                     tf.__version__)
  # At least one weight shard must have been written alongside model.json.
  self.assertTrue(glob.glob(os.path.join(output_dir, 'group*-*')))

def test_convert_saved_model_v1_with_hashtable(self):
  """Conversion of a V1 SavedModel whose graph holds an unused hash table."""
  self._create_saved_model_v1_with_hashtable()

  model_dir = os.path.join(self._tmp_dir, SAVED_MODEL_DIR)
  artifacts_dir = os.path.join(model_dir, 'js')
  tf_saved_model_conversion_v2.convert_tf_saved_model(
      model_dir,
      artifacts_dir
  )

  # Only the 2x2 variable `w` should survive into the weights manifest;
  # the unused hash table contributes nothing.
  manifest_expected = [{
      'paths': ['group1-shard1of1.bin'],
      'weights': [{'dtype': 'float32', 'name': 'w', 'shape': [2, 2]}]}]

  # Load the emitted model.json for inspection.
  model_json_path = os.path.join(artifacts_dir, 'model.json')
  with open(model_json_path, 'rt') as json_file:
    model_json = json.load(json_file)

  self.assertTrue(model_json['modelTopology'])
  self.assertEqual(model_json['weightsManifest'], manifest_expected)
  # Artifact metadata checks.
  self.assertEqual(model_json['format'], 'graph-model')
  self.assertEqual(
      model_json['convertedBy'],
      'TensorFlow.js Converter v%s' % version.version)
  self.assertEqual(model_json['generatedBy'], tf.__version__)
  # At least one weight shard file must sit next to model.json.
  self.assertTrue(glob.glob(os.path.join(artifacts_dir, 'group*-*')))

def test_convert_saved_model(self):
self._create_saved_model()
Expand Down