Skip to content

Commit af46b87

Browse files
pvaneckpyu10055
andauthored
[tfjs-converter] Allow custom shard size for more conversion pairs (#2774)
FEATURE * Allow custom shard size for more conversion pairs * Adjust arg size arg check * Add tests for sharded argument * Update readme with new arg * Add shard size support to tensorflowjs_wizard Co-authored-by: Ping Yu <4018+pyu10055@users.noreply.github.com>
1 parent f058075 commit af46b87

File tree

9 files changed

+271
-43
lines changed

9 files changed

+271
-43
lines changed

tfjs-converter/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ saved a tf.keras model in the SavedModel format.
126126
|`--signature_name` | Only applicable to TensorFlow SavedModel and Hub module conversion, signature to load. Defaults to `serving_default` for SavedModel and `default` for Hub module. See https://www.tensorflow.org/hub/common_signatures/.|
127127
|`--strip_debug_ops` | Strips out TensorFlow debug operations `Print`, `Assert`, `CheckNumerics`. Defaults to `True`.|
128128
|`--quantization_bytes` | How many bytes to optionally quantize/compress the weights to. Valid values are 1 and 2. which will quantize int32 and float32 to 1 or 2 bytes respectively. The default (unquantized) size is 4 bytes.|
129+
|`--weight_shard_size_bytes` | Shard size (in bytes) of the weight files. Only supported when `output_format` is `tfjs_layers_model` or `tfjs_graph_model`. Default size is 4 MB (4194304 bytes).|
129130
|<nobr>`--output_node_names`</nobr>| Only applicable to Frozen Model. The names of the output nodes, separated by commas.|
130131

131132
__Note: If you want to convert TensorFlow session bundle, you can install older versions of the tensorflowjs pip package, i.e. `pip install tensorflowjs==0.8.6`.__

tfjs-converter/python/tensorflowjs/converters/converter.py

Lines changed: 48 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,8 @@ def dispatch_keras_h5_to_tfjs_graph_model_conversion(
104104
h5_path, output_dir=None,
105105
quantization_dtype=None,
106106
skip_op_check=False,
107-
strip_debug_ops=False):
107+
strip_debug_ops=False,
108+
weight_shard_size_bytes=1024 * 1024 * 4):
108109
"""
109110
Convert a keras HDF5-format model to tfjs GraphModel artifacts.
110111
@@ -117,6 +118,8 @@ def dispatch_keras_h5_to_tfjs_graph_model_conversion(
117118
(Default: `None`).
118119
skip_op_check: Bool whether to skip the op check.
119120
strip_debug_ops: Bool whether to allow unsupported debug ops.
121+
weight_shard_size_bytes: Shard size (in bytes) of the weight files.
122+
The size of each weight file will be <= this value.
120123
"""
121124

122125
if not os.path.exists(h5_path):
@@ -138,15 +141,17 @@ def dispatch_keras_h5_to_tfjs_graph_model_conversion(
138141
saved_model_tags='serve',
139142
quantization_dtype=quantization_dtype,
140143
skip_op_check=skip_op_check,
141-
strip_debug_ops=strip_debug_ops)
144+
strip_debug_ops=strip_debug_ops,
145+
weight_shard_size_bytes=weight_shard_size_bytes)
142146

143147
# Clean up the temporary SavedModel directory.
144148
shutil.rmtree(temp_savedmodel_dir)
145149

146150

147151
def dispatch_keras_saved_model_to_tensorflowjs_conversion(
148152
keras_saved_model_path, output_dir, quantization_dtype=None,
149-
split_weights_by_layer=False):
153+
split_weights_by_layer=False,
154+
weight_shard_size_bytes=1024 * 1024 * 4):
150155
"""Converts keras model saved in the SavedModel format to tfjs format.
151156
152157
Note that the SavedModel format exists in keras, but not in
@@ -166,6 +171,8 @@ def dispatch_keras_saved_model_to_tensorflowjs_conversion(
166171
split_weights_by_layer: Whether to split the weights into separate weight
167172
groups (corresponding to separate binary weight files) layer by layer
168173
(Default: `False`).
174+
weight_shard_size_bytes: Shard size (in bytes) of the weight files.
175+
The size of each weight file will be <= this value.
169176
"""
170177
with tf.Graph().as_default(), tf.compat.v1.Session():
171178
model = tf.keras.models.load_model(keras_saved_model_path)
@@ -179,7 +186,8 @@ def dispatch_keras_saved_model_to_tensorflowjs_conversion(
179186
temp_h5_path,
180187
output_dir,
181188
quantization_dtype=quantization_dtype,
182-
split_weights_by_layer=split_weights_by_layer)
189+
split_weights_by_layer=split_weights_by_layer,
190+
weight_shard_size_bytes=weight_shard_size_bytes)
183191

