diff --git a/tfjs-converter/README.md b/tfjs-converter/README.md
index 0f810052dd2..40fa234ce77 100644
--- a/tfjs-converter/README.md
+++ b/tfjs-converter/README.md
@@ -154,7 +154,10 @@ saved a tf.keras model in the SavedModel format.
|`--saved_model_tags` | Only applicable to SavedModel conversion. Tags of the MetaGraphDef to load, in comma separated format. Defaults to `serve`.|
|`--signature_name` | Only applicable to TensorFlow SavedModel and Hub module conversion, signature to load. Defaults to `serving_default` for SavedModel and `default` for Hub module. See https://www.tensorflow.org/hub/common_signatures/.|
|`--strip_debug_ops` | Strips out TensorFlow debug operations `Print`, `Assert`, `CheckNumerics`. Defaults to `True`.|
-|`--quantization_bytes` | How many bytes to optionally quantize/compress the weights to. Valid values are 1 and 2. which will quantize int32 and float32 to 1 or 2 bytes respectively. The default (unquantized) size is 4 bytes.|
+|`--quantization_bytes` | (Deprecated) How many bytes to optionally quantize/compress the weights to. Valid values are 1 and 2, which will quantize int32 and float32 to 1 or 2 bytes respectively. The default (unquantized) size is 4 bytes.|
+|`--quantize_float16` | Comma separated list of node names to apply float16 quantization. You can also use the wildcard symbol (*) to apply quantization to multiple nodes (e.g., conv/*/weights). When the flag is provided without any node names, all nodes will be quantized. |
+|`--quantize_uint8` | Comma separated list of node names to apply 1-byte affine quantization. You can also use the wildcard symbol (*) to apply quantization to multiple nodes (e.g., conv/*/weights). When the flag is provided without any node names, all nodes will be quantized. |
+|`--quantize_uint16` | Comma separated list of node names to apply 2-byte affine quantization. You can also use the wildcard symbol (*) to apply quantization to multiple nodes (e.g., conv/*/weights). When the flag is provided without any node names, all nodes will be quantized. |
|`--weight_shard_size_bytes` | Shard size (in bytes) of the weight files. Only supported when `output_format` is `tfjs_layers_model` or `tfjs_graph_model`. Default size is 4 MB (4194304 bytes).|
|`--output_node_names`| Only applicable to Frozen Model. The names of the output nodes, separated by commas.|
@@ -216,7 +219,7 @@ purposes:
tensorflowjs_converter \
--input_format tfjs_layers_model \
--output_format tfjs_layers_model \
- --quantization_bytes 2 \
+ --quantize_uint16 \
original_model/model.json
quantized_model/
```
@@ -380,18 +383,33 @@ browser to cache them automatically. If the model architecture is less than 4MB
__4. Can I quantize the weights over the wire?__
-Yes, you can use the --quantization_bytes option to compress int32/float32 to 1
-or 2 bytes. Here is
-an example of 8-bit quantization:
+Yes, you can use the --quantize_{float16, uint8, uint16} flags to compress
+weights with 1-byte integer quantization (`uint8`), 2-byte integer
+quantization (`uint16`), or 2-byte float quantization (`float16`).
+Quantizing to float16 may provide better accuracy than
+2-byte affine integer scaling (`uint16`), while 1-byte affine quantization
+(`uint8`) provides a 4x size reduction at the cost of accuracy.
+For example, we can quantize our MobileNet model using float16 quantization:
```
-tensorflowjs_converter \
+tensorflowjs_converter \
+ --quantize_float16 \
--input_format=tf_hub \
- --quantization_bytes=1
'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/1' \
/mobilenet/web_model
```
+You can also quantize specific weights, or groups of weights, using
+wildcard patterns. For example,
+```
+tensorflowjs_converter \
+ --quantize_float16="conv/*/weights"
+```
+which will quantize all weights that match the pattern `conv/*/weights`.
+This excludes biases and any weights that don't begin with `conv/`.
+This can be a powerful tool for reducing model size while minimizing the
+impact on model performance.
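+
+The quantization flags can also be combined to apply different quantization
+types to different groups of weights. For example (the node pattern below is
+only illustrative), the following would quantize weights matching
+`conv/*/weights` to uint8 and fall back to float16 for all remaining weights:
+```
+tensorflowjs_converter \
+    --quantize_uint8="conv/*/weights" \
+    --quantize_float16 \
+    --input_format=tf_hub \
+    'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/1' \
+    /mobilenet/web_model
+```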
+
__5. Why is the predict() method for inference so much slower on the first call than the subsequent calls?__
The time of first call also includes the compilation time of WebGL shader
diff --git a/tfjs-converter/python/tensorflowjs/converters/common.py b/tfjs-converter/python/tensorflowjs/converters/common.py
index 5aa062d1fe2..2775d5aedb6 100644
--- a/tfjs-converter/python/tensorflowjs/converters/common.py
+++ b/tfjs-converter/python/tensorflowjs/converters/common.py
@@ -50,6 +50,9 @@
SIGNATURE_NAME = 'signature_name'
SAVED_MODEL_TAGS = 'saved_model_tags'
QUANTIZATION_BYTES = 'quantization_bytes'
+QUANTIZATION_TYPE_FLOAT16 = 'quantize_float16'
+QUANTIZATION_TYPE_UINT8 = 'quantize_uint8'
+QUANTIZATION_TYPE_UINT16 = 'quantize_uint16'
SPLIT_WEIGHTS_BY_LAYER = 'split_weights_by_layer'
VERSION = 'version'
SKIP_OP_CHECK = 'skip_op_check'
diff --git a/tfjs-converter/python/tensorflowjs/converters/converter.py b/tfjs-converter/python/tensorflowjs/converters/converter.py
index 3858e666591..ed682ac1b73 100644
--- a/tfjs-converter/python/tensorflowjs/converters/converter.py
+++ b/tfjs-converter/python/tensorflowjs/converters/converter.py
@@ -26,7 +26,6 @@
import tempfile
import h5py
-import numpy as np
import tensorflow.compat.v1 as tf1
import tensorflow.compat.v2 as tf
@@ -39,7 +38,7 @@
def dispatch_keras_h5_to_tfjs_layers_model_conversion(
- h5_path, output_dir=None, quantization_dtype=None,
+ h5_path, output_dir=None, quantization_dtype_map=None,
split_weights_by_layer=False,
weight_shard_size_bytes=1024 * 1024 * 4):
"""Converts a Keras HDF5 saved-model file to TensorFlow.js format.
@@ -56,8 +55,8 @@ def dispatch_keras_h5_to_tfjs_layers_model_conversion(
output_dir: Output directory to which the TensorFlow.js-format model JSON
file and weights files will be written. If the directory does not exist,
it will be created.
- quantization_dtype: The quantized data type to store the weights in
- (Default: `None`).
+ quantization_dtype_map: A mapping from dtype (`uint8`, `uint16`, `float16`)
+      to weight names. The weight mapping supports wildcard substitution.
split_weights_by_layer: Whether to split the weights into separate weight
groups (corresponding to separate binary weight files) layer by layer
(Default: `False`).
@@ -94,7 +93,7 @@ def dispatch_keras_h5_to_tfjs_layers_model_conversion(
if not os.path.isdir(output_dir):
os.makedirs(output_dir)
conversion.write_artifacts(
- model_json, groups, output_dir, quantization_dtype,
+ model_json, groups, output_dir, quantization_dtype_map,
weight_shard_size_bytes=weight_shard_size_bytes)
return model_json, groups
@@ -102,7 +101,7 @@ def dispatch_keras_h5_to_tfjs_layers_model_conversion(
def dispatch_keras_h5_to_tfjs_graph_model_conversion(
h5_path, output_dir=None,
- quantization_dtype=None,
+ quantization_dtype_map=None,
skip_op_check=False,
strip_debug_ops=False,
weight_shard_size_bytes=1024 * 1024 * 4,
@@ -115,8 +114,8 @@ def dispatch_keras_h5_to_tfjs_graph_model_conversion(
keras or tf.keras.
output_dir: The destination to which the tfjs GraphModel artifacts will be
written.
- quantization_dtype: The quantized data type to store the weights in
- (Default: `None`).
+ quantization_dtype_map: A mapping from dtype (`uint8`, `uint16`, `float16`)
+      to weight names. The weight mapping supports wildcard substitution.
skip_op_check: Bool whether to skip the op check.
strip_debug_ops: Bool whether to allow unsupported debug ops.
weight_shard_size_bytes: Shard size (in bytes) of the weight files.
@@ -140,7 +139,7 @@ def dispatch_keras_h5_to_tfjs_graph_model_conversion(
temp_savedmodel_dir, output_dir,
signature_def='serving_default',
saved_model_tags='serve',
- quantization_dtype=quantization_dtype,
+ quantization_dtype_map=quantization_dtype_map,
skip_op_check=skip_op_check,
strip_debug_ops=strip_debug_ops,
weight_shard_size_bytes=weight_shard_size_bytes,
@@ -151,7 +150,7 @@ def dispatch_keras_h5_to_tfjs_graph_model_conversion(
def dispatch_keras_saved_model_to_tensorflowjs_conversion(
- keras_saved_model_path, output_dir, quantization_dtype=None,
+ keras_saved_model_path, output_dir, quantization_dtype_map=None,
split_weights_by_layer=False,
weight_shard_size_bytes=1024 * 1024 * 4):
"""Converts keras model saved in the SavedModel format to tfjs format.
@@ -168,8 +167,8 @@ def dispatch_keras_saved_model_to_tensorflowjs_conversion(
output_dir: Output directory to which the TensorFlow.js-format model JSON
file and weights files will be written. If the directory does not exist,
it will be created.
- quantization_dtype: The quantized data type to store the weights in
- (Default: `None`).
+ quantization_dtype_map: A mapping from dtype (`uint8`, `uint16`, `float16`)
+      to weight names. The weight mapping supports wildcard substitution.
split_weights_by_layer: Whether to split the weights into separate weight
groups (corresponding to separate binary weight files) layer by layer
(Default: `False`).
@@ -187,7 +186,7 @@ def dispatch_keras_saved_model_to_tensorflowjs_conversion(
dispatch_keras_h5_to_tfjs_layers_model_conversion(
temp_h5_path,
output_dir,
- quantization_dtype=quantization_dtype,
+ quantization_dtype_map=quantization_dtype_map,
split_weights_by_layer=split_weights_by_layer,
weight_shard_size_bytes=weight_shard_size_bytes)
@@ -271,7 +270,7 @@ def dispatch_tensorflowjs_to_keras_saved_model_conversion(
def dispatch_tensorflowjs_to_tensorflowjs_conversion(
config_json_path,
output_dir_path,
- quantization_dtype=None,
+ quantization_dtype_map=None,
weight_shard_size_bytes=1024 * 1024 * 4):
"""Converts a Keras Model from tensorflowjs format to H5.
@@ -280,8 +279,8 @@ def dispatch_tensorflowjs_to_tensorflowjs_conversion(
topology and weights manifest, in tensorflowjs format.
output_dir_path: Path to output directory in which the result of the
conversion will be saved.
- quantization_dtype: The quantized data type to store the weights in
- (Default: `None`).
+ quantization_dtype_map: A mapping from dtype (`uint8`, `uint16`, `float16`)
+      to weight names. The weight mapping supports wildcard substitution.
weight_shard_size_bytes: Shard size (in bytes) of the weight files.
The size of each weight file will be <= this value.
@@ -318,7 +317,7 @@ def dispatch_tensorflowjs_to_tensorflowjs_conversion(
with tf.Graph().as_default(), tf.compat.v1.Session():
dispatch_keras_h5_to_tfjs_layers_model_conversion(
temp_h5_path, output_dir_path,
- quantization_dtype=quantization_dtype,
+ quantization_dtype_map=quantization_dtype_map,
weight_shard_size_bytes=weight_shard_size_bytes)
# TODO(cais): Support weight quantization.
@@ -329,7 +328,7 @@ def dispatch_tensorflowjs_to_tensorflowjs_conversion(
def dispatch_tfjs_layers_model_to_tfjs_graph_conversion(
config_json_path,
output_dir_path,
- quantization_dtype=None,
+ quantization_dtype_map=None,
skip_op_check=False,
strip_debug_ops=False,
weight_shard_size_bytes=1024 * 1024 * 4):
@@ -343,8 +342,8 @@ def dispatch_tfjs_layers_model_to_tfjs_graph_conversion(
topology and weights manifest, in tensorflowjs format.
output_dir_path: Path to output directory in which the result of the
conversion will be saved.
- quantization_dtype: The quantized data type to store the weights in
- (Default: `None`).
+ quantization_dtype_map: A mapping from dtype (`uint8`, `uint16`, `float16`)
+      to weight names. The weight mapping supports wildcard substitution.
skip_op_check: Bool whether to skip the op check.
strip_debug_ops: Bool whether to allow unsupported debug ops.
weight_shard_size_bytes: Shard size (in bytes) of the weight files.
@@ -380,7 +379,7 @@ def dispatch_tfjs_layers_model_to_tfjs_graph_conversion(
model.save(temp_h5_path)
dispatch_keras_h5_to_tfjs_graph_model_conversion(
temp_h5_path, output_dir_path,
- quantization_dtype=quantization_dtype,
+ quantization_dtype_map=quantization_dtype_map,
skip_op_check=skip_op_check,
strip_debug_ops=strip_debug_ops,
weight_shard_size_bytes=weight_shard_size_bytes)
@@ -416,16 +415,32 @@ def _standardize_input_output_formats(input_format, output_format):
return (input_format, output_format)
+def _parse_quantization_dtype_map(float16, uint8, uint16, quantization_bytes):
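+  """Constructs the quantization dtype map from the quantization flags.
+
+  Args:
+    float16: A comma separated string of node names to quantize to float16,
+      True to quantize all nodes, or None.
+    uint8: A comma separated string of node names to quantize with 1-byte
+      affine quantization, True to quantize all nodes, or None.
+    uint16: A comma separated string of node names to quantize with 2-byte
+      affine quantization, True to quantize all nodes, or None.
+    quantization_bytes: (Deprecated) Number of bytes (1 or 2) to quantize all
+      weights to; cannot be combined with the flags above.
+
+  Returns:
+    A mapping from dtype name (`uint8`, `uint16`, `float16`) to a list of node
+    name patterns, or to True for a fallthrough matching all remaining nodes.
+  """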
+ quantization_dtype_map = {}
-def _parse_quantization_bytes(quantization_bytes):
- if quantization_bytes is None:
- return None
- elif quantization_bytes == 1:
- return np.uint8
- elif quantization_bytes == 2:
- return np.uint16
- else:
- raise ValueError('Unsupported quantization bytes: %s' % quantization_bytes)
+ if quantization_bytes:
+ print(
+ 'Warning: --quantization_bytes will be deprecated in a future release\n'
+ 'Please consider using --quantize_uint8, --quantize_uint16, '
+ '--quantize_float16.', file=sys.stderr)
+ if float16 is not None or uint8 is not None or uint16 is not None:
+ raise ValueError(
+ '--quantization_bytes cannot be used with the new quantization flags')
+
+ dtype = quantization.QUANTIZATION_BYTES_TO_DTYPES[quantization_bytes]
+ quantization_dtype_map[dtype] = True
+
+ if float16 is not None:
+ quantization_dtype_map[quantization.QUANTIZATION_DTYPE_FLOAT16] = \
+ float16.split(',') if isinstance(float16, str) else float16
+ if uint8 is not None:
+ quantization_dtype_map[quantization.QUANTIZATION_DTYPE_UINT8] = \
+ uint8.split(',') if isinstance(uint8, str) else uint8
+ if uint16 is not None:
+ quantization_dtype_map[quantization.QUANTIZATION_DTYPE_UINT16] = \
+ uint16.split(',') if isinstance(uint16, str) else uint16
+
+ return quantization_dtype_map
def get_arg_parser():
"""
@@ -488,13 +503,43 @@ def get_arg_parser():
help='Tags of the MetaGraphDef to load, in comma separated string '
'format. Defaults to "serve". Applicable only if input format is '
'"tf_saved_model".')
+ parser.add_argument(
+ '--%s' % common.QUANTIZATION_TYPE_FLOAT16,
+ type=str,
+ default=None,
+ const=True,
+ nargs='?',
+ help='Comma separated list of node names to apply float16 quantization. '
+    'You can also use the wildcard symbol (*) to apply quantization to '
+    'multiple nodes (e.g., conv/*/weights). When the flag is provided '
+    'without any node names, all nodes will be quantized.')
+ parser.add_argument(
+ '--%s' % common.QUANTIZATION_TYPE_UINT8,
+ type=str,
+ default=None,
+ const=True,
+ nargs='?',
+ help='Comma separated list of node names to apply 1-byte affine '
+    'quantization. You can also use the wildcard symbol (*) to apply '
+    'quantization to multiple nodes (e.g., conv/*/weights). When the flag '
+    'is provided without any node names, all nodes will be quantized.')
+ parser.add_argument(
+ '--%s' % common.QUANTIZATION_TYPE_UINT16,
+ type=str,
+ default=None,
+ const=True,
+ nargs='?',
+ help='Comma separated list of node names to apply 2-byte affine '
+    'quantization. You can also use the wildcard symbol (*) to apply '
+    'quantization to multiple nodes (e.g., conv/*/weights). When the flag '
+    'is provided without any node names, all nodes will be quantized.')
parser.add_argument(
'--%s' % common.QUANTIZATION_BYTES,
type=int,
choices=set(quantization.QUANTIZATION_BYTES_TO_DTYPES.keys()),
- help='How many bytes to optionally quantize/compress the weights to. 1- '
- 'and 2-byte quantizaton is supported. The default (unquantized) size is '
- '4 bytes.')
+ help='(Deprecated) How many bytes to optionally quantize/compress the '
+    'weights to. 1- and 2-byte quantization is supported. The default '
+ '(unquantized) size is 4 bytes.')
parser.add_argument(
'--%s' % common.SPLIT_WEIGHTS_BY_LAYER,
action='store_true',
@@ -574,9 +619,12 @@ def convert(arguments):
'but got %s' % args.weight_shard_size_bytes)
weight_shard_size_bytes = args.weight_shard_size_bytes
- quantization_dtype = (
- quantization.QUANTIZATION_BYTES_TO_DTYPES[args.quantization_bytes]
- if args.quantization_bytes else None)
+ quantization_dtype_map = _parse_quantization_dtype_map(
+ args.quantize_float16,
+ args.quantize_uint8,
+ args.quantize_uint16,
+ args.quantization_bytes
+ )
if (not args.output_node_names and input_format == common.TF_FROZEN_MODEL):
raise ValueError(
@@ -601,14 +649,14 @@ def convert(arguments):
output_format == common.TFJS_LAYERS_MODEL):
dispatch_keras_h5_to_tfjs_layers_model_conversion(
args.input_path, output_dir=args.output_path,
- quantization_dtype=quantization_dtype,
+ quantization_dtype_map=quantization_dtype_map,
split_weights_by_layer=args.split_weights_by_layer,
weight_shard_size_bytes=weight_shard_size_bytes)
elif (input_format == common.KERAS_MODEL and
output_format == common.TFJS_GRAPH_MODEL):
dispatch_keras_h5_to_tfjs_graph_model_conversion(
args.input_path, output_dir=args.output_path,
- quantization_dtype=quantization_dtype,
+ quantization_dtype_map=quantization_dtype_map,
skip_op_check=args.skip_op_check,
strip_debug_ops=args.strip_debug_ops,
weight_shard_size_bytes=weight_shard_size_bytes,
@@ -617,7 +665,7 @@ def convert(arguments):
output_format == common.TFJS_LAYERS_MODEL):
dispatch_keras_saved_model_to_tensorflowjs_conversion(
args.input_path, args.output_path,
- quantization_dtype=quantization_dtype,
+ quantization_dtype_map=quantization_dtype_map,
split_weights_by_layer=args.split_weights_by_layer,
weight_shard_size_bytes=weight_shard_size_bytes)
elif (input_format == common.TF_SAVED_MODEL and
@@ -626,7 +674,7 @@ def convert(arguments):
args.input_path, args.output_path,
signature_def=args.signature_name,
saved_model_tags=args.saved_model_tags,
- quantization_dtype=quantization_dtype,
+ quantization_dtype_map=quantization_dtype_map,
skip_op_check=args.skip_op_check,
strip_debug_ops=args.strip_debug_ops,
weight_shard_size_bytes=weight_shard_size_bytes,
@@ -637,7 +685,7 @@ def convert(arguments):
args.input_path, args.output_path,
signature=args.signature_name,
saved_model_tags=args.saved_model_tags,
- quantization_dtype=quantization_dtype,
+ quantization_dtype_map=quantization_dtype_map,
skip_op_check=args.skip_op_check,
strip_debug_ops=args.strip_debug_ops,
weight_shard_size_bytes=weight_shard_size_bytes,
@@ -654,13 +702,13 @@ def convert(arguments):
output_format == common.TFJS_LAYERS_MODEL):
dispatch_tensorflowjs_to_tensorflowjs_conversion(
args.input_path, args.output_path,
- quantization_dtype=_parse_quantization_bytes(args.quantization_bytes),
+ quantization_dtype_map=quantization_dtype_map,
weight_shard_size_bytes=weight_shard_size_bytes)
elif (input_format == common.TFJS_LAYERS_MODEL and
output_format == common.TFJS_GRAPH_MODEL):
dispatch_tfjs_layers_model_to_tfjs_graph_conversion(
args.input_path, args.output_path,
- quantization_dtype=_parse_quantization_bytes(args.quantization_bytes),
+ quantization_dtype_map=quantization_dtype_map,
skip_op_check=args.skip_op_check,
strip_debug_ops=args.strip_debug_ops,
weight_shard_size_bytes=weight_shard_size_bytes)
@@ -668,7 +716,7 @@ def convert(arguments):
output_format == common.TFJS_GRAPH_MODEL):
tf_saved_model_conversion_v2.convert_tf_frozen_model(
args.input_path, args.output_node_names, args.output_path,
- quantization_dtype=_parse_quantization_bytes(args.quantization_bytes),
+ quantization_dtype_map=quantization_dtype_map,
skip_op_check=args.skip_op_check,
strip_debug_ops=args.strip_debug_ops,
weight_shard_size_bytes=weight_shard_size_bytes)
diff --git a/tfjs-converter/python/tensorflowjs/converters/converter_test.py b/tfjs-converter/python/tensorflowjs/converters/converter_test.py
index 3a8ecaf20f4..1430bc549cc 100644
--- a/tfjs-converter/python/tensorflowjs/converters/converter_test.py
+++ b/tfjs-converter/python/tensorflowjs/converters/converter_test.py
@@ -564,7 +564,8 @@ def testConvertTfKerasSequentialSavedAsSavedModelWithQuantization(self):
# Convert the keras SavedModel to tfjs format.
tfjs_output_dir = os.path.join(self._tmp_dir, 'tfjs')
converter.dispatch_keras_saved_model_to_tensorflowjs_conversion(
- self._tmp_dir, tfjs_output_dir, quantization_dtype=np.uint16)
+ self._tmp_dir, tfjs_output_dir,
+ quantization_dtype_map={'uint16': '*'})
# Verify the size of the weight file.
weight_path = glob.glob(os.path.join(tfjs_output_dir, 'group*-*'))[0]
@@ -688,7 +689,7 @@ def testConvertTfjsLayersModelWithUint16Quantization(self):
sharded_model_path = os.path.join(self._tmp_dir, 'sharded_model')
converter.dispatch_tensorflowjs_to_tensorflowjs_conversion(
os.path.join(tfjs_output_dir, 'model.json'), sharded_model_path,
- quantization_dtype=np.uint16,
+ quantization_dtype_map={'uint16': '*'},
weight_shard_size_bytes=weight_shard_size_bytes)
# Check the number of quantized files and their sizes.
@@ -723,7 +724,7 @@ def testConvertTfjsLayersModelWithUint8Quantization(self):
sharded_model_path = os.path.join(self._tmp_dir, 'sharded_model')
converter.dispatch_tensorflowjs_to_tensorflowjs_conversion(
os.path.join(tfjs_output_dir, 'model.json'), sharded_model_path,
- quantization_dtype=np.uint8,
+ quantization_dtype_map={'uint8': '*'},
weight_shard_size_bytes=weight_shard_size_bytes)
# Check the number of quantized files and their sizes.
diff --git a/tfjs-converter/python/tensorflowjs/converters/keras_h5_conversion.py b/tfjs-converter/python/tensorflowjs/converters/keras_h5_conversion.py
index c25a02869aa..7a37de2d7b5 100644
--- a/tfjs-converter/python/tensorflowjs/converters/keras_h5_conversion.py
+++ b/tfjs-converter/python/tensorflowjs/converters/keras_h5_conversion.py
@@ -259,7 +259,7 @@ def _get_generated_by(topology):
def write_artifacts(topology,
weights,
output_dir,
- quantization_dtype=None,
+ quantization_dtype_map=None,
weight_shard_size_bytes=1024 * 1024 * 4):
"""Writes weights and topology to the output_dir.
@@ -269,8 +269,9 @@ def write_artifacts(topology,
topology: a JSON dictionary, representing the Keras config.
weights: an array of weight groups (as defined in tfjs write_weights).
output_dir: the directory to hold all the contents.
- quantization_dtype: An optional numpy dtype to quantize weights to for
- compression. Only np.uint8 and np.uint16 are supported.
+ quantization_dtype_map: (Optional) A mapping from dtype
+      (`uint8`, `uint16`, `float16`) to weight names. The weight mapping
+ supports wildcard substitution.
weight_shard_size_bytes: Shard size (in bytes) of the weight files.
The size of each weight file will be <= this value.
"""
@@ -297,7 +298,7 @@ def write_artifacts(topology,
model_json[common.ARTIFACT_MODEL_TOPOLOGY_KEY] = topology or None
weights_manifest = write_weights.write_weights(
weights, output_dir, write_manifest=False,
- quantization_dtype=quantization_dtype,
+ quantization_dtype_map=quantization_dtype_map,
shard_size_bytes=weight_shard_size_bytes)
assert isinstance(weights_manifest, list)
model_json[common.ARTIFACT_WEIGHTS_MANIFEST_KEY] = weights_manifest
@@ -308,7 +309,7 @@ def write_artifacts(topology,
json.dump(model_json, f)
-def save_keras_model(model, artifacts_dir, quantization_dtype=None,
+def save_keras_model(model, artifacts_dir, quantization_dtype_map=None,
weight_shard_size_bytes=1024 * 1024 * 4):
r"""Save a Keras model and its weights in TensorFlow.js format.
@@ -326,8 +327,9 @@ def save_keras_model(model, artifacts_dir, quantization_dtype=None,
- files containing weight values in groups, with the file name pattern
group(\d+)-shard(\d+)of(\d+).
If the directory does not exist, this function will attempt to create it.
- quantization_dtype: An optional numpy dtype to quantize weights to for
- compression. Only np.uint8 and np.uint16 are supported.
+ quantization_dtype_map: (Optional) A mapping from dtype
+      (`uint8`, `uint16`, `float16`) to weight names. The weight mapping
+ supports wildcard substitution.
weight_shard_size_bytes: Shard size (in bytes) of the weight files.
The size of each weight file will be <= this value.
@@ -344,6 +346,6 @@ def save_keras_model(model, artifacts_dir, quantization_dtype=None,
os.makedirs(artifacts_dir)
write_artifacts(
topology_json, weight_groups, artifacts_dir,
- quantization_dtype=quantization_dtype,
+ quantization_dtype_map=quantization_dtype_map,
weight_shard_size_bytes=weight_shard_size_bytes)
os.remove(temp_h5_path)
diff --git a/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py b/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py
index 6f2377447c2..0041929adb7 100644
--- a/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py
+++ b/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py
@@ -110,8 +110,8 @@ def _run_grappler(config, graph_def, graph, signature_def):
config, meta_graph, cluster=get_cluster())
def optimize_graph(graph, signature_def, output_graph,
- tf_version, quantization_dtype=None, skip_op_check=False,
- strip_debug_ops=False,
+ tf_version, quantization_dtype_map=None,
+ skip_op_check=False, strip_debug_ops=False,
weight_shard_size_bytes=1024 * 1024 * 4):
"""Takes a Python Graph object and optimizes the graph.
@@ -120,8 +120,9 @@ def optimize_graph(graph, signature_def, output_graph,
signature_def: the SignatureDef of the inference graph.
output_graph: The location of the output graph.
tf_version: Tensorflow version of the input graph.
- quantization_dtype: An optional numpy dtype to quantize weights to for
- compression. Only np.uint8 and np.uint16 are supported.
+ quantization_dtype_map: A mapping from dtype
+        (`uint8`, `uint16`, `float16`) to weight names. The weight mapping
+ supports wildcard substitution.
skip_op_check: Bool whether to skip the op check.
strip_debug_ops: Bool whether to strip debug ops.
weight_shard_size_bytes: Shard size (in bytes) of the weight files.
@@ -192,7 +193,7 @@ def optimize_graph(graph, signature_def, output_graph,
extract_weights(
optimized_graph, output_graph, tf_version,
- signature_def, quantization_dtype, weight_shard_size_bytes)
+ signature_def, quantization_dtype_map, weight_shard_size_bytes)
return optimize_graph
@@ -230,7 +231,7 @@ def extract_weights(graph_def,
output_graph,
tf_version,
signature_def,
- quantization_dtype=None,
+ quantization_dtype_map=None,
weight_shard_size_bytes=1024 * 1024 * 4):
"""Takes a Python GraphDef object and extract the weights.
@@ -239,8 +240,10 @@ def extract_weights(graph_def,
the model topology.
tf_version: Tensorflow version of the input graph.
signature_def: the SignatureDef of the inference graph.
- quantization_dtype: An optional numpy dtype to quantize weights to for
- compression. Only np.uint8 and np.uint16 are supported.
+ quantization_dtype_map: A mapping from dtype
+        (`uint8`, `uint16`, `float16`) to weight names. The weight mapping
+        supports wildcard substitution.
weight_shard_size_bytes: Shard size (in bytes) of the weight files.
The size of each weight file will be <= this value.
"""
@@ -260,7 +263,7 @@ def extract_weights(graph_def,
[global_manifest + function_manifests],
output_graph,
tf_version, signature_def,
- quantization_dtype=quantization_dtype,
+ quantization_dtype_map=quantization_dtype_map,
weight_shard_size_bytes=weight_shard_size_bytes)
@@ -269,7 +272,7 @@ def write_artifacts(topology,
output_graph,
tf_version,
signature_def,
- quantization_dtype=None,
+ quantization_dtype_map=None,
weight_shard_size_bytes=1024 * 1024 * 4):
"""Writes weights and topology to the output_dir.
@@ -282,8 +285,9 @@ def write_artifacts(topology,
output_graph: the output file name to hold all the contents.
tf_version: Tensorflow version of the input graph.
signature_def: the SignatureDef of the inference graph.
- quantization_dtype: An optional numpy dtype to quantize weights to for
- compression. Only np.uint8 and np.uint16 are supported.
+ quantization_dtype_map: A mapping from dtype
+        (`uint8`, `uint16`, `float16`) to weight names. The weight mapping
+ supports wildcard substitution.
weight_shard_size_bytes: Shard size (in bytes) of the weight files.
The size of each weight file will be <= this value.
"""
@@ -300,7 +304,7 @@ def write_artifacts(topology,
model_json[common.ARTIFACT_MODEL_TOPOLOGY_KEY] = topology or None
weights_manifest = write_weights.write_weights(
weights, os.path.dirname(output_graph), write_manifest=False,
- quantization_dtype=quantization_dtype,
+ quantization_dtype_map=quantization_dtype_map,
shard_size_bytes=weight_shard_size_bytes)
assert isinstance(weights_manifest, list)
model_json[common.ARTIFACT_WEIGHTS_MANIFEST_KEY] = weights_manifest
@@ -378,7 +382,8 @@ def _build_signature_def(frozen_graph, input_nodes, output_nodes):
def convert_tf_frozen_model(frozen_model_path,
output_node_names,
- output_dir, quantization_dtype=None,
+ output_dir,
+ quantization_dtype_map=None,
skip_op_check=False,
strip_debug_ops=False,
weight_shard_size_bytes=1024 * 1024 * 4):
@@ -392,8 +397,9 @@ def convert_tf_frozen_model(frozen_model_path,
will consist of
- a file named 'model.json'
- possibly sharded binary weight files.
- quantization_dtype: An optional numpy dtype to quantize weights to for
- compression. Only np.uint8 and np.uint16 are supported.
+ quantization_dtype_map: A mapping from dtype
+        (`uint8`, `uint16`, `float16`) to weight names. The weight mapping
+ supports wildcard substitution.
skip_op_check: Bool whether to skip the op check.
strip_debug_ops: Bool whether to strip debug ops.
weight_shard_size_bytes: Shard size (in bytes) of the weight files.
@@ -410,7 +416,7 @@ def convert_tf_frozen_model(frozen_model_path,
optimize_graph(graph, signature,
output_graph, tf.__version__,
- quantization_dtype=quantization_dtype,
+ quantization_dtype_map=quantization_dtype_map,
skip_op_check=skip_op_check,
strip_debug_ops=strip_debug_ops,
weight_shard_size_bytes=weight_shard_size_bytes)
@@ -418,7 +424,7 @@ def convert_tf_frozen_model(frozen_model_path,
def convert_tf_saved_model(saved_model_dir,
output_dir, signature_def='serving_default',
saved_model_tags='serve',
- quantization_dtype=None,
+ quantization_dtype_map=None,
skip_op_check=False,
strip_debug_ops=False,
weight_shard_size_bytes=1024 * 1024 * 4,
@@ -438,8 +444,9 @@ def convert_tf_saved_model(saved_model_dir,
signature_def: string Tagset of the SignatureDef to load. Defaults to
'serving_default'.
saved_model_tags: tags of the GraphDef to load. Defaults to 'serve'.
- quantization_dtype: An optional numpy dtype to quantize weights to for
- compression. Only np.uint8 and np.uint16 are supported.
+ quantization_dtype_map: A mapping from dtype
+        (`uint8`, `uint16`, `float16`) to weight names. The weight mapping
+ supports wildcard substitution.
skip_op_check: Bool whether to skip the op check.
strip_debug_ops: Bool whether to strip debug ops.
weight_shard_size_bytes: Shard size (in bytes) of the weight files.
@@ -481,7 +488,7 @@ def convert_tf_saved_model(saved_model_dir,
optimize_graph(frozen_graph, signature,
output_graph, model.tensorflow_version,
- quantization_dtype=quantization_dtype,
+ quantization_dtype_map=quantization_dtype_map,
skip_op_check=skip_op_check,
strip_debug_ops=strip_debug_ops,
weight_shard_size_bytes=weight_shard_size_bytes)
@@ -533,7 +540,7 @@ def load_and_initialize_hub_module(module_path, signature='default'):
def convert_tf_hub_module_v1(module_path, output_dir,
- signature='default', quantization_dtype=None,
+ signature='default', quantization_dtype_map=None,
skip_op_check=False, strip_debug_ops=False,
weight_shard_size_bytes=1024 * 1024 * 4):
"""Freeze the TF-Hub module and check compatibility with Tensorflow.js.
@@ -548,6 +555,9 @@ def convert_tf_hub_module_v1(module_path, output_dir,
- a file named 'model.json'
- possibly sharded binary weight files.
signature: string Signature to load.
+ quantization_dtype_map: A mapping from dtype
+        (`uint8`, `uint16`, `float16`) to weight names. The weight mapping
+ supports wildcard substitution.
skip_op_check: Bool whether to skip the op check.
strip_debug_ops: Bool whether to strip debug ops.
weight_shard_size_bytes: Shard size (in bytes) of the weight files.
@@ -589,7 +599,7 @@ def convert_tf_hub_module_v1(module_path, output_dir,
optimize_graph(frozen_graph, signature,
output_graph, tf.__version__,
- quantization_dtype=quantization_dtype,
+ quantization_dtype_map=quantization_dtype_map,
skip_op_check=skip_op_check,
strip_debug_ops=strip_debug_ops,
weight_shard_size_bytes=weight_shard_size_bytes)
@@ -601,8 +611,8 @@ def convert_tf_hub_module_v1(module_path, output_dir,
def convert_tf_hub_module(module_handle, output_dir,
signature='default', saved_model_tags='serve',
- quantization_dtype=None, skip_op_check=False,
- strip_debug_ops=False,
+ quantization_dtype_map=None,
+ skip_op_check=False, strip_debug_ops=False,
weight_shard_size_bytes=1024 * 1024 * 4,
control_flow_v2=False):
"""Conversion for TF Hub modules V1 and V2.
@@ -617,6 +627,9 @@ def convert_tf_hub_module(module_handle, output_dir,
- possibly sharded binary weight files.
signature: string Signature to load.
saved_model_tags: tags of the GraphDef to load. Defaults to ''.
+ quantization_dtype_map: A mapping from dtype
+        (`uint8`, `uint16`, `float16`) to weight names. The weight mapping
+ supports wildcard substitution.
skip_op_check: Bool whether to skip the op check.
strip_debug_ops: Bool whether to strip debug ops.
weight_shard_size_bytes: Shard size (in bytes) of the weight files.
@@ -630,7 +643,8 @@ def convert_tf_hub_module(module_handle, output_dir,
if tf.io.gfile.exists(os.path.join(module_path, _HUB_V1_MODULE_PB)):
print("Loading the module using TF 1.X interface from %s." % module_path)
convert_tf_hub_module_v1(module_path, output_dir, signature,
- quantization_dtype, skip_op_check, strip_debug_ops,
+ quantization_dtype_map,
+ skip_op_check, strip_debug_ops,
weight_shard_size_bytes)
else:
print("Loading the module using TF 2.X interface from %s." % module_path)
@@ -640,7 +654,7 @@ def convert_tf_hub_module(module_handle, output_dir,
output_dir=output_dir,
signature_def=signature,
saved_model_tags=saved_model_tags,
- quantization_dtype=quantization_dtype,
+ quantization_dtype_map=quantization_dtype_map,
skip_op_check=skip_op_check,
strip_debug_ops=strip_debug_ops,
weight_shard_size_bytes=weight_shard_size_bytes,
diff --git a/tfjs-converter/python/tensorflowjs/converters/wizard.py b/tfjs-converter/python/tensorflowjs/converters/wizard.py
index cdfea64a0c7..9e23f99611c 100644
--- a/tfjs-converter/python/tensorflowjs/converters/wizard.py
+++ b/tfjs-converter/python/tensorflowjs/converters/wizard.py
@@ -215,7 +215,7 @@ def generate_arguments(params):
"""
args = []
not_param_list = [common.INPUT_PATH, common.OUTPUT_PATH,
- 'overwrite_output_path']
+ 'overwrite_output_path', 'quantize']
no_false_param = [common.SPLIT_WEIGHTS_BY_LAYER, common.SKIP_OP_CHECK]
for key, value in sorted(params.items()):
if key not in not_param_list and value is not None:
@@ -463,20 +463,54 @@ def run(dryrun):
},
{
'type': 'list',
- 'name': common.QUANTIZATION_BYTES,
+ 'name': 'quantize',
'message': 'Do you want to compress the model? '
'(this will decrease the model precision.)',
'choices': [{
'name': 'No compression (Higher accuracy)',
'value': None
}, {
- 'name': '2x compression (Accuracy/size trade-off)',
- 'value': 2
+ 'name': 'float16 quantization '
+ '(2x smaller, Minimal accuracy loss)',
+ 'value': 'float16'
}, {
- 'name': '4x compression (Smaller size)',
- 'value': 1
+ 'name': 'uint16 affine quantization (2x smaller, Accuracy loss)',
+ 'value': 'uint16'
+ }, {
+ 'name': 'uint8 affine quantization (4x smaller, Accuracy loss)',
+ 'value': 'uint8'
}]
},
+ {
+ 'type': 'input',
+ 'name': common.QUANTIZATION_TYPE_FLOAT16,
+ 'message': 'Please enter the layers to apply float16 quantization '
+ '(2x smaller, minimal accuracy tradeoff).\n'
+ 'Supports wildcard expansion with *, e.g., conv/*/weights',
+ 'default': '*',
+ 'when': lambda answers:
+          value_in_list(answers, 'quantize', ('float16',))
+ },
+ {
+ 'type': 'input',
+ 'name': common.QUANTIZATION_TYPE_UINT8,
+ 'message': 'Please enter the layers to apply affine 1-byte integer '
+ 'quantization (4x smaller, accuracy tradeoff).\n'
+ 'Supports wildcard expansion with *, e.g., conv/*/weights',
+ 'default': '*',
+ 'when': lambda answers:
+          value_in_list(answers, 'quantize', ('uint8',))
+ },
+ {
+ 'type': 'input',
+ 'name': common.QUANTIZATION_TYPE_UINT16,
+ 'message': 'Please enter the layers to apply affine 2-byte integer '
+ 'quantization (2x smaller, accuracy tradeoff).\n'
+ 'Supports wildcard expansion with *, e.g., conv/*/weights',
+ 'default': '*',
+ 'when': lambda answers:
+          value_in_list(answers, 'quantize', ('uint16',))
+ },
{
'type': 'input',
'name': common.WEIGHT_SHARD_SIZE_BYTES,
diff --git a/tfjs-converter/python/tensorflowjs/converters/wizard_test.py b/tfjs-converter/python/tensorflowjs/converters/wizard_test.py
index c50545cde0d..d724e79401f 100644
--- a/tfjs-converter/python/tensorflowjs/converters/wizard_test.py
+++ b/tfjs-converter/python/tensorflowjs/converters/wizard_test.py
@@ -183,7 +183,7 @@ def testGenerateCommandForSavedModel(self):
'input_path': 'tmp/saved_model',
'saved_model_tags': 'test',
'signature_name': 'test_default',
- 'quantization_bytes': 2,
+ 'quantize_float16': 'conv/*/weights',
'weight_shard_size_bytes': '4194304',
'skip_op_check': False,
'strip_debug_ops': True,
@@ -192,8 +192,10 @@ def testGenerateCommandForSavedModel(self):
self.assertEqual(['--control_flow_v2=True',
'--input_format=tf_saved_model',
- '--quantization_bytes=2', '--saved_model_tags=test',
- '--signature_name=test_default', '--strip_debug_ops=True',
+ '--quantize_float16=conv/*/weights',
+ '--saved_model_tags=test',
+ '--signature_name=test_default',
+ '--strip_debug_ops=True',
'--weight_shard_size_bytes=4194304',
'tmp/saved_model', 'tmp/web_model'],
wizard.generate_arguments(options))
@@ -205,7 +207,7 @@ def testGenerateCommandForKerasSavedModel(self):
'saved_model_tags': 'test',
'signature_name': 'test_default',
'weight_shard_size_bytes': '100',
- 'quantization_bytes': 1,
+ 'quantize_float16': 'conv/*/weights',
'skip_op_check': True,
'strip_debug_ops': False,
'control_flow_v2': False,
@@ -214,7 +216,8 @@ def testGenerateCommandForKerasSavedModel(self):
self.assertEqual(['--control_flow_v2=False',
'--input_format=tf_keras_saved_model',
'--output_format=tfjs_layers_model',
- '--quantization_bytes=1', '--saved_model_tags=test',
+ '--quantize_float16=conv/*/weights',
+ '--saved_model_tags=test',
'--signature_name=test_default', '--skip_op_check',
'--strip_debug_ops=False',
'--weight_shard_size_bytes=100',
@@ -225,10 +228,11 @@ def testGenerateCommandForKerasModel(self):
options = {'input_format': 'keras',
'input_path': 'tmp/model.HD5',
'weight_shard_size_bytes': '100',
- 'quantization_bytes': 1,
+ 'quantize_uint16': 'conv/*/weights',
'output_path': 'tmp/web_model'}
- self.assertEqual(['--input_format=keras', '--quantization_bytes=1',
+ self.assertEqual(['--input_format=keras',
+ '--quantize_uint16=conv/*/weights',
'--weight_shard_size_bytes=100',
'tmp/model.HD5', 'tmp/web_model'],
wizard.generate_arguments(options))
@@ -237,13 +241,13 @@ def testGenerateCommandForLayerModel(self):
options = {'input_format': 'tfjs_layers_model',
'output_format': 'keras',
'input_path': 'tmp/model.json',
- 'quantization_bytes': 1,
+ 'quantize_uint8': 'conv/*/weights',
'weight_shard_size_bytes': '100',
'output_path': 'tmp/web_model'}
self.assertEqual(['--input_format=tfjs_layers_model',
'--output_format=keras',
- '--quantization_bytes=1',
+ '--quantize_uint8=conv/*/weights',
'--weight_shard_size_bytes=100',
'tmp/model.json',
'tmp/web_model'],
diff --git a/tfjs-converter/python/tensorflowjs/quantization.py b/tfjs-converter/python/tensorflowjs/quantization.py
index a26259067c6..11b535334bb 100644
--- a/tfjs-converter/python/tensorflowjs/quantization.py
+++ b/tfjs-converter/python/tensorflowjs/quantization.py
@@ -16,10 +16,79 @@
from __future__ import division
from __future__ import print_function
+import fnmatch
import numpy as np
-QUANTIZATION_BYTES_TO_DTYPES = {1: np.uint8, 2: np.uint16}
+QUANTIZATION_DTYPE_FLOAT16 = 'float16'
+QUANTIZATION_DTYPE_UINT8 = 'uint8'
+QUANTIZATION_DTYPE_UINT16 = 'uint16'
+QUANTIZATION_BYTES_TO_DTYPES = {1: QUANTIZATION_DTYPE_UINT8,
+ 2: QUANTIZATION_DTYPE_UINT16}
+QUANTIZATION_OPTION_TO_DTYPES = {QUANTIZATION_DTYPE_UINT8: np.uint8,
+ QUANTIZATION_DTYPE_UINT16: np.uint16,
+ QUANTIZATION_DTYPE_FLOAT16: np.float16}
+
+
+def map_layers_to_quantization_dtype(names, quantization_dtype_map):
+ """Maps node names to their quantization dtypes.
+
+  Given a quantization_dtype_map, which maps dtypes (`uint8`, `uint16`,
+  `float16`) to node patterns (e.g., conv/*/weights), we construct a new
+  mapping from each individual node name to its dtype
+  (e.g., conv/1/weight -> `uint8`).
+  The value for a dtype may also be the boolean True, signaling a fallthrough
+  dtype. There can be only one fallthrough dtype in the map; a fallthrough
+  converts all weights that don't match any pattern to that dtype.
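+
+  For example, given names = ['conv/0/weight', 'conv/1/weight'] and
+  quantization_dtype_map = {'float16': ['conv/0/*'], 'uint8': True}, the
+  result is {'conv/0/weight': np.float16, 'conv/1/weight': np.uint8}.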
+
+ Args:
+ names: Array of node names.
+ quantization_dtype_map: A mapping from dtype (`uint8`, `uint16`, `float16`)
+      to weight names. The weight mapping supports wildcard substitution.
+
+ Returns:
+ quantization_dtype: A mapping from each node name which matches
+ an entry in quantization_dtype_map to its corresponding dtype.
+
+ Raises:
+ ValueError: - If multiple dtypes match the same node name
+ - If more than one fallthrough is provided
+ """
+ if quantization_dtype_map is None:
+ return {}
+
+ fallthrough = None
+ quantization_dtype = {}
+ for dtype_name, patterns in quantization_dtype_map.items():
+ # Record fallthrough if there is one
+ if isinstance(patterns, bool) and patterns:
+ # Only one fallthrough is supported
+ if fallthrough is not None:
+ raise ValueError(
+ 'More than one quantization fallthrough provided, '
+ 'exactly one is supported')
+ fallthrough = dtype_name
+ continue
+ elif isinstance(patterns, str):
+      patterns = [patterns]
+
+ # Record matched weights for dtype
+ for pattern in patterns:
+ for match in fnmatch.filter(names, pattern):
+ dtype = QUANTIZATION_OPTION_TO_DTYPES[dtype_name]
+ if match in quantization_dtype and quantization_dtype[match] != dtype:
+ raise ValueError(
+ 'Two quantization values %s, %s match the same node %s' %
+ (dtype, quantization_dtype[match], match))
+ quantization_dtype[match] = dtype
+
+ # Catch all remaining names with fallthrough
+ if fallthrough is not None:
+ nameset = set(names)
+ fallthrough_names = nameset - set(quantization_dtype.keys())
+ for name in fallthrough_names:
+ quantization_dtype[name] = QUANTIZATION_OPTION_TO_DTYPES[fallthrough]
+
+ return quantization_dtype
def quantize_weights(data, quantization_dtype):
"""Quantizes the weights by linearly re-scaling across available bits.
@@ -36,43 +105,71 @@ def quantize_weights(data, quantization_dtype):
Args:
data: A numpy array of dtype 'float32' or 'int32'.
- quantization_dtype: A numpy dtype to quantize weights to. Only np.uint8 and
- np.uint16 are supported.
+ quantization_dtype: A numpy dtype to quantize weights to. Only np.float16,
+ np.uint8, and np.uint16 are supported.
Returns:
quantized_data: The quantized weights as a numpy array with dtype
`quantization_dtype`.
- scale: The linearly scaling constant used for quantization.
- min_val: The minimum value of the linear range.
+ metadata: A dictionary with the corresponding metadata for the quantization
+ type. There is no metadata associated with float16.
+ For affine quantization there are two associated metadata values:
+ scale: The linearly scaling constant used for quantization.
+ min_val: The minimum value of the linear range.
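+      For example, uint8 quantization of float32 data spanning the range
+      [-3, 3] yields scale = 6 / 255.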
Raises:
ValueError: if `quantization_dtype` is not a valid type.
"""
- if quantization_dtype not in QUANTIZATION_BYTES_TO_DTYPES.values():
+ if quantization_dtype in [np.uint8, np.uint16]:
+ # Compute the min and max for the group.
+ min_val = data.min().astype(np.float64)
+ max_val = data.max().astype(np.float64)
+ if min_val == max_val:
+ # If there is only a single value, we can represent everything as zeros.
+ quantized_data = np.zeros_like(data, dtype=quantization_dtype)
+ scale = 1.0
+ else:
+ # Quantize data.
+ scale, min_val, max_val = _get_affine_quantization_range(
+ min_val, max_val, quantization_dtype)
+ quantized_data = np.round(
+ (data.clip(min_val, max_val) - min_val) / scale).astype(
+ quantization_dtype)
+
+ return quantized_data, {'min': min_val, 'scale': scale}
+ elif quantization_dtype == np.float16:
+ if data.dtype != np.float32:
+ raise ValueError(
+ 'Invalid data dtype %r\n'
+ 'float16 quantization only supports float32 dtype' % data.dtype)
+ quantized_data = data.astype(np.float16)
+ return quantized_data, {}
+ else:
raise ValueError('Invalid `quantization_dtype`: %r' % quantization_dtype)
- # Compute the min and max for the group.
- min_val = data.min().astype(np.float64)
- max_val = data.max().astype(np.float64)
- if min_val == max_val:
- # If there is only a single value, we can represent everything as zeros.
- quantized_data = np.zeros_like(data, dtype=quantization_dtype)
- scale = 1.0
- else:
- # Quantize data.
- scale, min_val, max_val = _get_quantization_range(
- min_val, max_val, quantization_dtype)
- quantized_data = np.round(
- (data.clip(min_val, max_val) - min_val) / scale).astype(
- quantization_dtype)
- return quantized_data, scale, min_val
+def dequantize_weights(data, metadata, original_dtype=np.float32):
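+  """Dequantizes the weights using the quantization metadata.
+
+  Args:
+    data: A numpy array of quantized weights (np.uint8, np.uint16 or
+      np.float16).
+    metadata: The metadata dictionary returned by `quantize_weights`. For
+      affine (uint8/uint16) quantization it must contain 'scale' and 'min';
+      it is ignored for float16.
+    original_dtype: The numpy dtype of the original weights
+      (Default: np.float32).
+
+  Returns:
+    The dequantized weights as a numpy array of dtype `original_dtype`.
+
+  Raises:
+    ValueError: if the data dtype is unsupported, the affine metadata is
+      missing, or float16 data is dequantized to a non-float32 dtype.
+  """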
+ dtype = data.dtype
-def dequantize_weights(
- quantized_data, scale, min_val, original_dtype=np.float32):
- return np.round(quantized_data * scale + min_val).astype(original_dtype)
+ if dtype in [np.uint8, np.uint16]:
+ if not ('scale' in metadata and 'min' in metadata):
+ raise ValueError(
+ 'Missing metadata min or scale for dtype %s' % dtype.name)
+ scale = metadata['scale']
+ min_val = metadata['min']
+ return np.round(data * scale + min_val).astype(original_dtype)
+ elif dtype == np.float16:
+ if original_dtype != np.float32:
+ raise ValueError(
+          'Invalid original dtype %r\n'
+          'float16 dequantization only supports float32 dtype'
+          % original_dtype)
+ return data.astype(original_dtype)
+ else:
+ raise ValueError(
+ 'Invalid dtype %s for dequantization\n'
+ 'Supported dtypes are uint8, uint16, float16' % dtype.name)
-def _get_quantization_range(min_val, max_val, quantization_dtype):
+def _get_affine_quantization_range(min_val, max_val, quantization_dtype):
"""Computes quantization range to ensure that zero is represented if covered.
Gymnastics with nudged zero point is to ensure that real zero maps to an
@@ -97,7 +194,7 @@ def _get_quantization_range(min_val, max_val, quantization_dtype):
Raises:
ValueError: if `quantization_dtype` is not a valid type.
"""
- if quantization_dtype not in QUANTIZATION_BYTES_TO_DTYPES.values():
+ if quantization_dtype not in [np.uint8, np.uint16]:
raise ValueError('Invalid `quantization_dtype`: %r' % quantization_dtype)
quant_max = np.iinfo(quantization_dtype).max
diff --git a/tfjs-converter/python/tensorflowjs/quantization_test.py b/tfjs-converter/python/tensorflowjs/quantization_test.py
index 13207e8e8ea..4762b5ecbb2 100644
--- a/tfjs-converter/python/tensorflowjs/quantization_test.py
+++ b/tfjs-converter/python/tensorflowjs/quantization_test.py
@@ -24,71 +24,219 @@
class TestQuantizationUtil(unittest.TestCase):
+ def assertDictContainsSubsetAlmostEqual(self, d1, d2):
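+    """Asserts that d2 is an approximate subset of d1.
+
+    Every key in d2 must be present in d1, and the corresponding values
+    must be almost equal.
+    """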
+ self.assertIsInstance(d1, dict)
+ self.assertIsInstance(d2, dict)
+
+ d1_keys = set(d1.keys())
+ d2_keys = set(d2.keys())
+
+ self.assertTrue(d2_keys.issubset(d1_keys))
+
+ for key in d2_keys:
+ self.assertAlmostEqual(d1[key], d2[key])
+
+
def _runQuantizeTest(
self, range_min, range_max, data_dtype, quantization_dtype,
- expected_scale):
+ expected_metadata):
d = np.arange(range_min, range_max + 1, dtype=data_dtype)
- q, s, m = quantization.quantize_weights(d, quantization_dtype)
- self.assertAlmostEqual(s, expected_scale)
+ q, metadata = quantization.quantize_weights(d, quantization_dtype)
+
+ self.assertDictContainsSubsetAlmostEqual(metadata, expected_metadata)
self.assertEqual(q.dtype, quantization_dtype)
- de_q = quantization.dequantize_weights(q, s, m, data_dtype)
+ de_q = quantization.dequantize_weights(
+ q, metadata, data_dtype)
np.testing.assert_allclose(de_q, d)
- if range_min <= 0 <= range_max:
- d_0 = np.zeros(1, data_dtype)
- q_0 = np.round((d_0 - m) / s).astype(quantization_dtype)
- self.assertEqual(
- quantization.dequantize_weights(q_0, s, m, data_dtype), d_0)
+ if quantization_dtype in [np.uint8, np.uint16]:
+ s = metadata['scale']
+ m = metadata['min']
+ if range_min <= 0 <= range_max:
+ d_0 = np.zeros(1, data_dtype)
+ q_0 = np.round((d_0 - m) / s).astype(quantization_dtype)
+ self.assertEqual(
+ quantization.dequantize_weights(q_0, metadata, data_dtype), d_0)
- def testAllEqual(self):
+ def testAffineQuantizeAllEqual(self):
d = np.ones(5, dtype=np.float32)
- q, s, m = quantization.quantize_weights(d, np.uint8)
- self.assertEqual(s, 1.0)
+ q, metadata = quantization.quantize_weights(d, np.uint8)
+ assert 'scale' in metadata and 'min' in metadata
+ self.assertEqual(metadata['scale'], 1.0)
self.assertEqual(q.dtype, np.uint8)
- de_q = quantization.dequantize_weights(q, s, m, np.float32)
+ de_q = quantization.dequantize_weights(q, metadata, np.float32)
np.testing.assert_array_equal(de_q, d)
- def testQuantizeNegativeFloats(self):
- self._runQuantizeTest(-3, -1, np.float32, np.uint8, expected_scale=2/255)
- self._runQuantizeTest(-3, -1, np.float32, np.uint16, expected_scale=2/65536)
-
- def testQuantizeNegativeAndZeroFloats(self):
- self._runQuantizeTest(-3, 0, np.float32, np.uint8, expected_scale=3/255)
- self._runQuantizeTest(-3, 0, np.float32, np.uint16, expected_scale=3/65536)
-
- def testQuantizeNegativeAndPositiveFloats(self):
- self._runQuantizeTest(-3, 3, np.float32, np.uint8, expected_scale=6/255)
- self._runQuantizeTest(-3, 3, np.float32, np.uint16, expected_scale=6/65536)
-
- def testQuantizeZeroAndPositiveFloats(self):
- self._runQuantizeTest(0, 3, np.float32, np.uint8, expected_scale=3/255)
- self._runQuantizeTest(0, 3, np.float32, np.uint16, expected_scale=3/65536)
-
- def testQuantizePositiveFloats(self):
- self._runQuantizeTest(1, 3, np.float32, np.uint8, expected_scale=2/255)
- self._runQuantizeTest(1, 3, np.float32, np.uint16, expected_scale=2/65536)
-
- def testQuantizeNegativeInts(self):
- self._runQuantizeTest(-3, -1, np.int32, np.uint8, expected_scale=2/255)
- self._runQuantizeTest(-3, -1, np.int32, np.uint16, expected_scale=2/65536)
-
- def testQuantizeNegativeAndZeroInts(self):
- self._runQuantizeTest(-3, 0, np.int32, np.uint8, expected_scale=3/255)
- self._runQuantizeTest(-3, 0, np.int32, np.uint16, expected_scale=3/65536)
+ def testFloatQuantizeAllEqual(self):
+ d = np.ones(5, dtype=np.float32)
+ q, metadata = quantization.quantize_weights(d, np.float16)
+ self.assertDictEqual(metadata, {})
- def testQuantizeNegativeAndPositiveInts(self):
- self._runQuantizeTest(-3, 3, np.int32, np.uint8, expected_scale=6/255)
- self._runQuantizeTest(-3, 3, np.int32, np.uint16, expected_scale=6/65536)
+ self.assertEqual(q.dtype, np.float16)
+ de_q = quantization.dequantize_weights(q, metadata, np.float32)
+ np.testing.assert_array_equal(de_q, d)
- def testQuantizeZeroAndPositiveInts(self):
- self._runQuantizeTest(0, 3, np.int32, np.uint8, expected_scale=3/255)
- self._runQuantizeTest(0, 3, np.int32, np.uint16, expected_scale=3/65536)
+ def testAffineQuantizeNegativeFloats(self):
+ self._runQuantizeTest(
+ -3, -1, np.float32, np.uint8,
+ expected_metadata={'scale': 2/255})
+ self._runQuantizeTest(
+ -3, -1, np.float32, np.uint16,
+ expected_metadata={'scale': 2/65536})
+
+ def testAffineQuantizeNegativeAndZeroFloats(self):
+ self._runQuantizeTest(
+ -3, 0, np.float32, np.uint8,
+ expected_metadata={'scale': 3/255})
+ self._runQuantizeTest(
+ -3, 0, np.float32, np.uint16,
+ expected_metadata={'scale': 3/65536})
+
+ def testAffineQuantizeNegativeAndPositiveFloats(self):
+ self._runQuantizeTest(
+ -3, 3, np.float32, np.uint8,
+ expected_metadata={'scale': 6/255})
+ self._runQuantizeTest(
+ -3, 3, np.float32, np.uint16,
+ expected_metadata={'scale': 6/65536})
+
+ def testAffineQuantizeZeroAndPositiveFloats(self):
+ self._runQuantizeTest(
+ 0, 3, np.float32, np.uint8,
+ expected_metadata={'scale': 3/255})
+ self._runQuantizeTest(
+ 0, 3, np.float32, np.uint16,
+ expected_metadata={'scale': 3/65536})
+
+ def testAffineQuantizePositiveFloats(self):
+ self._runQuantizeTest(
+ 1, 3, np.float32, np.uint8,
+ expected_metadata={'scale': 2/255})
+ self._runQuantizeTest(
+ 1, 3, np.float32, np.uint16,
+ expected_metadata={'scale': 2/65536})
+
+ def testAffineQuantizeNegativeInts(self):
+ self._runQuantizeTest(
+ -3, -1, np.int32, np.uint8,
+ expected_metadata={'scale': 2/255})
+ self._runQuantizeTest(
+ -3, -1, np.int32, np.uint16,
+ expected_metadata={'scale': 2/65536})
+
+ def testAffineQuantizeNegativeAndZeroInts(self):
+ self._runQuantizeTest(
+ -3, 0, np.int32, np.uint8,
+ expected_metadata={'scale': 3/255})
+ self._runQuantizeTest(
+ -3, 0, np.int32, np.uint16,
+ expected_metadata={'scale': 3/65536})
+
+ def testAffineQuantizeNegativeAndPositiveInts(self):
+ self._runQuantizeTest(
+ -3, 3, np.int32, np.uint8,
+ expected_metadata={'scale': 6/255})
+ self._runQuantizeTest(
+ -3, 3, np.int32, np.uint16,
+ expected_metadata={'scale': 6/65536})
+
+ def testAffineQuantizeZeroAndPositiveInts(self):
+ self._runQuantizeTest(
+ 0, 3, np.int32, np.uint8,
+ expected_metadata={'scale': 3/255})
+ self._runQuantizeTest(
+ 0, 3, np.int32, np.uint16,
+ expected_metadata={'scale': 3/65536})
+
+ def testAffineQuantizePositiveInts(self):
+ self._runQuantizeTest(
+ 1, 3, np.int32, np.uint8,
+ expected_metadata={'scale': 2/255})
+ self._runQuantizeTest(
+ 1, 3, np.int32, np.uint16,
+ expected_metadata={'scale': 2/65536})
+
+ def testInvalidQuantizationTypes(self):
+ # Invalid quantization type
+ with self.assertRaises(ValueError):
+ quantization.quantize_weights(np.array([]), np.bool)
+
+ # Invalid data dtype for float16 quantization
+ with self.assertRaises(ValueError):
+ d = np.ones(1, dtype=np.int32)
+ quantization.quantize_weights(d, np.float16)
+
+ def testInvalidDequantizationTypes(self):
+ # Invalid metadata for affine quantization
+ with self.assertRaises(ValueError):
+ d = np.ones(1, dtype=np.uint8)
+      quantization.dequantize_weights(d, {})
+
+ # Invalid target dtype for float16 quantization
+ with self.assertRaises(ValueError):
+ d = np.ones(1, dtype=np.float16)
+ quantization.dequantize_weights(d, {}, np.int32)
+
+ # Invalid dequantization type
+ with self.assertRaises(ValueError):
+ d = np.ones(1, dtype=np.bool)
+ quantization.dequantize_weights(d, {})
+
+ def testMapLayerFallthrough(self):
+ names = ['conv/0/weight', 'conv/0/bias', 'conv/1/weight', 'conv/1/bias']
+ quantization_dtype_map = {'float16': ['conv/0/*'], 'uint8': True}
+ dtype_map = quantization.map_layers_to_quantization_dtype(
+ names, quantization_dtype_map)
+
+ self.assertDictEqual(dtype_map, {
+ 'conv/0/weight': np.float16,
+ 'conv/0/bias': np.float16,
+ 'conv/1/weight': np.uint8,
+ 'conv/1/bias': np.uint8
+ })
+
+ def testMapLayerConflictingMap(self):
+ names = ['conv/0/weight', 'conv/0/bias', 'conv/1/weight', 'conv/1/bias']
+ quantization_dtype_map = {'float16': ['conv/0/*'], 'uint8': ['conv/0/bias']}
+
+ with self.assertRaises(ValueError):
+ quantization.map_layers_to_quantization_dtype(
+ names, quantization_dtype_map)
+
+ def testMapLayerStringToList(self):
+ names = ['conv/0/weight', 'conv/0/bias', 'conv/1/weight', 'conv/1/bias']
+ quantization_dtype_map = {'float16': '*'}
+
+ dtype_map = quantization.map_layers_to_quantization_dtype(
+ names, quantization_dtype_map)
+
+ self.assertDictEqual(dtype_map, {
+ 'conv/0/weight': np.float16,
+ 'conv/0/bias': np.float16,
+ 'conv/1/weight': np.float16,
+ 'conv/1/bias': np.float16
+ })
+
+ def testMapLayerNoDtypeMap(self):
+ names = ['conv/0/weight', 'conv/0/bias', 'conv/1/weight', 'conv/1/bias']
+ quantization_dtype_map = {}
+ dtype_map = quantization.map_layers_to_quantization_dtype(
+ names, quantization_dtype_map)
+
+ self.assertDictEqual(dtype_map, {})
+
+ def testMapLayerExactlyOneFallthrough(self):
+ names = ['conv/0/weight', 'conv/0/bias', 'conv/1/weight', 'conv/1/bias']
+ quantization_dtype_map = {'float16': True, 'uint8': True}
+
+ with self.assertRaises(ValueError):
+ quantization.map_layers_to_quantization_dtype(
+ names, quantization_dtype_map)
- def testQuantizePositiveInts(self):
- self._runQuantizeTest(1, 3, np.int32, np.uint8, expected_scale=2/255)
- self._runQuantizeTest(1, 3, np.int32, np.uint16, expected_scale=2/65536)
if __name__ == '__main__':
diff --git a/tfjs-converter/python/tensorflowjs/read_weights.py b/tfjs-converter/python/tensorflowjs/read_weights.py
index 5f5d433faa9..26485b7d173 100644
--- a/tfjs-converter/python/tensorflowjs/read_weights.py
+++ b/tfjs-converter/python/tensorflowjs/read_weights.py
@@ -24,7 +24,7 @@
import numpy as np
from tensorflowjs import quantization
-_INPUT_DTYPES = [np.float32, np.int32, np.complex64,
+_INPUT_DTYPES = [np.float16, np.float32, np.int32, np.complex64,
np.uint8, np.uint16, np.object]
# Number of bytes used to encode the length of a string in a string tensor.
@@ -191,8 +191,7 @@ def decode_weights(weights_manifest, data_buffers, flatten=False):
offset += dtype.itemsize * value.size
if quant_info:
value = quantization.dequantize_weights(
- value, quant_info['scale'], quant_info['min'],
- np.dtype(weight['dtype']))
+ value, quant_info, np.dtype(weight['dtype']))
out_group.append({'name': name, 'data': value})
if flatten:
diff --git a/tfjs-converter/python/tensorflowjs/read_weights_test.py b/tfjs-converter/python/tensorflowjs/read_weights_test.py
index 61af1d3212f..3cca2170d3a 100644
--- a/tfjs-converter/python/tensorflowjs/read_weights_test.py
+++ b/tfjs-converter/python/tensorflowjs/read_weights_test.py
@@ -326,7 +326,7 @@ def testReadWeightsWithIncorrectTypeInWeightsManifestRaisesError(self):
read_weights.read_weights(groups[0][0], self._tmp_dir)
- def testReadQuantizedWeights(self):
+ def testReadAffineQuantizedWeights(self):
groups = [
[{
'name': 'weight1',
@@ -335,13 +335,34 @@ def testReadQuantizedWeights(self):
]
manifest = write_weights.write_weights(
- groups, self._tmp_dir, quantization_dtype=np.uint8)
+ groups, self._tmp_dir, quantization_dtype_map={'uint8': '*'})
# Read the weights using `read_weights`.
read_output = read_weights.read_weights(manifest, self._tmp_dir)
self.assertEqual(1, len(read_output))
self.assertEqual(1, len(read_output[0]))
self.assertEqual('weight1', read_output[0][0]['name'])
+ self.assertEqual(read_output[0][0]['data'].dtype, np.float32)
+ self.assertTrue(
+ np.allclose(groups[0][0]['data'], read_output[0][0]['data']))
+
+ def testReadFloat16QuantizedWeights(self):
+ groups = [
+ [{
+ 'name': 'weight1',
+ 'data': np.array([0, 1, 2, 3], 'float32')
+ }]
+ ]
+
+ manifest = write_weights.write_weights(
+ groups, self._tmp_dir, quantization_dtype_map={'float16': '*'})
+
+ # Read the weights using `read_weights`.
+ read_output = read_weights.read_weights(manifest, self._tmp_dir)
+ self.assertEqual(1, len(read_output))
+ self.assertEqual(1, len(read_output[0]))
+ self.assertEqual('weight1', read_output[0][0]['name'])
+ self.assertEqual(read_output[0][0]['data'].dtype, np.float32)
self.assertTrue(
np.allclose(groups[0][0]['data'], read_output[0][0]['data']))
diff --git a/tfjs-converter/python/tensorflowjs/write_weights.py b/tfjs-converter/python/tensorflowjs/write_weights.py
index fbbcbc44cd0..7603fb67014 100644
--- a/tfjs-converter/python/tensorflowjs/write_weights.py
+++ b/tfjs-converter/python/tensorflowjs/write_weights.py
@@ -24,7 +24,7 @@
from tensorflowjs import quantization
from tensorflowjs import read_weights
-_OUTPUT_DTYPES = [np.float32, np.int32, np.complex64,
+_OUTPUT_DTYPES = [np.float16, np.float32, np.int32, np.complex64,
np.uint8, np.uint16, np.bool, np.object]
_AUTO_DTYPE_CONVERSION = {
np.dtype(np.float64): np.float32,
@@ -33,7 +33,7 @@
def write_weights(
weight_groups, write_dir, shard_size_bytes=1024 * 1024 * 4,
- write_manifest=True, quantization_dtype=None):
+ write_manifest=True, quantization_dtype_map=None):
"""Writes weights to a binary format on disk for ingestion by JavaScript.
Weights are organized into groups. When writing to disk, the bytes from all
@@ -44,7 +44,7 @@ def write_weights(
shard size.
Weights are optionally quantized to either 8 or 16 bits for compression,
- which is enabled via the `quantization_dtype` argument.
+ which is enabled via the `quantization_dtype_map` argument.
Args:
weight_groups: An list of groups. Each group is an array of weight
@@ -68,8 +68,9 @@ def write_weights(
the max file size for caching for all major browsers.
write_manifest: Whether to write the manifest JSON to disk. Defaults to
True.
- quantization_dtype: An optional numpy dtype to quantize weights to for
- compression. Only np.uint8 and np.uint16 are supported.
+ quantization_dtype_map: (Optional) A mapping from quantization dtype
+ (`uint8`, `uint16`, or `float16`) to a list of weight names, a single
+ name string, or `True` to match all weights not covered by another
+ entry. Weight names support wildcard (`*`) substitution (see the
+ usage sketch below).
Returns:
The weights manifest JSON dict.
@@ -107,7 +108,7 @@ def write_weights(
'name': 'weight2',
'shape': [2000, 2000],
'dtype': 'float32',
- 'quantization': {'min': -2.4, 'scale': 0.08, 'dtype': 'uint8'}
+ 'quantization': {'dtype': 'float16'}
}]
}]
"""
@@ -120,8 +121,14 @@ def write_weights(
for group_index, group in enumerate(weight_groups):
for e in group:
_auto_convert_weight_entry(e)
- if quantization_dtype:
- group = [_quantize_entry(e, quantization_dtype) for e in group]
+ names = [entry['name'] for entry in group]
+ quantization_dtype = quantization.map_layers_to_quantization_dtype(
+ names, quantization_dtype_map)
+
+ group = [
+ _quantize_entry(e, quantization_dtype[e['name']])
+ if e['name'] in quantization_dtype else e for e in group
+ ]
group_bytes, total_bytes, _ = _stack_group_bytes(group)
shard_filenames = _shard_group_bytes_to_disk(
@@ -154,8 +161,8 @@ def _quantize_entry(entry, quantization_dtype):
Args:
entry: A weight entries to quantize.
- quantization_dtype: An numpy dtype to quantize weights to. Only np.uint8 and
- np.uint16 are supported.
+ quantization_dtype: A numpy dtype to quantize weights to.
+ Only np.uint8, np.uint16, and np.float16 are supported.
Returns:
A new entry containing the quantized data and additional quantization info,
@@ -168,19 +175,19 @@ def _quantize_entry(entry, quantization_dtype):
'name': 'weight1',
'data': np.array([20, 0, 255], 'uint8')
'quantization': {'min': -0.10196078817, 'scale': 0.00509803940852,
- 'original_dtype': 'float32'}
+ 'dtype': 'uint8', 'original_dtype': 'float32'}
}
"""
data = entry['data']
# Only float32 tensors are quantized.
if data.dtype != 'float32':
return entry
- quantized_data, scale, min_val = quantization.quantize_weights(
+ quantized_data, metadata = quantization.quantize_weights(
data, quantization_dtype)
+ metadata.update({'original_dtype': data.dtype.name})
quantized_entry = entry.copy()
quantized_entry['data'] = quantized_data
- quantized_entry['quantization'] = {
- 'min': min_val, 'scale': scale, 'original_dtype': data.dtype.name}
+ quantized_entry['quantization'] = metadata
return quantized_entry
@@ -322,11 +329,9 @@ def _get_weights_manifest_for_group(group):
if dtype == 'object':
var_manifest['dtype'] = 'string'
if is_quantized:
- var_manifest['quantization'] = {
- 'min': entry['quantization']['min'],
- 'scale': entry['quantization']['scale'],
- 'dtype': entry['data'].dtype.name
- }
+ manifest = {'dtype': entry['data'].dtype.name}
+ manifest.update(entry['quantization'])
+ var_manifest['quantization'] = manifest
weights_entries.append(var_manifest)
return weights_entries
diff --git a/tfjs-converter/python/tensorflowjs/write_weights_test.py b/tfjs-converter/python/tensorflowjs/write_weights_test.py
index 2ea37b5b159..3459edd41d5 100644
--- a/tfjs-converter/python/tensorflowjs/write_weights_test.py
+++ b/tfjs-converter/python/tensorflowjs/write_weights_test.py
@@ -717,7 +717,11 @@ def test_quantize_group(self):
]
manifest = write_weights.write_weights(
- groups, TMP_DIR, shard_size_bytes=1024, quantization_dtype=np.uint8)
+ groups, TMP_DIR, shard_size_bytes=1024,
+ quantization_dtype_map={
+ 'float16': 'weight1',
+ 'uint8': 'weight3'
+ })
self.assertTrue(
os.path.isfile(os.path.join(TMP_DIR, 'weights_manifest.json')),
@@ -731,7 +735,8 @@ def test_quantize_group(self):
'shape': [3],
'dtype': 'float32',
'quantization': {
- 'min': 1.0, 'scale': 2/255.0, 'dtype': 'uint8'
+ 'original_dtype': 'float32',
+ 'dtype': 'float16'
}
}, {
'name': 'weight2',
@@ -742,7 +747,10 @@ def test_quantize_group(self):
'shape': [2],
'dtype': 'float32',
'quantization': {
- 'min': 6.0, 'scale': 1/255.0, 'dtype': 'uint8'
+ 'min': 6.0,
+ 'scale': 1/255.0,
+ 'original_dtype': 'float32',
+ 'dtype': 'uint8'
}
}, {
'name': 'weight4',
@@ -754,19 +762,19 @@ def test_quantize_group(self):
weights_path = os.path.join(TMP_DIR, 'group1-shard1of1.bin')
with open(weights_path, 'rb') as f:
weight_bytes = f.read()
- self.assertEqual(len(weight_bytes), 22)
- w1 = np.frombuffer(weight_bytes[:3], 'uint8')
- np.testing.assert_array_equal(w1, np.array([0, 127, 255], 'uint8'))
+ self.assertEqual(len(weight_bytes), 25)
+ w1 = np.frombuffer(weight_bytes[:6], 'float16')
+ np.testing.assert_array_equal(w1, np.array([1, 2, 3], 'float16'))
- w2 = np.frombuffer(weight_bytes[3:11], 'int32')
+ w2 = np.frombuffer(weight_bytes[6:14], 'int32')
np.testing.assert_array_equal(w2, np.array([4, 5], 'int32'))
- w3 = np.frombuffer(weight_bytes[11:13], 'uint8')
+ w3 = np.frombuffer(weight_bytes[14:16], 'uint8')
np.testing.assert_array_equal(w3, np.array([0, 255], 'uint8'))
- size = np.frombuffer(weight_bytes[13:17], 'uint32')[0]
+ size = np.frombuffer(weight_bytes[16:20], 'uint32')[0]
self.assertEqual(size, 5) # 5 ascii letters.
- w4 = weight_bytes[17:].decode('utf-8')
+ w4 = weight_bytes[20:].decode('utf-8')
self.assertEqual(w4, u'hello')
diff --git a/tfjs-converter/python/test_pip_package.py b/tfjs-converter/python/test_pip_package.py
index 8926417e8e4..f8b6a473b42 100644
--- a/tfjs-converter/python/test_pip_package.py
+++ b/tfjs-converter/python/test_pip_package.py
@@ -871,7 +871,7 @@ def testConvertTfjsLayersModelIntoShardedWeights(self):
new_y = new_model.predict(x)
self.assertAllClose(new_y, y)
- def testConvertTfjsLayersModelWithQuantization(self):
+ def testConvertTfjsLayersModelWithLegacyQuantization(self):
with tf.Graph().as_default(), tf.compat.v1.Session():
# 1. Saved the model as a SavedModel.
model = self._createNestedSequentialModel()
@@ -911,6 +911,47 @@ def testConvertTfjsLayersModelWithQuantization(self):
# The size of the weight file should reflect the uint16 quantization.
self.assertEqual(weight_file_size, total_weight_bytes // 2)
+
+ def testConvertTfjsLayersModelWithQuantization(self):
+ with tf.Graph().as_default(), tf.compat.v1.Session():
+ # 1. Save the model as a SavedModel.
+ model = self._createNestedSequentialModel()
+
+ weights = model.get_weights()
+ total_weight_bytes = sum(np.size(w) for w in weights) * 4
+
+ tf.keras.models.save_model(model, self._tmp_dir)
+
+ # 2. Convert the keras saved model to tfjs_layers_model format.
+ tfjs_output_dir = os.path.join(self._tmp_dir, 'tfjs')
+ # Implicit value of --output_format: tfjs_layers_model
+ process = subprocess.Popen([
+ 'tensorflowjs_converter', '--input_format', 'keras_saved_model',
+ self._tmp_dir, tfjs_output_dir
+ ])
+ process.communicate()
+ self.assertEqual(0, process.returncode)
+
+ # 3. Convert the tfjs_layers_model to another tfjs_layers_model,
+ # with uint16 quantization.
+ sharded_model_dir = os.path.join(self._tmp_dir, 'tfjs_sharded')
+ process = subprocess.Popen([
+ 'tensorflowjs_converter', '--input_format', 'tfjs_layers_model',
+ '--output_format', 'tfjs_layers_model',
+ '--quantize_uint16', '*',
+ os.path.join(tfjs_output_dir, 'model.json'), sharded_model_dir
+ ])
+ process.communicate()
+ self.assertEqual(0, process.returncode)
+
+ # 4. Check the quantized weight file and its size.
+ weight_files = sorted(
+ glob.glob(os.path.join(sharded_model_dir, 'group*.bin')))
+ self.assertEqual(len(weight_files), 1)
+ weight_file_size = os.path.getsize(weight_files[0])
+ # The size of the weight file should reflect the uint16 quantization.
+ self.assertEqual(weight_file_size, total_weight_bytes // 2)
+
def testConvertTfjsLayersModelToTfjsGraphModel(self):
with tf.Graph().as_default(), tf.compat.v1.Session():
# 1. Create a model for testing.
diff --git a/tfjs-core/src/io/io_utils.ts b/tfjs-core/src/io/io_utils.ts
index 1387152e2c0..d83d4e3646b 100644
--- a/tfjs-core/src/io/io_utils.ts
+++ b/tfjs-core/src/io/io_utils.ts
@@ -114,6 +114,7 @@ export function decodeWeights(
buffer: ArrayBuffer, specs: WeightsManifestEntry[]): NamedTensorMap {
// TODO(adarob, cais): Support quantization.
const out: NamedTensorMap = {};
+ let float16Decode: ((buffer: Uint16Array) => Float32Array) | undefined;
let offset = 0;
for (const spec of specs) {
const name = spec.name;
@@ -124,11 +125,26 @@ export function decodeWeights(
if ('quantization' in spec) {
const quantization = spec.quantization;
- if (quantization.dtype !== 'uint8' && quantization.dtype !== 'uint16') {
+ if (quantization.dtype === 'uint8' || quantization.dtype === 'uint16') {
+ if (!('min' in quantization && 'scale' in quantization)) {
+ throw new Error(
+ `Weight ${spec.name} with quantization ${quantization.dtype} ` +
+ `doesn't have corresponding metadata min and scale.`
+ );
+ }
+ } else if (quantization.dtype === 'float16') {
+ if (dtype !== 'float32') {
+ throw new Error(
+ `Weight ${spec.name} is quantized with ${quantization.dtype} ` +
+ `which only supports weights of type float32 not ${dtype}.`
+ );
+ }
+ } else {
throw new Error(
`Weight ${spec.name} has unknown ` +
`quantization dtype ${quantization.dtype}. ` +
- `Supported quantization dtypes are: 'uint8' and 'uint16'.`);
+ `Supported quantization dtypes are: ` +
+ `'uint8', 'uint16', and 'float16'.`);
}
const quantizationSizeFactor = DTYPE_VALUE_SIZE_MAP[quantization.dtype];
const byteBuffer =
@@ -137,12 +153,30 @@ export function decodeWeights(
new Uint8Array(byteBuffer) :
new Uint16Array(byteBuffer);
if (dtype === 'float32') {
- values = new Float32Array(quantizedArray.length);
- for (let i = 0; i < quantizedArray.length; i++) {
- const v = quantizedArray[i];
- values[i] = v * quantization.scale + quantization.min;
+ if (quantization.dtype === 'uint8' || quantization.dtype === 'uint16') {
+ values = new Float32Array(quantizedArray.length);
+ for (let i = 0; i < quantizedArray.length; i++) {
+ const v = quantizedArray[i];
+ values[i] = v * quantization.scale + quantization.min;
+ }
+ } else if (quantization.dtype === 'float16') {
+ if (float16Decode === undefined) {
+ float16Decode = getFloat16Decoder();
+ }
+ values = float16Decode(quantizedArray as Uint16Array);
+ } else {
+ throw new Error(
+ `Unsupported quantization type ${quantization.dtype} ` +
+ `for weight type float32.`
+ );
}
} else if (dtype === 'int32') {
+ if (quantization.dtype !== 'uint8' && quantization.dtype !== 'uint16') {
+ throw new Error(
+ `Unsupported quantization type ${quantization.dtype} ` +
+ `for weight type int32.`
+ );
+ }
values = new Int32Array(quantizedArray.length);
for (let i = 0; i < quantizedArray.length; i++) {
const v = quantizedArray[i];
@@ -363,3 +397,106 @@ export function getModelArtifactsInfoForJSON(modelArtifacts: ModelArtifacts):
modelArtifacts.weightData.byteLength,
};
}
+
+/**
+ * Computes the mantissa table for casting Float16 to Float32
+ * See http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf
+ *
+ * @returns Uint32Array, 2048 mantissa lookup values.
+ */
+function computeFloat16MantisaTable(): Uint32Array {
+ const convertMantissa = (i: number): number => {
+ let m = i << 13;
+ let e = 0;
+
+ while ((m & 0x00800000) === 0) {
+ e -= 0x00800000;
+ m <<= 1;
+ }
+ m &= ~0x00800000;
+ e += 0x38800000;
+
+ return m | e;
+ };
+
+ const mantisaTable = new Uint32Array(2048);
+
+ mantisaTable[0] = 0;
+ for (let i = 1; i < 1024; i++) {
+ mantisaTable[i] = convertMantissa(i);
+ }
+ for (let i = 1024; i < 2048; i++) {
+ mantisaTable[i] = 0x38000000 + ((i - 1024) << 13);
+ }
+
+ return mantisaTable;
+}
+
+/**
+ * Computes exponent table for casting Float16 to Float32
+ * See http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf
+ *
+ * @returns Uint32Array, 64 exponent lookup values.
+ */
+function computeFloat16ExponentTable(): Uint32Array {
+ const exponentTable = new Uint32Array(64);
+
+ exponentTable[0] = 0;
+ exponentTable[31] = 0x47800000;
+ exponentTable[32] = 0x80000000;
+ exponentTable[63] = 0xc7800000;
+ for (let i = 1; i < 31; i++) {
+ exponentTable[i] = i << 23;
+ }
+ for (let i = 33; i < 63; i++) {
+ exponentTable[i] = 0x80000000 + ((i - 32) << 23);
+ }
+
+ return exponentTable;
+}
+
+/**
+ * Computes offset table for casting Float16 to Float32
+ * See http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf
+ *
+ * @returns Uint32Array, 64 offset lookup values.
+ */
+function computeFloat16OffsetTable(): Uint32Array {
+ const offsetTable = new Uint32Array(64);
+
+ for (let i = 0; i < 64; i++) {
+ offsetTable[i] = 1024;
+ }
+ offsetTable[0] = offsetTable[32] = 0;
+
+ return offsetTable;
+}
+
+/**
+ * Retrieve a Float16 decoder which will decode a Uint16Array of Float16 values
+ * to a Float32Array.
+ *
+ * @returns Function (buffer: Uint16Array) => Float32Array which decodes
+ * the Uint16Array of Float16 bytes to a Float32Array.
+ */
+export function getFloat16Decoder(): (buffer: Uint16Array) => Float32Array {
+ // The algorithm is based on http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf
+
+ // Cache lookup tables
+ const mantisaTable = computeFloat16MantisaTable();
+ const exponentTable = computeFloat16ExponentTable();
+ const offsetTable = computeFloat16OffsetTable();
+
+ return (quantizedArray: Uint16Array) => {
+ const buffer = new ArrayBuffer(4 * quantizedArray.length);
+ const bufferUint32View = new Uint32Array(buffer);
+ for (let index = 0; index < quantizedArray.length; index++) {
+ const float16Bits = quantizedArray[index];
+ const float32Bits =
+ mantisaTable[offsetTable[float16Bits >> 10] + (float16Bits & 0x3ff)] +
+ exponentTable[float16Bits >> 10];
+ bufferUint32View[index] = float32Bits;
+ }
+ return new Float32Array(buffer);
+ };
+}
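
As a quick, language-independent cross-check of the lookup-table decoder above, the raw float16 bit patterns used in the new tests can be decoded with NumPy (a standalone sketch, not part of the patch):

```python
import numpy as np

# Reinterpret the 16-bit patterns from io_utils_test.ts as IEEE half floats
# and widen them to float32, which is what getFloat16Decoder() is expected
# to produce for the same inputs.
bits = np.array([0x3400, 0x3800, 0x3A00, 0x3555], dtype=np.uint16)
print(bits.view(np.float16).astype(np.float32))
# -> [0.25, 0.5, 0.75, ~0.333], matching the expected values in the test.
```
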
diff --git a/tfjs-core/src/io/io_utils_test.ts b/tfjs-core/src/io/io_utils_test.ts
index 9ee1f35915a..6656ef43aa3 100644
--- a/tfjs-core/src/io/io_utils_test.ts
+++ b/tfjs-core/src/io/io_utils_test.ts
@@ -23,7 +23,7 @@ import {expectArraysEqual} from '../test_util';
import {expectArraysClose} from '../test_util';
import {encodeString} from '../util';
-import {arrayBufferToBase64String, base64StringToArrayBuffer, basename, concatenateArrayBuffers, concatenateTypedArrays, stringByteLength} from './io_utils';
+import {arrayBufferToBase64String, base64StringToArrayBuffer, basename, concatenateArrayBuffers, concatenateTypedArrays, stringByteLength, getFloat16Decoder} from './io_utils';
import {WeightsManifestEntry} from './types';
describe('concatenateTypedArrays', () => {
@@ -565,6 +565,22 @@ describeWithFlags('decodeWeights', {}, () => {
expect(weight1.shape).toEqual([3]);
expect(weight1.dtype).toEqual('int32');
});
+ it('support quantization float16 weights', async () => {
+ const manifestSpecs: WeightsManifestEntry[] = [
+ {
+ name: 'weight0',
+ dtype: 'float32',
+ shape: [3],
+ quantization: { dtype: 'float16' },
+ },
+ ];
+ const data = new Uint16Array([13312, 14336, 14848]);
+ const decoded = tf.io.decodeWeights(data.buffer, manifestSpecs);
+ const weight0 = decoded['weight0'];
+ expectArraysClose(await weight0.data(), [0.25, 0.5, 0.75]);
+ expect(weight0.shape).toEqual([3]);
+ expect(weight0.dtype).toEqual('float32');
+ });
});
describe('stringByteLength', () => {
@@ -654,3 +670,62 @@ describe('basename', () => {
expect(basename('foo/bar/baz//')).toEqual('baz');
});
});
+
+describe('float16', () => {
+ it('decodes NaN to float32 NaN', () => {
+ const decoder = getFloat16Decoder();
+ const float16NaN = 0x00007e00;
+ const buffer = new Uint16Array([float16NaN]);
+ const f32 = decoder(buffer);
+ expect(f32).toEqual(new Float32Array([NaN]));
+ });
+
+ it('decodes ±Infinity to float32 ±Infinity', () => {
+ const decoder = getFloat16Decoder();
+ const positiveInfinity = 0x00007c00;
+ const negativeInfinity = 0xfffffc00;
+ const buffer = new Uint16Array([positiveInfinity, negativeInfinity]);
+ const f32 = decoder(buffer);
+ expect(f32).toEqual(new Float32Array([Infinity, -Infinity]));
+ });
+
+ it('decodes ±0 to float32 ±0', () => {
+ const decoder = getFloat16Decoder();
+ const positiveZero = 0x00000000;
+ const negativeZero = 0xffff8000;
+ const buffer = new Uint16Array([positiveZero, negativeZero]);
+ const f32 = decoder(buffer);
+ expect(f32).toEqual(new Float32Array([0.0, -0.0]));
+ });
+
+ it('decodes -Infinity on underflow', () => {
+ const decoder = getFloat16Decoder();
+ const minVal = 0xfffffbff;
+ const buffer = new Uint16Array([minVal + 1]);
+ const f32 = decoder(buffer);
+ expect(f32).toEqual(new Float32Array([-Infinity]));
+ });
+
+ it('decodes +Infinity on overflow', () => {
+ const decoder = getFloat16Decoder();
+ const maxVal = 0x00007bff;
+ const buffer = new Uint16Array([maxVal + 1]);
+ const f32 = decoder(buffer);
+ expect(f32).toEqual(new Float32Array([Infinity]));
+ });
+
+ it('decodes interpretable float16 to float32', () => {
+ const decoder = getFloat16Decoder();
+ const buffer = new Uint16Array([
+ 0x00003400,
+ 0x00003800,
+ 0x00003A00,
+ 0x00003555
+ ]);
+ const f32 = decoder(buffer);
+ expect(f32[0]).toBeCloseTo(0.25);
+ expect(f32[1]).toBeCloseTo(0.5);
+ expect(f32[2]).toBeCloseTo(0.75);
+ expect(f32[3]).toBeCloseTo(0.333);
+ });
+});
diff --git a/tfjs-core/src/io/types.ts b/tfjs-core/src/io/types.ts
index f4ee6ff0939..63c5e3deff0 100644
--- a/tfjs-core/src/io/types.ts
+++ b/tfjs-core/src/io/types.ts
@@ -22,6 +22,7 @@
*/
export const DTYPE_VALUE_SIZE_MAP: {[dtype: string]: number} = {
'float32': 4,
+ 'float16': 2,
'int32': 4,
'uint16': 2,
'uint8': 1,
@@ -102,9 +103,9 @@ export declare interface WeightsManifestEntry {
* Information for dequantization of the weight.
*/
quantization?: {
- scale: number, // The scaling constant to multiply by.
- min: number, // The (possibly nudged) minimum weight to add.
- dtype: 'uint16'|'uint8' // The dtype of the quantized weights.
+ scale?: number, // The scaling constant to multiply by.
+ min?: number, // The (possibly nudged) minimum weight to add.
+ dtype: 'uint16'|'uint8'|'float16' // The dtype of the quantized weights.
};
}
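
For reference, the two manifest shapes this widened `quantization` type now admits, written out as the Python-side writer produces them (values are illustrative, taken from the updated write_weights tests):

```python
# Affine (uint8/uint16) quantization keeps the min/scale pair:
affine_entry = {
    'name': 'weight3',
    'shape': [2],
    'dtype': 'float32',
    'quantization': {'min': 6.0, 'scale': 1 / 255.0,
                     'original_dtype': 'float32', 'dtype': 'uint8'},
}

# float16 quantization needs no affine metadata, so min/scale are omitted:
float16_entry = {
    'name': 'weight1',
    'shape': [3],
    'dtype': 'float32',
    'quantization': {'original_dtype': 'float32', 'dtype': 'float16'},
}
```
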