From a0832e15f06b16219f0ed00cfc4e9c60079d8db8 Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Mon, 25 Nov 2019 14:29:33 -0800 Subject: [PATCH 1/2] add output format auto fill for frozen model and required param check --- tfjs-converter/python/tensorflowjs/converters/converter.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tfjs-converter/python/tensorflowjs/converters/converter.py b/tfjs-converter/python/tensorflowjs/converters/converter.py index c81167373e1..729274bab46 100644 --- a/tfjs-converter/python/tensorflowjs/converters/converter.py +++ b/tfjs-converter/python/tensorflowjs/converters/converter.py @@ -390,7 +390,8 @@ def _standardize_input_output_formats(input_format, output_format): input_format_is_keras = ( input_format in [common.KERAS_MODEL, common.KERAS_SAVED_MODEL]) input_format_is_tf = ( - input_format in [common.TF_SAVED_MODEL, common.TF_HUB_MODEL]) + input_format in [common.TF_SAVED_MODEL, + common.TF_FROZEN_MODEL, common.TF_HUB_MODEL]) if output_format is None: # If no explicit output_format is provided, infer it from input format. if input_format_is_keras: @@ -552,6 +553,10 @@ def convert(arguments): quantization.QUANTIZATION_BYTES_TO_DTYPES[args.quantization_bytes] if args.quantization_bytes else None) + if (not args.output_node_names and input_format == common.TF_FROZEN_MODEL): + raise ValueError( + 'The --output_node_names flag is required for "tf_frozeon_model"') + if (args.signature_name and input_format not in (common.TF_SAVED_MODEL, common.TF_HUB_MODEL)): raise ValueError( From 43ba3364c4783e253a2d05c2680be25cb6cbd0a7 Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Mon, 25 Nov 2019 14:52:17 -0800 Subject: [PATCH 2/2] fix typo --- tfjs-converter/python/tensorflowjs/converters/converter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tfjs-converter/python/tensorflowjs/converters/converter.py b/tfjs-converter/python/tensorflowjs/converters/converter.py index 729274bab46..0bf06cd7928 100644 --- a/tfjs-converter/python/tensorflowjs/converters/converter.py +++ b/tfjs-converter/python/tensorflowjs/converters/converter.py @@ -555,7 +555,7 @@ def convert(arguments): if (not args.output_node_names and input_format == common.TF_FROZEN_MODEL): raise ValueError( - 'The --output_node_names flag is required for "tf_frozeon_model"') + 'The --output_node_names flag is required for "tf_frozen_model"') if (args.signature_name and input_format not in (common.TF_SAVED_MODEL, common.TF_HUB_MODEL)):