Skip to content

Conversation

@pyu10055
Copy link
Collaborator

@pyu10055 pyu10055 commented Nov 27, 2019

This PR fuses the DepthwiseConv2dNative + BiasAdd + (Activation) => FusedDepthwiseConv2dNative op.
This speeds on models that use similar structure as mobilenet.

javascript executor changes will be in a follow up PR.
To see the logs from the Cloud Build CI, please join either our discussion or announcement mailing list.


This change is Reviewable

Copy link
Contributor

@dsmilkov dsmilkov left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice to see this so quickly! Left a few comments.

Reviewed 9 of 9 files at r1.
Reviewable status: 0 of 1 approvals obtained (waiting on @annxingyuan, @dsmilkov, and @pyu10055)


tfjs-converter/python/tensorflowjs/converters/common.py, line 124 at r1 (raw file):

  return node_name

def cleanup_graph_def(input_graph_def, nodes_to_skip, inputs_to_remove):

generic filenames like common.py do not provide context about their logic. Let's move the graph rewriting methods (like this cleanup) in a graph_cleanup.py


tfjs-converter/python/tensorflowjs/converters/common.py, line 133 at r1 (raw file):

    inputs_to_remove: List of nodes to be removed from inputs of all nodes.
  Returns:
    GraphDef that has been cleaned..

remove double period at the end


tfjs-converter/python/tensorflowjs/converters/common.py, line 143 at r1 (raw file):

    new_node.CopyFrom(node)
    for value in inputs_to_remove:
      if value.name in new_node.input:

remove the "if" since the "for-loop" after this line will do the search for you. Otherwise, you end up searching twice through new_node.input (once with the "in" operator and once with the "for-loop" below)


tfjs-converter/python/tensorflowjs/converters/common.py, line 146 at r1 (raw file):

        for i, input_node in enumerate(new_node.input):
          if input_node == value.name:
            print(value.input)

remove print statement


tfjs-converter/python/tensorflowjs/converters/fuse_depthwise_conv2d.py, line 72 at r1 (raw file):

