diff --git a/tfjs-converter/python/tensorflowjs/converters/wizard.py b/tfjs-converter/python/tensorflowjs/converters/wizard.py index 990a0598934..457e0de0eac 100644 --- a/tfjs-converter/python/tensorflowjs/converters/wizard.py +++ b/tfjs-converter/python/tensorflowjs/converters/wizard.py @@ -264,6 +264,10 @@ def available_output_formats(answers): 'key': 's', 'name': 'Keras Saved Model', 'value': common.KERAS_SAVED_MODEL, + }, { + 'key': 'l', + 'name': 'TensoFlow.js Layers Model', + 'value': common.TFJS_LAYERS_MODEL, }] return [] @@ -486,8 +490,11 @@ def run(dryrun): 'name': common.SPLIT_WEIGHTS_BY_LAYER, 'message': 'Do you want to split weights by layers?', 'default': False, - 'when': lambda answers: value_in_list(answers, common.INPUT_FORMAT, - (common.TFJS_LAYERS_MODEL)) + 'when': lambda answers: (value_in_list(answers, common.OUTPUT_FORMAT, + (common.TFJS_LAYERS_MODEL)) and + value_in_list(answers, common.INPUT_FORMAT, + (common.KERAS_MODEL, + common.KERAS_SAVED_MODEL))) }, { 'type': 'confirm', @@ -577,10 +584,10 @@ def pip_main(): def main(argv): - if len(argv) > 2 or len(argv) == 2 and not argv[1] == '--dryrun': + if argv[0] and not argv[0] == '--dryrun': print("Usage: tensorflowjs_wizard [--dryrun]") sys.exit(1) - dry_run = len(argv) == 2 and argv[1] == '--dryrun' + dry_run = argv[0] == '--dryrun' run(dry_run) if __name__ == '__main__':