[TFL] Enhance MoveBinaryOpBeforeReshape pattern to enable fusion of 2D bias. #47510
Some uses of the EinsumDense layer create a 2D or higher-rank bias. For example,

```python
import tensorflow as tf

layer = tf.keras.layers.MultiHeadAttention(num_heads=3, key_dim=5)
target = tf.keras.Input(shape=[8, 16], batch_size=1)
source = tf.keras.Input(shape=[4, 16], batch_size=1)
output_tensor = layer(target, source, return_attention_scores=False)
model = tf.keras.Model([target, source], output_tensor)
```

This PR reorders Reshape and BinaryOp so that the 2D bias is flattened and can be fused into FullyConnected.
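To illustrate why the reordering preserves results, here is a minimal NumPy sketch. It is not code from this PR: the shapes (8 positions, 3 heads, key_dim 5) are assumptions loosely modeled on the example above, and `fc_out` is a stand-in for the 2D output of a FullyConnected op.

```python
import numpy as np

rng = np.random.default_rng(0)

# Assumed shapes: FullyConnected output is [8, 15]; EinsumDense-style bias is [3, 5].
fc_out = rng.standard_normal((8, 15)).astype(np.float32)
bias_2d = rng.standard_normal((3, 5)).astype(np.float32)

# Original order: Reshape first, then Add with the 2D bias
# (the bias broadcasts over the leading dimensions).
reshape_then_add = fc_out.reshape(1, 8, 3, 5) + bias_2d

# Reordered: Add the flattened (1D) bias first, then Reshape.
# A 1D bias on the last dimension is the form FullyConnected can absorb.
add_then_reshape = (fc_out + bias_2d.reshape(-1)).reshape(1, 8, 3, 5)

same = np.allclose(reshape_then_add, add_then_reshape)
```

Both orders add the same bias element to the same output element, because the Reshape only splits the last dimension (15) into the bias dimensions (3, 5) without changing the row-major layout.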
@@ -642,6 +642,34 @@ func @NotReorderReshapeAddIfHighDim(%arg0: tensor<1x1x1x1x30x96xf32>) -> tensor<
// CHECK: return %[[rs2]]
}

// CHECK-LABEL: @ReorderReshapeAdd2DConst
Could you add test cases for reordering and fusing?
Updated. I followed the test cases of FuseFullyConnectedReshapeAddConst*. Let me know if that's applicable. Thank you!
Thanks for the contributions!
Please add additional checks in the new patterns to make sure that the newly generated binary ops have <=4D inputs, since TFLite binary kernels support broadcasting only up to 4D inputs.
It would also be nice to add a test case for >4D inputs.
Also, can we limit the reorderings to cases where the FullyConnected op appears next to it?
Do you mean this pattern only, or all other reordering patterns?
It is okay to apply this to the newly added patterns only for now.
Because we restrict the input to be defined by FullyConnected, which usually outputs 2D, I cannot really find a use case with 4D input. I did include one test for high-D input in e26e2a7. Let me know if it passes all internal checks. Thank you!
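The constraints discussed in this thread can be sketched as a plain Python predicate. This is only an illustration of the review requests, not the pattern constraints actually implemented in the PR; the function name, its parameters, and the string-based op check are all hypothetical.

```python
from math import prod

# From the review: TFLite binary kernels support broadcasting up to 4D inputs.
MAX_BROADCAST_RANK = 4


def can_move_binary_op_before_reshape(producer_op, reshape_input_shape,
                                      bias_shape, reshape_output_shape):
    """Hypothetical check for moving a binary op above a Reshape."""
    if not bias_shape or not reshape_input_shape:
        return False
    # Only reorder when the Reshape input is defined by FullyConnected.
    if producer_op != "FullyConnected":
        return False
    # The new binary op consumes the Reshape input, so it must be <=4D.
    if len(reshape_input_shape) > MAX_BROADCAST_RANK:
        return False
    # The flattened bias must match the last dimension of the Reshape input.
    if prod(bias_shape) != reshape_input_shape[-1]:
        return False
    # The bias must cover exactly the trailing dimensions of the Reshape output.
    trailing = tuple(reshape_output_shape[-len(bias_shape):])
    return trailing == tuple(bias_shape)


ok = can_move_binary_op_before_reshape(
    "FullyConnected", (8, 15), (3, 5), (1, 8, 3, 5))
```

With the assumed shapes, `ok` is true for a 2D FullyConnected output, while a 5D Reshape input or a non-FullyConnected producer is rejected.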
Some uses of the EinsumDense layer create a 2D or higher-rank bias, as in the MultiHeadAttention example above.
This PR reorders Reshape and BinaryOp so that the 2D bias is flattened and can be fused into FullyConnected. In the case above, the bias add can be fused into FullyConnected at locations 0, 2 and 172.