diff --git a/tfjs-converter/README.md b/tfjs-converter/README.md
index 98a402451b9..c07a5635e28 100644
--- a/tfjs-converter/README.md
+++ b/tfjs-converter/README.md
@@ -126,6 +126,7 @@ saved a tf.keras model in the SavedModel format.
 |`--signature_name` | Only applicable to TensorFlow SavedModel and Hub module conversion, signature to load. Defaults to `serving_default` for SavedModel and `default` for Hub module. See https://www.tensorflow.org/hub/common_signatures/.|
 |`--strip_debug_ops` | Strips out TensorFlow debug operations `Print`, `Assert`, `CheckNumerics`. Defaults to `True`.|
 |`--quantization_bytes` | How many bytes to optionally quantize/compress the weights to. Valid values are 1 and 2. which will quantize int32 and float32 to 1 or 2 bytes respectively. The default (unquantized) size is 4 bytes.|
+|`--weight_shard_size_bytes` | Shard size (in bytes) of the weight files. Only supported when `output_format` is `tfjs_layers_model` or `tfjs_graph_model`. Default size is 4 MB (4194304 bytes).|
 |`--output_node_names`| Only applicable to Frozen Model. The names of the output nodes, separated by commas.|
 
 __Note: If you want to convert TensorFlow session bundle, you can install older versions of the tensorflowjs pip package, i.e. `pip install tensorflowjs==0.8.6`.__
diff --git a/tfjs-converter/python/tensorflowjs/converters/converter.py b/tfjs-converter/python/tensorflowjs/converters/converter.py
index deacdfd96e3..fb519b33eaf 100644
--- a/tfjs-converter/python/tensorflowjs/converters/converter.py
+++ b/tfjs-converter/python/tensorflowjs/converters/converter.py
@@ -104,7 +104,8 @@ def dispatch_keras_h5_to_tfjs_graph_model_conversion(
     h5_path, output_dir=None,
     quantization_dtype=None,
     skip_op_check=False,
-    strip_debug_ops=False):
+    strip_debug_ops=False,
+    weight_shard_size_bytes=1024 * 1024 * 4):
   """
   Convert a keras HDF5-format model to tfjs GraphModel artifacts.
 
@@ -117,6 +118,8 @@ def dispatch_keras_h5_to_tfjs_graph_model_conversion(
       (Default: `None`).
     skip_op_check: Bool whether to skip the op check.
     strip_debug_ops: Bool whether to allow unsupported debug ops.
+    weight_shard_size_bytes: Shard size (in bytes) of the weight files.
+      The size of each weight file will be <= this value.
   """
 
   if not os.path.exists(h5_path):
@@ -138,7 +141,8 @@ def dispatch_keras_h5_to_tfjs_graph_model_conversion(
       saved_model_tags='serve',
       quantization_dtype=quantization_dtype,
       skip_op_check=skip_op_check,
-      strip_debug_ops=strip_debug_ops)
+      strip_debug_ops=strip_debug_ops,
+      weight_shard_size_bytes=weight_shard_size_bytes)
 
   # Clean up the temporary SavedModel directory.
   shutil.rmtree(temp_savedmodel_dir)
@@ -146,7 +150,8 @@ def dispatch_keras_h5_to_tfjs_graph_model_conversion(
 
 def dispatch_keras_saved_model_to_tensorflowjs_conversion(
     keras_saved_model_path, output_dir, quantization_dtype=None,
-    split_weights_by_layer=False):
+    split_weights_by_layer=False,
+    weight_shard_size_bytes=1024 * 1024 * 4):
   """Converts keras model saved in the SavedModel format to tfjs format.
 
   Note that the SavedModel format exists in keras, but not in
@@ -166,6 +171,8 @@ def dispatch_keras_saved_model_to_tensorflowjs_conversion(
     split_weights_by_layer: Whether to split the weights into separate weight
       groups (corresponding to separate binary weight files) layer by layer
      (Default: `False`).
+    weight_shard_size_bytes: Shard size (in bytes) of the weight files.
+      The size of each weight file will be <= this value.
""" with tf.Graph().as_default(), tf.compat.v1.Session(): model = tf.keras.models.load_model(keras_saved_model_path) @@ -179,7 +186,8 @@ def dispatch_keras_saved_model_to_tensorflowjs_conversion( temp_h5_path, output_dir, quantization_dtype=quantization_dtype, - split_weights_by_layer=split_weights_by_layer) + split_weights_by_layer=split_weights_by_layer, + weight_shard_size_bytes=weight_shard_size_bytes) # Delete temporary .h5 file. os.remove(temp_h5_path) @@ -321,7 +329,8 @@ def dispatch_tfjs_layers_model_to_tfjs_graph_conversion( output_dir_path, quantization_dtype=None, skip_op_check=False, - strip_debug_ops=False): + strip_debug_ops=False, + weight_shard_size_bytes=1024 * 1024 * 4): """Converts a TensorFlow.js Layers Model to TensorFlow.js Graph Model. This conversion often benefits speed of inference, due to the graph @@ -336,6 +345,8 @@ def dispatch_tfjs_layers_model_to_tfjs_graph_conversion( (Default: `None`). skip_op_check: Bool whether to skip the op check. strip_debug_ops: Bool whether to allow unsupported debug ops. + weight_shard_size_bytes: Shard size (in bytes) of the weight files. + The size of each weight file will be <= this value. Raises: ValueError, if `config_json_path` is not a path to a valid JSON @@ -369,7 +380,8 @@ def dispatch_tfjs_layers_model_to_tfjs_graph_conversion( temp_h5_path, output_dir_path, quantization_dtype=quantization_dtype, skip_op_check=skip_op_check, - strip_debug_ops=strip_debug_ops) + strip_debug_ops=strip_debug_ops, + weight_shard_size_bytes=weight_shard_size_bytes) # Clean up temporary HDF5 file. os.remove(temp_h5_path) @@ -507,7 +519,7 @@ def get_arg_parser(): type=int, default=None, help='Shard size (in bytes) of the weight files. Currently applicable ' - 'only to output_format=tfjs_layers_model.') + 'only when output_format is tfjs_layers_model or tfjs_graph_model.') parser.add_argument( '--output_node_names', type=str, @@ -532,14 +544,6 @@ def convert(arguments): raise ValueError( 'Missing output_path argument. For usage, use the --help flag.') - weight_shard_size_bytes = 1024 * 1024 * 4 - if args.weight_shard_size_bytes: - if args.output_format != common.TFJS_LAYERS_MODEL: - raise ValueError( - 'The --weight_shard_size_bytes flag is only supported under ' - 'output_format=tfjs_layers_model.') - weight_shard_size_bytes = args.weight_shard_size_bytes - if args.input_path is None: raise ValueError( 'Error: The input_path argument must be set. 
@@ -548,6 +552,21 @@ def convert(arguments):
   input_format, output_format = _standardize_input_output_formats(
       args.input_format, args.output_format)
 
+  weight_shard_size_bytes = 1024 * 1024 * 4
+  if args.weight_shard_size_bytes is not None:
+    if (output_format not in
+        (common.TFJS_LAYERS_MODEL, common.TFJS_GRAPH_MODEL)):
+      raise ValueError(
+          'The --weight_shard_size_bytes flag is only supported when '
+          'output_format is tfjs_layers_model or tfjs_graph_model.')
+
+    if not (isinstance(args.weight_shard_size_bytes, int) and
+            args.weight_shard_size_bytes > 0):
+      raise ValueError(
+          'Expected weight_shard_size_bytes to be a positive integer, '
+          'but got %s' % args.weight_shard_size_bytes)
+    weight_shard_size_bytes = args.weight_shard_size_bytes
+
   quantization_dtype = (
       quantization.QUANTIZATION_BYTES_TO_DTYPES[args.quantization_bytes]
       if args.quantization_bytes else None)
@@ -570,20 +589,23 @@ def convert(arguments):
     dispatch_keras_h5_to_tfjs_layers_model_conversion(
         args.input_path, output_dir=args.output_path,
         quantization_dtype=quantization_dtype,
-        split_weights_by_layer=args.split_weights_by_layer)
+        split_weights_by_layer=args.split_weights_by_layer,
+        weight_shard_size_bytes=weight_shard_size_bytes)
   elif (input_format == common.KERAS_MODEL and
         output_format == common.TFJS_GRAPH_MODEL):
     dispatch_keras_h5_to_tfjs_graph_model_conversion(
         args.input_path, output_dir=args.output_path,
         quantization_dtype=quantization_dtype,
         skip_op_check=args.skip_op_check,
-        strip_debug_ops=args.strip_debug_ops)
+        strip_debug_ops=args.strip_debug_ops,
+        weight_shard_size_bytes=weight_shard_size_bytes)
   elif (input_format == common.KERAS_SAVED_MODEL and
         output_format == common.TFJS_LAYERS_MODEL):
     dispatch_keras_saved_model_to_tensorflowjs_conversion(
         args.input_path, args.output_path,
         quantization_dtype=quantization_dtype,
-        split_weights_by_layer=args.split_weights_by_layer)
+        split_weights_by_layer=args.split_weights_by_layer,
+        weight_shard_size_bytes=weight_shard_size_bytes)
   elif (input_format == common.TF_SAVED_MODEL and
         output_format == common.TFJS_GRAPH_MODEL):
     tf_saved_model_conversion_v2.convert_tf_saved_model(
@@ -592,7 +614,8 @@ def convert(arguments):
         saved_model_tags=args.saved_model_tags,
         quantization_dtype=quantization_dtype,
         skip_op_check=args.skip_op_check,
-        strip_debug_ops=args.strip_debug_ops)
+        strip_debug_ops=args.strip_debug_ops,
+        weight_shard_size_bytes=weight_shard_size_bytes)
   elif (input_format == common.TF_HUB_MODEL and
         output_format == common.TFJS_GRAPH_MODEL):
     tf_saved_model_conversion_v2.convert_tf_hub_module(
@@ -601,7 +624,8 @@ def convert(arguments):
         saved_model_tags=args.saved_model_tags,
         quantization_dtype=quantization_dtype,
         skip_op_check=args.skip_op_check,
-        strip_debug_ops=args.strip_debug_ops)
+        strip_debug_ops=args.strip_debug_ops,
+        weight_shard_size_bytes=weight_shard_size_bytes)
   elif (input_format == common.TFJS_LAYERS_MODEL and
         output_format == common.KERAS_MODEL):
     dispatch_tensorflowjs_to_keras_h5_conversion(args.input_path,
@@ -622,14 +646,16 @@ def convert(arguments):
         args.input_path, args.output_path,
         quantization_dtype=_parse_quantization_bytes(args.quantization_bytes),
         skip_op_check=args.skip_op_check,
-        strip_debug_ops=args.strip_debug_ops)
+        strip_debug_ops=args.strip_debug_ops,
+        weight_shard_size_bytes=weight_shard_size_bytes)
   elif (input_format == common.TF_FROZEN_MODEL and
         output_format == common.TFJS_GRAPH_MODEL):
     tf_saved_model_conversion_v2.convert_tf_frozen_model(
         args.input_path, args.output_node_names, args.output_path,
         quantization_dtype=_parse_quantization_bytes(args.quantization_bytes),
         skip_op_check=args.skip_op_check,
-        strip_debug_ops=args.strip_debug_ops)
+        strip_debug_ops=args.strip_debug_ops,
+        weight_shard_size_bytes=weight_shard_size_bytes)
   else:
     raise ValueError(
         'Unsupported input_format - output_format pair: %s - %s' %
diff --git a/tfjs-converter/python/tensorflowjs/converters/converter_test.py b/tfjs-converter/python/tensorflowjs/converters/converter_test.py
index fa8049a7887..2870a8cb37c 100644
--- a/tfjs-converter/python/tensorflowjs/converters/converter_test.py
+++ b/tfjs-converter/python/tensorflowjs/converters/converter_test.py
@@ -151,6 +151,35 @@ def testConvertSavedKerasModelSplitByLayer(self):
     self.assertIsInstance(output_json['weightsManifest'], list)
     self.assertTrue(glob.glob(os.path.join(self._tmp_dir, 'group*-*')))
 
+  def testConvertSavedKerasModeltoTfLayersModelSharded(self):
+    with tf.Graph().as_default(), tf.compat.v1.Session():
+      sequential_model = keras.models.Sequential([
+          keras.layers.Dense(
+              3, input_shape=(2,), use_bias=True, kernel_initializer='ones',
+              name='Dense1')])
+      h5_path = os.path.join(self._tmp_dir, 'SequentialModel.h5')
+      sequential_model.save(h5_path)
+
+      weights = sequential_model.get_weights()
+      total_weight_bytes = sum(np.size(w) for w in weights) * 4
+
+      # Due to the shard size, there ought to be 4 shards after conversion.
+      weight_shard_size_bytes = int(total_weight_bytes * 0.3)
+
+      # Convert Keras model to tfjs_layers_model format.
+      output_dir = os.path.join(self._tmp_dir, 'sharded_tfjs')
+      converter.dispatch_keras_h5_to_tfjs_layers_model_conversion(
+          h5_path, output_dir,
+          weight_shard_size_bytes=weight_shard_size_bytes)
+
+      weight_files = sorted(glob.glob(os.path.join(output_dir, 'group*.bin')))
+      self.assertEqual(len(weight_files), 4)
+      weight_file_sizes = [os.path.getsize(f) for f in weight_files]
+      self.assertEqual(sum(weight_file_sizes), total_weight_bytes)
+      self.assertEqual(weight_file_sizes[0], weight_file_sizes[1])
+      self.assertEqual(weight_file_sizes[0], weight_file_sizes[2])
+      self.assertLess(weight_file_sizes[3], weight_file_sizes[0])
+
   def testConvertWeightsFromSequentialModel(self):
     with tf.Graph().as_default(), tf.compat.v1.Session():
       sequential_model = keras.models.Sequential([
@@ -318,6 +347,40 @@ def testConvertKerasModelToTfGraphModel(self):
                      tf.__version__)
     self.assertTrue(glob.glob(os.path.join(output_dir, 'group*-*')))
 
+  def testConvertKerasModelToTfGraphModelSharded(self):
+    output_dir = os.path.join(self._tmp_dir, 'foo_model')
+    sequential_model = keras.models.Sequential([
+        keras.layers.Dense(
+            3, input_shape=(2,), use_bias=True, kernel_initializer='ones',
+            name='Dense1')])
+    h5_path = os.path.join(self._tmp_dir, 'SequentialModel.h5')
+    sequential_model.save(h5_path)
+
+    # Do initial conversion without sharding.
+    converter.dispatch_keras_h5_to_tfjs_graph_model_conversion(
+        h5_path, output_dir)
+    weight_files = glob.glob(os.path.join(output_dir, 'group*.bin'))
+
+    # Get size of weights in bytes after graph optimizations.
+    optimized_total_weight = sum([os.path.getsize(f) for f in weight_files])
+
+    # Due to the shard size, there ought to be 4 shards after conversion.
+    weight_shard_size_bytes = int(optimized_total_weight * 0.3)
+
+    output_dir = os.path.join(self._tmp_dir, 'sharded_model')
+    # Convert Keras model again with shard argument set.
+    converter.dispatch_keras_h5_to_tfjs_graph_model_conversion(
+        h5_path, output_dir,
+        weight_shard_size_bytes=weight_shard_size_bytes)
+
+    weight_files = sorted(glob.glob(os.path.join(output_dir, 'group*.bin')))
+    self.assertEqual(len(weight_files), 4)
+    weight_file_sizes = [os.path.getsize(f) for f in weight_files]
+    self.assertEqual(sum(weight_file_sizes), optimized_total_weight)
+    self.assertEqual(weight_file_sizes[0], weight_file_sizes[1])
+    self.assertEqual(weight_file_sizes[0], weight_file_sizes[2])
+    self.assertLess(weight_file_sizes[3], weight_file_sizes[0])
+
 
 class ConvertTfKerasSavedModelTest(tf.test.TestCase):
diff --git a/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py b/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py
index 3c512e2471c..81ec8ba050e 100644
--- a/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py
+++ b/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py
@@ -113,7 +113,8 @@ def _run_grappler(config, graph_def, graph, signature_def):
 
 def optimize_graph(graph, signature_def, output_graph,
                    tf_version, quantization_dtype=None, skip_op_check=False,
-                   strip_debug_ops=False):
+                   strip_debug_ops=False,
+                   weight_shard_size_bytes=1024 * 1024 * 4):
   """Takes a Python Graph object and optimizes the graph.
 
   Args:
@@ -125,6 +126,8 @@ def optimize_graph(graph, signature_def, output_graph,
       compression. Only np.uint8 and np.uint16 are supported.
     skip_op_check: Bool whether to skip the op check.
     strip_debug_ops: Bool whether to strip debug ops.
+    weight_shard_size_bytes: Shard size (in bytes) of the weight files.
+      The size of each weight file will be <= this value.
   """
 
   # Add a collection 'train_op' so that Grappler knows the outputs.
@@ -191,7 +194,7 @@ def optimize_graph(graph, signature_def, output_graph,
 
   extract_weights(
       optimized_graph, output_graph, tf_version,
-      signature_def, quantization_dtype)
+      signature_def, quantization_dtype, weight_shard_size_bytes)
   return optimize_graph
 
 
@@ -199,7 +202,8 @@ def extract_weights(graph_def,
                     output_graph,
                     tf_version,
                     signature_def,
-                    quantization_dtype=None):
+                    quantization_dtype=None,
+                    weight_shard_size_bytes=1024 * 1024 * 4):
   """Takes a Python GraphDef object and extract the weights.
 
   Args:
@@ -209,6 +213,8 @@ def extract_weights(graph_def,
     signature_def: the SignatureDef of the inference graph.
     quantization_dtype: An optional numpy dtype to quantize weights to for
       compression. Only np.uint8 and np.uint16 are supported.
+    weight_shard_size_bytes: Shard size (in bytes) of the weight files.
+      The size of each weight file will be <= this value.
   """
   constants = [node for node in graph_def.node if node.op == 'Const']
   const_inputs = {}
@@ -234,7 +240,8 @@ def extract_weights(graph_def,
 
   write_artifacts(MessageToDict(graph_def), [const_manifest], output_graph,
                   tf_version, signature_def,
-                  quantization_dtype=quantization_dtype)
+                  quantization_dtype=quantization_dtype,
+                  weight_shard_size_bytes=weight_shard_size_bytes)
 
 
 def write_artifacts(topology,
@@ -242,7 +249,8 @@ def write_artifacts(topology,
                     output_graph,
                     tf_version,
                     signature_def,
-                    quantization_dtype=None):
+                    quantization_dtype=None,
+                    weight_shard_size_bytes=1024 * 1024 * 4):
   """Writes weights and topology to the output_dir.
 
   If `topology` is Falsy (e.g., `None`), only emit weights to output_dir.
@@ -256,7 +264,10 @@ def write_artifacts(topology,
     signature_def: the SignatureDef of the inference graph.
     quantization_dtype: An optional numpy dtype to quantize weights to for
       compression. Only np.uint8 and np.uint16 are supported.
+    weight_shard_size_bytes: Shard size (in bytes) of the weight files.
+      The size of each weight file will be <= this value.
   """
+
   model_json = {
       common.FORMAT_KEY: common.TFJS_GRAPH_MODEL_FORMAT,
       # TODO(piyu): Add tensorflow version below by using `meta_info_def`.
@@ -269,7 +280,8 @@ def write_artifacts(topology,
   model_json[common.ARTIFACT_MODEL_TOPOLOGY_KEY] = topology or None
   weights_manifest = write_weights.write_weights(
       weights, os.path.dirname(output_graph), write_manifest=False,
-      quantization_dtype=quantization_dtype)
+      quantization_dtype=quantization_dtype,
+      shard_size_bytes=weight_shard_size_bytes)
   assert isinstance(weights_manifest, list)
   model_json[common.ARTIFACT_WEIGHTS_MANIFEST_KEY] = weights_manifest
 
@@ -347,7 +359,8 @@ def convert_tf_frozen_model(frozen_model_path,
                             output_node_names,
                             output_dir,
                             quantization_dtype=None,
                             skip_op_check=False,
-                            strip_debug_ops=False):
+                            strip_debug_ops=False,
+                            weight_shard_size_bytes=1024 * 1024 * 4):
   """Convert frozen model and check the model compatibility with Tensorflow.js.
 
   Optimize and convert the model to Tensorflow.js format, when the model passes
   the compatiblity check.
@@ -362,6 +375,8 @@ def convert_tf_frozen_model(frozen_model_path,
       compression. Only np.uint8 and np.uint16 are supported.
     skip_op_check: Bool whether to skip the op check.
     strip_debug_ops: Bool whether to strip debug ops.
+    weight_shard_size_bytes: Shard size (in bytes) of the weight files.
+      The size of each weight file will be <= this value.
   """
 
   if not os.path.exists(output_dir):
@@ -376,14 +391,16 @@ def convert_tf_frozen_model(frozen_model_path,
                  output_graph, tf.__version__,
                  quantization_dtype=quantization_dtype,
                  skip_op_check=skip_op_check,
-                 strip_debug_ops=strip_debug_ops)
+                 strip_debug_ops=strip_debug_ops,
+                 weight_shard_size_bytes=weight_shard_size_bytes)
 
 
 def convert_tf_saved_model(saved_model_dir,
                            output_dir, signature_def='serving_default',
                            saved_model_tags='serve',
                            quantization_dtype=None,
                            skip_op_check=False,
-                           strip_debug_ops=False):
+                           strip_debug_ops=False,
+                           weight_shard_size_bytes=1024 * 1024 * 4):
   """Freeze the SavedModel and check the model compatibility with Tensorflow.js.
 
   Optimize and convert the model to Tensorflow.js format, when the model passes
@@ -403,6 +420,8 @@ def convert_tf_saved_model(saved_model_dir,
       compression. Only np.uint8 and np.uint16 are supported.
     skip_op_check: Bool whether to skip the op check.
     strip_debug_ops: Bool whether to strip debug ops.
+    weight_shard_size_bytes: Shard size (in bytes) of the weight files.
+      The size of each weight file will be <= this value.
   """
   if signature_def is None:
     signature_def = 'serving_default'
@@ -437,7 +456,8 @@ def convert_tf_saved_model(saved_model_dir,
                  output_graph, model.tensorflow_version,
                  quantization_dtype=quantization_dtype,
                  skip_op_check=skip_op_check,
-                 strip_debug_ops=strip_debug_ops)
+                 strip_debug_ops=strip_debug_ops,
+                 weight_shard_size_bytes=weight_shard_size_bytes)
 
 
 def load_and_initialize_hub_module(module_path, signature='default'):
   """Loads graph of a TF-Hub module and initializes it into a session.
@@ -487,7 +507,8 @@ def load_and_initialize_hub_module(module_path, signature='default'):
 
 def convert_tf_hub_module_v1(module_path, output_dir,
                              signature='default', quantization_dtype=None,
-                             skip_op_check=False, strip_debug_ops=False):
+                             skip_op_check=False, strip_debug_ops=False,
+                             weight_shard_size_bytes=1024 * 1024 * 4):
   """Freeze the TF-Hub module and check compatibility with Tensorflow.js.
 
   Optimize and convert the TF-Hub module to Tensorflow.js format, if it passes
@@ -502,6 +523,8 @@ def convert_tf_hub_module_v1(module_path, output_dir,
     signature: string Signature to load.
     skip_op_check: Bool whether to skip the op check.
     strip_debug_ops: Bool whether to strip debug ops.
+    weight_shard_size_bytes: Shard size (in bytes) of the weight files.
+      The size of each weight file will be <= this value.
   """
 
   if signature is None:
@@ -540,7 +563,9 @@ def convert_tf_hub_module_v1(module_path, output_dir,
     optimize_graph(frozen_graph, signature,
                    output_graph, tf.__version__,
                    quantization_dtype=quantization_dtype,
-                   skip_op_check=skip_op_check, strip_debug_ops=strip_debug_ops)
+                   skip_op_check=skip_op_check,
+                   strip_debug_ops=strip_debug_ops,
+                   weight_shard_size_bytes=weight_shard_size_bytes)
   finally:
     # Clean up the temp files.
     if os.path.exists(frozen_file):
@@ -550,7 +575,8 @@ def convert_tf_hub_module_v1(module_path, output_dir,
 
 def convert_tf_hub_module(module_handle, output_dir,
                           signature='default', saved_model_tags='serve',
                           quantization_dtype=None, skip_op_check=False,
-                          strip_debug_ops=False):
+                          strip_debug_ops=False,
+                          weight_shard_size_bytes=1024 * 1024 * 4):
   """Conversion for TF Hub modules V1 and V2.
 
   See convert_tf_hub_module and convert_tf_saved_model.
@@ -565,6 +591,8 @@ def convert_tf_hub_module(module_handle, output_dir,
     saved_model_tags: tags of the GraphDef to load. Defaults to ''.
     skip_op_check: Bool whether to skip the op check.
     strip_debug_ops: Bool whether to strip debug ops.
+    weight_shard_size_bytes: Shard size (in bytes) of the weight files.
+      The size of each weight file will be <= this value.
   """
   module_path = hub.resolve(module_handle)
   # TODO(vbardiovskyg): We can remove this v1 code path once loading of all v1
@@ -573,7 +601,8 @@ def convert_tf_hub_module(module_handle, output_dir,
   if tf.io.gfile.exists(os.path.join(module_path, _HUB_V1_MODULE_PB)):
     print("Loading the module using TF 1.X interface from %s." % module_path)
     convert_tf_hub_module_v1(module_path, output_dir, signature,
-                             quantization_dtype, skip_op_check, strip_debug_ops)
+                             quantization_dtype, skip_op_check, strip_debug_ops,
+                             weight_shard_size_bytes)
   else:
     print("Loading the module using TF 2.X interface from %s."
           % module_path)
     if signature is None:
@@ -584,4 +613,5 @@ def convert_tf_hub_module(module_handle, output_dir,
         saved_model_tags=saved_model_tags,
         quantization_dtype=quantization_dtype,
         skip_op_check=skip_op_check,
-        strip_debug_ops=strip_debug_ops)
+        strip_debug_ops=strip_debug_ops,
+        weight_shard_size_bytes=weight_shard_size_bytes)
diff --git a/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2_test.py b/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2_test.py
index e8bb8a70125..47ee320369b 100644
--- a/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2_test.py
+++ b/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2_test.py
@@ -678,6 +678,34 @@ def test_convert_saved_model_with_control_flow(self):
         glob.glob(
             os.path.join(self._tmp_dir, SAVED_MODEL_DIR, 'group*-*')))
 
+  def test_convert_saved_model_sharded(self):
+    self._create_saved_model()
+    model_path = os.path.join(self._tmp_dir, SAVED_MODEL_DIR)
+    tfjs_path = os.path.join(self._tmp_dir, SAVED_MODEL_DIR)
+
+    # Do initial conversion without sharding.
+    tf_saved_model_conversion_v2.convert_tf_saved_model(model_path, tfjs_path)
+    weight_files = glob.glob(os.path.join(tfjs_path, 'group*.bin'))
+
+    # Get size of weights in bytes after graph optimizations.
+    optimized_total_weight = sum([os.path.getsize(f) for f in weight_files])
+
+    # Due to the shard size, there ought to be 2 shards after conversion.
+    weight_shard_size_bytes = int(optimized_total_weight * 0.8)
+
+    tfjs_path = os.path.join(self._tmp_dir, 'sharded_model')
+    # Convert Saved Model again with shard argument set.
+    tf_saved_model_conversion_v2.convert_tf_saved_model(
+        model_path, tfjs_path,
+        weight_shard_size_bytes=weight_shard_size_bytes)
+
+    weight_files = sorted(glob.glob(os.path.join(tfjs_path, 'group*.bin')))
+    self.assertEqual(len(weight_files), 2)
+    weight_file_sizes = [os.path.getsize(f) for f in weight_files]
+
+    self.assertEqual(sum(weight_file_sizes), optimized_total_weight)
+    self.assertLess(weight_file_sizes[1], weight_file_sizes[0])
+
   def test_optimizer_add_unsupported_op(self):
     self._create_unsupported_saved_model()
     with self.assertRaisesRegexp(  # pylint: disable=deprecated-method
@@ -770,6 +798,35 @@ def test_convert_hub_module_v1(self):
         glob.glob(
             os.path.join(self._tmp_dir, SAVED_MODEL_DIR, 'group*-*')))
 
+  def test_convert_hub_module_v1_sharded(self):
+    self._create_hub_module()
+    module_path = os.path.join(self._tmp_dir, HUB_MODULE_DIR)
+    tfjs_path = os.path.join(self._tmp_dir, SAVED_MODEL_DIR)
+
+    # Do initial conversion without sharding.
+    tf_saved_model_conversion_v2.convert_tf_hub_module(module_path, tfjs_path)
+    weight_files = glob.glob(os.path.join(tfjs_path, 'group*.bin'))
+
+    # Get size of weights in bytes after graph optimizations.
+    optimized_total_weight = sum([os.path.getsize(f) for f in weight_files])
+
+    # Due to the shard size, there ought to be 3 shards after conversion.
+    weight_shard_size_bytes = int(optimized_total_weight * 0.4)
+
+    tfjs_path = os.path.join(self._tmp_dir, 'sharded_model')
+    # Convert Hub model again with shard argument set.
+    tf_saved_model_conversion_v2.convert_tf_hub_module(
+        module_path, tfjs_path,
+        weight_shard_size_bytes=weight_shard_size_bytes)
+
+    weight_files = sorted(glob.glob(os.path.join(tfjs_path, 'group*.bin')))
+    self.assertEqual(len(weight_files), 3)
+    weight_file_sizes = [os.path.getsize(f) for f in weight_files]
+
+    self.assertEqual(sum(weight_file_sizes), optimized_total_weight)
+    self.assertEqual(weight_file_sizes[0], weight_file_sizes[1])
+    self.assertLess(weight_file_sizes[2], weight_file_sizes[0])
+
   def test_convert_hub_module_v2(self):
     self._create_saved_model()
     module_path = os.path.join(self._tmp_dir, SAVED_MODEL_DIR)
diff --git a/tfjs-converter/python/tensorflowjs/converters/wizard.py b/tfjs-converter/python/tensorflowjs/converters/wizard.py
index 457e0de0eac..cfeb9a0d312 100644
--- a/tfjs-converter/python/tensorflowjs/converters/wizard.py
+++ b/tfjs-converter/python/tensorflowjs/converters/wizard.py
@@ -482,8 +482,15 @@ def run(dryrun):
         'name': common.WEIGHT_SHARD_SIZE_BYTES,
         'message': 'Please enter shard size (in bytes) of the weight files?',
         'default': str(4 * 1024 * 1024),
-        'when': lambda answers: value_in_list(answers, common.OUTPUT_FORMAT,
-                                              (common.TFJS_LAYERS_MODEL))
+        'validate':
+            lambda size: ('Please enter a positive integer' if not
+                          (size.isdigit() and int(size) > 0) else True),
+        'when': lambda answers: (value_in_list(answers, common.OUTPUT_FORMAT,
+                                               (common.TFJS_LAYERS_MODEL,
+                                                common.TFJS_GRAPH_MODEL)) or
+                                 value_in_list(answers, common.INPUT_FORMAT,
+                                               (common.TF_SAVED_MODEL,
+                                                common.TF_HUB_MODEL)))
     },
     {
         'type': 'confirm',
diff --git a/tfjs-converter/python/tensorflowjs/converters/wizard_test.py b/tfjs-converter/python/tensorflowjs/converters/wizard_test.py
index 752eac58918..4b61bac5eb0 100644
--- a/tfjs-converter/python/tensorflowjs/converters/wizard_test.py
+++ b/tfjs-converter/python/tensorflowjs/converters/wizard_test.py
@@ -185,6 +185,7 @@ def testGenerateCommandForSavedModel(self):
                'saved_model_tags': 'test',
                'signature_name': 'test_default',
                'quantization_bytes': 2,
+               'weight_shard_size_bytes': '4194304',
                'skip_op_check': False,
                'strip_debug_ops': True,
                'output_path': 'tmp/web_model'}
@@ -192,6 +193,7 @@
     self.assertEqual(['--input_format=tf_saved_model',
                       '--quantization_bytes=2', '--saved_model_tags=test',
                       '--signature_name=test_default', '--strip_debug_ops=True',
+                      '--weight_shard_size_bytes=4194304',
                       'tmp/saved_model', 'tmp/web_model'],
                      wizard.generate_arguments(options))
@@ -201,6 +203,7 @@
                'input_path': 'tmp/saved_model',
                'saved_model_tags': 'test',
                'signature_name': 'test_default',
+               'weight_shard_size_bytes': '100',
                'quantization_bytes': 1,
                'skip_op_check': True,
                'strip_debug_ops': False,
                'output_path': 'tmp/web_model'}
@@ -210,17 +213,20 @@
                       '--output_format=tfjs_layers_model',
                       '--quantization_bytes=1', '--saved_model_tags=test',
                       '--signature_name=test_default', '--skip_op_check',
-                      '--strip_debug_ops=False', 'tmp/saved_model',
-                      'tmp/web_model'],
+                      '--strip_debug_ops=False',
+                      '--weight_shard_size_bytes=100',
+                      'tmp/saved_model', 'tmp/web_model'],
                      wizard.generate_arguments(options))
 
   def testGenerateCommandForKerasModel(self):
     options = {'input_format': 'keras',
                'input_path': 'tmp/model.HD5',
+               'weight_shard_size_bytes': '100',
                'quantization_bytes': 1,
                'output_path': 'tmp/web_model'}
     self.assertEqual(['--input_format=keras', '--quantization_bytes=1',
+                      '--weight_shard_size_bytes=100',
                       'tmp/model.HD5', 'tmp/web_model'],
                      wizard.generate_arguments(options))
@@ -229,11 +235,14 @@ def testGenerateCommandForLayerModel(self):
                'output_format': 'keras',
                'input_path': 'tmp/model.json',
                'quantization_bytes': 1,
+               'weight_shard_size_bytes': '100',
                'output_path': 'tmp/web_model'}
     self.assertEqual(['--input_format=tfjs_layers_model',
                       '--output_format=keras',
-                      '--quantization_bytes=1', 'tmp/model.json',
+                      '--quantization_bytes=1',
+                      '--weight_shard_size_bytes=100',
+                      'tmp/model.json',
                       'tmp/web_model'],
                      wizard.generate_arguments(options))
 
diff --git a/tfjs-converter/python/tensorflowjs/write_weights.py b/tfjs-converter/python/tensorflowjs/write_weights.py
index 1c9fceb55db..692606ccb37 100644
--- a/tfjs-converter/python/tensorflowjs/write_weights.py
+++ b/tfjs-converter/python/tensorflowjs/write_weights.py
@@ -397,7 +397,7 @@ def _assert_weight_groups_valid(weight_groups):
 
 
 def _assert_shard_size_bytes_valid(shard_size_bytes):
-  if shard_size_bytes < 0:
+  if shard_size_bytes <= 0:
     raise ValueError(
         'shard_size_bytes must be greater than 0, but got %s' %
         shard_size_bytes)
diff --git a/tfjs-converter/python/test_pip_package.py b/tfjs-converter/python/test_pip_package.py
index 65d11b509b1..644a182ed08 100644
--- a/tfjs-converter/python/test_pip_package.py
+++ b/tfjs-converter/python/test_pip_package.py
@@ -520,6 +520,41 @@ def testConvertTFSavedModelWithCommandLineWorks(self):
     # Check the content of the output directory.
     self.assertTrue(glob.glob(os.path.join(output_dir, 'group*-*')))
 
+  def testConvertTFSavedModelIntoShardedWeights(self):
+    output_dir = os.path.join(self._tmp_dir, 'tfjs_model')
+    # Do initial conversion without sharding.
+    process = subprocess.Popen([
+        'tensorflowjs_converter', '--input_format', 'tf_saved_model',
+        '--output_format', 'tfjs_graph_model',
+        self.tf_saved_model_dir, output_dir
+    ])
+    process.communicate()
+    self.assertEqual(0, process.returncode)
+
+    weight_files = glob.glob(os.path.join(output_dir, 'group*.bin'))
+
+    # Get size of weights in bytes after graph optimizations.
+    optimized_total_weight = sum([os.path.getsize(f) for f in weight_files])
+    # Due to the shard size, there ought to be 2 shards after conversion.
+    weight_shard_size_bytes = int(optimized_total_weight * 0.8)
+
+    output_dir = os.path.join(self._tmp_dir, 'sharded_model')
+    # Convert Saved Model again with shard argument set.
+    process = subprocess.Popen([
+        'tensorflowjs_converter', '--input_format', 'tf_saved_model',
+        '--output_format', 'tfjs_graph_model',
+        '--weight_shard_size_bytes', str(weight_shard_size_bytes),
+        self.tf_saved_model_dir, output_dir
+    ])
+    process.communicate()
+    self.assertEqual(0, process.returncode)
+
+    weight_files = sorted(glob.glob(os.path.join(output_dir, 'group*.bin')))
+    self.assertEqual(len(weight_files), 2)
+    weight_file_sizes = [os.path.getsize(f) for f in weight_files]
+    self.assertEqual(sum(weight_file_sizes), optimized_total_weight)
+    self.assertLess(weight_file_sizes[1], weight_file_sizes[0])
+
   def testConvertTFFrozenModelWithCommandLineWorks(self):
     output_dir = os.path.join(self._tmp_dir)
     frozen_file = os.path.join(self.tf_frozen_model_dir, 'frozen.pb')
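For illustration only, and not part of the patch: a minimal sketch of how the new `weight_shard_size_bytes` argument threaded through above could be exercised from the Python API, mirroring the call sites added in `tf_saved_model_conversion_v2_test.py`. The SavedModel path, output path, and the 16 MB shard size are hypothetical placeholders.

```python
# Hypothetical usage sketch; the paths and shard size below are placeholders.
from tensorflowjs.converters import tf_saved_model_conversion_v2

# Convert a TF SavedModel to a tfjs_graph_model, capping each weight shard
# (group*.bin file) at 16 MB instead of the default 4 MB (4194304 bytes).
tf_saved_model_conversion_v2.convert_tf_saved_model(
    '/tmp/my_saved_model',   # placeholder: input SavedModel directory
    '/tmp/web_model',        # placeholder: output directory for web artifacts
    weight_shard_size_bytes=16 * 1024 * 1024)
```

The command-line equivalent, matching the flags used in `test_pip_package.py`, would pass `--weight_shard_size_bytes 16777216` alongside `--input_format tf_saved_model --output_format tfjs_graph_model`.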