From 0b30984bd67e2b7512be4f09354217f94cee4805 Mon Sep 17 00:00:00 2001 From: Jesse Farebrother Date: Fri, 22 May 2020 17:53:38 -0600 Subject: [PATCH 1/2] Converter float16 quantization support --- tfjs-converter/README.md | 32 ++- .../python/tensorflowjs/converters/common.py | 3 + .../tensorflowjs/converters/converter.py | 136 ++++++---- .../tensorflowjs/converters/converter_test.py | 7 +- .../converters/keras_h5_conversion.py | 18 +- .../tf_saved_model_conversion_v2.py | 68 +++-- .../python/tensorflowjs/converters/wizard.py | 46 +++- .../tensorflowjs/converters/wizard_test.py | 22 +- .../python/tensorflowjs/quantization.py | 149 +++++++++-- .../python/tensorflowjs/quantization_test.py | 246 ++++++++++++++---- .../python/tensorflowjs/read_weights.py | 5 +- .../python/tensorflowjs/read_weights_test.py | 25 +- .../python/tensorflowjs/write_weights.py | 43 +-- .../python/tensorflowjs/write_weights_test.py | 28 +- tfjs-converter/python/test_pip_package.py | 43 ++- 15 files changed, 657 insertions(+), 214 deletions(-) diff --git a/tfjs-converter/README.md b/tfjs-converter/README.md index 0f810052dd2..40fa234ce77 100644 --- a/tfjs-converter/README.md +++ b/tfjs-converter/README.md @@ -154,7 +154,10 @@ saved a tf.keras model in the SavedModel format. |`--saved_model_tags` | Only applicable to SavedModel conversion. Tags of the MetaGraphDef to load, in comma separated format. Defaults to `serve`.| |`--signature_name` | Only applicable to TensorFlow SavedModel and Hub module conversion, signature to load. Defaults to `serving_default` for SavedModel and `default` for Hub module. See https://www.tensorflow.org/hub/common_signatures/.| |`--strip_debug_ops` | Strips out TensorFlow debug operations `Print`, `Assert`, `CheckNumerics`. Defaults to `True`.| -|`--quantization_bytes` | How many bytes to optionally quantize/compress the weights to. Valid values are 1 and 2. which will quantize int32 and float32 to 1 or 2 bytes respectively. The default (unquantized) size is 4 bytes.| +|`--quantization_bytes` | (Deprecated) How many bytes to optionally quantize/compress the weights to. Valid values are 1 and 2. which will quantize int32 and float32 to 1 or 2 bytes respectively. The default (unquantized) size is 4 bytes.| +|`--quantize_float16` | Comma separated list of node names to apply float16 quantization. You can also use wildcard symbol (*) to apply quantization to multiple nodes (e.g., conv/*/weights). When the flag is provided without any nodes the default behavior will match all nodes. | +|`--quantize_uint8` | Comma separated list of node names to apply 1-byte affine quantization. You can also use wildcard symbol (*) to apply quantization to multiple nodes (e.g., conv/*/weights). When the flag is provided without any nodes the default behavior will match all nodes. | +|`--quantize_uint16` | Comma separated list of node names to apply 2-byte affine quantization. You can also use wildcard symbol (*) to apply quantization to multiple nodes (e.g., conv/*/weights). When the flag is provided without any nodes the default behavior will match all nodes. | |`--weight_shard_size_bytes` | Shard size (in bytes) of the weight files. Only supported when `output_format` is `tfjs_layers_model` or `tfjs_graph_model`. Default size is 4 MB (4194304 bytes).| |`--output_node_names`| Only applicable to Frozen Model. 
The names of the output nodes, separated by commas.| @@ -216,7 +219,7 @@ purposes: tensorflowjs_converter \ --input_format tfjs_layers_model \ --output_format tfjs_layers_model \ - --quantization_bytes 2 \ + --quantize_uint16 \ original_model/model.json quantized_model/ ``` @@ -380,18 +383,33 @@ browser to cache them automatically. If the model architecture is less than 4MB __4. Can I quantize the weights over the wire?__ -Yes, you can use the --quantization_bytes option to compress int32/float32 to 1 -or 2 bytes. Here is -an example of 8-bit quantization: +Yes, you can use the --quantize_{float16, uint8, uint16} flags to compress +weights with 1 byte integer quantization (`uint8`) or 2 byte integer +(`uint16`)/float (`float16`) quantization. +Quantizing to float16 may provide better accuracy over +2 byte affine integer scaling (`uint16`). 1-byte affine quantization, +i.e., `uint8` provides a 4x size reduction at the cost of accuracy. +For example, we can quantize our MobileNet model using float16 quantization: ``` -tensorflowjs_converter \ +tensorflowjs_converter + --quantize_float16 \ --input_format=tf_hub \ - --quantization_bytes=1 'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/1' \ /mobilenet/web_model ``` +You can also quantize specific weights as well as weight groupings using +a wildcard replacement. For example, +``` +tensorflowjs_converter + --quantize_float16="conv/*/weights" +``` +which will quantize all weights that match the pattern conv/*/weights. +This will exclude biases and any weights that don't begin with conv/. +This can be a powerful tool to reduce model size while trying to maximize +performance. + __5. Why is the predict() method for inference so much slower on the first call than the subsequent calls?__ The time of first call also includes the compilation time of WebGL shader diff --git a/tfjs-converter/python/tensorflowjs/converters/common.py b/tfjs-converter/python/tensorflowjs/converters/common.py index 5aa062d1fe2..2775d5aedb6 100644 --- a/tfjs-converter/python/tensorflowjs/converters/common.py +++ b/tfjs-converter/python/tensorflowjs/converters/common.py @@ -50,6 +50,9 @@ SIGNATURE_NAME = 'signature_name' SAVED_MODEL_TAGS = 'saved_model_tags' QUANTIZATION_BYTES = 'quantization_bytes' +QUANTIZATION_TYPE_FLOAT16 = 'quantize_float16' +QUANTIZATION_TYPE_UINT8 = 'quantize_uint8' +QUANTIZATION_TYPE_UINT16 = 'quantize_uint16' SPLIT_WEIGHTS_BY_LAYER = 'split_weights_by_layer' VERSION = 'version' SKIP_OP_CHECK = 'skip_op_check' diff --git a/tfjs-converter/python/tensorflowjs/converters/converter.py b/tfjs-converter/python/tensorflowjs/converters/converter.py index 3858e666591..ed682ac1b73 100644 --- a/tfjs-converter/python/tensorflowjs/converters/converter.py +++ b/tfjs-converter/python/tensorflowjs/converters/converter.py @@ -26,7 +26,6 @@ import tempfile import h5py -import numpy as np import tensorflow.compat.v1 as tf1 import tensorflow.compat.v2 as tf @@ -39,7 +38,7 @@ def dispatch_keras_h5_to_tfjs_layers_model_conversion( - h5_path, output_dir=None, quantization_dtype=None, + h5_path, output_dir=None, quantization_dtype_map=None, split_weights_by_layer=False, weight_shard_size_bytes=1024 * 1024 * 4): """Converts a Keras HDF5 saved-model file to TensorFlow.js format. @@ -56,8 +55,8 @@ def dispatch_keras_h5_to_tfjs_layers_model_conversion( output_dir: Output directory to which the TensorFlow.js-format model JSON file and weights files will be written. If the directory does not exist, it will be created. 
- quantization_dtype: The quantized data type to store the weights in - (Default: `None`). + quantization_dtype_map: A mapping from dtype (`uint8`, `uint16`, `float16`) + to weights. The weight mapping supports wildcard substitution. split_weights_by_layer: Whether to split the weights into separate weight groups (corresponding to separate binary weight files) layer by layer (Default: `False`). @@ -94,7 +93,7 @@ def dispatch_keras_h5_to_tfjs_layers_model_conversion( if not os.path.isdir(output_dir): os.makedirs(output_dir) conversion.write_artifacts( - model_json, groups, output_dir, quantization_dtype, + model_json, groups, output_dir, quantization_dtype_map, weight_shard_size_bytes=weight_shard_size_bytes) return model_json, groups @@ -102,7 +101,7 @@ def dispatch_keras_h5_to_tfjs_layers_model_conversion( def dispatch_keras_h5_to_tfjs_graph_model_conversion( h5_path, output_dir=None, - quantization_dtype=None, + quantization_dtype_map=None, skip_op_check=False, strip_debug_ops=False, weight_shard_size_bytes=1024 * 1024 * 4, @@ -115,8 +114,8 @@ def dispatch_keras_h5_to_tfjs_graph_model_conversion( keras or tf.keras. output_dir: The destination to which the tfjs GraphModel artifacts will be written. - quantization_dtype: The quantized data type to store the weights in - (Default: `None`). + quantization_dtype_map: A mapping from dtype (`uint8`, `uint16`, `float16`) + to weights. The weight mapping supports wildcard substitution. skip_op_check: Bool whether to skip the op check. strip_debug_ops: Bool whether to allow unsupported debug ops. weight_shard_size_bytes: Shard size (in bytes) of the weight files. @@ -140,7 +139,7 @@ def dispatch_keras_h5_to_tfjs_graph_model_conversion( temp_savedmodel_dir, output_dir, signature_def='serving_default', saved_model_tags='serve', - quantization_dtype=quantization_dtype, + quantization_dtype_map=quantization_dtype_map, skip_op_check=skip_op_check, strip_debug_ops=strip_debug_ops, weight_shard_size_bytes=weight_shard_size_bytes, @@ -151,7 +150,7 @@ def dispatch_keras_h5_to_tfjs_graph_model_conversion( def dispatch_keras_saved_model_to_tensorflowjs_conversion( - keras_saved_model_path, output_dir, quantization_dtype=None, + keras_saved_model_path, output_dir, quantization_dtype_map=None, split_weights_by_layer=False, weight_shard_size_bytes=1024 * 1024 * 4): """Converts keras model saved in the SavedModel format to tfjs format. @@ -168,8 +167,8 @@ def dispatch_keras_saved_model_to_tensorflowjs_conversion( output_dir: Output directory to which the TensorFlow.js-format model JSON file and weights files will be written. If the directory does not exist, it will be created. - quantization_dtype: The quantized data type to store the weights in - (Default: `None`). + quantization_dtype_map: A mapping from dtype (`uint8`, `uint16`, `float16`) + to weights. The weight mapping supports wildcard substitution. split_weights_by_layer: Whether to split the weights into separate weight groups (corresponding to separate binary weight files) layer by layer (Default: `False`). 
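For reference (not part of the patch itself), the `quantization_dtype_map` accepted by these dispatch functions is a plain dict keyed by dtype name. Values are a pattern string, a list of patterns, or `True` to mark the single allowed fallthrough dtype; a minimal sketch with made-up weight names:

```
quantization_dtype_map = {
    'float16': ['conv/*/weights'],  # wildcard patterns, matched with fnmatch
    'uint8': True                   # fallthrough: every weight not matched above
}
```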
@@ -187,7 +186,7 @@ def dispatch_keras_saved_model_to_tensorflowjs_conversion( dispatch_keras_h5_to_tfjs_layers_model_conversion( temp_h5_path, output_dir, - quantization_dtype=quantization_dtype, + quantization_dtype_map=quantization_dtype_map, split_weights_by_layer=split_weights_by_layer, weight_shard_size_bytes=weight_shard_size_bytes) @@ -271,7 +270,7 @@ def dispatch_tensorflowjs_to_keras_saved_model_conversion( def dispatch_tensorflowjs_to_tensorflowjs_conversion( config_json_path, output_dir_path, - quantization_dtype=None, + quantization_dtype_map=None, weight_shard_size_bytes=1024 * 1024 * 4): """Converts a Keras Model from tensorflowjs format to H5. @@ -280,8 +279,8 @@ def dispatch_tensorflowjs_to_tensorflowjs_conversion( topology and weights manifest, in tensorflowjs format. output_dir_path: Path to output directory in which the result of the conversion will be saved. - quantization_dtype: The quantized data type to store the weights in - (Default: `None`). + quantization_dtype_map: A mapping from dtype (`uint8`, `uint16`, `float16`) + to weights. The weight mapping supports wildcard substitution. weight_shard_size_bytes: Shard size (in bytes) of the weight files. The size of each weight file will be <= this value. @@ -318,7 +317,7 @@ def dispatch_tensorflowjs_to_tensorflowjs_conversion( with tf.Graph().as_default(), tf.compat.v1.Session(): dispatch_keras_h5_to_tfjs_layers_model_conversion( temp_h5_path, output_dir_path, - quantization_dtype=quantization_dtype, + quantization_dtype_map=quantization_dtype_map, weight_shard_size_bytes=weight_shard_size_bytes) # TODO(cais): Support weight quantization. @@ -329,7 +328,7 @@ def dispatch_tensorflowjs_to_tensorflowjs_conversion( def dispatch_tfjs_layers_model_to_tfjs_graph_conversion( config_json_path, output_dir_path, - quantization_dtype=None, + quantization_dtype_map=None, skip_op_check=False, strip_debug_ops=False, weight_shard_size_bytes=1024 * 1024 * 4): @@ -343,8 +342,8 @@ def dispatch_tfjs_layers_model_to_tfjs_graph_conversion( topology and weights manifest, in tensorflowjs format. output_dir_path: Path to output directory in which the result of the conversion will be saved. - quantization_dtype: The quantized data type to store the weights in - (Default: `None`). + quantization_dtype_map: A mapping from dtype (`uint8`, `uint16`, `float16`) + to weights. The weight mapping supports wildcard substitution. skip_op_check: Bool whether to skip the op check. strip_debug_ops: Bool whether to allow unsupported debug ops. weight_shard_size_bytes: Shard size (in bytes) of the weight files. 
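As a usage sketch of the updated layers-model path (assuming the in-tree `tensorflowjs.converters.converter` import; the paths are taken from the README requantization example above):

```
from tensorflowjs.converters import converter

# Rewrite an existing tfjs_layers_model with uint16-quantized weights.
converter.dispatch_tensorflowjs_to_tensorflowjs_conversion(
    'original_model/model.json', 'quantized_model/',
    quantization_dtype_map={'uint16': '*'})
```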
@@ -380,7 +379,7 @@ def dispatch_tfjs_layers_model_to_tfjs_graph_conversion( model.save(temp_h5_path) dispatch_keras_h5_to_tfjs_graph_model_conversion( temp_h5_path, output_dir_path, - quantization_dtype=quantization_dtype, + quantization_dtype_map=quantization_dtype_map, skip_op_check=skip_op_check, strip_debug_ops=strip_debug_ops, weight_shard_size_bytes=weight_shard_size_bytes) @@ -416,16 +415,32 @@ def _standardize_input_output_formats(input_format, output_format): return (input_format, output_format) +def _parse_quantization_dtype_map(float16, uint8, uint16, quantization_bytes): + quantization_dtype_map = {} -def _parse_quantization_bytes(quantization_bytes): - if quantization_bytes is None: - return None - elif quantization_bytes == 1: - return np.uint8 - elif quantization_bytes == 2: - return np.uint16 - else: - raise ValueError('Unsupported quantization bytes: %s' % quantization_bytes) + if quantization_bytes: + print( + 'Warning: --quantization_bytes will be deprecated in a future release\n' + 'Please consider using --quantize_uint8, --quantize_uint16, ' + '--quantize_float16.', file=sys.stderr) + if float16 is not None or uint8 is not None or uint16 is not None: + raise ValueError( + '--quantization_bytes cannot be used with the new quantization flags') + + dtype = quantization.QUANTIZATION_BYTES_TO_DTYPES[quantization_bytes] + quantization_dtype_map[dtype] = True + + if float16 is not None: + quantization_dtype_map[quantization.QUANTIZATION_DTYPE_FLOAT16] = \ + float16.split(',') if isinstance(float16, str) else float16 + if uint8 is not None: + quantization_dtype_map[quantization.QUANTIZATION_DTYPE_UINT8] = \ + uint8.split(',') if isinstance(uint8, str) else uint8 + if uint16 is not None: + quantization_dtype_map[quantization.QUANTIZATION_DTYPE_UINT16] = \ + uint16.split(',') if isinstance(uint16, str) else uint16 + + return quantization_dtype_map def get_arg_parser(): """ @@ -488,13 +503,43 @@ def get_arg_parser(): help='Tags of the MetaGraphDef to load, in comma separated string ' 'format. Defaults to "serve". Applicable only if input format is ' '"tf_saved_model".') + parser.add_argument( + '--%s' % common.QUANTIZATION_TYPE_FLOAT16, + type=str, + default=None, + const=True, + nargs='?', + help='Comma separated list of node names to apply float16 quantization. ' + 'You can also use wildcard symbol (*) to apply quantization to multiple ' + 'nodes (e.g., conv/*/weights). When the flag is provided without any ' + 'nodes the default behavior will match all nodes.') + parser.add_argument( + '--%s' % common.QUANTIZATION_TYPE_UINT8, + type=str, + default=None, + const=True, + nargs='?', + help='Comma separated list of node names to apply 1-byte affine ' + 'quantization. You can also use wildcard symbol (*) to apply ' + 'quantization to multiple nodes (e.g., conv/*/weights). When the flag is ' + 'provided without any nodes the default behavior will match all nodes.') + parser.add_argument( + '--%s' % common.QUANTIZATION_TYPE_UINT16, + type=str, + default=None, + const=True, + nargs='?', + help='Comma separated list of node names to apply 2-byte affine ' + 'quantization. You can also use wildcard symbol (*) to apply ' + 'quantization to multiple nodes (e.g., conv/*/weights). When the flag is ' + 'provided without any nodes the default behavior will match all nodes.') parser.add_argument( '--%s' % common.QUANTIZATION_BYTES, type=int, choices=set(quantization.QUANTIZATION_BYTES_TO_DTYPES.keys()), - help='How many bytes to optionally quantize/compress the weights to. 
1- ' - 'and 2-byte quantizaton is supported. The default (unquantized) size is ' - '4 bytes.') + help='(Deprecated) How many bytes to optionally quantize/compress the ' + 'weights to. 1- and 2-byte quantizaton is supported. The default ' + '(unquantized) size is 4 bytes.') parser.add_argument( '--%s' % common.SPLIT_WEIGHTS_BY_LAYER, action='store_true', @@ -574,9 +619,12 @@ def convert(arguments): 'but got %s' % args.weight_shard_size_bytes) weight_shard_size_bytes = args.weight_shard_size_bytes - quantization_dtype = ( - quantization.QUANTIZATION_BYTES_TO_DTYPES[args.quantization_bytes] - if args.quantization_bytes else None) + quantization_dtype_map = _parse_quantization_dtype_map( + args.quantize_float16, + args.quantize_uint8, + args.quantize_uint16, + args.quantization_bytes + ) if (not args.output_node_names and input_format == common.TF_FROZEN_MODEL): raise ValueError( @@ -601,14 +649,14 @@ def convert(arguments): output_format == common.TFJS_LAYERS_MODEL): dispatch_keras_h5_to_tfjs_layers_model_conversion( args.input_path, output_dir=args.output_path, - quantization_dtype=quantization_dtype, + quantization_dtype_map=quantization_dtype_map, split_weights_by_layer=args.split_weights_by_layer, weight_shard_size_bytes=weight_shard_size_bytes) elif (input_format == common.KERAS_MODEL and output_format == common.TFJS_GRAPH_MODEL): dispatch_keras_h5_to_tfjs_graph_model_conversion( args.input_path, output_dir=args.output_path, - quantization_dtype=quantization_dtype, + quantization_dtype_map=quantization_dtype_map, skip_op_check=args.skip_op_check, strip_debug_ops=args.strip_debug_ops, weight_shard_size_bytes=weight_shard_size_bytes, @@ -617,7 +665,7 @@ def convert(arguments): output_format == common.TFJS_LAYERS_MODEL): dispatch_keras_saved_model_to_tensorflowjs_conversion( args.input_path, args.output_path, - quantization_dtype=quantization_dtype, + quantization_dtype_map=quantization_dtype_map, split_weights_by_layer=args.split_weights_by_layer, weight_shard_size_bytes=weight_shard_size_bytes) elif (input_format == common.TF_SAVED_MODEL and @@ -626,7 +674,7 @@ def convert(arguments): args.input_path, args.output_path, signature_def=args.signature_name, saved_model_tags=args.saved_model_tags, - quantization_dtype=quantization_dtype, + quantization_dtype_map=quantization_dtype_map, skip_op_check=args.skip_op_check, strip_debug_ops=args.strip_debug_ops, weight_shard_size_bytes=weight_shard_size_bytes, @@ -637,7 +685,7 @@ def convert(arguments): args.input_path, args.output_path, signature=args.signature_name, saved_model_tags=args.saved_model_tags, - quantization_dtype=quantization_dtype, + quantization_dtype_map=quantization_dtype_map, skip_op_check=args.skip_op_check, strip_debug_ops=args.strip_debug_ops, weight_shard_size_bytes=weight_shard_size_bytes, @@ -654,13 +702,13 @@ def convert(arguments): output_format == common.TFJS_LAYERS_MODEL): dispatch_tensorflowjs_to_tensorflowjs_conversion( args.input_path, args.output_path, - quantization_dtype=_parse_quantization_bytes(args.quantization_bytes), + quantization_dtype_map=quantization_dtype_map, weight_shard_size_bytes=weight_shard_size_bytes) elif (input_format == common.TFJS_LAYERS_MODEL and output_format == common.TFJS_GRAPH_MODEL): dispatch_tfjs_layers_model_to_tfjs_graph_conversion( args.input_path, args.output_path, - quantization_dtype=_parse_quantization_bytes(args.quantization_bytes), + quantization_dtype_map=quantization_dtype_map, skip_op_check=args.skip_op_check, strip_debug_ops=args.strip_debug_ops, 
weight_shard_size_bytes=weight_shard_size_bytes) @@ -668,7 +716,7 @@ def convert(arguments): output_format == common.TFJS_GRAPH_MODEL): tf_saved_model_conversion_v2.convert_tf_frozen_model( args.input_path, args.output_node_names, args.output_path, - quantization_dtype=_parse_quantization_bytes(args.quantization_bytes), + quantization_dtype_map=quantization_dtype_map, skip_op_check=args.skip_op_check, strip_debug_ops=args.strip_debug_ops, weight_shard_size_bytes=weight_shard_size_bytes) diff --git a/tfjs-converter/python/tensorflowjs/converters/converter_test.py b/tfjs-converter/python/tensorflowjs/converters/converter_test.py index 3a8ecaf20f4..1430bc549cc 100644 --- a/tfjs-converter/python/tensorflowjs/converters/converter_test.py +++ b/tfjs-converter/python/tensorflowjs/converters/converter_test.py @@ -564,7 +564,8 @@ def testConvertTfKerasSequentialSavedAsSavedModelWithQuantization(self): # Convert the keras SavedModel to tfjs format. tfjs_output_dir = os.path.join(self._tmp_dir, 'tfjs') converter.dispatch_keras_saved_model_to_tensorflowjs_conversion( - self._tmp_dir, tfjs_output_dir, quantization_dtype=np.uint16) + self._tmp_dir, tfjs_output_dir, + quantization_dtype_map={'uint16': '*'}) # Verify the size of the weight file. weight_path = glob.glob(os.path.join(tfjs_output_dir, 'group*-*'))[0] @@ -688,7 +689,7 @@ def testConvertTfjsLayersModelWithUint16Quantization(self): sharded_model_path = os.path.join(self._tmp_dir, 'sharded_model') converter.dispatch_tensorflowjs_to_tensorflowjs_conversion( os.path.join(tfjs_output_dir, 'model.json'), sharded_model_path, - quantization_dtype=np.uint16, + quantization_dtype_map={'uint16': '*'}, weight_shard_size_bytes=weight_shard_size_bytes) # Check the number of quantized files and their sizes. @@ -723,7 +724,7 @@ def testConvertTfjsLayersModelWithUint8Quantization(self): sharded_model_path = os.path.join(self._tmp_dir, 'sharded_model') converter.dispatch_tensorflowjs_to_tensorflowjs_conversion( os.path.join(tfjs_output_dir, 'model.json'), sharded_model_path, - quantization_dtype=np.uint8, + quantization_dtype_map={'uint8': '*'}, weight_shard_size_bytes=weight_shard_size_bytes) # Check the number of quantized files and their sizes. diff --git a/tfjs-converter/python/tensorflowjs/converters/keras_h5_conversion.py b/tfjs-converter/python/tensorflowjs/converters/keras_h5_conversion.py index c25a02869aa..7a37de2d7b5 100644 --- a/tfjs-converter/python/tensorflowjs/converters/keras_h5_conversion.py +++ b/tfjs-converter/python/tensorflowjs/converters/keras_h5_conversion.py @@ -259,7 +259,7 @@ def _get_generated_by(topology): def write_artifacts(topology, weights, output_dir, - quantization_dtype=None, + quantization_dtype_map=None, weight_shard_size_bytes=1024 * 1024 * 4): """Writes weights and topology to the output_dir. @@ -269,8 +269,9 @@ def write_artifacts(topology, topology: a JSON dictionary, representing the Keras config. weights: an array of weight groups (as defined in tfjs write_weights). output_dir: the directory to hold all the contents. - quantization_dtype: An optional numpy dtype to quantize weights to for - compression. Only np.uint8 and np.uint16 are supported. + quantization_dtype_map: (Optional) A mapping from dtype + (`uint8`, `uint16`, `float16`) to weights names. The weight mapping + supports wildcard substitution. weight_shard_size_bytes: Shard size (in bytes) of the weight files. The size of each weight file will be <= this value. 
""" @@ -297,7 +298,7 @@ def write_artifacts(topology, model_json[common.ARTIFACT_MODEL_TOPOLOGY_KEY] = topology or None weights_manifest = write_weights.write_weights( weights, output_dir, write_manifest=False, - quantization_dtype=quantization_dtype, + quantization_dtype_map=quantization_dtype_map, shard_size_bytes=weight_shard_size_bytes) assert isinstance(weights_manifest, list) model_json[common.ARTIFACT_WEIGHTS_MANIFEST_KEY] = weights_manifest @@ -308,7 +309,7 @@ def write_artifacts(topology, json.dump(model_json, f) -def save_keras_model(model, artifacts_dir, quantization_dtype=None, +def save_keras_model(model, artifacts_dir, quantization_dtype_map=None, weight_shard_size_bytes=1024 * 1024 * 4): r"""Save a Keras model and its weights in TensorFlow.js format. @@ -326,8 +327,9 @@ def save_keras_model(model, artifacts_dir, quantization_dtype=None, - files containing weight values in groups, with the file name pattern group(\d+)-shard(\d+)of(\d+). If the directory does not exist, this function will attempt to create it. - quantization_dtype: An optional numpy dtype to quantize weights to for - compression. Only np.uint8 and np.uint16 are supported. + quantization_dtype_map: (Optional) A mapping from dtype + (`uint8`, `uint16`, `float16`) to weights names. The weight mapping + supports wildcard substitution. weight_shard_size_bytes: Shard size (in bytes) of the weight files. The size of each weight file will be <= this value. @@ -344,6 +346,6 @@ def save_keras_model(model, artifacts_dir, quantization_dtype=None, os.makedirs(artifacts_dir) write_artifacts( topology_json, weight_groups, artifacts_dir, - quantization_dtype=quantization_dtype, + quantization_dtype_map=quantization_dtype_map, weight_shard_size_bytes=weight_shard_size_bytes) os.remove(temp_h5_path) diff --git a/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py b/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py index 6f2377447c2..0041929adb7 100644 --- a/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py +++ b/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py @@ -110,8 +110,8 @@ def _run_grappler(config, graph_def, graph, signature_def): config, meta_graph, cluster=get_cluster()) def optimize_graph(graph, signature_def, output_graph, - tf_version, quantization_dtype=None, skip_op_check=False, - strip_debug_ops=False, + tf_version, quantization_dtype_map=None, + skip_op_check=False, strip_debug_ops=False, weight_shard_size_bytes=1024 * 1024 * 4): """Takes a Python Graph object and optimizes the graph. @@ -120,8 +120,9 @@ def optimize_graph(graph, signature_def, output_graph, signature_def: the SignatureDef of the inference graph. output_graph: The location of the output graph. tf_version: Tensorflow version of the input graph. - quantization_dtype: An optional numpy dtype to quantize weights to for - compression. Only np.uint8 and np.uint16 are supported. + quantization_dtype_map: A mapping from dtype + (`uint8`, `uint16`, `float16`) to weights names. The weight mapping + supports wildcard substitution. skip_op_check: Bool whether to skip the op check. strip_debug_ops: Bool whether to strip debug ops. weight_shard_size_bytes: Shard size (in bytes) of the weight files. 
@@ -192,7 +193,7 @@ def optimize_graph(graph, signature_def, output_graph, extract_weights( optimized_graph, output_graph, tf_version, - signature_def, quantization_dtype, weight_shard_size_bytes) + signature_def, quantization_dtype_map, weight_shard_size_bytes) return optimize_graph @@ -230,7 +231,7 @@ def extract_weights(graph_def, output_graph, tf_version, signature_def, - quantization_dtype=None, + quantization_dtype_map=None, weight_shard_size_bytes=1024 * 1024 * 4): """Takes a Python GraphDef object and extract the weights. @@ -239,8 +240,10 @@ def extract_weights(graph_def, the model topology. tf_version: Tensorflow version of the input graph. signature_def: the SignatureDef of the inference graph. - quantization_dtype: An optional numpy dtype to quantize weights to for - compression. Only np.uint8 and np.uint16 are supported. + quantization_dtype_map: A mapping from dtype + (`uint8`, `uint16`, `float16`) to weights names. The weight mapping + compression. Only np.uint8 and np.uint16 are supported. + supports wildcard substitution. weight_shard_size_bytes: Shard size (in bytes) of the weight files. The size of each weight file will be <= this value. """ @@ -260,7 +263,7 @@ def extract_weights(graph_def, [global_manifest + function_manifests], output_graph, tf_version, signature_def, - quantization_dtype=quantization_dtype, + quantization_dtype_map=quantization_dtype_map, weight_shard_size_bytes=weight_shard_size_bytes) @@ -269,7 +272,7 @@ def write_artifacts(topology, output_graph, tf_version, signature_def, - quantization_dtype=None, + quantization_dtype_map=None, weight_shard_size_bytes=1024 * 1024 * 4): """Writes weights and topology to the output_dir. @@ -282,8 +285,9 @@ def write_artifacts(topology, output_graph: the output file name to hold all the contents. tf_version: Tensorflow version of the input graph. signature_def: the SignatureDef of the inference graph. - quantization_dtype: An optional numpy dtype to quantize weights to for - compression. Only np.uint8 and np.uint16 are supported. + quantization_dtype_map: A mapping from dtype + (`uint8`, `uint16`, `float16`) to weights names. The weight mapping + supports wildcard substitution. weight_shard_size_bytes: Shard size (in bytes) of the weight files. The size of each weight file will be <= this value. """ @@ -300,7 +304,7 @@ def write_artifacts(topology, model_json[common.ARTIFACT_MODEL_TOPOLOGY_KEY] = topology or None weights_manifest = write_weights.write_weights( weights, os.path.dirname(output_graph), write_manifest=False, - quantization_dtype=quantization_dtype, + quantization_dtype_map=quantization_dtype_map, shard_size_bytes=weight_shard_size_bytes) assert isinstance(weights_manifest, list) model_json[common.ARTIFACT_WEIGHTS_MANIFEST_KEY] = weights_manifest @@ -378,7 +382,8 @@ def _build_signature_def(frozen_graph, input_nodes, output_nodes): def convert_tf_frozen_model(frozen_model_path, output_node_names, - output_dir, quantization_dtype=None, + output_dir, + quantization_dtype_map=None, skip_op_check=False, strip_debug_ops=False, weight_shard_size_bytes=1024 * 1024 * 4): @@ -392,8 +397,9 @@ def convert_tf_frozen_model(frozen_model_path, will consist of - a file named 'model.json' - possibly sharded binary weight files. - quantization_dtype: An optional numpy dtype to quantize weights to for - compression. Only np.uint8 and np.uint16 are supported. + quantization_dtype_map: A mapping from dtype + (`uint8`, `uint16`, `float16`) to weights names. The weight mapping + supports wildcard substitution. 
skip_op_check: Bool whether to skip the op check. strip_debug_ops: Bool whether to strip debug ops. weight_shard_size_bytes: Shard size (in bytes) of the weight files. @@ -410,7 +416,7 @@ def convert_tf_frozen_model(frozen_model_path, optimize_graph(graph, signature, output_graph, tf.__version__, - quantization_dtype=quantization_dtype, + quantization_dtype_map=quantization_dtype_map, skip_op_check=skip_op_check, strip_debug_ops=strip_debug_ops, weight_shard_size_bytes=weight_shard_size_bytes) @@ -418,7 +424,7 @@ def convert_tf_frozen_model(frozen_model_path, def convert_tf_saved_model(saved_model_dir, output_dir, signature_def='serving_default', saved_model_tags='serve', - quantization_dtype=None, + quantization_dtype_map=None, skip_op_check=False, strip_debug_ops=False, weight_shard_size_bytes=1024 * 1024 * 4, @@ -438,8 +444,9 @@ def convert_tf_saved_model(saved_model_dir, signature_def: string Tagset of the SignatureDef to load. Defaults to 'serving_default'. saved_model_tags: tags of the GraphDef to load. Defaults to 'serve'. - quantization_dtype: An optional numpy dtype to quantize weights to for - compression. Only np.uint8 and np.uint16 are supported. + quantization_dtype_map: A mapping from dtype + (`uint8`, `uint16`, `float16`) to weights names. The weight mapping + supports wildcard substitution. skip_op_check: Bool whether to skip the op check. strip_debug_ops: Bool whether to strip debug ops. weight_shard_size_bytes: Shard size (in bytes) of the weight files. @@ -481,7 +488,7 @@ def convert_tf_saved_model(saved_model_dir, optimize_graph(frozen_graph, signature, output_graph, model.tensorflow_version, - quantization_dtype=quantization_dtype, + quantization_dtype_map=quantization_dtype_map, skip_op_check=skip_op_check, strip_debug_ops=strip_debug_ops, weight_shard_size_bytes=weight_shard_size_bytes) @@ -533,7 +540,7 @@ def load_and_initialize_hub_module(module_path, signature='default'): def convert_tf_hub_module_v1(module_path, output_dir, - signature='default', quantization_dtype=None, + signature='default', quantization_dtype_map=None, skip_op_check=False, strip_debug_ops=False, weight_shard_size_bytes=1024 * 1024 * 4): """Freeze the TF-Hub module and check compatibility with Tensorflow.js. @@ -548,6 +555,9 @@ def convert_tf_hub_module_v1(module_path, output_dir, - a file named 'model.json' - possibly sharded binary weight files. signature: string Signature to load. + quantization_dtype_map: A mapping from dtype + (`uint8`, `uint16`, `float16`) to weights names. The weight mapping + supports wildcard substitution. skip_op_check: Bool whether to skip the op check. strip_debug_ops: Bool whether to strip debug ops. weight_shard_size_bytes: Shard size (in bytes) of the weight files. @@ -589,7 +599,7 @@ def convert_tf_hub_module_v1(module_path, output_dir, optimize_graph(frozen_graph, signature, output_graph, tf.__version__, - quantization_dtype=quantization_dtype, + quantization_dtype_map=quantization_dtype_map, skip_op_check=skip_op_check, strip_debug_ops=strip_debug_ops, weight_shard_size_bytes=weight_shard_size_bytes) @@ -601,8 +611,8 @@ def convert_tf_hub_module_v1(module_path, output_dir, def convert_tf_hub_module(module_handle, output_dir, signature='default', saved_model_tags='serve', - quantization_dtype=None, skip_op_check=False, - strip_debug_ops=False, + quantization_dtype_map=None, + skip_op_check=False, strip_debug_ops=False, weight_shard_size_bytes=1024 * 1024 * 4, control_flow_v2=False): """Conversion for TF Hub modules V1 and V2. 
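For the SavedModel path above, a minimal sketch of the Python API with the new argument; the directories are placeholders and the pattern reuses the `conv/*/weights` example from the flag help text:

```
from tensorflowjs.converters import tf_saved_model_conversion_v2

# Quantize only weights whose names match the pattern; all other weights
# stay float32.
tf_saved_model_conversion_v2.convert_tf_saved_model(
    '/tmp/my_saved_model', '/tmp/web_model',
    quantization_dtype_map={'float16': 'conv/*/weights'})
```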
@@ -617,6 +627,9 @@ def convert_tf_hub_module(module_handle, output_dir, - possibly sharded binary weight files. signature: string Signature to load. saved_model_tags: tags of the GraphDef to load. Defaults to ''. + quantization_dtype_map: A mapping from dtype + (`uint8`, `uint16`, `float16`) to weights names. The weight mapping + supports wildcard substitution. skip_op_check: Bool whether to skip the op check. strip_debug_ops: Bool whether to strip debug ops. weight_shard_size_bytes: Shard size (in bytes) of the weight files. @@ -630,7 +643,8 @@ def convert_tf_hub_module(module_handle, output_dir, if tf.io.gfile.exists(os.path.join(module_path, _HUB_V1_MODULE_PB)): print("Loading the module using TF 1.X interface from %s." % module_path) convert_tf_hub_module_v1(module_path, output_dir, signature, - quantization_dtype, skip_op_check, strip_debug_ops, + quantization_dtype_map, + skip_op_check, strip_debug_ops, weight_shard_size_bytes) else: print("Loading the module using TF 2.X interface from %s." % module_path) @@ -640,7 +654,7 @@ def convert_tf_hub_module(module_handle, output_dir, output_dir=output_dir, signature_def=signature, saved_model_tags=saved_model_tags, - quantization_dtype=quantization_dtype, + quantization_dtype_map=quantization_dtype_map, skip_op_check=skip_op_check, strip_debug_ops=strip_debug_ops, weight_shard_size_bytes=weight_shard_size_bytes, diff --git a/tfjs-converter/python/tensorflowjs/converters/wizard.py b/tfjs-converter/python/tensorflowjs/converters/wizard.py index cdfea64a0c7..9e23f99611c 100644 --- a/tfjs-converter/python/tensorflowjs/converters/wizard.py +++ b/tfjs-converter/python/tensorflowjs/converters/wizard.py @@ -215,7 +215,7 @@ def generate_arguments(params): """ args = [] not_param_list = [common.INPUT_PATH, common.OUTPUT_PATH, - 'overwrite_output_path'] + 'overwrite_output_path', 'quantize'] no_false_param = [common.SPLIT_WEIGHTS_BY_LAYER, common.SKIP_OP_CHECK] for key, value in sorted(params.items()): if key not in not_param_list and value is not None: @@ -463,20 +463,54 @@ def run(dryrun): }, { 'type': 'list', - 'name': common.QUANTIZATION_BYTES, + 'name': 'quantize', 'message': 'Do you want to compress the model? 
' '(this will decrease the model precision.)', 'choices': [{ 'name': 'No compression (Higher accuracy)', 'value': None }, { - 'name': '2x compression (Accuracy/size trade-off)', - 'value': 2 + 'name': 'float16 quantization ' + '(2x smaller, Minimal accuracy loss)', + 'value': 'float16' }, { - 'name': '4x compression (Smaller size)', - 'value': 1 + 'name': 'uint16 affine quantization (2x smaller, Accuracy loss)', + 'value': 'uint16' + }, { + 'name': 'uint8 affine quantization (4x smaller, Accuracy loss)', + 'value': 'uint8' }] }, + { + 'type': 'input', + 'name': common.QUANTIZATION_TYPE_FLOAT16, + 'message': 'Please enter the layers to apply float16 quantization ' + '(2x smaller, minimal accuracy tradeoff).\n' + 'Supports wildcard expansion with *, e.g., conv/*/weights', + 'default': '*', + 'when': lambda answers: + value_in_list(answers, 'quantize', ('float16')) + }, + { + 'type': 'input', + 'name': common.QUANTIZATION_TYPE_UINT8, + 'message': 'Please enter the layers to apply affine 1-byte integer ' + 'quantization (4x smaller, accuracy tradeoff).\n' + 'Supports wildcard expansion with *, e.g., conv/*/weights', + 'default': '*', + 'when': lambda answers: + value_in_list(answers, 'quantize', ('uint8')) + }, + { + 'type': 'input', + 'name': common.QUANTIZATION_TYPE_UINT16, + 'message': 'Please enter the layers to apply affine 2-byte integer ' + 'quantization (2x smaller, accuracy tradeoff).\n' + 'Supports wildcard expansion with *, e.g., conv/*/weights', + 'default': '*', + 'when': lambda answers: + value_in_list(answers, 'quantize', ('uint16')) + }, { 'type': 'input', 'name': common.WEIGHT_SHARD_SIZE_BYTES, diff --git a/tfjs-converter/python/tensorflowjs/converters/wizard_test.py b/tfjs-converter/python/tensorflowjs/converters/wizard_test.py index c50545cde0d..d724e79401f 100644 --- a/tfjs-converter/python/tensorflowjs/converters/wizard_test.py +++ b/tfjs-converter/python/tensorflowjs/converters/wizard_test.py @@ -183,7 +183,7 @@ def testGenerateCommandForSavedModel(self): 'input_path': 'tmp/saved_model', 'saved_model_tags': 'test', 'signature_name': 'test_default', - 'quantization_bytes': 2, + 'quantize_float16': 'conv/*/weights', 'weight_shard_size_bytes': '4194304', 'skip_op_check': False, 'strip_debug_ops': True, @@ -192,8 +192,10 @@ def testGenerateCommandForSavedModel(self): self.assertEqual(['--control_flow_v2=True', '--input_format=tf_saved_model', - '--quantization_bytes=2', '--saved_model_tags=test', - '--signature_name=test_default', '--strip_debug_ops=True', + '--quantize_float16=conv/*/weights', + '--saved_model_tags=test', + '--signature_name=test_default', + '--strip_debug_ops=True', '--weight_shard_size_bytes=4194304', 'tmp/saved_model', 'tmp/web_model'], wizard.generate_arguments(options)) @@ -205,7 +207,7 @@ def testGenerateCommandForKerasSavedModel(self): 'saved_model_tags': 'test', 'signature_name': 'test_default', 'weight_shard_size_bytes': '100', - 'quantization_bytes': 1, + 'quantize_float16': 'conv/*/weights', 'skip_op_check': True, 'strip_debug_ops': False, 'control_flow_v2': False, @@ -214,7 +216,8 @@ def testGenerateCommandForKerasSavedModel(self): self.assertEqual(['--control_flow_v2=False', '--input_format=tf_keras_saved_model', '--output_format=tfjs_layers_model', - '--quantization_bytes=1', '--saved_model_tags=test', + '--quantize_float16=conv/*/weights', + '--saved_model_tags=test', '--signature_name=test_default', '--skip_op_check', '--strip_debug_ops=False', '--weight_shard_size_bytes=100', @@ -225,10 +228,11 @@ def testGenerateCommandForKerasModel(self): 
options = {'input_format': 'keras', 'input_path': 'tmp/model.HD5', 'weight_shard_size_bytes': '100', - 'quantization_bytes': 1, + 'quantize_uint16': 'conv/*/weights', 'output_path': 'tmp/web_model'} - self.assertEqual(['--input_format=keras', '--quantization_bytes=1', + self.assertEqual(['--input_format=keras', + '--quantize_uint16=conv/*/weights', '--weight_shard_size_bytes=100', 'tmp/model.HD5', 'tmp/web_model'], wizard.generate_arguments(options)) @@ -237,13 +241,13 @@ def testGenerateCommandForLayerModel(self): options = {'input_format': 'tfjs_layers_model', 'output_format': 'keras', 'input_path': 'tmp/model.json', - 'quantization_bytes': 1, + 'quantize_uint8': 'conv/*/weights', 'weight_shard_size_bytes': '100', 'output_path': 'tmp/web_model'} self.assertEqual(['--input_format=tfjs_layers_model', '--output_format=keras', - '--quantization_bytes=1', + '--quantize_uint8=conv/*/weights', '--weight_shard_size_bytes=100', 'tmp/model.json', 'tmp/web_model'], diff --git a/tfjs-converter/python/tensorflowjs/quantization.py b/tfjs-converter/python/tensorflowjs/quantization.py index a26259067c6..11b535334bb 100644 --- a/tfjs-converter/python/tensorflowjs/quantization.py +++ b/tfjs-converter/python/tensorflowjs/quantization.py @@ -16,10 +16,79 @@ from __future__ import division from __future__ import print_function +import fnmatch import numpy as np -QUANTIZATION_BYTES_TO_DTYPES = {1: np.uint8, 2: np.uint16} +QUANTIZATION_DTYPE_FLOAT16 = 'float16' +QUANTIZATION_DTYPE_UINT8 = 'uint8' +QUANTIZATION_DTYPE_UINT16 = 'uint16' +QUANTIZATION_BYTES_TO_DTYPES = {1: QUANTIZATION_DTYPE_UINT8, + 2: QUANTIZATION_DTYPE_UINT16} +QUANTIZATION_OPTION_TO_DTYPES = {QUANTIZATION_DTYPE_UINT8: np.uint8, + QUANTIZATION_DTYPE_UINT16: np.uint16, + QUANTIZATION_DTYPE_FLOAT16: np.float16} + + +def map_layers_to_quantization_dtype(names, quantization_dtype_map): + """Maps node names to their quantization dtypes. + + Given a quantization_dtype_map which maps dtypes `uint8`, `uint16`, `float16` + to node patterns, e.g., conv/*/weights we construct a new mapping for each + individual node name to its dtype, e.g., conv/1/weight -> `uint8`. + A dtype in the map can also be a boolean, signaling a fallthrough dtype. + There can only be one fallthrough dtype in the map. A fallthrough dtype + will convert all weights that don't match any pattern to the provided dtype. + + Args: + names: Array of node names. + quantization_dtype_map: A mapping from dtype (`uint8`, `uint16`, `float16`) + to weights. The weight mapping supports wildcard substitution. + + Returns: + quantization_dtype: A mapping from each node name which matches + an entry in quantization_dtype_map to its corresponding dtype. 
+ + Raises: + ValueError: - If multiple dtypes match the same node name + - If more than one fallthrough is provided + """ + if quantization_dtype_map is None: + return {} + + fallthrough = None + quantization_dtype = {} + for dtype_name, patterns in quantization_dtype_map.items(): + # Record fallthrough if there is one + if isinstance(patterns, bool) and patterns: + # Only one fallthrough is supported + if fallthrough is not None: + raise ValueError( + 'More than one quantization fallthrough provided, ' + 'exactly one is supported') + fallthrough = dtype_name + continue + elif isinstance(patterns, str): + patterns = list([patterns]) + + # Record matched weights for dtype + for pattern in patterns: + for match in fnmatch.filter(names, pattern): + dtype = QUANTIZATION_OPTION_TO_DTYPES[dtype_name] + if match in quantization_dtype and quantization_dtype[match] != dtype: + raise ValueError( + 'Two quantization values %s, %s match the same node %s' % + (dtype, quantization_dtype[match], match)) + quantization_dtype[match] = dtype + + # Catch all remaining names with fallthrough + if fallthrough is not None: + nameset = set(names) + fallthrough_names = nameset - set(quantization_dtype.keys()) + for name in fallthrough_names: + quantization_dtype[name] = QUANTIZATION_OPTION_TO_DTYPES[fallthrough] + + return quantization_dtype def quantize_weights(data, quantization_dtype): """Quantizes the weights by linearly re-scaling across available bits. @@ -36,43 +105,71 @@ def quantize_weights(data, quantization_dtype): Args: data: A numpy array of dtype 'float32' or 'int32'. - quantization_dtype: A numpy dtype to quantize weights to. Only np.uint8 and - np.uint16 are supported. + quantization_dtype: A numpy dtype to quantize weights to. Only np.float16, + np.uint8, and np.uint16 are supported. Returns: quantized_data: The quantized weights as a numpy array with dtype `quantization_dtype`. - scale: The linearly scaling constant used for quantization. - min_val: The minimum value of the linear range. + metadata: A dictionary with the corresponding metadata for the quantization + type. There is no metadata associated with float16. + For affine quantization there are two associated metadata values: + scale: The linearly scaling constant used for quantization. + min_val: The minimum value of the linear range. Raises: ValueError: if `quantization_dtype` is not a valid type. """ - if quantization_dtype not in QUANTIZATION_BYTES_TO_DTYPES.values(): + if quantization_dtype in [np.uint8, np.uint16]: + # Compute the min and max for the group. + min_val = data.min().astype(np.float64) + max_val = data.max().astype(np.float64) + if min_val == max_val: + # If there is only a single value, we can represent everything as zeros. + quantized_data = np.zeros_like(data, dtype=quantization_dtype) + scale = 1.0 + else: + # Quantize data. + scale, min_val, max_val = _get_affine_quantization_range( + min_val, max_val, quantization_dtype) + quantized_data = np.round( + (data.clip(min_val, max_val) - min_val) / scale).astype( + quantization_dtype) + + return quantized_data, {'min': min_val, 'scale': scale} + elif quantization_dtype == np.float16: + if data.dtype != np.float32: + raise ValueError( + 'Invalid data dtype %r\n' + 'float16 quantization only supports float32 dtype' % data.dtype) + quantized_data = data.astype(np.float16) + return quantized_data, {} + else: raise ValueError('Invalid `quantization_dtype`: %r' % quantization_dtype) - # Compute the min and max for the group. 
- min_val = data.min().astype(np.float64) - max_val = data.max().astype(np.float64) - if min_val == max_val: - # If there is only a single value, we can represent everything as zeros. - quantized_data = np.zeros_like(data, dtype=quantization_dtype) - scale = 1.0 - else: - # Quantize data. - scale, min_val, max_val = _get_quantization_range( - min_val, max_val, quantization_dtype) - quantized_data = np.round( - (data.clip(min_val, max_val) - min_val) / scale).astype( - quantization_dtype) - return quantized_data, scale, min_val +def dequantize_weights(data, metadata, original_dtype=np.float32): + dtype = data.dtype -def dequantize_weights( - quantized_data, scale, min_val, original_dtype=np.float32): - return np.round(quantized_data * scale + min_val).astype(original_dtype) + if dtype in [np.uint8, np.uint16]: + if not ('scale' in metadata and 'min' in metadata): + raise ValueError( + 'Missing metadata min or scale for dtype %s' % dtype.name) + scale = metadata['scale'] + min_val = metadata['min'] + return np.round(data * scale + min_val).astype(original_dtype) + elif dtype == np.float16: + if original_dtype != np.float32: + raise ValueError( + 'Invalid data dtype %r\n' + 'float16 quantization only supports float32 dtype' % data.dtype) + return data.astype(original_dtype) + else: + raise ValueError( + 'Invalid dtype %s for dequantization\n' + 'Supported dtypes are uint8, uint16, float16' % dtype.name) -def _get_quantization_range(min_val, max_val, quantization_dtype): +def _get_affine_quantization_range(min_val, max_val, quantization_dtype): """Computes quantization range to ensure that zero is represented if covered. Gymnastics with nudged zero point is to ensure that real zero maps to an @@ -97,7 +194,7 @@ def _get_quantization_range(min_val, max_val, quantization_dtype): Raises: ValueError: if `quantization_dtype` is not a valid type. 
""" - if quantization_dtype not in QUANTIZATION_BYTES_TO_DTYPES.values(): + if quantization_dtype not in [np.uint8, np.uint16]: raise ValueError('Invalid `quantization_dtype`: %r' % quantization_dtype) quant_max = np.iinfo(quantization_dtype).max diff --git a/tfjs-converter/python/tensorflowjs/quantization_test.py b/tfjs-converter/python/tensorflowjs/quantization_test.py index 13207e8e8ea..4762b5ecbb2 100644 --- a/tfjs-converter/python/tensorflowjs/quantization_test.py +++ b/tfjs-converter/python/tensorflowjs/quantization_test.py @@ -24,71 +24,219 @@ class TestQuantizationUtil(unittest.TestCase): + def assertDictContainsSubsetAlmostEqual(self, d1, d2): + self.assertIsInstance(d1, dict) + self.assertIsInstance(d2, dict) + + d1_keys = set(d1.keys()) + d2_keys = set(d2.keys()) + + self.assertTrue(d2_keys.issubset(d1_keys)) + + for key in d2_keys: + self.assertAlmostEqual(d1[key], d2[key]) + + def _runQuantizeTest( self, range_min, range_max, data_dtype, quantization_dtype, - expected_scale): + expected_metadata): d = np.arange(range_min, range_max + 1, dtype=data_dtype) - q, s, m = quantization.quantize_weights(d, quantization_dtype) - self.assertAlmostEqual(s, expected_scale) + q, metadata = quantization.quantize_weights(d, quantization_dtype) + + self.assertDictContainsSubsetAlmostEqual(metadata, expected_metadata) self.assertEqual(q.dtype, quantization_dtype) - de_q = quantization.dequantize_weights(q, s, m, data_dtype) + de_q = quantization.dequantize_weights( + q, metadata, data_dtype) np.testing.assert_allclose(de_q, d) - if range_min <= 0 <= range_max: - d_0 = np.zeros(1, data_dtype) - q_0 = np.round((d_0 - m) / s).astype(quantization_dtype) - self.assertEqual( - quantization.dequantize_weights(q_0, s, m, data_dtype), d_0) + if quantization_dtype in [np.uint8, np.uint16]: + s = metadata['scale'] + m = metadata['min'] + if range_min <= 0 <= range_max: + d_0 = np.zeros(1, data_dtype) + q_0 = np.round((d_0 - m) / s).astype(quantization_dtype) + self.assertEqual( + quantization.dequantize_weights(q_0, metadata, data_dtype), d_0) - def testAllEqual(self): + def testAffineQuantizeAllEqual(self): d = np.ones(5, dtype=np.float32) - q, s, m = quantization.quantize_weights(d, np.uint8) - self.assertEqual(s, 1.0) + q, metadata = quantization.quantize_weights(d, np.uint8) + assert 'scale' in metadata and 'min' in metadata + self.assertEqual(metadata['scale'], 1.0) self.assertEqual(q.dtype, np.uint8) - de_q = quantization.dequantize_weights(q, s, m, np.float32) + de_q = quantization.dequantize_weights(q, metadata, np.float32) np.testing.assert_array_equal(de_q, d) - def testQuantizeNegativeFloats(self): - self._runQuantizeTest(-3, -1, np.float32, np.uint8, expected_scale=2/255) - self._runQuantizeTest(-3, -1, np.float32, np.uint16, expected_scale=2/65536) - - def testQuantizeNegativeAndZeroFloats(self): - self._runQuantizeTest(-3, 0, np.float32, np.uint8, expected_scale=3/255) - self._runQuantizeTest(-3, 0, np.float32, np.uint16, expected_scale=3/65536) - - def testQuantizeNegativeAndPositiveFloats(self): - self._runQuantizeTest(-3, 3, np.float32, np.uint8, expected_scale=6/255) - self._runQuantizeTest(-3, 3, np.float32, np.uint16, expected_scale=6/65536) - - def testQuantizeZeroAndPositiveFloats(self): - self._runQuantizeTest(0, 3, np.float32, np.uint8, expected_scale=3/255) - self._runQuantizeTest(0, 3, np.float32, np.uint16, expected_scale=3/65536) - - def testQuantizePositiveFloats(self): - self._runQuantizeTest(1, 3, np.float32, np.uint8, expected_scale=2/255) - self._runQuantizeTest(1, 3, 
np.float32, np.uint16, expected_scale=2/65536) - - def testQuantizeNegativeInts(self): - self._runQuantizeTest(-3, -1, np.int32, np.uint8, expected_scale=2/255) - self._runQuantizeTest(-3, -1, np.int32, np.uint16, expected_scale=2/65536) - - def testQuantizeNegativeAndZeroInts(self): - self._runQuantizeTest(-3, 0, np.int32, np.uint8, expected_scale=3/255) - self._runQuantizeTest(-3, 0, np.int32, np.uint16, expected_scale=3/65536) + def testFloatQuantizeAllEqual(self): + d = np.ones(5, dtype=np.float32) + q, metadata = quantization.quantize_weights(d, np.float16) + self.assertDictEqual(metadata, {}) - def testQuantizeNegativeAndPositiveInts(self): - self._runQuantizeTest(-3, 3, np.int32, np.uint8, expected_scale=6/255) - self._runQuantizeTest(-3, 3, np.int32, np.uint16, expected_scale=6/65536) + self.assertEqual(q.dtype, np.float16) + de_q = quantization.dequantize_weights(q, metadata, np.float32) + np.testing.assert_array_equal(de_q, d) - def testQuantizeZeroAndPositiveInts(self): - self._runQuantizeTest(0, 3, np.int32, np.uint8, expected_scale=3/255) - self._runQuantizeTest(0, 3, np.int32, np.uint16, expected_scale=3/65536) + def testAffineQuantizeNegativeFloats(self): + self._runQuantizeTest( + -3, -1, np.float32, np.uint8, + expected_metadata={'scale': 2/255}) + self._runQuantizeTest( + -3, -1, np.float32, np.uint16, + expected_metadata={'scale': 2/65536}) + + def testAffineQuantizeNegativeAndZeroFloats(self): + self._runQuantizeTest( + -3, 0, np.float32, np.uint8, + expected_metadata={'scale': 3/255}) + self._runQuantizeTest( + -3, 0, np.float32, np.uint16, + expected_metadata={'scale': 3/65536}) + + def testAffineQuantizeNegativeAndPositiveFloats(self): + self._runQuantizeTest( + -3, 3, np.float32, np.uint8, + expected_metadata={'scale': 6/255}) + self._runQuantizeTest( + -3, 3, np.float32, np.uint16, + expected_metadata={'scale': 6/65536}) + + def testAffineQuantizeZeroAndPositiveFloats(self): + self._runQuantizeTest( + 0, 3, np.float32, np.uint8, + expected_metadata={'scale': 3/255}) + self._runQuantizeTest( + 0, 3, np.float32, np.uint16, + expected_metadata={'scale': 3/65536}) + + def testAffineQuantizePositiveFloats(self): + self._runQuantizeTest( + 1, 3, np.float32, np.uint8, + expected_metadata={'scale': 2/255}) + self._runQuantizeTest( + 1, 3, np.float32, np.uint16, + expected_metadata={'scale': 2/65536}) + + def testAffineQuantizeNegativeInts(self): + self._runQuantizeTest( + -3, -1, np.int32, np.uint8, + expected_metadata={'scale': 2/255}) + self._runQuantizeTest( + -3, -1, np.int32, np.uint16, + expected_metadata={'scale': 2/65536}) + + def testAffineQuantizeNegativeAndZeroInts(self): + self._runQuantizeTest( + -3, 0, np.int32, np.uint8, + expected_metadata={'scale': 3/255}) + self._runQuantizeTest( + -3, 0, np.int32, np.uint16, + expected_metadata={'scale': 3/65536}) + + def testAffineQuantizeNegativeAndPositiveInts(self): + self._runQuantizeTest( + -3, 3, np.int32, np.uint8, + expected_metadata={'scale': 6/255}) + self._runQuantizeTest( + -3, 3, np.int32, np.uint16, + expected_metadata={'scale': 6/65536}) + + def testAffineQuantizeZeroAndPositiveInts(self): + self._runQuantizeTest( + 0, 3, np.int32, np.uint8, + expected_metadata={'scale': 3/255}) + self._runQuantizeTest( + 0, 3, np.int32, np.uint16, + expected_metadata={'scale': 3/65536}) + + def testAffineQuantizePositiveInts(self): + self._runQuantizeTest( + 1, 3, np.int32, np.uint8, + expected_metadata={'scale': 2/255}) + self._runQuantizeTest( + 1, 3, np.int32, np.uint16, + expected_metadata={'scale': 2/65536}) + + def 
testInvalidQuantizationTypes(self): + # Invalid quantization type + with self.assertRaises(ValueError): + quantization.quantize_weights(np.array([]), np.bool) + + # Invalid data dtype for float16 quantization + with self.assertRaises(ValueError): + d = np.ones(1, dtype=np.int32) + quantization.quantize_weights(d, np.float16) + + def testInvalidDequantizationTypes(self): + # Invalid metadata for affine quantization + with self.assertRaises(ValueError): + d = np.ones(1, dtype=np.uint8) + quantization.dequantize_weights(np.array([]), {}) + + # Invalid target dtype for float16 quantization + with self.assertRaises(ValueError): + d = np.ones(1, dtype=np.float16) + quantization.dequantize_weights(d, {}, np.int32) + + # Invalid dequantization type + with self.assertRaises(ValueError): + d = np.ones(1, dtype=np.bool) + quantization.dequantize_weights(d, {}) + + def testMapLayerFallthrough(self): + names = ['conv/0/weight', 'conv/0/bias', 'conv/1/weight', 'conv/1/bias'] + quantization_dtype_map = {'float16': ['conv/0/*'], 'uint8': True} + dtype_map = quantization.map_layers_to_quantization_dtype( + names, quantization_dtype_map) + + self.assertDictEqual(dtype_map, { + 'conv/0/weight': np.float16, + 'conv/0/bias': np.float16, + 'conv/1/weight': np.uint8, + 'conv/1/bias': np.uint8 + }) + + def testMapLayerConflictingMap(self): + names = ['conv/0/weight', 'conv/0/bias', 'conv/1/weight', 'conv/1/bias'] + quantization_dtype_map = {'float16': ['conv/0/*'], 'uint8': ['conv/0/bias']} + + with self.assertRaises(ValueError): + quantization.map_layers_to_quantization_dtype( + names, quantization_dtype_map) + + + def testMapLayerStringToList(self): + names = ['conv/0/weight', 'conv/0/bias', 'conv/1/weight', 'conv/1/bias'] + quantization_dtype_map = {'float16': '*'} + + + dtype_map = quantization.map_layers_to_quantization_dtype( + names, quantization_dtype_map) + + self.assertDictEqual(dtype_map, { + 'conv/0/weight': np.float16, + 'conv/0/bias': np.float16, + 'conv/1/weight': np.float16, + 'conv/1/bias': np.float16 + }) + + def testMapLayerNoDtypeMap(self): + names = ['conv/0/weight', 'conv/0/bias', 'conv/1/weight', 'conv/1/bias'] + quantization_dtype_map = {} + dtype_map = quantization.map_layers_to_quantization_dtype( + names, quantization_dtype_map) + + self.assertDictEqual(dtype_map, {}) + + def testMapLayerExactlyOneFallthrough(self): + names = ['conv/0/weight', 'conv/0/bias', 'conv/1/weight', 'conv/1/bias'] + quantization_dtype_map = {'float16': True, 'uint8': True} + + with self.assertRaises(ValueError): + quantization.map_layers_to_quantization_dtype( + names, quantization_dtype_map) - def testQuantizePositiveInts(self): - self._runQuantizeTest(1, 3, np.int32, np.uint8, expected_scale=2/255) - self._runQuantizeTest(1, 3, np.int32, np.uint16, expected_scale=2/65536) if __name__ == '__main__': diff --git a/tfjs-converter/python/tensorflowjs/read_weights.py b/tfjs-converter/python/tensorflowjs/read_weights.py index 5f5d433faa9..26485b7d173 100644 --- a/tfjs-converter/python/tensorflowjs/read_weights.py +++ b/tfjs-converter/python/tensorflowjs/read_weights.py @@ -24,7 +24,7 @@ import numpy as np from tensorflowjs import quantization -_INPUT_DTYPES = [np.float32, np.int32, np.complex64, +_INPUT_DTYPES = [np.float16, np.float32, np.int32, np.complex64, np.uint8, np.uint16, np.object] # Number of bytes used to encode the length of a string in a string tensor. 
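The quantization_test.py changes above exercise the reworked `quantization` module API; as a reference, a short round trip (illustrative only, not part of the patch):

```
import numpy as np
from tensorflowjs import quantization

weights = np.array([-3.0, 0.0, 3.0], dtype=np.float32)

# Affine quantization returns the quantized array plus {'min': ..., 'scale': ...}.
q, meta = quantization.quantize_weights(weights, np.uint8)
restored = quantization.dequantize_weights(q, meta, np.float32)

# float16 quantization is a cast and carries empty metadata.
q16, empty_meta = quantization.quantize_weights(weights, np.float16)

# Wildcard mapping: conv kernels get float16; unmatched names are left alone
# because no fallthrough (a True value) is given.
dtype_map = quantization.map_layers_to_quantization_dtype(
    ['conv/0/weight', 'conv/0/bias'], {'float16': ['conv/*/weight']})
# -> {'conv/0/weight': np.float16}
```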
@@ -191,8 +191,7 @@ def decode_weights(weights_manifest, data_buffers, flatten=False): offset += dtype.itemsize * value.size if quant_info: value = quantization.dequantize_weights( - value, quant_info['scale'], quant_info['min'], - np.dtype(weight['dtype'])) + value, quant_info, np.dtype(weight['dtype'])) out_group.append({'name': name, 'data': value}) if flatten: diff --git a/tfjs-converter/python/tensorflowjs/read_weights_test.py b/tfjs-converter/python/tensorflowjs/read_weights_test.py index 61af1d3212f..3cca2170d3a 100644 --- a/tfjs-converter/python/tensorflowjs/read_weights_test.py +++ b/tfjs-converter/python/tensorflowjs/read_weights_test.py @@ -326,7 +326,7 @@ def testReadWeightsWithIncorrectTypeInWeightsManifestRaisesError(self): read_weights.read_weights(groups[0][0], self._tmp_dir) - def testReadQuantizedWeights(self): + def testReadAffineQuantizedWeights(self): groups = [ [{ 'name': 'weight1', @@ -335,13 +335,34 @@ def testReadQuantizedWeights(self): ] manifest = write_weights.write_weights( - groups, self._tmp_dir, quantization_dtype=np.uint8) + groups, self._tmp_dir, quantization_dtype_map={'uint8': '*'}) # Read the weights using `read_weights`. read_output = read_weights.read_weights(manifest, self._tmp_dir) self.assertEqual(1, len(read_output)) self.assertEqual(1, len(read_output[0])) self.assertEqual('weight1', read_output[0][0]['name']) + self.assertEqual(read_output[0][0]['data'].dtype, np.float32) + self.assertTrue( + np.allclose(groups[0][0]['data'], read_output[0][0]['data'])) + + def testReadFloat16QuantizedWeights(self): + groups = [ + [{ + 'name': 'weight1', + 'data': np.array([0, 1, 2, 3], 'float32') + }] + ] + + manifest = write_weights.write_weights( + groups, self._tmp_dir, quantization_dtype_map={'float16': '*'}) + + # Read the weights using `read_weights`. + read_output = read_weights.read_weights(manifest, self._tmp_dir) + self.assertEqual(1, len(read_output)) + self.assertEqual(1, len(read_output[0])) + self.assertEqual('weight1', read_output[0][0]['name']) + self.assertEqual(read_output[0][0]['data'].dtype, np.float32) self.assertTrue( np.allclose(groups[0][0]['data'], read_output[0][0]['data'])) diff --git a/tfjs-converter/python/tensorflowjs/write_weights.py b/tfjs-converter/python/tensorflowjs/write_weights.py index fbbcbc44cd0..7603fb67014 100644 --- a/tfjs-converter/python/tensorflowjs/write_weights.py +++ b/tfjs-converter/python/tensorflowjs/write_weights.py @@ -24,7 +24,7 @@ from tensorflowjs import quantization from tensorflowjs import read_weights -_OUTPUT_DTYPES = [np.float32, np.int32, np.complex64, +_OUTPUT_DTYPES = [np.float16, np.float32, np.int32, np.complex64, np.uint8, np.uint16, np.bool, np.object] _AUTO_DTYPE_CONVERSION = { np.dtype(np.float64): np.float32, @@ -33,7 +33,7 @@ def write_weights( weight_groups, write_dir, shard_size_bytes=1024 * 1024 * 4, - write_manifest=True, quantization_dtype=None): + write_manifest=True, quantization_dtype_map=None): """Writes weights to a binary format on disk for ingestion by JavaScript. Weights are organized into groups. When writing to disk, the bytes from all @@ -44,7 +44,7 @@ def write_weights( shard size. Weights are optionally quantized to either 8 or 16 bits for compression, - which is enabled via the `quantization_dtype` argument. + which is enabled via the `quantization_dtype_map`. Args: weight_groups: An list of groups. Each group is an array of weight @@ -68,8 +68,9 @@ def write_weights( the max file size for caching for all major browsers. 
    write_manifest: Whether to write the manifest JSON to disk. Defaults to
      True.
-    quantization_dtype: An optional numpy dtype to quantize weights to for
-      compression. Only np.uint8 and np.uint16 are supported.
+    quantization_dtype_map: (Optional) A mapping from dtype
+      (`uint8`, `uint16`, `float16`) to weight names. The weight mapping
+      supports wildcard substitution.
 
   Returns:
     The weights manifest JSON dict.
@@ -107,7 +108,7 @@ def write_weights(
         'name': 'weight2',
         'shape': [2000, 2000],
         'dtype': 'float32',
-        'quantization': {'min': -2.4, 'scale': 0.08, 'dtype': 'uint8'}
+        'quantization': {'dtype': 'float16'}
       }]
   }]
   """
@@ -120,8 +121,14 @@ def write_weights(
   for group_index, group in enumerate(weight_groups):
     for e in group:
       _auto_convert_weight_entry(e)
-    if quantization_dtype:
-      group = [_quantize_entry(e, quantization_dtype) for e in group]
+    names = [entry['name'] for entry in group]
+    quantization_dtype = quantization.map_layers_to_quantization_dtype(
+        names, quantization_dtype_map)
+
+    group = [
+        _quantize_entry(e, quantization_dtype[e['name']])
+        if e['name'] in quantization_dtype else e for e in group
+    ]
 
     group_bytes, total_bytes, _ = _stack_group_bytes(group)
 
     shard_filenames = _shard_group_bytes_to_disk(
@@ -154,8 +161,8 @@ def _quantize_entry(entry, quantization_dtype):
 
   Args:
     entry: A weight entries to quantize.
-    quantization_dtype: An numpy dtype to quantize weights to. Only np.uint8 and
-      np.uint16 are supported.
+    quantization_dtype: A numpy dtype to quantize weights to.
+      Only np.uint8, np.uint16, and np.float16 are supported.
 
   Returns:
     A new entry containing the quantized data and additional quantization info,
@@ -168,19 +175,19 @@ def _quantize_entry(entry, quantization_dtype):
       'name': 'weight1',
       'data': np.array([20, 0, 255], 'uint8')
       'quantization': {'min': -0.10196078817, 'scale': 0.00509803940852,
-                       'original_dtype': 'float32'}
+                       'dtype': 'uint8', 'original_dtype': 'float32'}
     }
   """
   data = entry['data']
   # Only float32 tensors are quantized.
if data.dtype != 'float32': return entry - quantized_data, scale, min_val = quantization.quantize_weights( + quantized_data, metadata = quantization.quantize_weights( data, quantization_dtype) + metadata.update({'original_dtype': data.dtype.name}) quantized_entry = entry.copy() quantized_entry['data'] = quantized_data - quantized_entry['quantization'] = { - 'min': min_val, 'scale': scale, 'original_dtype': data.dtype.name} + quantized_entry['quantization'] = metadata return quantized_entry @@ -322,11 +329,9 @@ def _get_weights_manifest_for_group(group): if dtype == 'object': var_manifest['dtype'] = 'string' if is_quantized: - var_manifest['quantization'] = { - 'min': entry['quantization']['min'], - 'scale': entry['quantization']['scale'], - 'dtype': entry['data'].dtype.name - } + manifest = {'dtype': entry['data'].dtype.name} + manifest.update(entry['quantization']) + var_manifest['quantization'] = manifest weights_entries.append(var_manifest) return weights_entries diff --git a/tfjs-converter/python/tensorflowjs/write_weights_test.py b/tfjs-converter/python/tensorflowjs/write_weights_test.py index 2ea37b5b159..3459edd41d5 100644 --- a/tfjs-converter/python/tensorflowjs/write_weights_test.py +++ b/tfjs-converter/python/tensorflowjs/write_weights_test.py @@ -717,7 +717,11 @@ def test_quantize_group(self): ] manifest = write_weights.write_weights( - groups, TMP_DIR, shard_size_bytes=1024, quantization_dtype=np.uint8) + groups, TMP_DIR, shard_size_bytes=1024, + quantization_dtype_map={ + 'float16': 'weight1', + 'uint8': 'weight3' + }) self.assertTrue( os.path.isfile(os.path.join(TMP_DIR, 'weights_manifest.json')), @@ -731,7 +735,8 @@ def test_quantize_group(self): 'shape': [3], 'dtype': 'float32', 'quantization': { - 'min': 1.0, 'scale': 2/255.0, 'dtype': 'uint8' + 'original_dtype': 'float32', + 'dtype': 'float16' } }, { 'name': 'weight2', @@ -742,7 +747,10 @@ def test_quantize_group(self): 'shape': [2], 'dtype': 'float32', 'quantization': { - 'min': 6.0, 'scale': 1/255.0, 'dtype': 'uint8' + 'min': 6.0, + 'scale': 1/255.0, + 'original_dtype': 'float32', + 'dtype': 'uint8' } }, { 'name': 'weight4', @@ -754,19 +762,19 @@ def test_quantize_group(self): weights_path = os.path.join(TMP_DIR, 'group1-shard1of1.bin') with open(weights_path, 'rb') as f: weight_bytes = f.read() - self.assertEqual(len(weight_bytes), 22) - w1 = np.frombuffer(weight_bytes[:3], 'uint8') - np.testing.assert_array_equal(w1, np.array([0, 127, 255], 'uint8')) + self.assertEqual(len(weight_bytes), 25) + w1 = np.frombuffer(weight_bytes[:6], 'float16') + np.testing.assert_array_equal(w1, np.array([1, 2, 3], 'float16')) - w2 = np.frombuffer(weight_bytes[3:11], 'int32') + w2 = np.frombuffer(weight_bytes[6:14], 'int32') np.testing.assert_array_equal(w2, np.array([4, 5], 'int32')) - w3 = np.frombuffer(weight_bytes[11:13], 'uint8') + w3 = np.frombuffer(weight_bytes[14:16], 'uint8') np.testing.assert_array_equal(w3, np.array([0, 255], 'uint8')) - size = np.frombuffer(weight_bytes[13:17], 'uint32')[0] + size = np.frombuffer(weight_bytes[16:20], 'uint32')[0] self.assertEqual(size, 5) # 5 ascii letters. 
- w4 = weight_bytes[17:].decode('utf-8') + w4 = weight_bytes[20:].decode('utf-8') self.assertEqual(w4, u'hello') diff --git a/tfjs-converter/python/test_pip_package.py b/tfjs-converter/python/test_pip_package.py index 8926417e8e4..f8b6a473b42 100644 --- a/tfjs-converter/python/test_pip_package.py +++ b/tfjs-converter/python/test_pip_package.py @@ -871,7 +871,7 @@ def testConvertTfjsLayersModelIntoShardedWeights(self): new_y = new_model.predict(x) self.assertAllClose(new_y, y) - def testConvertTfjsLayersModelWithQuantization(self): + def testConvertTfjsLayersModelWithLegacyQuantization(self): with tf.Graph().as_default(), tf.compat.v1.Session(): # 1. Saved the model as a SavedModel. model = self._createNestedSequentialModel() @@ -911,6 +911,47 @@ def testConvertTfjsLayersModelWithQuantization(self): # The size of the weight file should reflect the uint16 quantization. self.assertEqual(weight_file_size, total_weight_bytes // 2) + + def testConvertTfjsLayersModelWithQuantization(self): + with tf.Graph().as_default(), tf.compat.v1.Session(): + # 1. Saved the model as a SavedModel. + model = self._createNestedSequentialModel() + + weights = model.get_weights() + total_weight_bytes = sum(np.size(w) for w in weights) * 4 + + tf.keras.models.save_model(model, self._tmp_dir) + + # 2. Convert the keras saved model to tfjs_layers_model format. + tfjs_output_dir = os.path.join(self._tmp_dir, 'tfjs') + # Implicit value of --output_format: tfjs_layers_model + process = subprocess.Popen([ + 'tensorflowjs_converter', '--input_format', 'keras_saved_model', + self._tmp_dir, tfjs_output_dir + ]) + process.communicate() + self.assertEqual(0, process.returncode) + + # 3. Convert the tfjs_layers_model to another tfjs_layers_model, + # with uint16 quantization. + sharded_model_dir = os.path.join(self._tmp_dir, 'tfjs_sharded') + process = subprocess.Popen([ + 'tensorflowjs_converter', '--input_format', 'tfjs_layers_model', + '--output_format', 'tfjs_layers_model', + '--quantize_uint16', '*', + os.path.join(tfjs_output_dir, 'model.json'), sharded_model_dir + ]) + process.communicate() + self.assertEqual(0, process.returncode) + + # 4. Check the quantized weight file and its size. + weight_files = sorted( + glob.glob(os.path.join(sharded_model_dir, 'group*.bin'))) + self.assertEqual(len(weight_files), 1) + weight_file_size = os.path.getsize(weight_files[0]) + # The size of the weight file should reflect the uint16 quantization. + self.assertEqual(weight_file_size, total_weight_bytes // 2) + def testConvertTfjsLayersModelToTfjsGraphModel(self): with tf.Graph().as_default(), tf.compat.v1.Session(): # 1. Create a model for testing. From 2fea4ee32c118d4ebc60c9179090c5772e4fcf32 Mon Sep 17 00:00:00 2001 From: Jesse Farebrother Date: Wed, 27 May 2020 21:07:08 -0600 Subject: [PATCH 2/2] IO decode support for float16 quantization --- tfjs-core/src/io/io_utils.ts | 149 ++++++++++++++++++++++++++++-- tfjs-core/src/io/io_utils_test.ts | 77 ++++++++++++++- tfjs-core/src/io/types.ts | 7 +- 3 files changed, 223 insertions(+), 10 deletions(-) diff --git a/tfjs-core/src/io/io_utils.ts b/tfjs-core/src/io/io_utils.ts index 1387152e2c0..d83d4e3646b 100644 --- a/tfjs-core/src/io/io_utils.ts +++ b/tfjs-core/src/io/io_utils.ts @@ -114,6 +114,7 @@ export function decodeWeights( buffer: ArrayBuffer, specs: WeightsManifestEntry[]): NamedTensorMap { // TODO(adarob, cais): Support quantization. 
const out: NamedTensorMap = {}; + let float16Decode: (buffer: Uint16Array) => Float32Array | undefined; let offset = 0; for (const spec of specs) { const name = spec.name; @@ -124,11 +125,26 @@ export function decodeWeights( if ('quantization' in spec) { const quantization = spec.quantization; - if (quantization.dtype !== 'uint8' && quantization.dtype !== 'uint16') { + if (quantization.dtype === 'uint8' || quantization.dtype === 'uint16') { + if (!('min' in quantization && 'scale' in quantization)) { + throw new Error( + `Weight ${spec.name} with quantization ${quantization.dtype} ` + + `doesn't have corresponding metadata min and scale.` + ); + } + } else if (quantization.dtype === 'float16') { + if (dtype !== 'float32') { + throw new Error( + `Weight ${spec.name} is quantized with ${quantization.dtype} ` + + `which only supports weights of type float32 not ${dtype}.` + ); + } + } else { throw new Error( `Weight ${spec.name} has unknown ` + `quantization dtype ${quantization.dtype}. ` + - `Supported quantization dtypes are: 'uint8' and 'uint16'.`); + `Supported quantization dtypes are: ` + + `'uint8', 'uint16', and 'float16'.`); } const quantizationSizeFactor = DTYPE_VALUE_SIZE_MAP[quantization.dtype]; const byteBuffer = @@ -137,12 +153,30 @@ export function decodeWeights( new Uint8Array(byteBuffer) : new Uint16Array(byteBuffer); if (dtype === 'float32') { - values = new Float32Array(quantizedArray.length); - for (let i = 0; i < quantizedArray.length; i++) { - const v = quantizedArray[i]; - values[i] = v * quantization.scale + quantization.min; + if (quantization.dtype === 'uint8' || quantization.dtype === 'uint16') { + values = new Float32Array(quantizedArray.length); + for (let i = 0; i < quantizedArray.length; i++) { + const v = quantizedArray[i]; + values[i] = v * quantization.scale + quantization.min; + } + } else if (quantization.dtype === 'float16') { + if (float16Decode === undefined) { + float16Decode = getFloat16Decoder(); + } + values = float16Decode(quantizedArray as Uint16Array); + } else { + throw new Error( + `Unsupported quantization type ${quantization.dtype} ` + + `for weight type float32.` + ); } } else if (dtype === 'int32') { + if (quantization.dtype !== 'uint8' && quantization.dtype !== 'uint16') { + throw new Error( + `Unsupported quantization type ${quantization.dtype} ` + + `for weight type int32.` + ); + } values = new Int32Array(quantizedArray.length); for (let i = 0; i < quantizedArray.length; i++) { const v = quantizedArray[i]; @@ -363,3 +397,106 @@ export function getModelArtifactsInfoForJSON(modelArtifacts: ModelArtifacts): modelArtifacts.weightData.byteLength, }; } + +/** + * Computes mantisa table for casting Float16 to Float32 + * See http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf + * + * @returns Uint32Array, 2048 mantissa lookup values. 
+ */
+function computeFloat16MantisaTable(): Uint32Array {
+  const convertMantissa = (i: number): number => {
+    let m = i << 13;
+    let e = 0;
+
+    while ((m & 0x00800000) === 0) {
+      e -= 0x00800000;
+      m <<= 1;
+    }
+    m &= ~0x00800000;
+    e += 0x38800000;
+
+    return m | e;
+  };
+
+  const mantisaTable = new Uint32Array(2048);
+
+  mantisaTable[0] = 0;
+  for (let i = 1; i < 1024; i++) {
+    mantisaTable[i] = convertMantissa(i);
+  }
+  for (let i = 1024; i < 2048; i++) {
+    mantisaTable[i] = 0x38000000 + ((i - 1024) << 13);
+  }
+
+  return mantisaTable;
+}
+
+/**
+ * Computes exponent table for casting Float16 to Float32
+ * See http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf
+ *
+ * @returns Uint32Array, 64 exponent lookup values.
+ */
+function computeFloat16ExponentTable(): Uint32Array {
+  const exponentTable = new Uint32Array(64);
+
+  exponentTable[0] = 0;
+  exponentTable[31] = 0x47800000;
+  exponentTable[32] = 0x80000000;
+  exponentTable[63] = 0xc7800000;
+  for (let i = 1; i < 31; i++) {
+    exponentTable[i] = i << 23;
+  }
+  for (let i = 33; i < 63; i++) {
+    exponentTable[i] = 0x80000000 + ((i - 32) << 23);
+  }
+
+  return exponentTable;
+}
+
+/**
+ * Computes offset table for casting Float16 to Float32
+ * See http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf
+ *
+ * @returns Uint32Array, 64 offset values.
+ */
+function computeFloat16OffsetTable(): Uint32Array {
+  const offsetTable = new Uint32Array(64);
+
+  for (let i = 0; i < 64; i++) {
+    offsetTable[i] = 1024;
+  }
+  offsetTable[0] = offsetTable[32] = 0;
+
+  return offsetTable;
+}
+
+/**
+ * Retrieve a Float16 decoder which will decode a ByteArray of Float16 values
+ * to a Float32Array.
+ *
+ * @returns Function (buffer: Uint16Array) => Float32Array which decodes
+ * the Uint16Array of Float16 bytes to a Float32Array.
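+ *
+ * For example, decoding the float16 bit patterns 0x3800 (0.5) and
+ * 0x3A00 (0.75):
+ *
+ *     const decodeFloat16 = getFloat16Decoder();
+ *     const f32 = decodeFloat16(new Uint16Array([0x3800, 0x3A00]));
+ *     // f32 is Float32Array [0.5, 0.75]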
+ */ +export function getFloat16Decoder(): (buffer: Uint16Array) => Float32Array { + // Algorithm is based off of http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf + + // Cache lookup tables + const mantisaTable = computeFloat16MantisaTable(); + const exponentTable = computeFloat16ExponentTable(); + const offsetTable = computeFloat16OffsetTable(); + + return (quantizedArray: Uint16Array) => { + const buffer = new ArrayBuffer(4 * quantizedArray.length); + const bufferUint32View = new Uint32Array(buffer); + for (let index = 0; index < quantizedArray.length; index++) { + const float16Bits = quantizedArray[index]; + const float32Bits = + mantisaTable[offsetTable[float16Bits >> 10] + (float16Bits & 0x3ff)] + + exponentTable[float16Bits >> 10]; + bufferUint32View[index] = float32Bits; + } + return new Float32Array(buffer); + }; +} diff --git a/tfjs-core/src/io/io_utils_test.ts b/tfjs-core/src/io/io_utils_test.ts index 9ee1f35915a..6656ef43aa3 100644 --- a/tfjs-core/src/io/io_utils_test.ts +++ b/tfjs-core/src/io/io_utils_test.ts @@ -23,7 +23,7 @@ import {expectArraysEqual} from '../test_util'; import {expectArraysClose} from '../test_util'; import {encodeString} from '../util'; -import {arrayBufferToBase64String, base64StringToArrayBuffer, basename, concatenateArrayBuffers, concatenateTypedArrays, stringByteLength} from './io_utils'; +import {arrayBufferToBase64String, base64StringToArrayBuffer, basename, concatenateArrayBuffers, concatenateTypedArrays, stringByteLength, getFloat16Decoder} from './io_utils'; import {WeightsManifestEntry} from './types'; describe('concatenateTypedArrays', () => { @@ -565,6 +565,22 @@ describeWithFlags('decodeWeights', {}, () => { expect(weight1.shape).toEqual([3]); expect(weight1.dtype).toEqual('int32'); }); + it('support quantization float16 weights', async () => { + const manifestSpecs: WeightsManifestEntry[] = [ + { + name: 'weight0', + dtype: 'float32', + shape: [3], + quantization: { dtype: 'float16' }, + }, + ]; + const data = new Uint16Array([13312, 14336, 14848]); + const decoded = tf.io.decodeWeights(data.buffer, manifestSpecs); + const weight0 = decoded['weight0']; + expectArraysClose(await weight0.data(), [0.25, 0.5, 0.75]); + expect(weight0.shape).toEqual([3]); + expect(weight0.dtype).toEqual('float32'); + }); }); describe('stringByteLength', () => { @@ -654,3 +670,62 @@ describe('basename', () => { expect(basename('foo/bar/baz//')).toEqual('baz'); }); }); + +describe('float16', () => { + it('decodes NaN to float32 NaN', () => { + const decoder = getFloat16Decoder(); + const float16NaN = 0x00007e00; + const buffer = new Uint16Array([float16NaN]); + const f32 = decoder(buffer); + expect(f32).toEqual(new Float32Array([NaN])); + }); + + it('decodes ±Infinity to float32 ±Infinity', () => { + const decoder = getFloat16Decoder(); + const positiveInfinity = 0x00007c00; + const negativeInfinity = 0xfffffc00; + const buffer = new Uint16Array([positiveInfinity, negativeInfinity]); + const f32 = decoder(buffer); + expect(f32).toEqual(new Float32Array([Infinity, -Infinity])); + }); + + it('decodes ±0 to float32 ±0', () => { + const decoder = getFloat16Decoder(); + const positiveZero = 0x00000000; + const negativeZero = 0xffff8000; + const buffer = new Uint16Array([positiveZero, negativeZero]); + const f32 = decoder(buffer); + expect(f32).toEqual(new Float32Array([0.0, -0.0])); + }); + + it('decodes -Infinity on underflow', () => { + const decoder = getFloat16Decoder(); + const minVal = 0xfffffbff; + const buffer = new Uint16Array([minVal + 1]); + const 
f32 = decoder(buffer); + expect(f32).toEqual(new Float32Array([-Infinity])); + }); + + it('decodes +Infinity on overflow', () => { + const decoder = getFloat16Decoder(); + const maxVal = 0x00007bff; + const buffer = new Uint16Array([maxVal + 1]); + const f32 = decoder(buffer); + expect(f32).toEqual(new Float32Array([Infinity])); + }); + + it('decodes interpretable float16 to float32', () => { + const decoder = getFloat16Decoder(); + const buffer = new Uint16Array([ + 0x00003400, + 0x00003800, + 0x00003A00, + 0x00003555 + ]); + const f32 = decoder(buffer); + expect(f32[0]).toBeCloseTo(0.25); + expect(f32[1]).toBeCloseTo(0.5); + expect(f32[2]).toBeCloseTo(0.75); + expect(f32[3]).toBeCloseTo(0.333); + }); +}); diff --git a/tfjs-core/src/io/types.ts b/tfjs-core/src/io/types.ts index f4ee6ff0939..63c5e3deff0 100644 --- a/tfjs-core/src/io/types.ts +++ b/tfjs-core/src/io/types.ts @@ -22,6 +22,7 @@ */ export const DTYPE_VALUE_SIZE_MAP: {[dtype: string]: number} = { 'float32': 4, + 'float16': 2, 'int32': 4, 'uint16': 2, 'uint8': 1, @@ -102,9 +103,9 @@ export declare interface WeightsManifestEntry { * Information for dequantization of the weight. */ quantization?: { - scale: number, // The scaling constant to multiply by. - min: number, // The (possibly nudged) minimum weight to add. - dtype: 'uint16'|'uint8' // The dtype of the quantized weights. + scale?: number, // The scaling constant to multiply by. + min?: number, // The (possibly nudged) minimum weight to add. + dtype: 'uint16'|'uint8'|'float16' // The dtype of the quantized weights. }; }
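Taken together, the converter now stores a float16-quantized weight with only `{'dtype': 'float16', 'original_dtype': 'float32'}` in the manifest (affine `uint8`/`uint16` quantization additionally keeps `min` and `scale`), and both the Python reader and `tf.io.decodeWeights` expand it back to float32. A minimal end-to-end sketch of the Python writer/reader pair, mirroring the tests in this patch (the output directory is illustrative):

```
import os

import numpy as np
from tensorflowjs import read_weights, write_weights

out_dir = '/tmp/float16_weights'  # illustrative path
os.makedirs(out_dir, exist_ok=True)

groups = [[{'name': 'weight1', 'data': np.array([1, 2, 3], 'float32')}]]

# Quantize every weight ('*' wildcard) to float16 when writing.
manifest = write_weights.write_weights(
    groups, out_dir, quantization_dtype_map={'float16': '*'})

# Reading dequantizes back to float32 on the fly.
read_back = read_weights.read_weights(manifest, out_dir)
assert read_back[0][0]['data'].dtype == np.float32
```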