Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions tfjs-converter/python/tensorflowjs/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ py_library(
name = "tensorflowjs",
srcs = ["__init__.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
":quantization",
":version",
"//tensorflowjs/converters:converter"
"//tensorflowjs/converters:converter",
],
visibility = ["//visibility:public"],
)

py_library(
Expand Down Expand Up @@ -72,6 +72,7 @@ py_library(
":quantization",
":read_weights",
"//tensorflowjs:expect_numpy_installed",
"//tensorflowjs:expect_tensorflow_installed",
],
)

Expand Down Expand Up @@ -115,6 +116,7 @@ py_test(
deps = [
":write_weights",
"//tensorflowjs:expect_numpy_installed",
"//tensorflowjs:expect_tensorflow_installed",
],
)

Expand All @@ -126,14 +128,15 @@ py_test(
":read_weights",
":write_weights",
"//tensorflowjs:expect_numpy_installed",
"//tensorflowjs:expect_tensorflow_installed",
],
)

py_test(
name = "resource_loader_test",
srcs = ["resource_loader_test.py"],
srcs_version = "PY2AND3",
data = [":op_list_jsons"],
srcs_version = "PY2AND3",
deps = [
":resource_loader",
],
Expand Down
28 changes: 22 additions & 6 deletions tfjs-converter/python/tensorflowjs/converters/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,8 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":tf_saved_model_conversion_v2",
"//tensorflowjs:expect_tensorflow_installed",
"//tensorflowjs:expect_tensorflow_hub_installed",
"//tensorflowjs:expect_tensorflow_installed",
],
)

Expand All @@ -154,11 +154,11 @@ py_library(
deps = [
":common",
":fold_batch_norms",
":fuse_prelu",
":fuse_depthwise_conv2d",
":fuse_prelu",
"//tensorflowjs:expect_numpy_installed",
"//tensorflowjs:expect_tensorflow_installed",
"//tensorflowjs:expect_tensorflow_hub_installed",
"//tensorflowjs:expect_tensorflow_installed",
"//tensorflowjs:resource_loader",
"//tensorflowjs:version",
"//tensorflowjs:write_weights",
Expand Down Expand Up @@ -187,10 +187,27 @@ py_binary(
],
)

py_library(
name = "converter_lib",
srcs = ["converter.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
":common",
":keras_h5_conversion",
":keras_tfjs_loader",
":tf_saved_model_conversion_v2",
"//third_party/py/h5py",
"//third_party/py/tensorflow",
"//third_party/py/tensorflowjs:version",
],
)

py_binary(
name = "converter",
srcs = ["converter.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
":common",
":keras_h5_conversion",
Expand All @@ -200,17 +217,16 @@ py_binary(
"//tensorflowjs:expect_tensorflow_installed",
"//tensorflowjs:version",
],
visibility = ["//visibility:public"],
)

py_binary(
name = "generate_test_model",
srcs = ["generate_test_model.py"],
testonly = True,
srcs = ["generate_test_model.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflowjs:expect_tensorflow_installed",
]
],
)

py_test(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import os
import shutil
import tempfile
import unittest

import numpy as np
import tensorflow.compat.v2 as tf
Expand Down Expand Up @@ -446,4 +445,4 @@ def testLoadFunctionalTfKerasModel(self):


if __name__ == '__main__':
unittest.main()
tf.test.main()
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import device_properties_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import convert_to_constants
from tensorflow.python.grappler import cluster as gcluster
from tensorflow.python.grappler import tf_optimizer
Expand Down Expand Up @@ -285,7 +286,7 @@ def write_artifacts(topology,
assert isinstance(weights_manifest, list)
model_json[common.ARTIFACT_WEIGHTS_MANIFEST_KEY] = weights_manifest

with open(output_graph, 'wt') as f:
with tf.io.gfile.GFile(output_graph, 'w') as f:
json.dump(model_json, f)

def _remove_unused_control_flow_inputs(input_graph_def):
Expand Down Expand Up @@ -426,14 +427,17 @@ def convert_tf_saved_model(saved_model_dir,
if signature_def is None:
signature_def = 'serving_default'

if not os.path.exists(output_dir):
os.makedirs(output_dir)
if not tf.io.gfile.exists(output_dir):
tf.io.gfile.makedirs(output_dir)
output_graph = os.path.join(
output_dir, common.ARTIFACT_MODEL_JSON_FILE_NAME)

if saved_model_tags:
saved_model_tags = saved_model_tags.split(',')
model = load(saved_model_dir, saved_model_tags)
model = None
# Ensure any graphs created in eager mode are able to run.
with context.eager_mode():
model = load(saved_model_dir, saved_model_tags)

_check_signature_in_model(model, signature_def)

Expand Down
6 changes: 3 additions & 3 deletions tfjs-converter/python/tensorflowjs/read_weights_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,17 @@

import os
import shutil
import unittest

import tempfile

import numpy as np
import tensorflow as tf

from tensorflowjs import read_weights
from tensorflowjs import write_weights


class ReadWeightsTest(unittest.TestCase):
class ReadWeightsTest(tf.test.TestCase):
def setUp(self):
self._tmp_dir = tempfile.mkdtemp()
super(ReadWeightsTest, self).setUp()
Expand Down Expand Up @@ -342,4 +342,4 @@ def testReadQuantizedWeights(self):


if __name__ == '__main__':
unittest.main()
tf.test.main()
5 changes: 3 additions & 2 deletions tfjs-converter/python/tensorflowjs/write_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import os

import numpy as np
import tensorflow as tf

from tensorflowjs import quantization
from tensorflowjs import read_weights
Expand Down Expand Up @@ -133,7 +134,7 @@ def write_weights(

if write_manifest:
manifest_path = os.path.join(write_dir, 'weights_manifest.json')
with open(manifest_path, 'wb') as f:
with tf.io.gfile.GFile(manifest_path, 'wb') as f:
f.write(json.dumps(manifest).encode())

return manifest
Expand Down Expand Up @@ -291,7 +292,7 @@ def _shard_group_bytes_to_disk(
filepath = os.path.join(write_dir, filename)

# Write the shard to disk.
with open(filepath, 'wb') as f:
with tf.io.gfile.GFile(filepath, 'wb') as f:
f.write(shard)

return filenames
Expand Down
6 changes: 3 additions & 3 deletions tfjs-converter/python/tensorflowjs/write_weights_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,16 @@

import os
import shutil
import unittest

import numpy as np
import tensorflow as tf

from tensorflowjs import write_weights

TMP_DIR = '/tmp/write_weights_test/'


class TestWriteWeights(unittest.TestCase):
class TestWriteWeights(tf.test.TestCase):
def setUp(self):
if not os.path.isdir(TMP_DIR):
os.makedirs(TMP_DIR)
Expand Down Expand Up @@ -751,4 +751,4 @@ def test_quantize_group(self):


if __name__ == '__main__':
unittest.main()
tf.test.main()