Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions tfjs-converter/python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ def _get_requirements(file):
'tensorflowjs.converters.common',
'tensorflowjs.converters.converter',
'tensorflowjs.converters.fold_batch_norms',
'tensorflowjs.converters.fuse_depthwise_conv2d',
'tensorflowjs.converters.fuse_prelu',
'tensorflowjs.converters.graph_rewrite_util',
'tensorflowjs.converters.keras_h5_conversion',
'tensorflowjs.converters.keras_tfjs_loader',
'tensorflowjs.converters.tf_saved_model_conversion_v2',
Expand Down
35 changes: 32 additions & 3 deletions tfjs-converter/python/tensorflowjs/converters/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,19 @@ py_test(
],
)

py_library(
name = "graph_rewrite_util",
srcs = ["graph_rewrite_util.py"],
srcs_version = "PY2AND3",
deps = ["//tensorflowjs:version"],
)

py_library(
name = "fuse_prelu",
srcs = ["fuse_prelu.py"],
srcs_version = "PY2AND3",
deps = [
":common",
":graph_rewrite_util",
"//tensorflowjs:expect_numpy_installed",
"//tensorflowjs:expect_tensorflow_installed",
],
Expand All @@ -88,7 +95,7 @@ py_library(
srcs = ["fold_batch_norms.py"],
srcs_version = "PY2AND3",
deps = [
":common",
":graph_rewrite_util",
"//tensorflowjs:expect_numpy_installed",
"//tensorflowjs:expect_tensorflow_installed",
],
Expand All @@ -105,6 +112,28 @@ py_test(
],
)

py_library(
name = "fuse_depthwise_conv2d",
srcs = ["fuse_depthwise_conv2d.py"],
srcs_version = "PY2AND3",
deps = [
":graph_rewrite_util",
"//tensorflowjs:expect_numpy_installed",
"//tensorflowjs:expect_tensorflow_installed",
],
)

py_test(
name = "fuse_depthwise_conv2d_test",
srcs = ["fuse_depthwise_conv2d_test.py"],
srcs_version = "PY2AND3",
deps = [
":fuse_depthwise_conv2d",
"//tensorflowjs:expect_numpy_installed",
"//tensorflowjs:expect_tensorflow_installed",
],
)

py_test(
name = "tf_saved_model_conversion_v2_test",
srcs = ["tf_saved_model_conversion_v2_test.py"],
Expand All @@ -125,6 +154,7 @@ py_library(
":common",
":fold_batch_norms",
":fuse_prelu",
":fuse_depthwise_conv2d",
"//tensorflowjs:expect_numpy_installed",
"//tensorflowjs:expect_tensorflow_installed",
"//tensorflowjs:expect_tensorflow_hub_installed",
Expand All @@ -151,7 +181,6 @@ py_binary(
deps = [
":common",
":converter",
":fuse_prelu",
"//tensorflowjs:expect_h5py_installed",
"//tensorflowjs:expect_tensorflow_installed",
],
Expand Down
59 changes: 0 additions & 59 deletions tfjs-converter/python/tensorflowjs/converters/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import re

from tensorflow.python.framework import tensor_util

from tensorflowjs import version


Expand Down Expand Up @@ -64,57 +59,3 @@
def get_converted_by():
"""Get the convertedBy string for storage in model artifacts."""
return 'TensorFlow.js Converter v%s' % version.version

def node_from_map(node_map, name):
"""Pulls a node def from a dictionary for a given name.

Args:
node_map: Dictionary containing an entry indexed by name for every node.
name: Identifies the node we want to find.

Returns:
NodeDef of the node with the given name.

Raises:
ValueError: If the node isn't present in the dictionary.
"""
stripped_name = node_name_from_input(name)
if stripped_name not in node_map:
raise ValueError("No node named '%s' found in map." % name)
return node_map[stripped_name]


def values_from_const(node_def):
"""Extracts the values from a const NodeDef as a numpy ndarray.

Args:
node_def: Const NodeDef that has the values we want to access.

Returns:
Numpy ndarray containing the values.

Raises:
ValueError: If the node isn't a Const.
"""
if node_def.op != "Const":
raise ValueError(
"Node named '%s' should be a Const op for values_from_const." %
node_def.name)
input_tensor = node_def.attr["value"].tensor
tensor_value = tensor_util.MakeNdarray(input_tensor)
return tensor_value

# Whether to scale by gamma after normalization.
def scale_after_normalization(node):
if node.op == "BatchNormWithGlobalNormalization":
return node.attr["scale_after_normalization"].b
return True

def node_name_from_input(node_name):
"""Strips off ports and other decorations to get the underlying node name."""
if node_name.startswith("^"):
node_name = node_name[1:]
m = re.search(r"(.*):\d+$", node_name)
if m:
node_name = m.group(1)
return node_name
25 changes: 13 additions & 12 deletions tfjs-converter/python/tensorflowjs/converters/fold_batch_norms.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from tensorflow.python.framework import tensor_util
from tensorflow.python.platform import tf_logging

from tensorflowjs.converters import common
from tensorflowjs.converters import graph_rewrite_util

INPUT_ORDER = {
# Order of inputs for BatchNormWithGlobalNormalization.
Expand Down Expand Up @@ -84,80 +84,81 @@ def fold_batch_norms(input_graph_def):
"FusedBatchNorm", "FusedBatchNormV3")):
continue

conv_op = common.node_from_map(
conv_op = graph_rewrite_util.node_from_map(
input_node_map,
node.input[INPUT_ORDER[node.op].index("conv_op")])
if conv_op.op != "Conv2D" and conv_op.op != "DepthwiseConv2dNative":
tf_logging.warning("Didn't find expected Conv2D or DepthwiseConv2dNative"
" input to '%s'" % node.name)
continue

weights_op = common.node_from_map(input_node_map, conv_op.input[1])
weights_op = graph_rewrite_util.node_from_map(
input_node_map, conv_op.input[1])
if weights_op.op != "Const":
tf_logging.warning("Didn't find expected conv Constant input to '%s',"
" found %s instead. Maybe because freeze_graph wasn't"
" run first?" % (conv_op.name, weights_op))
continue
weights = common.values_from_const(weights_op)
weights = graph_rewrite_util.values_from_const(weights_op)
if conv_op.op == "Conv2D":
channel_count = weights.shape[3]
elif conv_op.op == "DepthwiseConv2dNative":
channel_count = weights.shape[2] * weights.shape[3]

mean_op = common.node_from_map(
mean_op = graph_rewrite_util.node_from_map(
input_node_map,
node.input[INPUT_ORDER[node.op].index("mean_op")])
if mean_op.op != "Const":
tf_logging.warning("Didn't find expected mean Constant input to '%s',"
" found %s instead. Maybe because freeze_graph wasn't"
" run first?" % (node.name, mean_op))
continue
mean_value = common.values_from_const(mean_op)
mean_value = graph_rewrite_util.values_from_const(mean_op)
if mean_value.shape != (channel_count,):
tf_logging.warning("Incorrect shape for mean, found %s, expected %s,"
" for node %s" % (str(mean_value.shape), str(
(channel_count,)), node.name))
continue

var_op = common.node_from_map(
var_op = graph_rewrite_util.node_from_map(
input_node_map,
node.input[INPUT_ORDER[node.op].index("var_op")])
if var_op.op != "Const":
tf_logging.warning("Didn't find expected var Constant input to '%s',"
" found %s instead. Maybe because freeze_graph wasn't"
" run first?" % (node.name, var_op))
continue
var_value = common.values_from_const(var_op)
var_value = graph_rewrite_util.values_from_const(var_op)
if var_value.shape != (channel_count,):
tf_logging.warning("Incorrect shape for var, found %s, expected %s,"
" for node %s" % (str(var_value.shape), str(
(channel_count,)), node.name))
continue

beta_op = common.node_from_map(
beta_op = graph_rewrite_util.node_from_map(
input_node_map,
node.input[INPUT_ORDER[node.op].index("beta_op")])
if beta_op.op != "Const":
tf_logging.warning("Didn't find expected beta Constant input to '%s',"
" found %s instead. Maybe because freeze_graph wasn't"
" run first?" % (node.name, beta_op))
continue
beta_value = common.values_from_const(beta_op)
beta_value = graph_rewrite_util.values_from_const(beta_op)
if beta_value.shape != (channel_count,):
tf_logging.warning("Incorrect shape for beta, found %s, expected %s,"
" for node %s" % (str(beta_value.shape), str(
(channel_count,)), node.name))
continue

gamma_op = common.node_from_map(
gamma_op = graph_rewrite_util.node_from_map(
input_node_map,
node.input[INPUT_ORDER[node.op].index("gamma_op")])
if gamma_op.op != "Const":
tf_logging.warning("Didn't find expected gamma Constant input to '%s',"
" found %s instead. Maybe because freeze_graph wasn't"
" run first?" % (node.name, gamma_op))
continue
gamma_value = common.values_from_const(gamma_op)
gamma_value = graph_rewrite_util.values_from_const(gamma_op)
if gamma_value.shape != (channel_count,):
tf_logging.warning("Incorrect shape for gamma, found %s, expected %s,"
" for node %s" % (str(gamma_value.shape), str(
Expand Down
Loading