Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 0 additions & 23 deletions tfjs-converter/python/tensorflowjs/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -123,29 +123,6 @@ py_test(
],
)

py_test(
name = "wizard_test",
srcs = ["wizard_test.py"],
srcs_version = "PY2AND3",
deps = [
":expect_numpy_installed",
":wizard",
],
)

py_binary(
name = "wizard",
srcs = ["wizard.py"],
srcs_version = "PY2AND3",
deps = [
":converters/common",
":converters/converter",
"//tensorflowjs:expect_h5py_installed",
"//tensorflowjs:expect_keras_installed",
"//tensorflowjs:expect_tensorflow_installed",
],
)

# A filegroup BUILD target that includes all the op list json files in the
# the op_list/ folder. The op_list folder itself is a symbolic link to the
# actual op_list folder under src/.
Expand Down
1 change: 0 additions & 1 deletion tfjs-converter/python/tensorflowjs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,5 @@
from tensorflowjs import converters
from tensorflowjs import quantization
from tensorflowjs import version
from tensorflowjs import wizard

__version__ = version.version
29 changes: 27 additions & 2 deletions tfjs-converter/python/tensorflowjs/converters/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":fuse_prelu",
":tf_saved_model_conversion_v2",
"//tensorflowjs:expect_numpy_installed",
"//tensorflowjs:expect_tensorflow_installed",
],
Expand Down Expand Up @@ -121,13 +122,37 @@ py_library(
data = ["//tensorflowjs:op_list_jsons"],
srcs_version = "PY2AND3",
deps = [
":common",
":fold_batch_norms",
":fuse_prelu",
"//tensorflowjs:expect_numpy_installed",
"//tensorflowjs:expect_tensorflow_installed",
"//tensorflowjs:expect_tensorflow_hub_installed",
"//tensorflowjs:version",
"//tensorflowjs:write_weights",
"//tensorflowjs/converters:common",
"//tensorflowjs/converters:fold_batch_norms",
],
)

py_test(
name = "wizard_test",
srcs = ["wizard_test.py"],
srcs_version = "PY2AND3",
deps = [
":wizard",
"//tensorflowjs:expect_numpy_installed",
],
)

py_binary(
name = "wizard",
srcs = ["wizard.py"],
srcs_version = "PY2AND3",
deps = [
":common",
":converter",
":fuse_prelu",
"//tensorflowjs:expect_h5py_installed",
"//tensorflowjs:expect_tensorflow_installed",
],
)

Expand Down
1 change: 1 addition & 0 deletions tfjs-converter/python/tensorflowjs/converters/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,5 +652,6 @@ def pip_main():
def main(argv):
convert(argv[0].split(' '))


if __name__ == '__main__':
tf.app.run(main=main, argv=[' '.join(sys.argv[1:])])
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
PyInquirer.Token.Question: '',
})


def value_in_list(answers, key, values):
"""Determine user's answer for the key is in the value list.
Args:
Expand All @@ -70,6 +71,7 @@ def get_tfjs_model_type(model_file):
else: # Default to layers model
return common.TFJS_LAYERS_MODEL_FORMAT


def detect_saved_model(input_path):
if os.path.exists(os.path.join(input_path, 'assets', 'saved_model.json')):
return common.KERAS_SAVED_MODEL
Expand All @@ -80,6 +82,7 @@ def detect_saved_model(input_path):
return common.KERAS_SAVED_MODEL
return common.TF_SAVED_MODEL


def detect_input_format(input_path):
"""Determine the input format from model's input path or file.
Args:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from tensorflow.python.training.tracking import tracking
from tensorflow.python.saved_model import save

from tensorflowjs import wizard
from tensorflowjs.converters import wizard

SAVED_MODEL_DIR = 'saved_model'
SAVED_MODEL_NAME = 'saved_model.pb'
Expand Down