184192
# Delete temporary .h5 file.
185193
os.remove(temp_h5_path)
@@ -321,7 +329,8 @@ def dispatch_tfjs_layers_model_to_tfjs_graph_conversion(
321329
output_dir_path,
322330
quantization_dtype=None,
323331
skip_op_check=False,
324-
strip_debug_ops=False):
332+
strip_debug_ops=False,
333+
weight_shard_size_bytes=1024 * 1024 * 4):
325334
"""Converts a TensorFlow.js Layers Model to TensorFlow.js Graph Model.
326335
327336
This conversion often benefits speed of inference, due to the graph
@@ -336,6 +345,8 @@ def dispatch_tfjs_layers_model_to_tfjs_graph_conversion(
336345
(Default: `None`).
337346
skip_op_check: Bool whether to skip the op check.
338347
strip_debug_ops: Bool whether to allow unsupported debug ops.
348+
weight_shard_size_bytes: Shard size (in bytes) of the weight files.
349+
The size of each weight file will be <= this value.
339350
340351
Raises:
341352
ValueError, if `config_json_path` is not a path to a valid JSON
@@ -369,7 +380,8 @@ def dispatch_tfjs_layers_model_to_tfjs_graph_conversion(
369380
temp_h5_path, output_dir_path,
370381
quantization_dtype=quantization_dtype,
371382
skip_op_check=skip_op_check,
372-
strip_debug_ops=strip_debug_ops)
383+
strip_debug_ops=strip_debug_ops,
384+
weight_shard_size_bytes=weight_shard_size_bytes)
373385

374386
# Clean up temporary HDF5 file.
375387
os.remove(temp_h5_path)
@@ -507,7 +519,7 @@ def get_arg_parser():
507519
type=int,
508520
default=None,
509521
help='Shard size (in bytes) of the weight files. Currently applicable '
510-
'only to output_format=tfjs_layers_model.')
522+
'only when output_format is tfjs_layers_model or tfjs_graph_model.')
511523
parser.add_argument(
512524
'--output_node_names',
513525
type=str,
@@ -532,14 +544,6 @@ def convert(arguments):
532544
raise ValueError(
533545
'Missing output_path argument. For usage, use the --help flag.')
534546

535-
weight_shard_size_bytes = 1024 * 1024 * 4
536-
if args.weight_shard_size_bytes:
537-
if args.output_format != common.TFJS_LAYERS_MODEL:
538-
raise ValueError(
539-
'The --weight_shard_size_bytes flag is only supported under '
540-
'output_format=tfjs_layers_model.')
541-
weight_shard_size_bytes = args.weight_shard_size_bytes
542-
543547
if args.input_path is None:
544548
raise ValueError(
545549
'Error: The input_path argument must be set. '
@@ -548,6 +552,21 @@ def convert(arguments):
548552
input_format, output_format = _standardize_input_output_formats(
549553
args.input_format, args.output_format)
550554

555+
weight_shard_size_bytes = 1024 * 1024 * 4
556+
if args.weight_shard_size_bytes is not None:
557+
if (output_format not in
558+
(common.TFJS_LAYERS_MODEL, common.TFJS_GRAPH_MODEL)):
559+
raise ValueError(
560+
'The --weight_shard_size_bytes flag is only supported when '
561+
'output_format is tfjs_layers_model or tfjs_graph_model.')
562+
563+
if not (isinstance(args.weight_shard_size_bytes, int) and
564+
args.weight_shard_size_bytes > 0):
565+
raise ValueError(
566+
'Expected weight_shard_size_bytes to be a positive integer, '
567+
'but got %s' % args.weight_shard_size_bytes)
568+
weight_shard_size_bytes = args.weight_shard_size_bytes
569+
551570
quantization_dtype = (
552571
quantization.QUANTIZATION_BYTES_TO_DTYPES[args.quantization_bytes]
553572
if args.quantization_bytes else None)
@@ -570,20 +589,23 @@ def convert(arguments):
570589
dispatch_keras_h5_to_tfjs_layers_model_conversion(
571590
args.input_path, output_dir=args.output_path,
572591
quantization_dtype=quantization_dtype,
573-
split_weights_by_layer=args.split_weights_by_layer)
592+
split_weights_by_layer=args.split_weights_by_layer,
593+
weight_shard_size_bytes=weight_shard_size_bytes)
574594
elif (input_format == common.KERAS_MODEL and
575595
output_format == common.TFJS_GRAPH_MODEL):
576596
dispatch_keras_h5_to_tfjs_graph_model_conversion(
577597
args.input_path, output_dir=args.output_path,
578598
quantization_dtype=quantization_dtype,
579599
skip_op_check=args.skip_op_check,
580-
strip_debug_ops=args.strip_debug_ops)
600+
strip_debug_ops=args.strip_debug_ops,
601+
weight_shard_size_bytes=weight_shard_size_bytes)
581602
elif (input_format == common.KERAS_SAVED_MODEL and
582603
output_format == common.TFJS_LAYERS_MODEL):
583604
dispatch_keras_saved_model_to_tensorflowjs_conversion(
584605
args.input_path, args.output_path,
585606
quantization_dtype=quantization_dtype,
586-
split_weights_by_layer=args.split_weights_by_layer)
607+
split_weights_by_layer=args.split_weights_by_layer,
608+
weight_shard_size_bytes=weight_shard_size_bytes)
587609
elif (input_format == common.TF_SAVED_MODEL and
588610
output_format == common.TFJS_GRAPH_MODEL):
589611
tf_saved_model_conversion_v2.convert_tf_saved_model(
@@ -592,7 +614,8 @@ def convert(arguments):
592614
saved_model_tags=args.saved_model_tags,
593615
quantization_dtype=quantization_dtype,
594616
skip_op_check=args.skip_op_check,
595-
strip_debug_ops=args.strip_debug_ops)
617+
strip_debug_ops=args.strip_debug_ops,
618+
weight_shard_size_bytes=weight_shard_size_bytes)
596619
elif (input_format == common.TF_HUB_MODEL and
597620
output_format == common.TFJS_GRAPH_MODEL):
598621
tf_saved_model_conversion_v2.convert_tf_hub_module(
@@ -601,7 +624,8 @@ def convert(arguments):
601624
saved_model_tags=args.saved_model_tags,
602625
quantization_dtype=quantization_dtype,
603626
skip_op_check=args.skip_op_check,
604-
strip_debug_ops=args.strip_debug_ops)
627+
strip_debug_ops=args.strip_debug_ops,
628+
weight_shard_size_bytes=weight_shard_size_bytes)
605629
elif (input_format == common.TFJS_LAYERS_MODEL and
606630
output_format == common.KERAS_MODEL):
607631
dispatch_tensorflowjs_to_keras_h5_conversion(args.input_path,
@@ -622,14 +646,16 @@ def convert(arguments):
622646
args.input_path, args.output_path,
623647
quantization_dtype=_parse_quantization_bytes(args.quantization_bytes),
624648
skip_op_check=args.skip_op_check,
625-
strip_debug_ops=args.strip_debug_ops)
649+
strip_debug_ops=args.strip_debug_ops,
650+
weight_shard_size_bytes=weight_shard_size_bytes)
626651
elif (input_format == common.TF_FROZEN_MODEL and
627652
output_format == common.TFJS_GRAPH_MODEL):
628653
tf_saved_model_conversion_v2.convert_tf_frozen_model(
629654
args.input_path, args.output_node_names, args.output_path,
630655
quantization_dtype=_parse_quantization_bytes(args.quantization_bytes),
631656
skip_op_check=args.skip_op_check,
632-
strip_debug_ops=args.strip_debug_ops)
657+
strip_debug_ops=args.strip_debug_ops,
658+
weight_shard_size_bytes=weight_shard_size_bytes)
633659
else:
634660
raise ValueError(
635661
'Unsupported input_format - output_format pair: %s - %s' %

tfjs-converter/python/tensorflowjs/converters/converter_test.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,35 @@ def testConvertSavedKerasModelSplitByLayer(self):
151151
self.assertIsInstance(output_json['weightsManifest'], list)
152152
self.assertTrue(glob.glob(os.path.join(self._tmp_dir, 'group*-*')))
153153

154+
def testConvertSavedKerasModeltoTfLayersModelSharded(self):
155+
with tf.Graph().as_default(), tf.compat.v1.Session():
156+
sequential_model = keras.models.Sequential([
157+
keras.layers.Dense(
158+
3, input_shape=(2,), use_bias=True, kernel_initializer='ones',
159+
name='Dense1')])
160+
h5_path = os.path.join(self._tmp_dir, 'SequentialModel.h5')
161+
sequential_model.save(h5_path)
162+
163+
weights = sequential_model.get_weights()
164+
total_weight_bytes = sum(np.size(w) for w in weights) * 4
165+
166+
# Due to the shard size, there ought to be 4 shards after conversion.
167+
weight_shard_size_bytes = int(total_weight_bytes * 0.3)
168+
169+
# Convert Keras model to tfjs_layers_model format.
170+
output_dir = os.path.join(self._tmp_dir, 'sharded_tfjs')
171+
converter.dispatch_keras_h5_to_tfjs_layers_model_conversion(
172+
h5_path, output_dir,
173+
weight_shard_size_bytes=weight_shard_size_bytes)
174+
175+
weight_files = sorted(glob.glob(os.path.join(output_dir, 'group*.bin')))
176+
self.assertEqual(len(weight_files), 4)
177+
weight_file_sizes = [os.path.getsize(f) for f in weight_files]
178+
self.assertEqual(sum(weight_file_sizes), total_weight_bytes)
179+
self.assertEqual(weight_file_sizes[0], weight_file_sizes[1])
180+
self.assertEqual(weight_file_sizes[0], weight_file_sizes[2])
181+
self.assertLess(weight_file_sizes[3], weight_file_sizes[0])
182+
154183
def testConvertWeightsFromSequentialModel(self):
155184
with tf.Graph().as_default(), tf.compat.v1.Session():
156185
sequential_model = keras.models.Sequential([
@@ -318,6 +347,40 @@ def testConvertKerasModelToTfGraphModel(self):
318347
tf.__version__)
319348
self.assertTrue(glob.glob(os.path.join(output_dir, 'group*-*')))
320349

350+
def testConvertKerasModelToTfGraphModelSharded(self):
351+
output_dir = os.path.join(self._tmp_dir, 'foo_model')
352+
sequential_model = keras.models.Sequential([
353+
keras.layers.Dense(
354+
3, input_shape=(2,), use_bias=True, kernel_initializer='ones',
355+
name='Dense1')])
356+
h5_path = os.path.join(self._tmp_dir, 'SequentialModel.h5')
357+
sequential_model.save(h5_path)
358+
359+
# Do initial conversion without sharding.
360+
converter.dispatch_keras_h5_to_tfjs_graph_model_conversion(
361+
h5_path, output_dir)
362+
weight_files = glob.glob(os.path.join(output_dir, 'group*.bin'))
363+
364+
# Get size of weights in bytes after graph optimizations.
365+
optimized_total_weight = sum([os.path.getsize(f) for f in weight_files])
366+
367+
# Due to the shard size, there ought to be 4 shards after conversion.
368+
weight_shard_size_bytes = int(optimized_total_weight * 0.3)
369+
370+
output_dir = os.path.join(self._tmp_dir, 'sharded_model')
371+
# Convert Keras model again with shard argument set.
372+
converter.dispatch_keras_h5_to_tfjs_graph_model_conversion(
373+
h5_path, output_dir,
374+
weight_shard_size_bytes=weight_shard_size_bytes)
375+
376+
weight_files = sorted(glob.glob(os.path.join(output_dir, 'group*.bin')))
377+
self.assertEqual(len(weight_files), 4)
378+
weight_file_sizes = [os.path.getsize(f) for f in weight_files]
379+
self.assertEqual(sum(weight_file_sizes), optimized_total_weight)
380+
self.assertEqual(weight_file_sizes[0], weight_file_sizes[1])
381+
self.assertEqual(weight_file_sizes[0], weight_file_sizes[2])
382+
self.assertLess(weight_file_sizes[3], weight_file_sizes[0])
383+
321384

322385
class ConvertTfKerasSavedModelTest(tf.test.TestCase):
323386

0 commit comments

Comments
 (0)