From 21297e5c359735b042f9eefd80c755209c5a0e05 Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Mon, 2 Dec 2019 13:23:31 -0800 Subject: [PATCH 1/6] support fused matmul --- .../tensorflowjs/converters/fuse_prelu.py | 17 +++--- .../converters/fuse_prelu_test.py | 55 ++++++++++++++++- .../converters/graph_rewrite_util.py | 1 + .../tf_saved_model_conversion_v2.py | 8 +-- .../tf_saved_model_conversion_v2_test.py | 59 +++++++++++++++++++ .../python/tensorflowjs/op_list/matrices.json | 53 +++++++++++++++++ .../operations/executors/matrices_executor.ts | 35 +++++++++++ .../executors/matrices_executor_test.ts | 26 +++++++- .../src/operations/op_list/matrices.ts | 35 +++++++++++ 9 files changed, 274 insertions(+), 15 deletions(-) diff --git a/tfjs-converter/python/tensorflowjs/converters/fuse_prelu.py b/tfjs-converter/python/tensorflowjs/converters/fuse_prelu.py index 9d3adc515e3..76f2da98d9c 100644 --- a/tfjs-converter/python/tensorflowjs/converters/fuse_prelu.py +++ b/tfjs-converter/python/tensorflowjs/converters/fuse_prelu.py @@ -187,19 +187,20 @@ def fuse_prelu_with_fused_conv2d(input_graph_def): if node.op != 'Prelu': continue - fused_conv_op = graph_rewrite_util.node_from_map( + fused_op = graph_rewrite_util.node_from_map( input_node_map, node.input[0]) - if (not fused_conv_op or - (fused_conv_op.op != '_FusedConv2D' - and fused_conv_op.op != 'FusedDepthwiseConv2dNative') or - len(fused_conv_op.attr['fused_ops'].list.s) > 1): + if (not fused_op or + (fused_op.op != '_FusedConv2D' + and fused_op.op != '_FusedMatMul' + and fused_op.op != 'FusedDepthwiseConv2dNative') or + len(fused_op.attr['fused_ops'].list.s) > 1): continue alpha_tensor_name = node.input[1] - fused_conv_op.input.extend([alpha_tensor_name]) - fused_conv_op.attr['fused_ops'].list.s.extend([b'Prelu']) - fused_conv_op.attr['num_args'].i = fused_conv_op.attr['num_args'].i + 1 + fused_op.input.extend([alpha_tensor_name]) + fused_op.attr['fused_ops'].list.s.extend([b'Prelu']) + fused_op.attr['num_args'].i = fused_op.attr['num_args'].i + 1 node.op = 'Identity' node.input[:] = [node.input[0]] nodes_to_skip[node.name] = True diff --git a/tfjs-converter/python/tensorflowjs/converters/fuse_prelu_test.py b/tfjs-converter/python/tensorflowjs/converters/fuse_prelu_test.py index 3925e92f0d0..85f41c70129 100644 --- a/tfjs-converter/python/tensorflowjs/converters/fuse_prelu_test.py +++ b/tfjs-converter/python/tensorflowjs/converters/fuse_prelu_test.py @@ -21,6 +21,10 @@ import tensorflow as tf from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import meta_graph_pb2 +from tensorflow.python.eager import def_function +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import variables +from tensorflow.python.training.tracking import tracking from tensorflowjs.converters import fuse_depthwise_conv2d from tensorflowjs.converters import fuse_prelu @@ -135,7 +139,55 @@ def execute_model(tensor): self.assertEqual(conv2d_op.attr['fused_ops'].list.s, [b'BiasAdd', b'Prelu']) self.assertEqual(conv2d_op.attr['num_args'].i, 2) - def testFusePreluWithDepthwiseConv2d(self): + def testFusePreluWithMatMul(self): + layers = [ + tf.keras.layers.Dense( + 2, use_bias=True, + kernel_initializer=tf.initializers.constant(0.25), + bias_initializer=tf.initializers.constant(0.25)), + tf.keras.layers.PReLU() + ] + model = tf.keras.Sequential(layers) + tf.keras.backend.set_learning_phase(0) + input_tensor = tf.constant([1.0, 1.0], shape=[1, 2]) + + @tf.function + def execute_model(tensor): + return model(tensor) + + graph = tf_saved_model_conversion_v2._freeze_saved_model_v2( + execute_model.get_concrete_function(input_tensor)) + graph_def = graph.as_graph_def() + for node in graph_def.node: + if node.op == 'MatMul': + node.device = "/CPU:0" + + config = config_pb2.ConfigProto() + rewriter_config = config.graph_options.rewrite_options + rewriter_config.optimizers[:] = [ + 'pruning', 'constfold', 'arithmetic', 'dependency', 'pruning', 'remap', + 'constfold', 'arithmetic', 'dependency' + ] + + for output in ['Identity']: + graph.add_to_collection('train_op', graph.get_operation_by_name(output)) + + signature = meta_graph_pb2.SignatureDef() + graph_def = tf_saved_model_conversion_v2._run_grappler( + config, graph_def, graph, signature) + graph_def = fuse_prelu.fuse_ops_for_prelu(graph_def) + optimized_graph_def = fuse_prelu.fuse_prelu_with_fused_conv2d(graph_def) + + matmul_op = None + for node in optimized_graph_def.node: + self.assertNotEqual("Prelu", node.op) + if node.op == '_FusedMatMul': + matmul_op = node + self.assertNotEqual(matmul_op, None) + self.assertEqual(matmul_op.attr['fused_ops'].list.s, [b'BiasAdd', b'Prelu']) + self.assertEqual(matmul_op.attr['num_args'].i, 2) + +def testFusePreluWithDepthwiseConv2d(self): layers = [ tf.keras.layers.DepthwiseConv2D( 1, bias_initializer=tf.initializers.constant(0.25)), @@ -183,6 +235,5 @@ def execute_model(tensor): self.assertNotEqual(conv2d_op, None) self.assertEqual(conv2d_op.attr['fused_ops'].list.s, [b'BiasAdd', b'Prelu']) self.assertEqual(conv2d_op.attr['num_args'].i, 2) - if __name__ == '__main__': tf.test.main() diff --git a/tfjs-converter/python/tensorflowjs/converters/graph_rewrite_util.py b/tfjs-converter/python/tensorflowjs/converters/graph_rewrite_util.py index 15531fab719..d1937f0b054 100644 --- a/tfjs-converter/python/tensorflowjs/converters/graph_rewrite_util.py +++ b/tfjs-converter/python/tensorflowjs/converters/graph_rewrite_util.py @@ -20,6 +20,7 @@ # Custom op name for fused depthwise conv2d FUSED_DEPTHWISE_CONV2D = 'FusedDepthwiseConv2dNative' +FUSED_MATMUL = '_FusedMatMul' def node_from_map(node_map, name): """Pulls a node def from a dictionary for a given name. diff --git a/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py b/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py index 012e6e37f32..dcc485773e3 100644 --- a/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py +++ b/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py @@ -156,13 +156,13 @@ def optimize_graph(graph, signature_def, output_graph, # batch norm folding optimized_graph = fold_batch_norms.fold_batch_norms(optimized_graph) - # set the device to CPU for all Conv2d nodes, since grappler remap optimizer - # only support FusedConv2D for CPU. + # set the device to CPU for all Conv2d and MatMul nodes, since grappler + # remap optimizer only support FusedConv2D and FusedMatMul for CPU. for node in optimized_graph.node: - if node.op == 'Conv2D': + if node.op == 'Conv2D' or node.op == 'MatMul': node.device = '/device:CPU:0' - # rerun grappler to fuse conv2d + # rerun grappler to fuse conv2d/matmul config.graph_options.rewrite_options.optimizers[:] = [ 'remap', 'constfold', 'arithmetic', 'dependency' diff --git a/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2_test.py b/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2_test.py index 7710ccc6056..3bd844e7436 100644 --- a/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2_test.py +++ b/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2_test.py @@ -191,6 +191,18 @@ def _create_saved_model(self): save_dir = os.path.join(self._tmp_dir, SAVED_MODEL_DIR) save(root, save_dir, to_save) + def _create_saved_model_with_fusable_matmul(self): + """Test a fusable matmul model with functions to make sure functions are inlined.""" + input_data = constant_op.constant(1., shape=[1, 1]) + bias_data = constant_op.constant(1., shape=[1]) + root = tracking.AutoTrackable() + root.v2 = variables.Variable([[2.]]) + root.f = def_function.function(lambda x: tf.nn.relu(tf.nn.bias_add(tf.matmul(x, root.v2), bias_data))) + to_save = root.f.get_concrete_function(input_data) + + save_dir = os.path.join(self._tmp_dir, SAVED_MODEL_DIR) + save(root, save_dir, to_save) + def _create_saved_model_with_control_flow(self): """Test a basic model with control flow to inlined.""" @tf.function @@ -440,6 +452,53 @@ def test_convert_saved_model_with_fused_conv2d(self): glob.glob( os.path.join(self._tmp_dir, SAVED_MODEL_DIR, 'group*-*'))) + def test_convert_saved_model_with_fused_matmul(self): + self._create_saved_model_with_fusable_matmul() + tf_saved_model_conversion_v2.convert_tf_saved_model( + os.path.join(self._tmp_dir, SAVED_MODEL_DIR), + os.path.join(self._tmp_dir, SAVED_MODEL_DIR) + ) + + tfjs_path = os.path.join(self._tmp_dir, SAVED_MODEL_DIR) + # Check model.json and weights manifest. + with open(os.path.join(tfjs_path, 'model.json'), 'rt') as f: + model_json = json.load(f) + self.assertTrue(model_json['modelTopology']) + self.assertIsNot(model_json['modelTopology']['versions'], None) + signature = model_json['userDefinedMetadata']['signature'] + self.assertIsNot(signature, None) + self.assertIsNot(signature['inputs'], None) + self.assertIsNot(signature['outputs'], None) + + nodes = model_json['modelTopology']['node'] + fusedOp = None + for node in nodes: + self.assertTrue(not 'MatMul' == node['op']) + self.assertTrue(not 'Relu' in node['op']) + self.assertTrue(not 'BiasAdd' in node['op']) + if node['op'] == graph_rewrite_util.FUSED_MATMUL: + fusedOp = node + self.assertTrue(fusedOp is not None) + self.assertIsNot(fusedOp['attr']['transpose_a'], None) + self.assertIsNot(fusedOp['attr']['transpose_b'], None) + self.assertEqual( + base64.b64decode(fusedOp['attr']['fused_ops']['list']['s'][0]), + b'BiasAdd') + self.assertEqual( + base64.b64decode(fusedOp['attr']['fused_ops']['list']['s'][1]), + b'Relu') + + # Check meta-data in the artifact JSON. + self.assertEqual(model_json['format'], 'graph-model') + self.assertEqual( + model_json['convertedBy'], + 'TensorFlow.js Converter v%s' % version.version) + self.assertEqual(model_json['generatedBy'], + tf.__version__) + self.assertTrue( + glob.glob( + os.path.join(self._tmp_dir, SAVED_MODEL_DIR, 'group*-*'))) + def test_convert_saved_model_with_fused_depthwise_conv2d(self): self._create_saved_model_with_fusable_depthwise_conv2d() tf_saved_model_conversion_v2.convert_tf_saved_model( diff --git a/tfjs-converter/python/tensorflowjs/op_list/matrices.json b/tfjs-converter/python/tensorflowjs/op_list/matrices.json index 62cd03398ed..fa382556301 100644 --- a/tfjs-converter/python/tensorflowjs/op_list/matrices.json +++ b/tfjs-converter/python/tensorflowjs/op_list/matrices.json @@ -1,4 +1,57 @@ [ + { + "tfOpName": "_FusedMatMul", + "category": "matrices", + "inputs": [ + { + "start": 0, + "name": "a", + "type": "tensor" + }, + { + "start": 1, + "name": "b", + "type": "tensor" + } + ], + "attrs": [ + { + "tfName": "num_args", + "name": "numArgs", + "type": "number" + }, + { + "tfName": "fused_ops", + "name": "fusedOps", + "type": "string[]", + "defaultValue": [] + }, + { + "tfName": "epsilon", + "name": "epsilon", + "type": "number", + "defaultValue": 0.0001 + }, + { + "tfName": "transpose_a", + "name": "transposeA", + "type": "bool", + "defaultValue": false + }, + { + "tfName": "transpose_b", + "name": "transposeB", + "type": "bool", + "defaultValue": false + }, + { + "tfName": "T", + "name": "dtype", + "type": "dtype", + "notSupported": true + } + ] + }, { "tfOpName": "MatMul", "category": "matrices", diff --git a/tfjs-converter/src/operations/executors/matrices_executor.ts b/tfjs-converter/src/operations/executors/matrices_executor.ts index e12cd11c8e3..ba8289a0684 100644 --- a/tfjs-converter/src/operations/executors/matrices_executor.ts +++ b/tfjs-converter/src/operations/executors/matrices_executor.ts @@ -36,11 +36,46 @@ export let executeOp: InternalOpExecutor = (node: Node, getParamValue('b', node, tensorMap, context) as tfc.Tensor2D, getParamValue('transposeA', node, tensorMap, context) as boolean, getParamValue('transposeB', node, tensorMap, context) as boolean)]; + case 'Transpose': return [tfc.transpose( getParamValue('x', node, tensorMap, context) as tfc.Tensor, getParamValue('perm', node, tensorMap, context) as number[])]; + case '_FusedMatMul': + const [extraOp, activationFunc] = + (getParamValue('fusedOps', node, tensorMap, context) as string[]); + + const isBiasAdd = extraOp === 'biasadd'; + const isPrelu = activationFunc === 'prelu'; + + const numArgs = + (getParamValue('numArgs', node, tensorMap, context) as number); + if (isBiasAdd) { + if (isPrelu && numArgs !== 2) { + throw new Error( + 'Fused MatMul with BiasAdd and Prelu must have two ' + + 'extra arguments: bias and alpha.'); + } + if (!isPrelu && numArgs !== 1) { + throw new Error( + 'Fused MatMul with BiasAdd must have one extra argument: bias.'); + } + } + const [biasArg, preluArg] = + getParamValue('args', node, tensorMap, context) as tfc.Tensor[]; + return [tfc.fused.matMul({ + a: getParamValue('a', node, tensorMap, context) as tfc.Tensor2D, + b: getParamValue('b', node, tensorMap, context) as tfc.Tensor2D, + transposeA: getParamValue('transposeA', node, tensorMap, context) as + boolean, + transposeB: getParamValue('transposeB', node, tensorMap, context) as + boolean, + bias: biasArg, + activation: activationFunc as tfc.fused.Activation, + preluActivationWeights: preluArg + })]; + default: throw TypeError(`Node type ${node.op} is not implemented`); } diff --git a/tfjs-converter/src/operations/executors/matrices_executor_test.ts b/tfjs-converter/src/operations/executors/matrices_executor_test.ts index 81bee35bd83..c321d9d0f32 100644 --- a/tfjs-converter/src/operations/executors/matrices_executor_test.ts +++ b/tfjs-converter/src/operations/executors/matrices_executor_test.ts @@ -20,7 +20,7 @@ import {ExecutionContext} from '../../executor/execution_context'; import {Node} from '../types'; import {executeOp} from './matrices_executor'; -import {createBoolAttr, createNumericArrayAttr, createTensorAttr} from './test_helper'; +import {createBoolAttr, createNumberAttr, createNumericArrayAttr, createStrArrayAttr, createTensorAttr, createTensorsAttr} from './test_helper'; describe('matrices', () => { let node: Node; @@ -54,6 +54,30 @@ describe('matrices', () => { .toHaveBeenCalledWith(input1[0], input2[0], true, false); }); }); + describe('_FusedMatMul', () => { + it('should call tfc.fused.matMul', () => { + spyOn(tfc.fused, 'matMul'); + node.op = '_FusedMatMul'; + node.inputParams['args'] = createTensorsAttr(2, 0); + node.attrParams['fusedOps'] = createStrArrayAttr(['biasadd', 'relu']); + node.attrParams['numArgs'] = createNumberAttr(1); + node.attrParams.transposeA = createBoolAttr(true); + node.attrParams.transposeB = createBoolAttr(false); + const input3 = [tfc.scalar(3.0)]; + node.inputNames = ['input1', 'input2', 'input3']; + executeOp(node, {input1, input2, input3}, context); + + expect(tfc.fused.matMul).toHaveBeenCalledWith({ + a: input1[0], + b: input2[0], + transposeA: true, + transposeB: false, + bias: input3[0], + activation: 'relu', + preluActivationWeights: undefined + }); + }); + }); describe('BatchMatMul', () => { it('should call tfc.matMul', () => { spyOn(tfc, 'matMul'); diff --git a/tfjs-converter/src/operations/op_list/matrices.ts b/tfjs-converter/src/operations/op_list/matrices.ts index 4e1f5a828e0..3b3bad2fb09 100644 --- a/tfjs-converter/src/operations/op_list/matrices.ts +++ b/tfjs-converter/src/operations/op_list/matrices.ts @@ -18,6 +18,41 @@ import {OpMapper} from '../types'; export const json: OpMapper[] = [ + { + 'tfOpName': '_FusedMatMul', + 'category': 'matrices', + 'inputs': [ + {'start': 0, 'name': 'a', 'type': 'tensor'}, + {'start': 1, 'name': 'b', 'type': 'tensor'}, + ], + 'attrs': [ + {'tfName': 'num_args', 'name': 'numArgs', 'type': 'number'}, { + 'tfName': 'fused_ops', + 'name': 'fusedOps', + 'type': 'string[]', + 'defaultValue': [] + }, + { + 'tfName': 'epsilon', + 'name': 'epsilon', + 'type': 'number', + 'defaultValue': 0.0001 + }, + { + 'tfName': 'transpose_a', + 'name': 'transposeA', + 'type': 'bool', + 'defaultValue': false + }, + { + 'tfName': 'transpose_b', + 'name': 'transposeB', + 'type': 'bool', + 'defaultValue': false + }, + {'tfName': 'T', 'name': 'dtype', 'type': 'dtype', 'notSupported': true} + ] + }, { 'tfOpName': 'MatMul', 'category': 'matrices', From 8f0ff7db21120b3fb62efae3b3167cc6f026f54a Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Mon, 2 Dec 2019 13:31:22 -0800 Subject: [PATCH 2/6] fix pylint errors --- .../python/tensorflowjs/converters/fuse_prelu_test.py | 6 +----- .../converters/tf_saved_model_conversion_v2.py | 2 +- .../converters/tf_saved_model_conversion_v2_test.py | 8 +++++--- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/tfjs-converter/python/tensorflowjs/converters/fuse_prelu_test.py b/tfjs-converter/python/tensorflowjs/converters/fuse_prelu_test.py index 85f41c70129..ef53b90a38d 100644 --- a/tfjs-converter/python/tensorflowjs/converters/fuse_prelu_test.py +++ b/tfjs-converter/python/tensorflowjs/converters/fuse_prelu_test.py @@ -21,10 +21,6 @@ import tensorflow as tf from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import meta_graph_pb2 -from tensorflow.python.eager import def_function -from tensorflow.python.framework import constant_op -from tensorflow.python.ops import variables -from tensorflow.python.training.tracking import tracking from tensorflowjs.converters import fuse_depthwise_conv2d from tensorflowjs.converters import fuse_prelu @@ -187,7 +183,7 @@ def execute_model(tensor): self.assertEqual(matmul_op.attr['fused_ops'].list.s, [b'BiasAdd', b'Prelu']) self.assertEqual(matmul_op.attr['num_args'].i, 2) -def testFusePreluWithDepthwiseConv2d(self): + def testFusePreluWithDepthwiseConv2d(self): layers = [ tf.keras.layers.DepthwiseConv2D( 1, bias_initializer=tf.initializers.constant(0.25)), diff --git a/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py b/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py index dcc485773e3..17041088377 100644 --- a/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py +++ b/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py @@ -156,7 +156,7 @@ def optimize_graph(graph, signature_def, output_graph, # batch norm folding optimized_graph = fold_batch_norms.fold_batch_norms(optimized_graph) - # set the device to CPU for all Conv2d and MatMul nodes, since grappler + # set the device to CPU for all Conv2d and MatMul nodes, since grappler # remap optimizer only support FusedConv2D and FusedMatMul for CPU. for node in optimized_graph.node: if node.op == 'Conv2D' or node.op == 'MatMul': diff --git a/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2_test.py b/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2_test.py index 3bd844e7436..b10af645479 100644 --- a/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2_test.py +++ b/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2_test.py @@ -192,12 +192,14 @@ def _create_saved_model(self): save(root, save_dir, to_save) def _create_saved_model_with_fusable_matmul(self): - """Test a fusable matmul model with functions to make sure functions are inlined.""" + """Test a fusable matmul model.""" input_data = constant_op.constant(1., shape=[1, 1]) bias_data = constant_op.constant(1., shape=[1]) root = tracking.AutoTrackable() root.v2 = variables.Variable([[2.]]) - root.f = def_function.function(lambda x: tf.nn.relu(tf.nn.bias_add(tf.matmul(x, root.v2), bias_data))) + root.f = def_function.function( + lambda x: tf.nn.relu(tf.nn.bias_add(tf.matmul(x, root.v2), + bias_data))) to_save = root.f.get_concrete_function(input_data) save_dir = os.path.join(self._tmp_dir, SAVED_MODEL_DIR) @@ -473,7 +475,7 @@ def test_convert_saved_model_with_fused_matmul(self): nodes = model_json['modelTopology']['node'] fusedOp = None for node in nodes: - self.assertTrue(not 'MatMul' == node['op']) + self.assertTrue(node['op'] != 'MatMul') self.assertTrue(not 'Relu' in node['op']) self.assertTrue(not 'BiasAdd' in node['op']) if node['op'] == graph_rewrite_util.FUSED_MATMUL: From 33d384a77aae93dc539fc249adf362b10c7c4fe0 Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Mon, 2 Dec 2019 15:45:32 -0800 Subject: [PATCH 3/6] addressed the comments --- tfjs-converter/.vscode/.ropeproject/config.py | 114 ++++++++++++++++++ tfjs-converter/.vscode/.ropeproject/objectdb | Bin 0 -> 6 bytes .../converters/graph_rewrite_util.py | 1 + .../executors/matrices_executor_test.ts | 23 ++++ .../src/operations/op_list/matrices.ts | 1 + 5 files changed, 139 insertions(+) create mode 100644 tfjs-converter/.vscode/.ropeproject/config.py create mode 100644 tfjs-converter/.vscode/.ropeproject/objectdb diff --git a/tfjs-converter/.vscode/.ropeproject/config.py b/tfjs-converter/.vscode/.ropeproject/config.py new file mode 100644 index 00000000000..dee2d1ae9a6 --- /dev/null +++ b/tfjs-converter/.vscode/.ropeproject/config.py @@ -0,0 +1,114 @@ +# The default ``config.py`` +# flake8: noqa + + +def set_prefs(prefs): + """This function is called before opening the project""" + + # Specify which files and folders to ignore in the project. + # Changes to ignored resources are not added to the history and + # VCSs. Also they are not returned in `Project.get_files()`. + # Note that ``?`` and ``*`` match all characters but slashes. + # '*.pyc': matches 'test.pyc' and 'pkg/test.pyc' + # 'mod*.pyc': matches 'test/mod1.pyc' but not 'mod/1.pyc' + # '.svn': matches 'pkg/.svn' and all of its children + # 'build/*.o': matches 'build/lib.o' but not 'build/sub/lib.o' + # 'build//*.o': matches 'build/lib.o' and 'build/sub/lib.o' + prefs['ignored_resources'] = ['*.pyc', '*~', '.ropeproject', + '.hg', '.svn', '_svn', '.git', '.tox'] + + # Specifies which files should be considered python files. It is + # useful when you have scripts inside your project. Only files + # ending with ``.py`` are considered to be python files by + # default. + # prefs['python_files'] = ['*.py'] + + # Custom source folders: By default rope searches the project + # for finding source folders (folders that should be searched + # for finding modules). You can add paths to that list. Note + # that rope guesses project source folders correctly most of the + # time; use this if you have any problems. + # The folders should be relative to project root and use '/' for + # separating folders regardless of the platform rope is running on. + # 'src/my_source_folder' for instance. + # prefs.add('source_folders', 'src') + + # You can extend python path for looking up modules + # prefs.add('python_path', '~/python/') + + # Should rope save object information or not. + prefs['save_objectdb'] = True + prefs['compress_objectdb'] = False + + # If `True`, rope analyzes each module when it is being saved. + prefs['automatic_soa'] = True + # The depth of calls to follow in static object analysis + prefs['soa_followed_calls'] = 0 + + # If `False` when running modules or unit tests "dynamic object + # analysis" is turned off. This makes them much faster. + prefs['perform_doa'] = True + + # Rope can check the validity of its object DB when running. + prefs['validate_objectdb'] = True + + # How many undos to hold? + prefs['max_history_items'] = 32 + + # Shows whether to save history across sessions. + prefs['save_history'] = True + prefs['compress_history'] = False + + # Set the number spaces used for indenting. According to + # :PEP:`8`, it is best to use 4 spaces. Since most of rope's + # unit-tests use 4 spaces it is more reliable, too. + prefs['indent_size'] = 4 + + # Builtin and c-extension modules that are allowed to be imported + # and inspected by rope. + prefs['extension_modules'] = [] + + # Add all standard c-extensions to extension_modules list. + prefs['import_dynload_stdmods'] = True + + # If `True` modules with syntax errors are considered to be empty. + # The default value is `False`; When `False` syntax errors raise + # `rope.base.exceptions.ModuleSyntaxError` exception. + prefs['ignore_syntax_errors'] = False + + # If `True`, rope ignores unresolvable imports. Otherwise, they + # appear in the importing namespace. + prefs['ignore_bad_imports'] = False + + # If `True`, rope will insert new module imports as + # `from import ` by default. + prefs['prefer_module_from_imports'] = False + + # If `True`, rope will transform a comma list of imports into + # multiple separate import statements when organizing + # imports. + prefs['split_imports'] = False + + # If `True`, rope will remove all top-level import statements and + # reinsert them at the top of the module when making changes. + prefs['pull_imports_to_top'] = True + + # If `True`, rope will sort imports alphabetically by module name instead + # of alphabetically by import statement, with from imports after normal + # imports. + prefs['sort_imports_alphabetically'] = False + + # Location of implementation of + # rope.base.oi.type_hinting.interfaces.ITypeHintingFactory In general + # case, you don't have to change this value, unless you're an rope expert. + # Change this value to inject you own implementations of interfaces + # listed in module rope.base.oi.type_hinting.providers.interfaces + # For example, you can add you own providers for Django Models, or disable + # the search type-hinting in a class hierarchy, etc. + prefs['type_hinting_factory'] = ( + 'rope.base.oi.type_hinting.factory.default_type_hinting_factory') + + +def project_opened(project): + """This function is called after opening the project""" + # Do whatever you like here! diff --git a/tfjs-converter/.vscode/.ropeproject/objectdb b/tfjs-converter/.vscode/.ropeproject/objectdb new file mode 100644 index 0000000000000000000000000000000000000000..0a47446c0ad231c193bdd44ff327ba2ab28bf3d8 GIT binary patch literal 6 NcmZo*sx4&D0{{kv0iOT> literal 0 HcmV?d00001 diff --git a/tfjs-converter/python/tensorflowjs/converters/graph_rewrite_util.py b/tfjs-converter/python/tensorflowjs/converters/graph_rewrite_util.py index d1937f0b054..dbfd8b0e30d 100644 --- a/tfjs-converter/python/tensorflowjs/converters/graph_rewrite_util.py +++ b/tfjs-converter/python/tensorflowjs/converters/graph_rewrite_util.py @@ -20,6 +20,7 @@ # Custom op name for fused depthwise conv2d FUSED_DEPTHWISE_CONV2D = 'FusedDepthwiseConv2dNative' +# The grappler op name for fused MatMul which starts with '_' FUSED_MATMUL = '_FusedMatMul' def node_from_map(node_map, name): diff --git a/tfjs-converter/src/operations/executors/matrices_executor_test.ts b/tfjs-converter/src/operations/executors/matrices_executor_test.ts index c321d9d0f32..df2912474e8 100644 --- a/tfjs-converter/src/operations/executors/matrices_executor_test.ts +++ b/tfjs-converter/src/operations/executors/matrices_executor_test.ts @@ -77,6 +77,29 @@ describe('matrices', () => { preluActivationWeights: undefined }); }); + it('should call tfc.fused.matMul - prelu activation', () => { + spyOn(tfc.fused, 'matMul'); + node.op = '_FusedMatMul'; + node.inputParams['args'] = createTensorsAttr(2, 0); + node.attrParams['fusedOps'] = createStrArrayAttr(['biasadd', 'prelu']); + node.attrParams['numArgs'] = createNumberAttr(2); + node.attrParams.transposeA = createBoolAttr(true); + node.attrParams.transposeB = createBoolAttr(false); + const input3 = [tfc.scalar(3.0)]; + const input4 = [tfc.scalar(4.0)]; + node.inputNames = ['input1', 'input2', 'input3', 'input4']; + executeOp(node, {input1, input2, input3, input4}, context); + + expect(tfc.fused.matMul).toHaveBeenCalledWith({ + a: input1[0], + b: input2[0], + transposeA: true, + transposeB: false, + bias: input3[0], + activation: 'prelu', + preluActivationWeights: input4[0] + }); + }); }); describe('BatchMatMul', () => { it('should call tfc.matMul', () => { diff --git a/tfjs-converter/src/operations/op_list/matrices.ts b/tfjs-converter/src/operations/op_list/matrices.ts index 3b3bad2fb09..ae6a1bbb405 100644 --- a/tfjs-converter/src/operations/op_list/matrices.ts +++ b/tfjs-converter/src/operations/op_list/matrices.ts @@ -24,6 +24,7 @@ export const json: OpMapper[] = [ 'inputs': [ {'start': 0, 'name': 'a', 'type': 'tensor'}, {'start': 1, 'name': 'b', 'type': 'tensor'}, + {'start': 2, end: 0, 'name': 'args', 'type': 'tensors'}, ], 'attrs': [ {'tfName': 'num_args', 'name': 'numArgs', 'type': 'number'}, { From 83fc37351b6c364331276a1e1deead0f529c5ea8 Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Mon, 2 Dec 2019 15:45:51 -0800 Subject: [PATCH 4/6] addressed the comments --- tfjs-converter/python/tensorflowjs/op_list/matrices.json | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tfjs-converter/python/tensorflowjs/op_list/matrices.json b/tfjs-converter/python/tensorflowjs/op_list/matrices.json index fa382556301..69c42b340e3 100644 --- a/tfjs-converter/python/tensorflowjs/op_list/matrices.json +++ b/tfjs-converter/python/tensorflowjs/op_list/matrices.json @@ -12,6 +12,12 @@ "start": 1, "name": "b", "type": "tensor" + }, + { + "start": 2, + "end": 0, + "name": "args", + "type": "tensors" } ], "attrs": [ From a2b23aa5deffa08f6c63fc3a756c1b3d17fb39f8 Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Mon, 2 Dec 2019 16:03:37 -0800 Subject: [PATCH 5/6] missing the method name change --- .../python/tensorflowjs/converters/fuse_prelu.py | 2 +- .../python/tensorflowjs/converters/fuse_prelu_test.py | 9 ++++++--- .../converters/tf_saved_model_conversion_v2.py | 3 ++- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/tfjs-converter/python/tensorflowjs/converters/fuse_prelu.py b/tfjs-converter/python/tensorflowjs/converters/fuse_prelu.py index 76f2da98d9c..b9f654736d3 100644 --- a/tfjs-converter/python/tensorflowjs/converters/fuse_prelu.py +++ b/tfjs-converter/python/tensorflowjs/converters/fuse_prelu.py @@ -157,7 +157,7 @@ def _create_alpha_node(neg_alpha_op, updated_alpha): alpha_value, alpha_value.dtype.type, alpha_value.shape))) updated_alpha.append(neg_alpha_op.name) -def fuse_prelu_with_fused_conv2d(input_graph_def): +def fuse_prelu_with_fused_conv2d_or_matmul(input_graph_def): """Tensorflow does not support Prelu op, and the grappler remap optimizer will not fuse the prelu op with _FusedConv2D op. This method searches for the pattern and fuse the (_FusedConv2D||FusedDepthwiseConv2dNative + Prelu) diff --git a/tfjs-converter/python/tensorflowjs/converters/fuse_prelu_test.py b/tfjs-converter/python/tensorflowjs/converters/fuse_prelu_test.py index ef53b90a38d..08f856b8476 100644 --- a/tfjs-converter/python/tensorflowjs/converters/fuse_prelu_test.py +++ b/tfjs-converter/python/tensorflowjs/converters/fuse_prelu_test.py @@ -124,7 +124,8 @@ def execute_model(tensor): config, graph_def, graph, signature) graph_def = fuse_prelu.fuse_ops_for_prelu(graph_def) - optimized_graph_def = fuse_prelu.fuse_prelu_with_fused_conv2d(graph_def) + optimized_graph_def = fuse_prelu.fuse_prelu_with_fused_conv2d_or_matmul( + graph_def) conv2d_op = None for node in optimized_graph_def.node: @@ -172,7 +173,8 @@ def execute_model(tensor): graph_def = tf_saved_model_conversion_v2._run_grappler( config, graph_def, graph, signature) graph_def = fuse_prelu.fuse_ops_for_prelu(graph_def) - optimized_graph_def = fuse_prelu.fuse_prelu_with_fused_conv2d(graph_def) + optimized_graph_def = fuse_prelu.fuse_prelu_with_fused_conv2d_or_matmul( + graph_def) matmul_op = None for node in optimized_graph_def.node: @@ -221,7 +223,8 @@ def execute_model(tensor): graph_def = fuse_prelu.fuse_ops_for_prelu(graph_def) graph_def = fuse_depthwise_conv2d.fuse_depthwise_conv2d(graph_def) - optimized_graph_def = fuse_prelu.fuse_prelu_with_fused_conv2d(graph_def) + optimized_graph_def = fuse_prelu.fuse_prelu_with_fused_conv2d_or_matmul( + graph_def) conv2d_op = None for node in optimized_graph_def.node: diff --git a/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py b/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py index 17041088377..3626a6ed022 100644 --- a/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py +++ b/tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py @@ -181,7 +181,8 @@ def optimize_graph(graph, signature_def, output_graph, # Since the grappler remap optimizer doe snot support prelu as the activation # function for _FusedConv2D op, we are doing it manually here. - optimized_graph = fuse_prelu.fuse_prelu_with_fused_conv2d(optimized_graph) + optimized_graph = fuse_prelu.fuse_prelu_with_fused_conv2d_or_matmul( + optimized_graph) unsupported = validate(optimized_graph.node, skip_op_check, strip_debug_ops) From 92b19db56f6b0ca8f35697fa5a420b0dbea9387b Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Mon, 2 Dec 2019 16:07:31 -0800 Subject: [PATCH 6/6] remove vscode config changes --- tfjs-converter/.vscode/.ropeproject/config.py | 114 ------------------ tfjs-converter/.vscode/.ropeproject/objectdb | Bin 6 -> 0 bytes 2 files changed, 114 deletions(-) delete mode 100644 tfjs-converter/.vscode/.ropeproject/config.py delete mode 100644 tfjs-converter/.vscode/.ropeproject/objectdb diff --git a/tfjs-converter/.vscode/.ropeproject/config.py b/tfjs-converter/.vscode/.ropeproject/config.py deleted file mode 100644 index dee2d1ae9a6..00000000000 --- a/tfjs-converter/.vscode/.ropeproject/config.py +++ /dev/null @@ -1,114 +0,0 @@ -# The default ``config.py`` -# flake8: noqa - - -def set_prefs(prefs): - """This function is called before opening the project""" - - # Specify which files and folders to ignore in the project. - # Changes to ignored resources are not added to the history and - # VCSs. Also they are not returned in `Project.get_files()`. - # Note that ``?`` and ``*`` match all characters but slashes. - # '*.pyc': matches 'test.pyc' and 'pkg/test.pyc' - # 'mod*.pyc': matches 'test/mod1.pyc' but not 'mod/1.pyc' - # '.svn': matches 'pkg/.svn' and all of its children - # 'build/*.o': matches 'build/lib.o' but not 'build/sub/lib.o' - # 'build//*.o': matches 'build/lib.o' and 'build/sub/lib.o' - prefs['ignored_resources'] = ['*.pyc', '*~', '.ropeproject', - '.hg', '.svn', '_svn', '.git', '.tox'] - - # Specifies which files should be considered python files. It is - # useful when you have scripts inside your project. Only files - # ending with ``.py`` are considered to be python files by - # default. - # prefs['python_files'] = ['*.py'] - - # Custom source folders: By default rope searches the project - # for finding source folders (folders that should be searched - # for finding modules). You can add paths to that list. Note - # that rope guesses project source folders correctly most of the - # time; use this if you have any problems. - # The folders should be relative to project root and use '/' for - # separating folders regardless of the platform rope is running on. - # 'src/my_source_folder' for instance. - # prefs.add('source_folders', 'src') - - # You can extend python path for looking up modules - # prefs.add('python_path', '~/python/') - - # Should rope save object information or not. - prefs['save_objectdb'] = True - prefs['compress_objectdb'] = False - - # If `True`, rope analyzes each module when it is being saved. - prefs['automatic_soa'] = True - # The depth of calls to follow in static object analysis - prefs['soa_followed_calls'] = 0 - - # If `False` when running modules or unit tests "dynamic object - # analysis" is turned off. This makes them much faster. - prefs['perform_doa'] = True - - # Rope can check the validity of its object DB when running. - prefs['validate_objectdb'] = True - - # How many undos to hold? - prefs['max_history_items'] = 32 - - # Shows whether to save history across sessions. - prefs['save_history'] = True - prefs['compress_history'] = False - - # Set the number spaces used for indenting. According to - # :PEP:`8`, it is best to use 4 spaces. Since most of rope's - # unit-tests use 4 spaces it is more reliable, too. - prefs['indent_size'] = 4 - - # Builtin and c-extension modules that are allowed to be imported - # and inspected by rope. - prefs['extension_modules'] = [] - - # Add all standard c-extensions to extension_modules list. - prefs['import_dynload_stdmods'] = True - - # If `True` modules with syntax errors are considered to be empty. - # The default value is `False`; When `False` syntax errors raise - # `rope.base.exceptions.ModuleSyntaxError` exception. - prefs['ignore_syntax_errors'] = False - - # If `True`, rope ignores unresolvable imports. Otherwise, they - # appear in the importing namespace. - prefs['ignore_bad_imports'] = False - - # If `True`, rope will insert new module imports as - # `from import ` by default. - prefs['prefer_module_from_imports'] = False - - # If `True`, rope will transform a comma list of imports into - # multiple separate import statements when organizing - # imports. - prefs['split_imports'] = False - - # If `True`, rope will remove all top-level import statements and - # reinsert them at the top of the module when making changes. - prefs['pull_imports_to_top'] = True - - # If `True`, rope will sort imports alphabetically by module name instead - # of alphabetically by import statement, with from imports after normal - # imports. - prefs['sort_imports_alphabetically'] = False - - # Location of implementation of - # rope.base.oi.type_hinting.interfaces.ITypeHintingFactory In general - # case, you don't have to change this value, unless you're an rope expert. - # Change this value to inject you own implementations of interfaces - # listed in module rope.base.oi.type_hinting.providers.interfaces - # For example, you can add you own providers for Django Models, or disable - # the search type-hinting in a class hierarchy, etc. - prefs['type_hinting_factory'] = ( - 'rope.base.oi.type_hinting.factory.default_type_hinting_factory') - - -def project_opened(project): - """This function is called after opening the project""" - # Do whatever you like here! diff --git a/tfjs-converter/.vscode/.ropeproject/objectdb b/tfjs-converter/.vscode/.ropeproject/objectdb deleted file mode 100644 index 0a47446c0ad231c193bdd44ff327ba2ab28bf3d8..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6 NcmZo*sx4&D0{{kv0iOT>