diff --git a/tfjs-converter/python/tensorflowjs/converters/converter.py b/tfjs-converter/python/tensorflowjs/converters/converter.py index c81167373e1..0bf06cd7928 100644 --- a/tfjs-converter/python/tensorflowjs/converters/converter.py +++ b/tfjs-converter/python/tensorflowjs/converters/converter.py @@ -390,7 +390,8 @@ def _standardize_input_output_formats(input_format, output_format): input_format_is_keras = ( input_format in [common.KERAS_MODEL, common.KERAS_SAVED_MODEL]) input_format_is_tf = ( - input_format in [common.TF_SAVED_MODEL, common.TF_HUB_MODEL]) + input_format in [common.TF_SAVED_MODEL, + common.TF_FROZEN_MODEL, common.TF_HUB_MODEL]) if output_format is None: # If no explicit output_format is provided, infer it from input format. if input_format_is_keras: @@ -552,6 +553,10 @@ def convert(arguments): quantization.QUANTIZATION_BYTES_TO_DTYPES[args.quantization_bytes] if args.quantization_bytes else None) + if (not args.output_node_names and input_format == common.TF_FROZEN_MODEL): + raise ValueError( + 'The --output_node_names flag is required for "tf_frozen_model"') + if (args.signature_name and input_format not in (common.TF_SAVED_MODEL, common.TF_HUB_MODEL)): raise ValueError(