@@ -104,7 +104,8 @@ def dispatch_keras_h5_to_tfjs_graph_model_conversion(
104104 h5_path , output_dir = None ,
105105 quantization_dtype = None ,
106106 skip_op_check = False ,
107- strip_debug_ops = False ):
107+ strip_debug_ops = False ,
108+ weight_shard_size_bytes = 1024 * 1024 * 4 ):
108109 """
109110 Convert a keras HDF5-format model to tfjs GraphModel artifacts.
110111
@@ -117,6 +118,8 @@ def dispatch_keras_h5_to_tfjs_graph_model_conversion(
117118 (Default: `None`).
118119 skip_op_check: Bool whether to skip the op check.
119120 strip_debug_ops: Bool whether to allow unsupported debug ops.
121+ weight_shard_size_bytes: Shard size (in bytes) of the weight files.
122+ The size of each weight file will be <= this value.
120123 """
121124
122125 if not os .path .exists (h5_path ):
@@ -138,15 +141,17 @@ def dispatch_keras_h5_to_tfjs_graph_model_conversion(
138141 saved_model_tags = 'serve' ,
139142 quantization_dtype = quantization_dtype ,
140143 skip_op_check = skip_op_check ,
141- strip_debug_ops = strip_debug_ops )
144+ strip_debug_ops = strip_debug_ops ,
145+ weight_shard_size_bytes = weight_shard_size_bytes )
142146
143147 # Clean up the temporary SavedModel directory.
144148 shutil .rmtree (temp_savedmodel_dir )
145149
146150
147151def dispatch_keras_saved_model_to_tensorflowjs_conversion (
148152 keras_saved_model_path , output_dir , quantization_dtype = None ,
149- split_weights_by_layer = False ):
153+ split_weights_by_layer = False ,
154+ weight_shard_size_bytes = 1024 * 1024 * 4 ):
150155 """Converts keras model saved in the SavedModel format to tfjs format.
151156
152157 Note that the SavedModel format exists in keras, but not in
@@ -166,6 +171,8 @@ def dispatch_keras_saved_model_to_tensorflowjs_conversion(
166171 split_weights_by_layer: Whether to split the weights into separate weight
167172 groups (corresponding to separate binary weight files) layer by layer
168173 (Default: `False`).
174+ weight_shard_size_bytes: Shard size (in bytes) of the weight files.
175+ The size of each weight file will be <= this value.
169176 """
170177 with tf .Graph ().as_default (), tf .compat .v1 .Session ():
171178 model = tf .keras .models .load_model (keras_saved_model_path )
@@ -179,7 +186,8 @@ def dispatch_keras_saved_model_to_tensorflowjs_conversion(
179186 temp_h5_path ,
180187 output_dir ,
181188 quantization_dtype = quantization_dtype ,
182- split_weights_by_layer = split_weights_by_layer )
189+ split_weights_by_layer = split_weights_by_layer ,
190+ weight_shard_size_bytes = weight_shard_size_bytes )
183191
184192 # Delete temporary .h5 file.
185193 os .remove (temp_h5_path )
@@ -321,7 +329,8 @@ def dispatch_tfjs_layers_model_to_tfjs_graph_conversion(
321329 output_dir_path ,
322330 quantization_dtype = None ,
323331 skip_op_check = False ,
324- strip_debug_ops = False ):
332+ strip_debug_ops = False ,
333+ weight_shard_size_bytes = 1024 * 1024 * 4 ):
325334 """Converts a TensorFlow.js Layers Model to TensorFlow.js Graph Model.
326335
327336 This conversion often benefits speed of inference, due to the graph
@@ -336,6 +345,8 @@ def dispatch_tfjs_layers_model_to_tfjs_graph_conversion(
336345 (Default: `None`).
337346 skip_op_check: Bool whether to skip the op check.
338347 strip_debug_ops: Bool whether to allow unsupported debug ops.
348+ weight_shard_size_bytes: Shard size (in bytes) of the weight files.
349+ The size of each weight file will be <= this value.
339350
340351 Raises:
341352 ValueError, if `config_json_path` is not a path to a valid JSON
@@ -369,7 +380,8 @@ def dispatch_tfjs_layers_model_to_tfjs_graph_conversion(
369380 temp_h5_path , output_dir_path ,
370381 quantization_dtype = quantization_dtype ,
371382 skip_op_check = skip_op_check ,
372- strip_debug_ops = strip_debug_ops )
383+ strip_debug_ops = strip_debug_ops ,
384+ weight_shard_size_bytes = weight_shard_size_bytes )
373385
374386 # Clean up temporary HDF5 file.
375387 os .remove (temp_h5_path )
@@ -507,7 +519,7 @@ def get_arg_parser():
507519 type = int ,
508520 default = None ,
509521 help = 'Shard size (in bytes) of the weight files. Currently applicable '
510- 'only to output_format= tfjs_layers_model.' )
522+ 'only when output_format is tfjs_layers_model or tfjs_graph_model .' )
511523 parser .add_argument (
512524 '--output_node_names' ,
513525 type = str ,
@@ -532,14 +544,6 @@ def convert(arguments):
532544 raise ValueError (
533545 'Missing output_path argument. For usage, use the --help flag.' )
534546
535- weight_shard_size_bytes = 1024 * 1024 * 4
536- if args .weight_shard_size_bytes :
537- if args .output_format != common .TFJS_LAYERS_MODEL :
538- raise ValueError (
539- 'The --weight_shard_size_bytes flag is only supported under '
540- 'output_format=tfjs_layers_model.' )
541- weight_shard_size_bytes = args .weight_shard_size_bytes
542-
543547 if args .input_path is None :
544548 raise ValueError (
545549 'Error: The input_path argument must be set. '
@@ -548,6 +552,21 @@ def convert(arguments):
548552 input_format , output_format = _standardize_input_output_formats (
549553 args .input_format , args .output_format )
550554
555+ weight_shard_size_bytes = 1024 * 1024 * 4
556+ if args .weight_shard_size_bytes is not None :
557+ if (output_format not in
558+ (common .TFJS_LAYERS_MODEL , common .TFJS_GRAPH_MODEL )):
559+ raise ValueError (
560+ 'The --weight_shard_size_bytes flag is only supported when '
561+ 'output_format is tfjs_layers_model or tfjs_graph_model.' )
562+
563+ if not (isinstance (args .weight_shard_size_bytes , int ) and
564+ args .weight_shard_size_bytes > 0 ):
565+ raise ValueError (
566+ 'Expected weight_shard_size_bytes to be a positive integer, '
567+ 'but got %s' % args .weight_shard_size_bytes )
568+ weight_shard_size_bytes = args .weight_shard_size_bytes
569+
551570 quantization_dtype = (
552571 quantization .QUANTIZATION_BYTES_TO_DTYPES [args .quantization_bytes ]
553572 if args .quantization_bytes else None )
@@ -570,20 +589,23 @@ def convert(arguments):
570589 dispatch_keras_h5_to_tfjs_layers_model_conversion (
571590 args .input_path , output_dir = args .output_path ,
572591 quantization_dtype = quantization_dtype ,
573- split_weights_by_layer = args .split_weights_by_layer )
592+ split_weights_by_layer = args .split_weights_by_layer ,
593+ weight_shard_size_bytes = weight_shard_size_bytes )
574594 elif (input_format == common .KERAS_MODEL and
575595 output_format == common .TFJS_GRAPH_MODEL ):
576596 dispatch_keras_h5_to_tfjs_graph_model_conversion (
577597 args .input_path , output_dir = args .output_path ,
578598 quantization_dtype = quantization_dtype ,
579599 skip_op_check = args .skip_op_check ,
580- strip_debug_ops = args .strip_debug_ops )
600+ strip_debug_ops = args .strip_debug_ops ,
601+ weight_shard_size_bytes = weight_shard_size_bytes )
581602 elif (input_format == common .KERAS_SAVED_MODEL and
582603 output_format == common .TFJS_LAYERS_MODEL ):
583604 dispatch_keras_saved_model_to_tensorflowjs_conversion (
584605 args .input_path , args .output_path ,
585606 quantization_dtype = quantization_dtype ,
586- split_weights_by_layer = args .split_weights_by_layer )
607+ split_weights_by_layer = args .split_weights_by_layer ,
608+ weight_shard_size_bytes = weight_shard_size_bytes )
587609 elif (input_format == common .TF_SAVED_MODEL and
588610 output_format == common .TFJS_GRAPH_MODEL ):
589611 tf_saved_model_conversion_v2 .convert_tf_saved_model (
@@ -592,7 +614,8 @@ def convert(arguments):
592614 saved_model_tags = args .saved_model_tags ,
593615 quantization_dtype = quantization_dtype ,
594616 skip_op_check = args .skip_op_check ,
595- strip_debug_ops = args .strip_debug_ops )
617+ strip_debug_ops = args .strip_debug_ops ,
618+ weight_shard_size_bytes = weight_shard_size_bytes )
596619 elif (input_format == common .TF_HUB_MODEL and
597620 output_format == common .TFJS_GRAPH_MODEL ):
598621 tf_saved_model_conversion_v2 .convert_tf_hub_module (
@@ -601,7 +624,8 @@ def convert(arguments):
601624 saved_model_tags = args .saved_model_tags ,
602625 quantization_dtype = quantization_dtype ,
603626 skip_op_check = args .skip_op_check ,
604- strip_debug_ops = args .strip_debug_ops )
627+ strip_debug_ops = args .strip_debug_ops ,
628+ weight_shard_size_bytes = weight_shard_size_bytes )
605629 elif (input_format == common .TFJS_LAYERS_MODEL and
606630 output_format == common .KERAS_MODEL ):
607631 dispatch_tensorflowjs_to_keras_h5_conversion (args .input_path ,
@@ -622,14 +646,16 @@ def convert(arguments):
622646 args .input_path , args .output_path ,
623647 quantization_dtype = _parse_quantization_bytes (args .quantization_bytes ),
624648 skip_op_check = args .skip_op_check ,
625- strip_debug_ops = args .strip_debug_ops )
649+ strip_debug_ops = args .strip_debug_ops ,
650+ weight_shard_size_bytes = weight_shard_size_bytes )
626651 elif (input_format == common .TF_FROZEN_MODEL and
627652 output_format == common .TFJS_GRAPH_MODEL ):
628653 tf_saved_model_conversion_v2 .convert_tf_frozen_model (
629654 args .input_path , args .output_node_names , args .output_path ,
630655 quantization_dtype = _parse_quantization_bytes (args .quantization_bytes ),
631656 skip_op_check = args .skip_op_check ,
632- strip_debug_ops = args .strip_debug_ops )
657+ strip_debug_ops = args .strip_debug_ops ,
658+ weight_shard_size_bytes = weight_shard_size_bytes )
633659 else :
634660 raise ValueError (
635661 'Unsupported input_format - output_format pair: %s - %s' %
0 commit comments