diff --git a/tfjs-converter/python/tensorflowjs/BUILD b/tfjs-converter/python/tensorflowjs/BUILD index b1d883ae850..cda74529acd 100644 --- a/tfjs-converter/python/tensorflowjs/BUILD +++ b/tfjs-converter/python/tensorflowjs/BUILD @@ -123,29 +123,6 @@ py_test( ], ) -py_test( - name = "wizard_test", - srcs = ["wizard_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":expect_numpy_installed", - ":wizard", - ], -) - -py_binary( - name = "wizard", - srcs = ["wizard.py"], - srcs_version = "PY2AND3", - deps = [ - ":converters/common", - ":converters/converter", - "//tensorflowjs:expect_h5py_installed", - "//tensorflowjs:expect_keras_installed", - "//tensorflowjs:expect_tensorflow_installed", - ], -) - # A filegroup BUILD target that includes all the op list json files in the # the op_list/ folder. The op_list folder itself is a symbolic link to the # actual op_list folder under src/. diff --git a/tfjs-converter/python/tensorflowjs/__init__.py b/tfjs-converter/python/tensorflowjs/__init__.py index d120fa33c0d..946f94954b2 100644 --- a/tfjs-converter/python/tensorflowjs/__init__.py +++ b/tfjs-converter/python/tensorflowjs/__init__.py @@ -21,6 +21,5 @@ from tensorflowjs import converters from tensorflowjs import quantization from tensorflowjs import version -from tensorflowjs import wizard __version__ = version.version diff --git a/tfjs-converter/python/tensorflowjs/converters/BUILD b/tfjs-converter/python/tensorflowjs/converters/BUILD index 42bec0e2edc..ae00c03aad9 100644 --- a/tfjs-converter/python/tensorflowjs/converters/BUILD +++ b/tfjs-converter/python/tensorflowjs/converters/BUILD @@ -77,6 +77,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":fuse_prelu", + ":tf_saved_model_conversion_v2", "//tensorflowjs:expect_numpy_installed", "//tensorflowjs:expect_tensorflow_installed", ], @@ -121,13 +122,37 @@ py_library( data = ["//tensorflowjs:op_list_jsons"], srcs_version = "PY2AND3", deps = [ + ":common", + ":fold_batch_norms", + ":fuse_prelu", "//tensorflowjs:expect_numpy_installed", "//tensorflowjs:expect_tensorflow_installed", "//tensorflowjs:expect_tensorflow_hub_installed", "//tensorflowjs:version", "//tensorflowjs:write_weights", - "//tensorflowjs/converters:common", - "//tensorflowjs/converters:fold_batch_norms", + ], +) + +py_test( + name = "wizard_test", + srcs = ["wizard_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":wizard", + "//tensorflowjs:expect_numpy_installed", + ], +) + +py_binary( + name = "wizard", + srcs = ["wizard.py"], + srcs_version = "PY2AND3", + deps = [ + ":common", + ":converter", + ":fuse_prelu", + "//tensorflowjs:expect_h5py_installed", + "//tensorflowjs:expect_tensorflow_installed", ], ) diff --git a/tfjs-converter/python/tensorflowjs/converters/converter.py b/tfjs-converter/python/tensorflowjs/converters/converter.py index 000a5e803c9..57cf1639232 100644 --- a/tfjs-converter/python/tensorflowjs/converters/converter.py +++ b/tfjs-converter/python/tensorflowjs/converters/converter.py @@ -652,5 +652,6 @@ def pip_main(): def main(argv): convert(argv[0].split(' ')) + if __name__ == '__main__': tf.app.run(main=main, argv=[' '.join(sys.argv[1:])]) diff --git a/tfjs-converter/python/tensorflowjs/wizard.py b/tfjs-converter/python/tensorflowjs/converters/wizard.py similarity index 99% rename from tfjs-converter/python/tensorflowjs/wizard.py rename to tfjs-converter/python/tensorflowjs/converters/wizard.py index bb394953f38..6fe8c198ad1 100644 --- a/tfjs-converter/python/tensorflowjs/wizard.py +++ b/tfjs-converter/python/tensorflowjs/converters/wizard.py @@ -47,6 +47,7 @@ PyInquirer.Token.Question: '', }) + def value_in_list(answers, key, values): """Determine user's answer for the key is in the value list. Args: @@ -70,6 +71,7 @@ def get_tfjs_model_type(model_file): else: # Default to layers model return common.TFJS_LAYERS_MODEL_FORMAT + def detect_saved_model(input_path): if os.path.exists(os.path.join(input_path, 'assets', 'saved_model.json')): return common.KERAS_SAVED_MODEL @@ -80,6 +82,7 @@ def detect_saved_model(input_path): return common.KERAS_SAVED_MODEL return common.TF_SAVED_MODEL + def detect_input_format(input_path): """Determine the input format from model's input path or file. Args: diff --git a/tfjs-converter/python/tensorflowjs/wizard_test.py b/tfjs-converter/python/tensorflowjs/converters/wizard_test.py similarity index 99% rename from tfjs-converter/python/tensorflowjs/wizard_test.py rename to tfjs-converter/python/tensorflowjs/converters/wizard_test.py index 706aeb7a336..def05f4bc41 100644 --- a/tfjs-converter/python/tensorflowjs/wizard_test.py +++ b/tfjs-converter/python/tensorflowjs/converters/wizard_test.py @@ -28,7 +28,7 @@ from tensorflow.python.training.tracking import tracking from tensorflow.python.saved_model import save -from tensorflowjs import wizard +from tensorflowjs.converters import wizard SAVED_MODEL_DIR = 'saved_model' SAVED_MODEL_NAME = 'saved_model.pb'