tfjs-converter/python/tensorflowjs/converters/fuse_depthwise_conv2d.py
@@ -21,12 +21,6 @@
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.python.framework import function

from tensorflowjs.converters import graph_rewrite_util

def _is_supported_activation(node):
@@ -149,41 +143,3 @@ def _fuse_depthwise_conv2d_with_match_function(input_graph_def, match_function):

  # No pattern detected
  return input_graph_def

def extract_op_attributes(input_graph_def):
"""Since TF does not allow function defined custom op to have any attributes,
we need to clean up the attributes for FusedDepthwiseConv2dNative op.
Args:
input_graph_def: Modified tf.Graph object.
"""
result_graph_def = graph_pb2.GraphDef()
for node in input_graph_def.node:
new_node = node_def_pb2.NodeDef()
new_node.CopyFrom(node)
if new_node.op == graph_rewrite_util.FUSED_DEPTHWISE_CONV2D:
new_node.ClearField('attr')
if len(new_node.input) > 3:
new_node.input[:] = new_node.input[0:3]
result_graph_def.node.extend([new_node])
result_graph_def.versions.CopyFrom(input_graph_def.versions)
return result_graph_def

def register_fused_depthwise_conv2d_func(graph):
"""Register FusedDepthwiseConv2dNative op with function def, this is needed
for importing graph_def with unregistered op.
Args:
graph: A tf.Graph object to insert FusedDepthwiseConv2d function into.
"""

# Create a function for FusedDepthwiseConv2dNative op
@function.Defun(tf.float32, tf.float32, tf.float32,
func_name=graph_rewrite_util.FUSED_DEPTHWISE_CONV2D)
def fused_depthwise_conv2d_fn(*args):
# This is a placeholder for the Op definition, the exact implemenation of the
# function does not matter.
return tf.nn.depthwise_conv2d(
args[0], args[1], strides=[1, 1, 1, 1], padding='SAME')
# Insert the function into graph
with graph.as_default():
fused_depthwise_conv2d_fn(
tf.ones([1, 1, 1]), tf.ones([1, 1, 1]), tf.ones([1]))
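Usage sketch (graph_def_with_fused_op is a hypothetical name): once the function def is registered on a graph, import_graph_def accepts nodes of the otherwise-unknown op type, mirroring how extract_weights used it before this change.

graph = tf.Graph()
register_fused_depthwise_conv2d_func(graph)
with graph.as_default():
  # Succeeds even though FusedDepthwiseConv2dNative has no registered OpDef,
  # because the graph's function library now defines it.
  tf.import_graph_def(graph_def_with_fused_op, name='')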
39 changes: 1 addition & 38 deletions tfjs-converter/python/tensorflowjs/converters/fuse_prelu.py
@@ -21,30 +21,11 @@
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import op_def_pb2
from tensorflow.python.framework import function
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import tensor_util

from tensorflowjs.converters import graph_rewrite_util

# The function is only needed for TensorFlow 1.X and 2.0. Remove once tfjs
# no longer depends on these versions.
def register_prelu_op():
"""Register a virtual PReLU OpDef.

This allows to bypass MetaGraph validity checks on TensorFlow 1.X and 2.0.
"""

prelu_op_def = op_def_pb2.OpDef()
prelu_op_def.name = 'Prelu'
missing_op_list = op_def_pb2.OpList()
missing_op_list.op.extend([prelu_op_def])
op_def_registry.register_op_list(missing_op_list)
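To make the effect concrete, a minimal sketch (node and input names are hypothetical): a GraphDef containing a 'Prelu' node would fail MetaGraph op validation on TF 1.X/2.0 without this registration.

from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2

graph_def = graph_pb2.GraphDef()
graph_def.node.extend([node_def_pb2.NodeDef(
    name='act', op='Prelu', input=['conv', 'alpha'])])
# Before register_prelu_op(): validation fails with
# "Op type not registered 'Prelu'". After: the empty OpDef satisfies it.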

def fuse_ops_for_prelu(input_graph_def):
"""Modifies the provided graph by fusing a set of ops into a single Prelu op.
The formula of PReLU is:
@@ -208,22 +189,4 @@ def fuse_prelu_with_fused_conv2d_or_matmul(input_graph_def):

  return graph_rewrite_util.cleanup_graph_def(
      input_graph_def, nodes_to_skip, inputs_to_remove)
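cleanup_graph_def itself lives in graph_rewrite_util and is not shown in this diff; a plausible shape, inferred from the call site (an assumption, not the actual implementation): it drops the fused-away nodes and rewires consumers of each removed input node.

from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2

def cleanup_graph_def(input_graph_def, nodes_to_skip, inputs_to_remove):
  # Sketch: rebuild the GraphDef without the fused-away nodes.
  result = graph_pb2.GraphDef()
  for node in input_graph_def.node:
    if node.name in nodes_to_skip:
      continue  # drop nodes folded into the fused op
    new_node = node_def_pb2.NodeDef()
    new_node.CopyFrom(node)
    for removed in inputs_to_remove:
      # Re-point consumers of a removed node at that node's own input.
      new_node.input[:] = [removed.input[0] if inp == removed.name else inp
                           for inp in new_node.input]
    result.node.extend([new_node])
  result.versions.CopyFrom(input_graph_def.versions)
  return result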

def register_prelu_func(graph):
"""Register Prelu op with function def, this is needed for importing graph_def
with unregistered Prelu op.
Args:
graph: A tf.Graph object to insert prelu function into.
"""

# Create a function for Prelu op
@function.Defun(tf.float32, tf.float32, func_name='Prelu')
def prelu_fn(*args):
return tf.add(args[0], args[1])
# Insert the function into graph
with graph.as_default():
prelu_fn(tf.constant(1.0), tf.constant(1.0))
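Usage sketch (graph_def_with_prelu is a hypothetical name), mirroring the fused-depthwise registration in fuse_depthwise_conv2d.py:

graph = tf.Graph()
register_prelu_func(graph)
with graph.as_default():
  tf.import_graph_def(graph_def_with_prelu, name='')  # Prelu now resolvable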


if hasattr(op_def_registry, 'register_op_list'):
  register_prelu_op()

tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py
@@ -21,7 +21,6 @@
import json
import os

import numpy as np
import tensorflow as tf
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
@@ -42,6 +41,7 @@
from tensorflowjs.converters import fold_batch_norms
from tensorflowjs.converters import fuse_prelu
from tensorflowjs.converters import fuse_depthwise_conv2d
from tensorflowjs.converters import graph_rewrite_util
from tensorflowjs import resource_loader

# enable eager execution for v2 APIs
@@ -126,7 +126,6 @@ def optimize_graph(graph, signature_def, output_graph,
    skip_op_check: Bool whether to skip the op check.
    strip_debug_ops: Bool whether to strip debug ops.
  """
  fuse_prelu.register_prelu_func(graph)

  # Add a collection 'train_op' so that Grappler knows the outputs.
  for _, output in signature_def.outputs.items():
@@ -221,27 +220,17 @@ def extract_weights(graph_def,
  print('Writing weight file ' + output_graph + '...')
  const_manifest = []

  graph = tf.Graph()
  fuse_prelu.register_prelu_func(graph)
  fuse_depthwise_conv2d.register_fused_depthwise_conv2d_func(graph)

  extracted_graph = fuse_depthwise_conv2d.extract_op_attributes(graph_def)
  with tf.compat.v1.Session(graph=graph) as sess:
    tf.import_graph_def(extracted_graph, name='')
    for const in constants:
      tensor = graph.get_tensor_by_name(const.name + ':0')
      value = tensor.eval(session=sess)
      if not isinstance(value, np.ndarray):
        value = np.array(value)

      const_manifest.append({'name': const.name, 'data': value})

      # Restore the conditional inputs
      const.input[:] = const_inputs[const.name]

      # Remove the binary array from tensor and save it to the external file.
      for field_name in CLEARED_TENSOR_FIELDS:
        const.attr["value"].tensor.ClearField(field_name)
  for const in constants:
    const_manifest.append({
        'name': const.name,
        'data': graph_rewrite_util.values_from_const(const)
    })
    # Restore the conditional inputs
    const.input[:] = const_inputs[const.name]

    # Remove the binary array from tensor and save it to the external file.
    for field_name in CLEARED_TENSOR_FIELDS:
      const.attr["value"].tensor.ClearField(field_name)

  write_artifacts(MessageToDict(graph_def), [const_manifest], output_graph,
                  tf_version, signature_def,
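The rewritten extract_weights decodes each Const node straight from its proto instead of importing the graph and evaluating tensors in a Session, which is why the Prelu/FusedDepthwiseConv2d function registrations and the numpy import can be dropped. graph_rewrite_util.values_from_const is not shown in this diff; a plausible implementation, modeled on TensorFlow's optimize_for_inference_lib (an assumption, not the actual tfjs code):

from tensorflow.python.framework import tensor_util

def values_from_const(node_def):
  # Sketch: decode the ndarray stored in a Const node's 'value' attribute.
  if node_def.op != 'Const':
    raise ValueError('Node %s is not a Const op.' % node_def.name)
  return tensor_util.MakeNdarray(node_def.attr['value'].tensor)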