From 0236b897ff0b6acec2d9888d2e33b81094967958 Mon Sep 17 00:00:00 2001 From: Paul Van Eck Date: Thu, 12 Mar 2020 18:44:58 -0700 Subject: [PATCH] Add weight shard arg to save_keras_model Make the weight shard size customizable when calling the `tfjs.converters.save_keras_model` Python function, via a new `weight_shard_size_bytes` argument that is forwarded to `write_artifacts`. --- .../converters/keras_h5_conversion.py | 8 ++++-- .../converters/keras_h5_conversion_test.py | 25 +++++++++++++++++++ 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/tfjs-converter/python/tensorflowjs/converters/keras_h5_conversion.py b/tfjs-converter/python/tensorflowjs/converters/keras_h5_conversion.py index 5be37141c0b..c25a02869aa 100644 --- a/tfjs-converter/python/tensorflowjs/converters/keras_h5_conversion.py +++ b/tfjs-converter/python/tensorflowjs/converters/keras_h5_conversion.py @@ -308,7 +308,8 @@ def write_artifacts(topology, json.dump(model_json, f) -def save_keras_model(model, artifacts_dir, quantization_dtype=None): +def save_keras_model(model, artifacts_dir, quantization_dtype=None, + weight_shard_size_bytes=1024 * 1024 * 4): r"""Save a Keras model and its weights in TensorFlow.js format. Args: @@ -327,6 +328,8 @@ def save_keras_model(model, artifacts_dir, quantization_dtype=None): If the directory does not exist, this function will attempt to create it. quantization_dtype: An optional numpy dtype to quantize weights to for compression. Only np.uint8 and np.uint16 are supported. + weight_shard_size_bytes: Shard size (in bytes) of the weight files. + The size of each weight file will be <= this value. Raises: ValueError: If `artifacts_dir` already exists as a file (not a directory). 
@@ -341,5 +344,6 @@ def save_keras_model(model, artifacts_dir, quantization_dtype=None): os.makedirs(artifacts_dir) write_artifacts( topology_json, weight_groups, artifacts_dir, - quantization_dtype=quantization_dtype) + quantization_dtype=quantization_dtype, + weight_shard_size_bytes=weight_shard_size_bytes) os.remove(temp_h5_path) diff --git a/tfjs-converter/python/tensorflowjs/converters/keras_h5_conversion_test.py b/tfjs-converter/python/tensorflowjs/converters/keras_h5_conversion_test.py index d20b6153b9b..2c002c6bbdd 100644 --- a/tfjs-converter/python/tensorflowjs/converters/keras_h5_conversion_test.py +++ b/tfjs-converter/python/tensorflowjs/converters/keras_h5_conversion_test.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function +import glob import json import os import shutil @@ -447,6 +448,30 @@ def testSavedModelSucceedsForExistingDirAndSequential(self): self.assertEqual(1, len(weights_manifest)) self.assertIn('paths', weights_manifest[0]) + def testSavedModelSucceedsForCustomShardSize(self): + model = tf.keras.Sequential([ + tf.keras.layers.Dense(1, input_shape=[2], activation='relu'), + tf.keras.layers.Dense(3, activation='tanh') + ]) + + weights = model.get_weights() + total_weight_bytes = sum(np.size(w) for w in weights) * 4 + + # Due to the shard size, there ought to be 4 shards after conversion. + weight_shard_size_bytes = int(total_weight_bytes * 0.3) + + # Convert Keras model to tfjs_layers_model format. 
+ conversion.save_keras_model(model, self._tmp_dir, + weight_shard_size_bytes=weight_shard_size_bytes) + + weight_files = sorted(glob.glob(os.path.join(self._tmp_dir, 'group*.bin'))) + self.assertEqual(len(weight_files), 4) + weight_file_sizes = [os.path.getsize(f) for f in weight_files] + self.assertEqual(sum(weight_file_sizes), total_weight_bytes) + self.assertEqual(weight_file_sizes[0], weight_file_sizes[1]) + self.assertEqual(weight_file_sizes[0], weight_file_sizes[2]) + self.assertLess(weight_file_sizes[3], weight_file_sizes[0]) + def testSavedModelRaisesErrorIfArtifactsDirExistsAsAFile(self): artifacts_dir = os.path.join(self._tmp_dir, 'artifacts') with open(artifacts_dir, 'wt') as f: