32 changes: 25 additions & 7 deletions tfjs-converter/README.md
@@ -154,7 +154,10 @@ saved a tf.keras model in the SavedModel format.
|<nobr>`--saved_model_tags`</nobr> | Only applicable to SavedModel conversion. Tags of the MetaGraphDef to load, in comma separated format. Defaults to `serve`.|
|`--signature_name` | Only applicable to TensorFlow SavedModel and Hub module conversion, signature to load. Defaults to `serving_default` for SavedModel and `default` for Hub module. See https://www.tensorflow.org/hub/common_signatures/.|
|`--strip_debug_ops` | Strips out TensorFlow debug operations `Print`, `Assert`, `CheckNumerics`. Defaults to `True`.|
|`--quantization_bytes` | How many bytes to optionally quantize/compress the weights to. Valid values are 1 and 2. which will quantize int32 and float32 to 1 or 2 bytes respectively. The default (unquantized) size is 4 bytes.|
|`--quantization_bytes` | (Deprecated) How many bytes to optionally quantize/compress the weights to. Valid values are 1 and 2, which will quantize int32 and float32 to 1 or 2 bytes respectively. The default (unquantized) size is 4 bytes. Prefer the `--quantize_float16`, `--quantize_uint8`, and `--quantize_uint16` flags instead.|
|`--quantize_float16` | Comma separated list of node names to apply float16 quantization to. You can also use the wildcard symbol (*) to apply quantization to multiple nodes (e.g., conv/*/weights). When the flag is provided without any node names, it will match all nodes. |
|`--quantize_uint8` | Comma separated list of node names to apply 1-byte affine quantization to. You can also use the wildcard symbol (*) to apply quantization to multiple nodes (e.g., conv/*/weights). When the flag is provided without any node names, it will match all nodes. |
|`--quantize_uint16` | Comma separated list of node names to apply 2-byte affine quantization to. You can also use the wildcard symbol (*) to apply quantization to multiple nodes (e.g., conv/*/weights). When the flag is provided without any node names, it will match all nodes. |
|`--weight_shard_size_bytes` | Shard size (in bytes) of the weight files. Only supported when `output_format` is `tfjs_layers_model` or `tfjs_graph_model`. Default size is 4 MB (4194304 bytes).|
|<nobr>`--output_node_names`</nobr>| Only applicable to Frozen Model. The names of the output nodes, separated by commas.|

@@ -216,7 +219,7 @@ purposes:
tensorflowjs_converter \
--input_format tfjs_layers_model \
--output_format tfjs_layers_model \
--quantization_bytes 2 \
--quantize_uint16 \
original_model/model.json \
quantized_model/
```
@@ -380,18 +383,33 @@ browser to cache them automatically. If the model architecture is less than 4MB

__4. Can I quantize the weights over the wire?__

Yes, you can use the --quantization_bytes option to compress int32/float32 to 1
or 2 bytes. Here is
an example of 8-bit quantization:
Yes, you can use the `--quantize_float16`, `--quantize_uint8`, and
`--quantize_uint16` flags to compress weights with 1-byte integer quantization
(`uint8`) or 2-byte integer (`uint16`) or float (`float16`) quantization.
Quantizing to float16 may provide better accuracy than 2-byte affine integer
scaling (`uint16`), while 1-byte affine quantization (`uint8`) provides a 4x
size reduction at a greater cost in accuracy.
For example, we can quantize our MobileNet model using float16 quantization:

```
tensorflowjs_converter \
tensorflowjs_converter \
--quantize_float16 \
--input_format=tf_hub \
--quantization_bytes=1
'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/1' \
/mobilenet/web_model
```

You can also quantize specific weights, or groups of weights, using
a wildcard pattern. For example,
```
tensorflowjs_converter \
--quantize_float16="conv/*/weights"
```
which will quantize all weights matching the pattern `conv/*/weights`, and
exclude biases and any weights whose names do not begin with `conv/`.
This can be a powerful tool for reducing model size while preserving model
performance.
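
Under the hood, patterns like `conv/*/weights` behave like filename-style globs
over the weight names. The following sketch is only an approximation of the
converter's matching logic (it assumes `fnmatch`-style semantics and uses
made-up weight names), but it shows which tensors such a pattern would select:

```
import fnmatch

# Hypothetical weight names from a small convolutional model.
weight_names = ['conv/1/weights', 'conv/2/weights', 'conv/1/bias', 'dense/kernel']
pattern = 'conv/*/weights'

# Only the convolution weight tensors match; biases and dense weights do not.
print([name for name in weight_names if fnmatch.fnmatch(name, pattern)])
# ['conv/1/weights', 'conv/2/weights']
```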

__5. Why is the predict() method for inference so much slower on the first call than the subsequent calls?__

The time of first call also includes the compilation time of WebGL shader
3 changes: 3 additions & 0 deletions tfjs-converter/python/tensorflowjs/converters/common.py
@@ -50,6 +50,9 @@
SIGNATURE_NAME = 'signature_name'
SAVED_MODEL_TAGS = 'saved_model_tags'
QUANTIZATION_BYTES = 'quantization_bytes'
QUANTIZATION_TYPE_FLOAT16 = 'quantize_float16'
QUANTIZATION_TYPE_UINT8 = 'quantize_uint8'
QUANTIZATION_TYPE_UINT16 = 'quantize_uint16'
SPLIT_WEIGHTS_BY_LAYER = 'split_weights_by_layer'
VERSION = 'version'
SKIP_OP_CHECK = 'skip_op_check'
136 changes: 92 additions & 44 deletions tfjs-converter/python/tensorflowjs/converters/converter.py
@@ -26,7 +26,6 @@
import tempfile

import h5py
import numpy as np
import tensorflow.compat.v1 as tf1
import tensorflow.compat.v2 as tf

@@ -39,7 +38,7 @@


def dispatch_keras_h5_to_tfjs_layers_model_conversion(
h5_path, output_dir=None, quantization_dtype=None,
h5_path, output_dir=None, quantization_dtype_map=None,
split_weights_by_layer=False,
weight_shard_size_bytes=1024 * 1024 * 4):
"""Converts a Keras HDF5 saved-model file to TensorFlow.js format.
@@ -56,8 +55,8 @@ def dispatch_keras_h5_to_tfjs_layers_model_conversion(
output_dir: Output directory to which the TensorFlow.js-format model JSON
file and weights files will be written. If the directory does not exist,
it will be created.
quantization_dtype: The quantized data type to store the weights in
(Default: `None`).
quantization_dtype_map: A mapping from dtype (`uint8`, `uint16`, `float16`)
to weights. The weight mapping supports wildcard substitution.
split_weights_by_layer: Whether to split the weights into separate weight
groups (corresponding to separate binary weight files) layer by layer
(Default: `False`).
@@ -94,15 +93,15 @@ def dispatch_keras_h5_to_tfjs_layers_model_conversion(
if not os.path.isdir(output_dir):
os.makedirs(output_dir)
conversion.write_artifacts(
model_json, groups, output_dir, quantization_dtype,
model_json, groups, output_dir, quantization_dtype_map,
weight_shard_size_bytes=weight_shard_size_bytes)

return model_json, groups


def dispatch_keras_h5_to_tfjs_graph_model_conversion(
h5_path, output_dir=None,
quantization_dtype=None,
quantization_dtype_map=None,
skip_op_check=False,
strip_debug_ops=False,
weight_shard_size_bytes=1024 * 1024 * 4,
@@ -115,8 +114,8 @@ def dispatch_keras_h5_to_tfjs_graph_model_conversion(
keras or tf.keras.
output_dir: The destination to which the tfjs GraphModel artifacts will be
written.
quantization_dtype: The quantized data type to store the weights in
(Default: `None`).
quantization_dtype_map: A mapping from dtype (`uint8`, `uint16`, `float16`)
to weights. The weight mapping supports wildcard substitution.
skip_op_check: Bool whether to skip the op check.
strip_debug_ops: Bool whether to allow unsupported debug ops.
weight_shard_size_bytes: Shard size (in bytes) of the weight files.
@@ -140,7 +139,7 @@ def dispatch_keras_h5_to_tfjs_graph_model_conversion(
temp_savedmodel_dir, output_dir,
signature_def='serving_default',
saved_model_tags='serve',
quantization_dtype=quantization_dtype,
quantization_dtype_map=quantization_dtype_map,
skip_op_check=skip_op_check,
strip_debug_ops=strip_debug_ops,
weight_shard_size_bytes=weight_shard_size_bytes,
@@ -151,7 +150,7 @@


def dispatch_keras_saved_model_to_tensorflowjs_conversion(
keras_saved_model_path, output_dir, quantization_dtype=None,
keras_saved_model_path, output_dir, quantization_dtype_map=None,
split_weights_by_layer=False,
weight_shard_size_bytes=1024 * 1024 * 4):
"""Converts keras model saved in the SavedModel format to tfjs format.
@@ -168,8 +167,8 @@ def dispatch_keras_saved_model_to_tensorflowjs_conversion(
output_dir: Output directory to which the TensorFlow.js-format model JSON
file and weights files will be written. If the directory does not exist,
it will be created.
quantization_dtype: The quantized data type to store the weights in
(Default: `None`).
quantization_dtype_map: A mapping from dtype (`uint8`, `uint16`, `float16`)
to weights. The weight mapping supports wildcard substitution.
split_weights_by_layer: Whether to split the weights into separate weight
groups (corresponding to separate binary weight files) layer by layer
(Default: `False`).
@@ -187,7 +186,7 @@ def dispatch_keras_saved_model_to_tensorflowjs_conversion(
dispatch_keras_h5_to_tfjs_layers_model_conversion(
temp_h5_path,
output_dir,
quantization_dtype=quantization_dtype,
quantization_dtype_map=quantization_dtype_map,
split_weights_by_layer=split_weights_by_layer,
weight_shard_size_bytes=weight_shard_size_bytes)

@@ -271,7 +270,7 @@ def dispatch_tensorflowjs_to_keras_saved_model_conversion(
def dispatch_tensorflowjs_to_tensorflowjs_conversion(
config_json_path,
output_dir_path,
quantization_dtype=None,
quantization_dtype_map=None,
weight_shard_size_bytes=1024 * 1024 * 4):
"""Converts a Keras Model from tensorflowjs format to H5.

@@ -280,8 +279,8 @@ def dispatch_tensorflowjs_to_tensorflowjs_conversion(
topology and weights manifest, in tensorflowjs format.
output_dir_path: Path to output directory in which the result of the
conversion will be saved.
quantization_dtype: The quantized data type to store the weights in
(Default: `None`).
quantization_dtype_map: A mapping from dtype (`uint8`, `uint16`, `float16`)
to weights. The weight mapping supports wildcard substitution.
weight_shard_size_bytes: Shard size (in bytes) of the weight files.
The size of each weight file will be <= this value.

@@ -318,7 +317,7 @@ def dispatch_tensorflowjs_to_tensorflowjs_conversion(
with tf.Graph().as_default(), tf.compat.v1.Session():
dispatch_keras_h5_to_tfjs_layers_model_conversion(
temp_h5_path, output_dir_path,
quantization_dtype=quantization_dtype,
quantization_dtype_map=quantization_dtype_map,
weight_shard_size_bytes=weight_shard_size_bytes)
# TODO(cais): Support weight quantization.

@@ -329,7 +328,7 @@ def dispatch_tensorflowjs_to_tensorflowjs_conversion(
def dispatch_tfjs_layers_model_to_tfjs_graph_conversion(
config_json_path,
output_dir_path,
quantization_dtype=None,
quantization_dtype_map=None,
skip_op_check=False,
strip_debug_ops=False,
weight_shard_size_bytes=1024 * 1024 * 4):
@@ -343,8 +342,8 @@ def dispatch_tfjs_layers_model_to_tfjs_graph_conversion(
topology and weights manifest, in tensorflowjs format.
output_dir_path: Path to output directory in which the result of the
conversion will be saved.
quantization_dtype: The quantized data type to store the weights in
(Default: `None`).
quantization_dtype_map: A mapping from dtype (`uint8`, `uint16`, `float16`)
to weights. The weight mapping supports wildcard substitution.
skip_op_check: Bool whether to skip the op check.
strip_debug_ops: Bool whether to allow unsupported debug ops.
weight_shard_size_bytes: Shard size (in bytes) of the weight files.
@@ -380,7 +379,7 @@ def dispatch_tfjs_layers_model_to_tfjs_graph_conversion(
model.save(temp_h5_path)
dispatch_keras_h5_to_tfjs_graph_model_conversion(
temp_h5_path, output_dir_path,
quantization_dtype=quantization_dtype,
quantization_dtype_map=quantization_dtype_map,
skip_op_check=skip_op_check,
strip_debug_ops=strip_debug_ops,
weight_shard_size_bytes=weight_shard_size_bytes)
@@ -416,16 +415,32 @@ def _standardize_input_output_formats(input_format, output_format):

return (input_format, output_format)

def _parse_quantization_dtype_map(float16, uint8, uint16, quantization_bytes):
quantization_dtype_map = {}

def _parse_quantization_bytes(quantization_bytes):
if quantization_bytes is None:
return None
elif quantization_bytes == 1:
return np.uint8
elif quantization_bytes == 2:
return np.uint16
else:
raise ValueError('Unsupported quantization bytes: %s' % quantization_bytes)
if quantization_bytes:
print(
'Warning: --quantization_bytes will be deprecated in a future release\n'
'Please consider using --quantize_uint8, --quantize_uint16, '
'--quantize_float16.', file=sys.stderr)
if float16 is not None or uint8 is not None or uint16 is not None:
raise ValueError(
'--quantization_bytes cannot be used with the new quantization flags')

dtype = quantization.QUANTIZATION_BYTES_TO_DTYPES[quantization_bytes]
quantization_dtype_map[dtype] = True

if float16 is not None:
quantization_dtype_map[quantization.QUANTIZATION_DTYPE_FLOAT16] = \
float16.split(',') if isinstance(float16, str) else float16
if uint8 is not None:
quantization_dtype_map[quantization.QUANTIZATION_DTYPE_UINT8] = \
uint8.split(',') if isinstance(uint8, str) else uint8
if uint16 is not None:
quantization_dtype_map[quantization.QUANTIZATION_DTYPE_UINT16] = \
uint16.split(',') if isinstance(uint16, str) else uint16

return quantization_dtype_map
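# Illustrative note (not part of the converter code): assuming the
# quantization.QUANTIZATION_DTYPE_* constants are the strings 'float16',
# 'uint8' and 'uint16', a call such as
#   _parse_quantization_dtype_map(float16='conv/*/weights', uint8=True,
#                                 uint16=None, quantization_bytes=None)
# returns {'float16': ['conv/*/weights'], 'uint8': True}: a comma separated
# flag value becomes a list of name patterns, and a bare flag becomes True
# (match all nodes).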

def get_arg_parser():
"""
@@ -488,13 +503,43 @@ def get_arg_parser():
help='Tags of the MetaGraphDef to load, in comma separated string '
'format. Defaults to "serve". Applicable only if input format is '
'"tf_saved_model".')
parser.add_argument(
'--%s' % common.QUANTIZATION_TYPE_FLOAT16,
type=str,
default=None,
const=True,
nargs='?',
help='Comma separated list of node names to apply float16 quantization. '
'You can also use wildcard symbol (*) to apply quantization to multiple '
'nodes (e.g., conv/*/weights). When the flag is provided without any '
'nodes the default behavior will match all nodes.')
parser.add_argument(
'--%s' % common.QUANTIZATION_TYPE_UINT8,
type=str,
default=None,
const=True,
nargs='?',
help='Comma separated list of node names to apply 1-byte affine '
'quantization. You can also use wildcard symbol (*) to apply '
'quantization to multiple nodes (e.g., conv/*/weights). When the flag is '
'provided without any nodes the default behavior will match all nodes.')
parser.add_argument(
'--%s' % common.QUANTIZATION_TYPE_UINT16,
type=str,
default=None,
const=True,
nargs='?',
help='Comma separated list of node names to apply 2-byte affine '
'quantization. You can also use wildcard symbol (*) to apply '
'quantization to multiple nodes (e.g., conv/*/weights). When the flag is '
'provided without any nodes the default behavior will match all nodes.')
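# Note on the flag definitions above: with nargs='?' and const=True, passing
# --quantize_uint8 with no value stores True (quantize every node), while
# --quantize_uint8 "conv/*/weights" stores the string 'conv/*/weights', which
# _parse_quantization_dtype_map later splits on commas into a list of name
# patterns.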
parser.add_argument(
'--%s' % common.QUANTIZATION_BYTES,
type=int,
choices=set(quantization.QUANTIZATION_BYTES_TO_DTYPES.keys()),
help='How many bytes to optionally quantize/compress the weights to. 1- '
'and 2-byte quantizaton is supported. The default (unquantized) size is '
'4 bytes.')
help='(Deprecated) How many bytes to optionally quantize/compress the '
'weights to. 1- and 2-byte quantization is supported. The default '
'(unquantized) size is 4 bytes.')
parser.add_argument(
'--%s' % common.SPLIT_WEIGHTS_BY_LAYER,
action='store_true',
@@ -574,9 +619,12 @@ def convert(arguments):
'but got %s' % args.weight_shard_size_bytes)
weight_shard_size_bytes = args.weight_shard_size_bytes

quantization_dtype = (
quantization.QUANTIZATION_BYTES_TO_DTYPES[args.quantization_bytes]
if args.quantization_bytes else None)
quantization_dtype_map = _parse_quantization_dtype_map(
args.quantize_float16,
args.quantize_uint8,
args.quantize_uint16,
args.quantization_bytes
)

if (not args.output_node_names and input_format == common.TF_FROZEN_MODEL):
raise ValueError(
@@ -601,14 +649,14 @@ def convert(arguments):
output_format == common.TFJS_LAYERS_MODEL):
dispatch_keras_h5_to_tfjs_layers_model_conversion(
args.input_path, output_dir=args.output_path,
quantization_dtype=quantization_dtype,
quantization_dtype_map=quantization_dtype_map,
split_weights_by_layer=args.split_weights_by_layer,
weight_shard_size_bytes=weight_shard_size_bytes)
elif (input_format == common.KERAS_MODEL and
output_format == common.TFJS_GRAPH_MODEL):
dispatch_keras_h5_to_tfjs_graph_model_conversion(
args.input_path, output_dir=args.output_path,
quantization_dtype=quantization_dtype,
quantization_dtype_map=quantization_dtype_map,
skip_op_check=args.skip_op_check,
strip_debug_ops=args.strip_debug_ops,
weight_shard_size_bytes=weight_shard_size_bytes,
@@ -617,7 +665,7 @@
output_format == common.TFJS_LAYERS_MODEL):
dispatch_keras_saved_model_to_tensorflowjs_conversion(
args.input_path, args.output_path,
quantization_dtype=quantization_dtype,
quantization_dtype_map=quantization_dtype_map,
split_weights_by_layer=args.split_weights_by_layer,
weight_shard_size_bytes=weight_shard_size_bytes)
elif (input_format == common.TF_SAVED_MODEL and
@@ -626,7 +674,7 @@
args.input_path, args.output_path,
signature_def=args.signature_name,
saved_model_tags=args.saved_model_tags,
quantization_dtype=quantization_dtype,
quantization_dtype_map=quantization_dtype_map,
skip_op_check=args.skip_op_check,
strip_debug_ops=args.strip_debug_ops,
weight_shard_size_bytes=weight_shard_size_bytes,
@@ -637,7 +685,7 @@
args.input_path, args.output_path,
signature=args.signature_name,
saved_model_tags=args.saved_model_tags,
quantization_dtype=quantization_dtype,
quantization_dtype_map=quantization_dtype_map,
skip_op_check=args.skip_op_check,
strip_debug_ops=args.strip_debug_ops,
weight_shard_size_bytes=weight_shard_size_bytes,
@@ -654,21 +702,21 @@ def convert(arguments):
output_format == common.TFJS_LAYERS_MODEL):
dispatch_tensorflowjs_to_tensorflowjs_conversion(
args.input_path, args.output_path,
quantization_dtype=_parse_quantization_bytes(args.quantization_bytes),
quantization_dtype_map=quantization_dtype_map,
weight_shard_size_bytes=weight_shard_size_bytes)
elif (input_format == common.TFJS_LAYERS_MODEL and
output_format == common.TFJS_GRAPH_MODEL):
dispatch_tfjs_layers_model_to_tfjs_graph_conversion(
args.input_path, args.output_path,
quantization_dtype=_parse_quantization_bytes(args.quantization_bytes),
quantization_dtype_map=quantization_dtype_map,
skip_op_check=args.skip_op_check,
strip_debug_ops=args.strip_debug_ops,
weight_shard_size_bytes=weight_shard_size_bytes)
elif (input_format == common.TF_FROZEN_MODEL and
output_format == common.TFJS_GRAPH_MODEL):
tf_saved_model_conversion_v2.convert_tf_frozen_model(
args.input_path, args.output_node_names, args.output_path,
quantization_dtype=_parse_quantization_bytes(args.quantization_bytes),
quantization_dtype_map=quantization_dtype_map,
skip_op_check=args.skip_op_check,
strip_debug_ops=args.strip_debug_ops,
weight_shard_size_bytes=weight_shard_size_bytes)