
[TFL] Enhance MoveBinaryOpBeforeReshape pattern to enable fusion of 2D bias. #47510


@WindQAQ (Member) commented Mar 2, 2021

Some uses of EinsumDense layer will create 2D or higher bias. For example,

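```python
import tensorflow as tf

layer = tf.keras.layers.MultiHeadAttention(num_heads=3, key_dim=5)
target = tf.keras.Input(shape=[8, 16], batch_size=1)
source = tf.keras.Input(shape=[4, 16], batch_size=1)
output_tensor = layer(target, source, return_attention_scores=False)
model = tf.keras.Model([target, source], output_tensor)
# Convert to TFLite; the bias-add fusion happens during this conversion.
tflite_model = tf.lite.TFLiteConverter.from_keras_model(model).convert()
```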

This PR reorders Reshape and the following BinaryOp so that the 2D bias is flattened to 1D and becomes fusable into FullyConnected. In the example above, the bias add can then be fused into FullyConnected at locations 0, 2, and 172.
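The reordering relies on Reshape being a pure metadata change on row-major data: adding a [3, 5] bias after reshaping an [8, 15] FullyConnected output to [8, 3, 5] touches exactly the same elements as adding that bias flattened to [15] before the Reshape, and that 1D Add can then fold into the FullyConnected bias. Below is a minimal sketch with plain Python lists (no TensorFlow), assuming the query projection's bias has shape [num_heads, key_dim] = [3, 5] and dropping the batch dimension of 1. `fully_connected` is a hypothetical reference helper with [out, in] weights, not the TFLite kernel.

```python
def fully_connected(x, n, k, w, o, bias=None):
    # Hypothetical reference FullyConnected on flat row-major data:
    # x is [n, k], w is [o, k] (output-major), bias is an optional [o] list.
    out = []
    for r in range(n):
        for c in range(o):
            acc = sum(x[r * k + j] * w[c * k + j] for j in range(k))
            out.append(acc + (bias[c] if bias is not None else 0))
    return out  # [n, o], row-major

# Assumed shapes: 8 query positions, 16 input features, 3 heads x key_dim 5.
T, D, H, K = 8, 16, 3, 5
x = [(i * 7) % 11 - 5 for i in range(T * D)]                  # [8, 16]
w = [(i * 3) % 13 - 6 for i in range(H * K * D)]              # [15, 16]
bias_2d = [[h * 10 + k for k in range(K)] for h in range(H)]  # [3, 5]
bias_flat = [v for row in bias_2d for v in row]               # flattened to [15]

fc = fully_connected(x, T, D, w, H * K)                       # [8, 15]

# Original order: FC -> Reshape to [8, 3, 5] -> Add(2D bias [3, 5]).
# Element (t, h, k) of the reshaped tensor is fc[t*15 + h*5 + k].
original = [fc[t * H * K + h * K + k] + bias_2d[h][k]
            for t in range(T) for h in range(H) for k in range(K)]

# Reordered: FC -> Add(1D bias [15]) -> Reshape; Reshape keeps the flat order.
reordered = [fc[t * H * K + c] + bias_flat[c]
             for t in range(T) for c in range(H * K)]

# With a 1D bias directly after FC, the Add can be folded into FC itself.
fused = fully_connected(x, T, D, w, H * K, bias=bias_flat)

assert original == reordered == fused
```

Integer inputs keep the comparison exact; with floats the fused and unfused results would still match here because the bias is added once, after the dot product, in every variant.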

@google-ml-butler google-ml-butler bot added the size:M CL Change Size: Medium label Mar 2, 2021
@google-cla google-cla bot added the cla: yes label Mar 2, 2021
```
@@ -642,6 +642,34 @@ func @NotReorderReshapeAddIfHighDim(%arg0: tensor<1x1x1x1x30x96xf32>) -> tensor<
// CHECK: return %[[rs2]]
}

// CHECK-LABEL: @ReorderReshapeAdd2DConst
```
Contributor

Could you add test cases for reordering and fusing?

Member Author

Updated. I followed the FuseFullyConnectedReshapeAddConst* test cases. Let me know if that works. Thank you!

@google-ml-butler google-ml-butler bot added kokoro:force-run Tests on submitted change ready to pull PR ready for merge process labels Mar 2, 2021
@abattery (Contributor) commented Mar 2, 2021

Thanks for the contributions!

@kokoro-team kokoro-team removed the kokoro:force-run Tests on submitted change label Mar 2, 2021
@abattery (Contributor) commented Mar 3, 2021

Please add additional checks in the new patterns to make sure that the newly generated binary ops have inputs of rank <= 4, since TFLite binary kernels support broadcasting for inputs of up to 4D.
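The requested guard could look roughly like the hypothetical predicate below. It is a sketch only: the function name and exact conditions are assumptions, not the checks that landed in this PR. The moved Add operates on the pre-reshape tensor and a 1D bias, so it is the pre-reshape rank that must stay within 4. The bias must also match the reshaped tensor's trailing dimensions and flatten to exactly the pre-reshape last dimension, so that broadcasting pairs the same elements before and after the rewrite.

```python
from math import prod

MAX_BINARY_RANK = 4  # TFLite binary kernels broadcast inputs of up to 4D

def can_move_add_before_reshape(input_shape, reshaped_shape, bias_shape):
    """Hypothetical guard for rewriting Add(Reshape(x), bias) into
    Reshape(Add(x, flatten(bias))).

    input_shape: shape of x before the Reshape.
    reshaped_shape: shape produced by the Reshape.
    bias_shape: shape of the constant bias (rank >= 1 in this sketch).
    """
    n = len(bias_shape)
    if n == 0 or n > len(reshaped_shape):
        return False
    # The new Add sees x's pre-reshape shape, so keep it within 4D.
    if not input_shape or len(input_shape) > MAX_BINARY_RANK:
        return False
    # The bias must cover exactly the reshaped tensor's trailing dims ...
    if list(reshaped_shape[-n:]) != list(bias_shape):
        return False
    # ... and its flattened form must line up with x's last dimension.
    return input_shape[-1] == prod(bias_shape)


# [8, 15] -> [1, 8, 3, 5] with a [3, 5] bias: reorderable.
print(can_move_add_before_reshape([8, 15], [1, 8, 3, 5], [3, 5]))  # True
# Same bias, but a 6D pre-reshape input: rejected by the rank check.
print(can_move_add_before_reshape([1, 1, 1, 1, 8, 15], [1, 1, 1, 1, 8, 3, 5], [3, 5]))  # False
```

The trailing-dimension comparison is deliberately conservative; it rejects some broadcast-compatible biases (for example, ones with extra leading 1s) rather than risk an incorrect rewrite.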

@abattery (Contributor) commented Mar 3, 2021

Also, it would be nice to add a test case for >4D inputs.

@abattery (Contributor) commented Mar 3, 2021

Also, can we limit the reordering to cases where the FullyConnected op appears next to it?

@WindQAQ (Member Author) commented Mar 3, 2021

> Also, can we limit the reordering to cases where the FullyConnected op appears next to it?

Do you mean this pattern only, or all the other reordering patterns too?

@abattery (Contributor) commented Mar 3, 2021

> Do you mean this pattern only, or all the other reordering patterns too?

Limiting only the newly added patterns is fine for now.
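The agreed restriction can be pictured as a guard on the producer of the Reshape's input. The sketch below uses a hypothetical `Op` node type and op-kind strings rather than the converter's real pattern-matching machinery; it only illustrates the shape of the check, assuming the Reshape is the Add's first operand.

```python
from dataclasses import dataclass, field

@dataclass
class Op:
    # Hypothetical, minimal graph node: an op kind plus its operand producers.
    kind: str
    inputs: list = field(default_factory=list)

def should_reorder_add(add_op):
    """Only move Add(Reshape(x), bias) before the Reshape when x is produced
    by a FullyConnected op, so the moved Add can later fold into its bias."""
    if add_op.kind != "ADD" or not add_op.inputs:
        return False
    reshape = add_op.inputs[0]
    if reshape.kind != "RESHAPE" or not reshape.inputs:
        return False
    return reshape.inputs[0].kind == "FULLY_CONNECTED"


x, bias = Op("INPUT"), Op("CONST")
print(should_reorder_add(Op("ADD", [Op("RESHAPE", [Op("FULLY_CONNECTED", [x])]), bias])))  # True
print(should_reorder_add(Op("ADD", [Op("RESHAPE", [Op("CONV_2D", [x])]), bias])))          # False
```

Gating on the producer keeps the rewrite from firing where no fusion can follow, so graphs without a FullyConnected before the Reshape are left unchanged.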

@google-ml-butler google-ml-butler bot removed the ready to pull PR ready for merge process label Mar 3, 2021
@google-ml-butler google-ml-butler bot added kokoro:force-run Tests on submitted change ready to pull PR ready for merge process labels Mar 3, 2021
@kokoro-team kokoro-team removed the kokoro:force-run Tests on submitted change label Mar 3, 2021
@WindQAQ (Member Author) commented Mar 3, 2021

> Also, it would be nice to add a test case for >4D inputs.

Because the input is restricted to be produced by FullyConnected, which usually outputs a 2D tensor, I couldn't really find a use case with >4D input. I did include one test with high-dimensional input in e26e2a7. Let me know whether it passes all internal checks. Thank you!

@rthadur rthadur added this to Assigned Reviewer in PR Queue via automation Mar 3, 2021
@copybara-service copybara-service bot merged commit 2812873 into tensorflow:master Mar 3, 2021
PR Queue automation moved this from Assigned Reviewer to Merged Mar 3, 2021
Labels: cla: yes, ready to pull (PR ready for merge process), size:M (CL Change Size: Medium)
Projects: PR Queue (Merged)

4 participants