1 change: 1 addition & 0 deletions tfjs-converter/README.md
@@ -126,6 +126,7 @@ saved a tf.keras model in the SavedModel format.
|`--signature_name` | Only applicable to TensorFlow SavedModel and Hub module conversion, signature to load. Defaults to `serving_default` for SavedModel and `default` for Hub module. See https://www.tensorflow.org/hub/common_signatures/.|
|`--strip_debug_ops` | Strips out TensorFlow debug operations `Print`, `Assert`, `CheckNumerics`. Defaults to `True`.|
|`--quantization_bytes` | How many bytes to optionally quantize/compress the weights to. Valid values are 1 and 2, which will quantize int32 and float32 weights to 1 or 2 bytes respectively. The default (unquantized) size is 4 bytes.|
|`--weight_shard_size_bytes` | Shard size (in bytes) of the weight files. Only supported when `output_format` is `tfjs_layers_model` or `tfjs_graph_model`. Default size is 4 MB (4194304 bytes).|
|<nobr>`--output_node_names`</nobr>| Only applicable to Frozen Model. The names of the output nodes, separated by commas.|
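
For example, with the default 4 MB shard size, a model whose weights total 50 MB is written as 13 weight files: 12 full 4 MB shards plus one smaller remainder shard. Below is a minimal sketch of the equivalent Python API call, using the dispatch function this PR extends; the `model.h5` and `web_model` paths are hypothetical:

```python
from tensorflowjs.converters import converter

# Convert a Keras HDF5 model to tfjs_graph_model format, capping each
# weight shard at 1 MB instead of the default 4 MB (4194304 bytes).
converter.dispatch_keras_h5_to_tfjs_graph_model_conversion(
    'model.h5',                           # hypothetical input path
    output_dir='web_model',               # hypothetical output directory
    weight_shard_size_bytes=1024 * 1024)
```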

__Note: If you want to convert a TensorFlow session bundle, you can install an older version of the tensorflowjs pip package, e.g. `pip install tensorflowjs==0.8.6`.__
70 changes: 48 additions & 22 deletions tfjs-converter/python/tensorflowjs/converters/converter.py
@@ -104,7 +104,8 @@ def dispatch_keras_h5_to_tfjs_graph_model_conversion(
h5_path, output_dir=None,
quantization_dtype=None,
skip_op_check=False,
strip_debug_ops=False):
strip_debug_ops=False,
weight_shard_size_bytes=1024 * 1024 * 4):
"""
Convert a keras HDF5-format model to tfjs GraphModel artifacts.

@@ -117,6 +118,8 @@ def dispatch_keras_h5_to_tfjs_graph_model_conversion(
(Default: `None`).
skip_op_check: Bool whether to skip the op check.
strip_debug_ops: Bool whether to allow unsupported debug ops.
weight_shard_size_bytes: Shard size (in bytes) of the weight files.
The size of each weight file will be <= this value.
"""

if not os.path.exists(h5_path):
@@ -138,15 +141,17 @@ def dispatch_keras_h5_to_tfjs_graph_model_conversion(
saved_model_tags='serve',
quantization_dtype=quantization_dtype,
skip_op_check=skip_op_check,
strip_debug_ops=strip_debug_ops)
strip_debug_ops=strip_debug_ops,
weight_shard_size_bytes=weight_shard_size_bytes)

# Clean up the temporary SavedModel directory.
shutil.rmtree(temp_savedmodel_dir)


def dispatch_keras_saved_model_to_tensorflowjs_conversion(
keras_saved_model_path, output_dir, quantization_dtype=None,
split_weights_by_layer=False):
split_weights_by_layer=False,
weight_shard_size_bytes=1024 * 1024 * 4):
"""Converts keras model saved in the SavedModel format to tfjs format.

Note that the SavedModel format exists in keras, but not in
@@ -166,6 +171,8 @@ def dispatch_keras_saved_model_to_tensorflowjs_conversion(
split_weights_by_layer: Whether to split the weights into separate weight
groups (corresponding to separate binary weight files) layer by layer
(Default: `False`).
weight_shard_size_bytes: Shard size (in bytes) of the weight files.
The size of each weight file will be <= this value.
"""
with tf.Graph().as_default(), tf.compat.v1.Session():
model = tf.keras.models.load_model(keras_saved_model_path)
@@ -179,7 +186,8 @@ def dispatch_keras_saved_model_to_tensorflowjs_conversion(
temp_h5_path,
output_dir,
quantization_dtype=quantization_dtype,
split_weights_by_layer=split_weights_by_layer)
split_weights_by_layer=split_weights_by_layer,
weight_shard_size_bytes=weight_shard_size_bytes)

# Delete temporary .h5 file.
os.remove(temp_h5_path)
@@ -321,7 +329,8 @@ def dispatch_tfjs_layers_model_to_tfjs_graph_conversion(
output_dir_path,
quantization_dtype=None,
skip_op_check=False,
strip_debug_ops=False):
strip_debug_ops=False,
weight_shard_size_bytes=1024 * 1024 * 4):
"""Converts a TensorFlow.js Layers Model to TensorFlow.js Graph Model.

This conversion often benefits speed of inference, due to the graph
@@ -336,6 +345,8 @@ def dispatch_tfjs_layers_model_to_tfjs_graph_conversion(
(Default: `None`).
skip_op_check: Bool whether to skip the op check.
strip_debug_ops: Bool whether to allow unsupported debug ops.
weight_shard_size_bytes: Shard size (in bytes) of the weight files.
The size of each weight file will be <= this value.

Raises:
ValueError, if `config_json_path` is not a path to a valid JSON
@@ -369,7 +380,8 @@ def dispatch_tfjs_layers_model_to_tfjs_graph_conversion(
temp_h5_path, output_dir_path,
quantization_dtype=quantization_dtype,
skip_op_check=skip_op_check,
strip_debug_ops=strip_debug_ops)
strip_debug_ops=strip_debug_ops,
weight_shard_size_bytes=weight_shard_size_bytes)

# Clean up temporary HDF5 file.
os.remove(temp_h5_path)
@@ -507,7 +519,7 @@ def get_arg_parser():
type=int,
default=None,
help='Shard size (in bytes) of the weight files. Currently applicable '
'only to output_format=tfjs_layers_model.')
'only when output_format is tfjs_layers_model or tfjs_graph_model.')
parser.add_argument(
'--output_node_names',
type=str,
@@ -532,14 +544,6 @@ def convert(arguments):
raise ValueError(
'Missing output_path argument. For usage, use the --help flag.')

weight_shard_size_bytes = 1024 * 1024 * 4
if args.weight_shard_size_bytes:
if args.output_format != common.TFJS_LAYERS_MODEL:
raise ValueError(
'The --weight_shard_size_bytes flag is only supported under '
'output_format=tfjs_layers_model.')
weight_shard_size_bytes = args.weight_shard_size_bytes

if args.input_path is None:
raise ValueError(
'Error: The input_path argument must be set. '
@@ -548,6 +552,21 @@
input_format, output_format = _standardize_input_output_formats(
args.input_format, args.output_format)

weight_shard_size_bytes = 1024 * 1024 * 4
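# Validate the shard-size flag only after the input/output formats have
# been standardized, so that format aliases are normalized before the
# output-format check below.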
if args.weight_shard_size_bytes is not None:
if (output_format not in
(common.TFJS_LAYERS_MODEL, common.TFJS_GRAPH_MODEL)):
raise ValueError(
'The --weight_shard_size_bytes flag is only supported when '
'output_format is tfjs_layers_model or tfjs_graph_model.')

if not (isinstance(args.weight_shard_size_bytes, int) and
args.weight_shard_size_bytes > 0):
raise ValueError(
'Expected weight_shard_size_bytes to be a positive integer, '
'but got %s' % args.weight_shard_size_bytes)
weight_shard_size_bytes = args.weight_shard_size_bytes

quantization_dtype = (
quantization.QUANTIZATION_BYTES_TO_DTYPES[args.quantization_bytes]
if args.quantization_bytes else None)
@@ -570,20 +589,23 @@
dispatch_keras_h5_to_tfjs_layers_model_conversion(
args.input_path, output_dir=args.output_path,
quantization_dtype=quantization_dtype,
split_weights_by_layer=args.split_weights_by_layer)
split_weights_by_layer=args.split_weights_by_layer,
weight_shard_size_bytes=weight_shard_size_bytes)
elif (input_format == common.KERAS_MODEL and
output_format == common.TFJS_GRAPH_MODEL):
dispatch_keras_h5_to_tfjs_graph_model_conversion(
args.input_path, output_dir=args.output_path,
quantization_dtype=quantization_dtype,
skip_op_check=args.skip_op_check,
strip_debug_ops=args.strip_debug_ops)
strip_debug_ops=args.strip_debug_ops,
weight_shard_size_bytes=weight_shard_size_bytes)
elif (input_format == common.KERAS_SAVED_MODEL and
output_format == common.TFJS_LAYERS_MODEL):
dispatch_keras_saved_model_to_tensorflowjs_conversion(
args.input_path, args.output_path,
quantization_dtype=quantization_dtype,
split_weights_by_layer=args.split_weights_by_layer)
split_weights_by_layer=args.split_weights_by_layer,
weight_shard_size_bytes=weight_shard_size_bytes)
elif (input_format == common.TF_SAVED_MODEL and
output_format == common.TFJS_GRAPH_MODEL):
tf_saved_model_conversion_v2.convert_tf_saved_model(
@@ -592,7 +614,8 @@ def convert(arguments):
saved_model_tags=args.saved_model_tags,
quantization_dtype=quantization_dtype,
skip_op_check=args.skip_op_check,
strip_debug_ops=args.strip_debug_ops)
strip_debug_ops=args.strip_debug_ops,
weight_shard_size_bytes=weight_shard_size_bytes)
elif (input_format == common.TF_HUB_MODEL and
output_format == common.TFJS_GRAPH_MODEL):
tf_saved_model_conversion_v2.convert_tf_hub_module(
@@ -601,7 +624,8 @@ def convert(arguments):
saved_model_tags=args.saved_model_tags,
quantization_dtype=quantization_dtype,
skip_op_check=args.skip_op_check,
strip_debug_ops=args.strip_debug_ops)
strip_debug_ops=args.strip_debug_ops,
weight_shard_size_bytes=weight_shard_size_bytes)
elif (input_format == common.TFJS_LAYERS_MODEL and
output_format == common.KERAS_MODEL):
dispatch_tensorflowjs_to_keras_h5_conversion(args.input_path,
@@ -622,14 +646,16 @@ def convert(arguments):
args.input_path, args.output_path,
quantization_dtype=_parse_quantization_bytes(args.quantization_bytes),
skip_op_check=args.skip_op_check,
strip_debug_ops=args.strip_debug_ops)
strip_debug_ops=args.strip_debug_ops,
weight_shard_size_bytes=weight_shard_size_bytes)
elif (input_format == common.TF_FROZEN_MODEL and
output_format == common.TFJS_GRAPH_MODEL):
tf_saved_model_conversion_v2.convert_tf_frozen_model(
args.input_path, args.output_node_names, args.output_path,
quantization_dtype=_parse_quantization_bytes(args.quantization_bytes),
skip_op_check=args.skip_op_check,
strip_debug_ops=args.strip_debug_ops)
strip_debug_ops=args.strip_debug_ops,
weight_shard_size_bytes=weight_shard_size_bytes)
else:
raise ValueError(
'Unsupported input_format - output_format pair: %s - %s' %
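Since each emitted weight file is capped at `weight_shard_size_bytes`, the expected number of shards follows directly from the total weight size. A small sketch of that arithmetic (the `expected_shard_count` helper is hypothetical, not part of this PR):

```python
import math

def expected_shard_count(total_weight_bytes,
                         weight_shard_size_bytes=1024 * 1024 * 4):
  """Hypothetical helper: number of weight files the converter should emit.

  Every shard except possibly the last is exactly weight_shard_size_bytes;
  any remainder goes into one final, smaller shard.
  """
  return max(1, math.ceil(total_weight_bytes / weight_shard_size_bytes))

# The tests below pick a shard size of 0.3 * total_weight_bytes, so
# expected_shard_count(total, int(total * 0.3)) == 4.
```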
63 changes: 63 additions & 0 deletions tfjs-converter/python/tensorflowjs/converters/converter_test.py
@@ -151,6 +151,35 @@ def testConvertSavedKerasModelSplitByLayer(self):
self.assertIsInstance(output_json['weightsManifest'], list)
self.assertTrue(glob.glob(os.path.join(self._tmp_dir, 'group*-*')))

def testConvertSavedKerasModeltoTfLayersModelSharded(self):
with tf.Graph().as_default(), tf.compat.v1.Session():
sequential_model = keras.models.Sequential([
keras.layers.Dense(
3, input_shape=(2,), use_bias=True, kernel_initializer='ones',
name='Dense1')])
h5_path = os.path.join(self._tmp_dir, 'SequentialModel.h5')
sequential_model.save(h5_path)

weights = sequential_model.get_weights()
total_weight_bytes = sum(np.size(w) for w in weights) * 4

# Due to the shard size, there ought to be 4 shards after conversion.
weight_shard_size_bytes = int(total_weight_bytes * 0.3)
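# (A shard size of 0.3x the total means three full shards cover 90% of
# the bytes and the remaining 10% lands in a fourth, smaller shard:
# ceil(1 / 0.3) == 4.)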

# Convert Keras model to tfjs_layers_model format.
output_dir = os.path.join(self._tmp_dir, 'sharded_tfjs')
converter.dispatch_keras_h5_to_tfjs_layers_model_conversion(
h5_path, output_dir,
weight_shard_size_bytes=weight_shard_size_bytes)

weight_files = sorted(glob.glob(os.path.join(output_dir, 'group*.bin')))
self.assertEqual(len(weight_files), 4)
weight_file_sizes = [os.path.getsize(f) for f in weight_files]
self.assertEqual(sum(weight_file_sizes), total_weight_bytes)
self.assertEqual(weight_file_sizes[0], weight_file_sizes[1])
self.assertEqual(weight_file_sizes[0], weight_file_sizes[2])
self.assertLess(weight_file_sizes[3], weight_file_sizes[0])

def testConvertWeightsFromSequentialModel(self):
with tf.Graph().as_default(), tf.compat.v1.Session():
sequential_model = keras.models.Sequential([
@@ -318,6 +347,40 @@ def testConvertKerasModelToTfGraphModel(self):
tf.__version__)
self.assertTrue(glob.glob(os.path.join(output_dir, 'group*-*')))

def testConvertKerasModelToTfGraphModelSharded(self):
output_dir = os.path.join(self._tmp_dir, 'foo_model')
sequential_model = keras.models.Sequential([
keras.layers.Dense(
3, input_shape=(2,), use_bias=True, kernel_initializer='ones',
name='Dense1')])
h5_path = os.path.join(self._tmp_dir, 'SequentialModel.h5')
sequential_model.save(h5_path)

# Do initial conversion without sharding.
converter.dispatch_keras_h5_to_tfjs_graph_model_conversion(
h5_path, output_dir)
weight_files = glob.glob(os.path.join(output_dir, 'group*.bin'))

# Get size of weights in bytes after graph optimizations.
optimized_total_weight = sum([os.path.getsize(f) for f in weight_files])

# Due to the shard size, there ought to be 4 shards after conversion.
weight_shard_size_bytes = int(optimized_total_weight * 0.3)
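# (As above: three full shards cover 90% of the optimized weight bytes,
# and the remainder lands in a fourth, smaller shard: ceil(1 / 0.3) == 4.)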

output_dir = os.path.join(self._tmp_dir, 'sharded_model')
# Convert Keras model again with shard argument set.
converter.dispatch_keras_h5_to_tfjs_graph_model_conversion(
h5_path, output_dir,
weight_shard_size_bytes=weight_shard_size_bytes)

weight_files = sorted(glob.glob(os.path.join(output_dir, 'group*.bin')))
self.assertEqual(len(weight_files), 4)
weight_file_sizes = [os.path.getsize(f) for f in weight_files]
self.assertEqual(sum(weight_file_sizes), optimized_total_weight)
self.assertEqual(weight_file_sizes[0], weight_file_sizes[1])
self.assertEqual(weight_file_sizes[0], weight_file_sizes[2])
self.assertLess(weight_file_sizes[3], weight_file_sizes[0])


class ConvertTfKerasSavedModelTest(tf.test.TestCase):

Expand Down