def _add_fused_contraction_node(contraction, bias_add, activation,
                                inputs_to_remove, nodes_to_skip):
  print("Fuse " + contraction.op + " with BiasAdd: " +

remove print


tfjs-converter/python/tensorflowjs/converters/fuse_depthwise_conv2d.py, line 104 at r1 (raw file):

  Returns:
    Modified graph with Prelu ops generated, and modified weights.

doc out of sync, remove prelu


tfjs-converter/python/tensorflowjs/converters/fuse_depthwise_conv2d.py, line 131 at r1 (raw file):

  Returns:
    Modified graph with Prelu ops generated, and modified weights.

doc out of sync, remove prelu


tfjs-converter/python/tensorflowjs/converters/fuse_depthwise_conv2d.py, line 163 at r1 (raw file):

     we need to clean up the attributes for FusedDepthwiseConv2dNative op.
  Args:
    input_graph_def: A tf.Graph object to insert prelu function into.

doc out of sync, remove prelu


tfjs-converter/python/tensorflowjs/converters/fuse_depthwise_conv2d.py, line 177 at r1 (raw file):

def register_fused_depthwise_conv2d_func(graph):
  """Register _FusedDepthwiseConv2dNative op with function def, this is needed
  for importing graph_def with unregistered op.

when will you need to import a graphdef that has _FusedDepthwiseConv2DNative? Shouldn't that op only exist temporarily during model conversion to model.json?


tfjs-converter/python/tensorflowjs/converters/fuse_depthwise_conv2d.py, line 179 at r1 (raw file):

  for importing graph_def with unregistered op.
  Args:
    graph: A tf.Graph object to insert prelu function into.

doc out of sync, remove prelu


tfjs-converter/python/tensorflowjs/converters/fuse_depthwise_conv2d.py, line 186 at r1 (raw file):

                  func_name=FUSED_DEPTHWISE_CONV2D)
  def fused_depthwise_conv2d_fn(*args):
    return tf.nn.depthwise_conv2d(

Does this function need to be the correct implementation of fused depthwise? If this function used? If yes, then you need to add the bias and the potential activation during forward mode. If not, leave a comment why this function's implementation doesn't matter/is not used.


tfjs-converter/python/tensorflowjs/converters/fuse_depthwise_conv2d_test.py, line 15 at r1 (raw file):

# limitations under the License.
# ==============================================================================
"""Unit tests for prelu op fusing."""

doc out of sync, remove prelu


tfjs-converter/python/tensorflowjs/converters/fuse_prelu.py, line 188 at r1 (raw file):

    fused_conv_op = common.node_from_map(input_node_map, node.input[0])
    if (not fused_conv_op or fused_conv_op.op != '_FusedConv2D' or

generalize the prelu fusing to work with _FusedDepthwiseConv2D too and add unit tests.


tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py, line 227 at r1 (raw file):

  fuse_depthwise_conv2d.register_fused_depthwise_conv2d_func(graph)

  extraced_graph = fuse_depthwise_conv2d.extract_op_attributes(graph_def)

potential typo. Rename to fused_graph ?


tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2_test.py, line 502 at r1 (raw file):

    with open(os.path.join(tfjs_path, 'model.json'), 'rt') as f:
      model_json = json.load(f)
    print(model_json)

remove print


tfjs-converter/python/tensorflowjs/op_list/convolution.json, line 462 at r1 (raw file):

  },
  {
    "tfOpName": "FusedDepthwiseConv2dNative",

If you changed the tfOpName to "doesnotmatter", will a unit test break? Also I think you need an underscore at the start of the name to be consistent with the tfOpName: "_FusedConv2D" above in this file.

Copy link
Collaborator Author

@pyu10055 pyu10055 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reviewable status: 0 of 1 approvals obtained (waiting on @annxingyuan and @dsmilkov)


tfjs-converter/python/tensorflowjs/converters/common.py, line 124 at r1 (raw file):

Previously, dsmilkov (Daniel Smilkov) wrote…

generic filenames like common.py do not provide context about their logic. Let's move the graph rewriting methods (like this cleanup) in a graph_cleanup.py

moved, thanks for the suggestion.


tfjs-converter/python/tensorflowjs/converters/common.py, line 143 at r1 (raw file):

Previously, dsmilkov (Daniel Smilkov) wrote…

remove the "if" since the "for-loop" after this line will do the search for you. Otherwise, you end up searching twice through new_node.input (once with the "in" operator and once with the "for-loop" below)

Done.


tfjs-converter/python/tensorflowjs/converters/common.py, line 146 at r1 (raw file):

Previously, dsmilkov (Daniel Smilkov) wrote…

remove print statement

Done.


tfjs-converter/python/tensorflowjs/converters/fuse_depthwise_conv2d.py, line 72 at r1 (raw file):

Previously, dsmilkov (Daniel Smilkov) wrote…

remove print

Done.


tfjs-converter/python/tensorflowjs/converters/fuse_depthwise_conv2d.py, line 177 at r1 (raw file):

Previously, dsmilkov (Daniel Smilkov) wrote…

when will you need to import a graphdef that has _FusedDepthwiseConv2DNative? Shouldn't that op only exist temporarily during model conversion to model.json?

The op is named as FusedDepthwiseConv2dNative, since TF does not allow op name to start with '_', which is reserved for internal ops.


tfjs-converter/python/tensorflowjs/converters/fuse_depthwise_conv2d.py, line 186 at r1 (raw file):

Previously, dsmilkov (Daniel Smilkov) wrote…

Does this function need to be the correct implementation of fused depthwise? If this function used? If yes, then you need to add the bias and the potential activation during forward mode. If not, leave a comment why this function's implementation doesn't matter/is not used.

Done.


tfjs-converter/python/tensorflowjs/converters/fuse_prelu.py, line 188 at r1 (raw file):

Previously, dsmilkov (Daniel Smilkov) wrote…

generalize the prelu fusing to work with _FusedDepthwiseConv2D too and add unit tests.

added support for prelu fusing with fused depthwise conv2d


tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py, line 227 at r1 (raw file):

Previously, dsmilkov (Daniel Smilkov) wrote…

potential typo. Rename to fused_graph ?

renamed to extracted_graph


tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2_test.py, line 502 at r1 (raw file):

Previously, dsmilkov (Daniel Smilkov) wrote…

remove print

Done.


tfjs-converter/python/tensorflowjs/op_list/convolution.json, line 462 at r1 (raw file):

Previously, dsmilkov (Daniel Smilkov) wrote…

If you changed the tfOpName to "doesnotmatter", will a unit test break? Also I think you need an underscore at the start of the name to be consistent with the tfOpName: "_FusedConv2D" above in this file.

yes, it will fail the conversion test, since it triggers not supported op. as mentioned earlier we cannot use op name started with '_'

Copy link
Contributor

@dsmilkov dsmilkov left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you! Just a short followup about couple of tiny comments missed and two non-blocking discussions for my own understanding. LGTM otherwise!!

Reviewed 11 of 11 files at r2.
Reviewable status: :shipit: complete! 1 of 1 approvals obtained (waiting on @annxingyuan and @pyu10055)


tfjs-converter/python/tensorflowjs/converters/fuse_depthwise_conv2d.py, line 104 at r1 (raw file):

Previously, dsmilkov (Daniel Smilkov) wrote…

doc out of sync, remove prelu

missed the comment?


tfjs-converter/python/tensorflowjs/converters/fuse_depthwise_conv2d.py, line 131 at r1 (raw file):

Previously, dsmilkov (Daniel Smilkov) wrote…

doc out of sync, remove prelu

missed the comment?


tfjs-converter/python/tensorflowjs/converters/fuse_depthwise_conv2d.py, line 177 at r1 (raw file):

Previously, pyu10055 (Ping Yu) wrote…

The op is named as FusedDepthwiseConv2dNative, since TF does not allow op name to start with '_', which is reserved for internal ops.

Thanks! I'm still curious about my original question. When will you need to import a graphdef with this op? In other words, why do you need to register it when we never really import a graphdef with that op? The rewritten graph lives temporarily.


tfjs-converter/python/tensorflowjs/op_list/convolution.json, line 462 at r1 (raw file):

Previously, pyu10055 (Ping Yu) wrote…

yes, it will fail the conversion test, since it triggers not supported op. as mentioned earlier we cannot use op name started with '_'

I think I'm still missing something. Why can _FusedConv2D have underscore but FusedDepthwiseConv2DNative not have underscore? What is the difference between these two ops?

Copy link
Collaborator Author

@pyu10055 pyu10055 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reviewable status: :shipit: complete! 1 of 1 approvals obtained (waiting on @annxingyuan and @dsmilkov)


tfjs-converter/python/tensorflowjs/converters/fuse_depthwise_conv2d.py, line 104 at r1 (raw file):

Previously, dsmilkov (Daniel Smilkov) wrote…

missed the comment?

sorry, updated.


tfjs-converter/python/tensorflowjs/converters/fuse_depthwise_conv2d.py, line 177 at r1 (raw file):

Previously, dsmilkov (Daniel Smilkov) wrote…

Thanks! I'm still curious about my original question. When will you need to import a graphdef with this op? In other words, why do you need to register it when we never really import a graphdef with that op? The rewritten graph lives temporarily.

The op registration is need for the step to extract constant from the graph, which we need to load the graph and run eval with session.


tfjs-converter/python/tensorflowjs/op_list/convolution.json, line 462 at r1 (raw file):

Previously, dsmilkov (Daniel Smilkov) wrote…

I think I'm still missing something. Why can _FusedConv2D have underscore but FusedDepthwiseConv2DNative not have underscore? What is the difference between these two ops?

FusedConv2D is grappler op that is registered internally using c++, we are basically trying to create an empty op by using tf.Function, the '' is not allowed in this case.

Copy link
Contributor

@dsmilkov dsmilkov left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Ping! Enjoy the holidays!

Reviewed 2 of 2 files at r3.
Reviewable status: :shipit: complete! 1 of 1 approvals obtained (waiting on @annxingyuan)

@pyu10055 pyu10055 merged commit b6f6774 into master Nov 27, 2019
@pyu10055 pyu10055 deleted the fuse_depthwise_conv2d branch November 27, 2019 20:34
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants