Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,8 @@ def write_artifacts(topology,
json.dump(model_json, f)


def save_keras_model(model, artifacts_dir, quantization_dtype=None):
def save_keras_model(model, artifacts_dir, quantization_dtype=None,
weight_shard_size_bytes=1024 * 1024 * 4):
r"""Save a Keras model and its weights in TensorFlow.js format.

Args:
Expand All @@ -327,6 +328,8 @@ def save_keras_model(model, artifacts_dir, quantization_dtype=None):
If the directory does not exist, this function will attempt to create it.
quantization_dtype: An optional numpy dtype to quantize weights to for
compression. Only np.uint8 and np.uint16 are supported.
weight_shard_size_bytes: Shard size (in bytes) of the weight files.
The size of each weight file will be <= this value.

Raises:
ValueError: If `artifacts_dir` already exists as a file (not a directory).
Expand All @@ -341,5 +344,6 @@ def save_keras_model(model, artifacts_dir, quantization_dtype=None):
os.makedirs(artifacts_dir)
write_artifacts(
topology_json, weight_groups, artifacts_dir,
quantization_dtype=quantization_dtype)
quantization_dtype=quantization_dtype,
weight_shard_size_bytes=weight_shard_size_bytes)
os.remove(temp_h5_path)
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import division
from __future__ import print_function

import glob
import json
import os
import shutil
Expand Down Expand Up @@ -447,6 +448,30 @@ def testSavedModelSucceedsForExistingDirAndSequential(self):
self.assertEqual(1, len(weights_manifest))
self.assertIn('paths', weights_manifest[0])

def testSavedModelSucceedsForCustomShardSize(self):
model = tf.keras.Sequential([
tf.keras.layers.Dense(1, input_shape=[2], activation='relu'),
tf.keras.layers.Dense(3, activation='tanh')
])

weights = model.get_weights()
total_weight_bytes = sum(np.size(w) for w in weights) * 4

# Due to the shard size, there ought to be 4 shards after conversion.
weight_shard_size_bytes = int(total_weight_bytes * 0.3)

# Convert Keras model to tfjs_layers_model format.
conversion.save_keras_model(model, self._tmp_dir,
weight_shard_size_bytes=weight_shard_size_bytes)

weight_files = sorted(glob.glob(os.path.join(self._tmp_dir, 'group*.bin')))
self.assertEqual(len(weight_files), 4)
weight_file_sizes = [os.path.getsize(f) for f in weight_files]
self.assertEqual(sum(weight_file_sizes), total_weight_bytes)
self.assertEqual(weight_file_sizes[0], weight_file_sizes[1])
self.assertEqual(weight_file_sizes[0], weight_file_sizes[2])
self.assertLess(weight_file_sizes[3], weight_file_sizes[0])

def testSavedModelRaisesErrorIfArtifactsDirExistsAsAFile(self):
artifacts_dir = os.path.join(self._tmp_dir, 'artifacts')
with open(artifacts_dir, 'wt') as f:
Expand Down