19 changes: 10 additions & 9 deletions tfjs-converter/python/tensorflowjs/converters/fuse_prelu.py
@@ -157,7 +157,7 @@ def _create_alpha_node(neg_alpha_op, updated_alpha):
alpha_value, alpha_value.dtype.type, alpha_value.shape)))
updated_alpha.append(neg_alpha_op.name)

-def fuse_prelu_with_fused_conv2d(input_graph_def):
+def fuse_prelu_with_fused_conv2d_or_matmul(input_graph_def):
"""Tensorflow does not support Prelu op, and the grappler remap optimizer
will not fuse the prelu op with _FusedConv2D op. This method searches for
the pattern and fuse the (_FusedConv2D||FusedDepthwiseConv2dNative + Prelu)
@@ -187,19 +187,20 @@ def fuse_prelu_with_fused_conv2d(input_graph_def):
if node.op != 'Prelu':
continue

-    fused_conv_op = graph_rewrite_util.node_from_map(
+    fused_op = graph_rewrite_util.node_from_map(
input_node_map, node.input[0])
-    if (not fused_conv_op or
-        (fused_conv_op.op != '_FusedConv2D'
-         and fused_conv_op.op != 'FusedDepthwiseConv2dNative') or
-        len(fused_conv_op.attr['fused_ops'].list.s) > 1):
+    if (not fused_op or
+        (fused_op.op != '_FusedConv2D'
+         and fused_op.op != '_FusedMatMul'
+         and fused_op.op != 'FusedDepthwiseConv2dNative') or
+        len(fused_op.attr['fused_ops'].list.s) > 1):
continue

alpha_tensor_name = node.input[1]

-    fused_conv_op.input.extend([alpha_tensor_name])
-    fused_conv_op.attr['fused_ops'].list.s.extend([b'Prelu'])
-    fused_conv_op.attr['num_args'].i = fused_conv_op.attr['num_args'].i + 1
+    fused_op.input.extend([alpha_tensor_name])
+    fused_op.attr['fused_ops'].list.s.extend([b'Prelu'])
+    fused_op.attr['num_args'].i = fused_op.attr['num_args'].i + 1
node.op = 'Identity'
node.input[:] = [node.input[0]]
nodes_to_skip[node.name] = True
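
What the pass does, as a minimal sketch (this is not the converter's API; node
and tensor names are made up): given a _FusedMatMul that already fused BiasAdd,
the Prelu's alpha input is appended to the fused node, b'Prelu' is recorded in
fused_ops, and the Prelu node itself is downgraded to a pass-through Identity.

from tensorflow.core.framework import graph_pb2

graph = graph_pb2.GraphDef()

# A _FusedMatMul node as grappler's remapper would leave it: MatMul + BiasAdd.
matmul = graph.node.add()
matmul.name = 'dense/matmul'
matmul.op = '_FusedMatMul'
matmul.input.extend(['x', 'kernel', 'bias'])
matmul.attr['fused_ops'].list.s.append(b'BiasAdd')
matmul.attr['num_args'].i = 1

# The Prelu that consumes it, with its alpha weights as second input.
prelu = graph.node.add()
prelu.name = 'dense/prelu'
prelu.op = 'Prelu'
prelu.input.extend(['dense/matmul', 'alpha'])

# The rewrite: move alpha onto the fused node, record the extra fused op,
# and neutralize the Prelu node.
matmul.input.append('alpha')
matmul.attr['fused_ops'].list.s.append(b'Prelu')
matmul.attr['num_args'].i += 1
prelu.op = 'Identity'
del prelu.input[1:]
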
56 changes: 53 additions & 3 deletions tfjs-converter/python/tensorflowjs/converters/fuse_prelu_test.py
@@ -124,7 +124,8 @@ def execute_model(tensor):
config, graph_def, graph, signature)
graph_def = fuse_prelu.fuse_ops_for_prelu(graph_def)

-    optimized_graph_def = fuse_prelu.fuse_prelu_with_fused_conv2d(graph_def)
+    optimized_graph_def = fuse_prelu.fuse_prelu_with_fused_conv2d_or_matmul(
+        graph_def)

conv2d_op = None
for node in optimized_graph_def.node:
@@ -135,6 +136,55 @@ def execute_model(tensor):
self.assertEqual(conv2d_op.attr['fused_ops'].list.s, [b'BiasAdd', b'Prelu'])
self.assertEqual(conv2d_op.attr['num_args'].i, 2)

def testFusePreluWithMatMul(self):
layers = [
tf.keras.layers.Dense(
2, use_bias=True,
kernel_initializer=tf.initializers.constant(0.25),
bias_initializer=tf.initializers.constant(0.25)),
tf.keras.layers.PReLU()
]
model = tf.keras.Sequential(layers)
tf.keras.backend.set_learning_phase(0)
input_tensor = tf.constant([1.0, 1.0], shape=[1, 2])

@tf.function
def execute_model(tensor):
return model(tensor)

graph = tf_saved_model_conversion_v2._freeze_saved_model_v2(
execute_model.get_concrete_function(input_tensor))
graph_def = graph.as_graph_def()
for node in graph_def.node:
if node.op == 'MatMul':
node.device = "/CPU:0"

config = config_pb2.ConfigProto()
rewriter_config = config.graph_options.rewrite_options
rewriter_config.optimizers[:] = [
'pruning', 'constfold', 'arithmetic', 'dependency', 'pruning', 'remap',
'constfold', 'arithmetic', 'dependency'
]

for output in ['Identity']:
graph.add_to_collection('train_op', graph.get_operation_by_name(output))

signature = meta_graph_pb2.SignatureDef()
graph_def = tf_saved_model_conversion_v2._run_grappler(
config, graph_def, graph, signature)
graph_def = fuse_prelu.fuse_ops_for_prelu(graph_def)
optimized_graph_def = fuse_prelu.fuse_prelu_with_fused_conv2d_or_matmul(
graph_def)

matmul_op = None
for node in optimized_graph_def.node:
self.assertNotEqual("Prelu", node.op)
if node.op == '_FusedMatMul':
matmul_op = node
self.assertNotEqual(matmul_op, None)
self.assertEqual(matmul_op.attr['fused_ops'].list.s, [b'BiasAdd', b'Prelu'])
self.assertEqual(matmul_op.attr['num_args'].i, 2)

def testFusePreluWithDepthwiseConv2d(self):
layers = [
tf.keras.layers.DepthwiseConv2D(
@@ -173,7 +223,8 @@ def execute_model(tensor):
graph_def = fuse_prelu.fuse_ops_for_prelu(graph_def)
graph_def = fuse_depthwise_conv2d.fuse_depthwise_conv2d(graph_def)

-    optimized_graph_def = fuse_prelu.fuse_prelu_with_fused_conv2d(graph_def)
+    optimized_graph_def = fuse_prelu.fuse_prelu_with_fused_conv2d_or_matmul(
+        graph_def)

conv2d_op = None
for node in optimized_graph_def.node:
@@ -183,6 +234,5 @@ def execute_model(tensor):
self.assertNotEqual(conv2d_op, None)
self.assertEqual(conv2d_op.attr['fused_ops'].list.s, [b'BiasAdd', b'Prelu'])
self.assertEqual(conv2d_op.attr['num_args'].i, 2)

if __name__ == '__main__':
tf.test.main()
tfjs-converter/python/tensorflowjs/converters/graph_rewrite_util.py
@@ -20,6 +20,8 @@

# Custom op name for fused depthwise conv2d
FUSED_DEPTHWISE_CONV2D = 'FusedDepthwiseConv2dNative'
# The grappler op name for fused MatMul; note that it starts with '_'.
FUSED_MATMUL = '_FusedMatMul'

def node_from_map(node_map, name):
"""Pulls a node def from a dictionary for a given name.
tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py
@@ -156,13 +156,13 @@ def optimize_graph(graph, signature_def, output_graph,
# batch norm folding
optimized_graph = fold_batch_norms.fold_batch_norms(optimized_graph)

-  # set the device to CPU for all Conv2d nodes, since grappler remap optimizer
-  # only support FusedConv2D for CPU.
+  # set the device to CPU for all Conv2d and MatMul nodes, since the grappler
+  # remap optimizer only supports FusedConv2D and FusedMatMul on CPU.
for node in optimized_graph.node:
-    if node.op == 'Conv2D':
+    if node.op == 'Conv2D' or node.op == 'MatMul':
node.device = '/device:CPU:0'

-  # rerun grappler to fuse conv2d
+  # rerun grappler to fuse conv2d/matmul
config.graph_options.rewrite_options.optimizers[:] = [
'remap',
'constfold', 'arithmetic', 'dependency'
@@ -181,7 +181,8 @@

  # Since the grappler remap optimizer does not support prelu as the activation
  # function for _FusedConv2D op, we are doing it manually here.
-  optimized_graph = fuse_prelu.fuse_prelu_with_fused_conv2d(optimized_graph)
+  optimized_graph = fuse_prelu.fuse_prelu_with_fused_conv2d_or_matmul(
+      optimized_graph)

unsupported = validate(optimized_graph.node, skip_op_check,
strip_debug_ops)
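
Why the CPU pinning above is needed: grappler's remapper only emits
_FusedConv2D/_FusedMatMul for nodes placed on CPU. A standalone sketch of the
same sequence, using the TF-internal helpers the converter itself relies on;
the tiny model, and the assumption that its output op is named 'Identity', are
illustrative only.

import tensorflow as tf
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.training.saver import export_meta_graph

@tf.function
def dense(x):
  w = tf.constant([[2.0]])
  b = tf.constant([1.0])
  return tf.nn.relu(tf.nn.bias_add(tf.matmul(x, w), b))

graph = dense.get_concrete_function(tf.constant([[1.0]])).graph
graph_def = graph.as_graph_def()
for node in graph_def.node:
  if node.op == 'MatMul':
    node.device = '/device:CPU:0'  # remap only fuses MatMul placed on CPU

# Keep the output alive so grappler does not prune the graph away, then run
# the same optimizer list the converter uses after device pinning.
graph.add_to_collection('train_op', graph.get_operation_by_name('Identity'))
config = config_pb2.ConfigProto()
config.graph_options.rewrite_options.optimizers[:] = [
    'remap', 'constfold', 'arithmetic', 'dependency']
meta_graph = export_meta_graph(graph_def=graph_def, graph=graph)
optimized = tf_optimizer.OptimizeGraph(config, meta_graph)
print([n.op for n in optimized.node])  # '_FusedMatMul' should appear here
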
tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2_test.py
@@ -191,6 +191,20 @@ def _create_saved_model(self):
save_dir = os.path.join(self._tmp_dir, SAVED_MODEL_DIR)
save(root, save_dir, to_save)

def _create_saved_model_with_fusable_matmul(self):
"""Test a fusable matmul model."""
input_data = constant_op.constant(1., shape=[1, 1])
bias_data = constant_op.constant(1., shape=[1])
root = tracking.AutoTrackable()
root.v2 = variables.Variable([[2.]])
root.f = def_function.function(
lambda x: tf.nn.relu(tf.nn.bias_add(tf.matmul(x, root.v2),
bias_data)))
to_save = root.f.get_concrete_function(input_data)

save_dir = os.path.join(self._tmp_dir, SAVED_MODEL_DIR)
save(root, save_dir, to_save)

def _create_saved_model_with_control_flow(self):
"""Test a basic model with control flow to inlined."""
@tf.function
@@ -441,6 +455,53 @@ def test_convert_saved_model_with_fused_conv2d(self):
glob.glob(
os.path.join(self._tmp_dir, SAVED_MODEL_DIR, 'group*-*')))

def test_convert_saved_model_with_fused_matmul(self):
self._create_saved_model_with_fusable_matmul()
tf_saved_model_conversion_v2.convert_tf_saved_model(
os.path.join(self._tmp_dir, SAVED_MODEL_DIR),
os.path.join(self._tmp_dir, SAVED_MODEL_DIR)
)

tfjs_path = os.path.join(self._tmp_dir, SAVED_MODEL_DIR)
# Check model.json and weights manifest.
with open(os.path.join(tfjs_path, 'model.json'), 'rt') as f:
model_json = json.load(f)
self.assertTrue(model_json['modelTopology'])
self.assertIsNot(model_json['modelTopology']['versions'], None)
signature = model_json['userDefinedMetadata']['signature']
self.assertIsNot(signature, None)
self.assertIsNot(signature['inputs'], None)
self.assertIsNot(signature['outputs'], None)

nodes = model_json['modelTopology']['node']
fusedOp = None
for node in nodes:
self.assertTrue(node['op'] != 'MatMul')
self.assertTrue(not 'Relu' in node['op'])
self.assertTrue(not 'BiasAdd' in node['op'])
if node['op'] == graph_rewrite_util.FUSED_MATMUL:
fusedOp = node
self.assertTrue(fusedOp is not None)
self.assertIsNot(fusedOp['attr']['transpose_a'], None)
self.assertIsNot(fusedOp['attr']['transpose_b'], None)
self.assertEqual(
base64.b64decode(fusedOp['attr']['fused_ops']['list']['s'][0]),
b'BiasAdd')
self.assertEqual(
base64.b64decode(fusedOp['attr']['fused_ops']['list']['s'][1]),
b'Relu')

# Check meta-data in the artifact JSON.
self.assertEqual(model_json['format'], 'graph-model')
self.assertEqual(
model_json['convertedBy'],
'TensorFlow.js Converter v%s' % version.version)
self.assertEqual(model_json['generatedBy'],
tf.__version__)
self.assertTrue(
glob.glob(
os.path.join(self._tmp_dir, SAVED_MODEL_DIR, 'group*-*')))

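For reference, the model.json assertions above condense into a small
standalone helper (tfjs_path is whatever directory convert_tf_saved_model
wrote to; the helper name is mine, not part of the PR):

import base64
import json
import os

def assert_matmul_fused(tfjs_path):
  """Checks that a converted model.json contains _FusedMatMul and no
  stand-alone MatMul node, with BiasAdd and Relu folded in."""
  with open(os.path.join(tfjs_path, 'model.json'), 'rt') as f:
    nodes = json.load(f)['modelTopology']['node']
  fused = [n for n in nodes if n['op'] == '_FusedMatMul']
  assert fused and all(n['op'] != 'MatMul' for n in nodes)
  fused_ops = [base64.b64decode(s)
               for s in fused[0]['attr']['fused_ops']['list']['s']]
  assert fused_ops == [b'BiasAdd', b'Relu']
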
def test_convert_saved_model_with_fused_depthwise_conv2d(self):
self._create_saved_model_with_fusable_depthwise_conv2d()
tf_saved_model_conversion_v2.convert_tf_saved_model(
59 changes: 59 additions & 0 deletions tfjs-converter/python/tensorflowjs/op_list/matrices.json
@@ -1,4 +1,63 @@
[
{
"tfOpName": "_FusedMatMul",
"category": "matrices",
"inputs": [
{
"start": 0,
"name": "a",
"type": "tensor"
},
{
"start": 1,
"name": "b",
"type": "tensor"
},
{
"start": 2,
"end": 0,
"name": "args",
"type": "tensors"
}
],
"attrs": [
{
"tfName": "num_args",
"name": "numArgs",
"type": "number"
},
{
"tfName": "fused_ops",
"name": "fusedOps",
"type": "string[]",
"defaultValue": []
},
{
"tfName": "epsilon",
"name": "epsilon",
"type": "number",
"defaultValue": 0.0001
},
{
"tfName": "transpose_a",
"name": "transposeA",
"type": "bool",
"defaultValue": false
},
{
"tfName": "transpose_b",
"name": "transposeB",
"type": "bool",
"defaultValue": false
},
{
"tfName": "T",
"name": "dtype",
"type": "dtype",
"notSupported": true
}
]
},
{
"tfOpName": "MatMul",
"category": "matrices",
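
How the inputs spec above resolves against a node's input list, in my reading
of the mapper ("start"/"end" are offsets into the inputs, and end == 0 means
"through the last input"); the tensor names below are made up:

# Inputs of a _FusedMatMul node with fused BiasAdd + Prelu (num_args = 2).
node_inputs = ['dense/x', 'dense/kernel', 'dense/bias', 'dense/alpha']

a = node_inputs[0]      # {"start": 0, "name": "a"}
b = node_inputs[1]      # {"start": 1, "name": "b"}
args = node_inputs[2:]  # {"start": 2, "end": 0, "name": "args"}
                        # -> bias first, then the prelu alpha
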
35 changes: 35 additions & 0 deletions tfjs-converter/src/operations/executors/matrices_executor.ts
@@ -36,11 +36,46 @@ export let executeOp: InternalOpExecutor = (node: Node,
getParamValue('b', node, tensorMap, context) as tfc.Tensor2D,
getParamValue('transposeA', node, tensorMap, context) as boolean,
getParamValue('transposeB', node, tensorMap, context) as boolean)];

case 'Transpose':
return [tfc.transpose(
getParamValue('x', node, tensorMap, context) as tfc.Tensor,
getParamValue('perm', node, tensorMap, context) as number[])];

case '_FusedMatMul':
const [extraOp, activationFunc] =
(getParamValue('fusedOps', node, tensorMap, context) as string[]);

const isBiasAdd = extraOp === 'biasadd';
const isPrelu = activationFunc === 'prelu';

const numArgs =
(getParamValue('numArgs', node, tensorMap, context) as number);
if (isBiasAdd) {
if (isPrelu && numArgs !== 2) {
throw new Error(
'Fused MatMul with BiasAdd and Prelu must have two ' +
'extra arguments: bias and alpha.');
}
if (!isPrelu && numArgs !== 1) {
throw new Error(
'Fused MatMul with BiasAdd must have one extra argument: bias.');
}
}
const [biasArg, preluArg] =
getParamValue('args', node, tensorMap, context) as tfc.Tensor[];
return [tfc.fused.matMul({
a: getParamValue('a', node, tensorMap, context) as tfc.Tensor2D,
b: getParamValue('b', node, tensorMap, context) as tfc.Tensor2D,
transposeA: getParamValue('transposeA', node, tensorMap, context) as
boolean,
transposeB: getParamValue('transposeB', node, tensorMap, context) as
boolean,
bias: biasArg,
activation: activationFunc as tfc.fused.Activation,
preluActivationWeights: preluArg
})];

default:
throw TypeError(`Node type ${node.op} is not implemented`);
}