[tosa] Legalize tfl.atan2 to tosa.table operation #59100
Forgot to hit send on comments ...
Thanks for the review @jpienaar! I left some responses below
b633a12 to 499bf34
Hi @lhutton1, can you please resolve the conflicts? Thank you!
499bf34 to bdd550e
Friendly ping @jpienaar
Thanks for the ping
@@ -640,6 +640,47 @@ func.func @test_cos(%arg0: tensor<10xf32>) -> tensor<*xf32> {

// -----

// CHECK-LABEL: test_atan2
// CHECK-SAME: -> tensor<13x21x3xf32>
// CHECK: %[[VAL_2:.*]] = "tosa.const"() {value = dense<2.000000e+00> : tensor<1x1x1xf32>} : () -> tensor<1x1x1xf32>
Prefer CHECK-DAG for constants (there is an internal check for this that ends up blocking presubmits, but seems not upstream)
done thanks
Type input_y_ety = input_y_ty.getElementType();
Type input_x_ety = input_x_ty.getElementType();
Type output_ety = output_ty.getElementType();
Check for output_ty above too to avoid null here or just use cast above - I don't think TFL::Atan2Op is legal unless it produces a ShapedType
Oops seems I left these in after the last review - removed them
auto tfl_atan2_op = cast<TFL::Atan2Op>(op);
Location loc = op->getLoc();
Value input_y = tfl_atan2_op.getY();
RankedTensorType input_y_ty = input_y.getType().dyn_cast<RankedTensorType>();
With the new casting mechanisms in LLVM one can now finally do
auto input_y_ty = dyn_cast(input_y.getType());
Prefer not to repeat the type for these (for values that are the output of a cast we use auto, as the type is clear locally).
Thanks, looks cleaner. I might have misunderstood the dyn_cast(...) suggestion here, but I was only able to get it working with a templated type dyn_cast<RankedTensorType>(...).
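One way to see why the explicit template argument is needed: in a free-function cast, the target type appears only in the return type, so the compiler cannot deduce it from the argument. The sketch below uses toy stand-ins (the `toy` namespace, its `Type` hierarchy and its `dyn_cast` are hypothetical and are not LLVM's or MLIR's implementation); it only illustrates the call shape `dyn_cast<To>(value)` and why `auto` is enough on the left-hand side.

```cpp
#include <cassert>

namespace toy {

// Toy type hierarchy. These are illustrative stand-ins, NOT mlir::Type.
struct Type {
  virtual ~Type() = default;
};
struct RankedTensorType : Type {};
struct UnrankedTensorType : Type {};

// Toy free-function cast. `To` is not mentioned in the parameter list, so it
// cannot be deduced and the caller has to spell it: dyn_cast<To>(value).
// `From` is deduced from the argument as usual.
template <typename To, typename From>
To* dyn_cast(From* value) {
  return dynamic_cast<To*>(value);  // yields nullptr when the cast fails
}

}  // namespace toy
```

With these toy types, `auto* ranked = toy::dyn_cast<toy::RankedTensorType>(type_ptr);` names the target type exactly once, and a null result signals that the value was not of that type.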
Type input_x_ety = input_x_ty.getElementType();
Type output_ety = output_ty.getElementType();

bool op_is_fp = input_y_ty.getElementType().isF32();
Use input_y_ety?
Removed as they are only used once anyway
if (!input_y_ty || !input_x_ty) {
return rewriter.notifyMatchFailure(
op, "ConvertTFLAtan2Op: ranked inputs required");
Update these to keep it consistent with the others (skipping the pattern name in error message)
Done - apologies, I realized I missed this from your previous review
// 2. Scale and translate the normalized domain to the table domain. This
// includes a translating and scaling to [-int16_max, int16_max] and casting
// to an i16.
Add reason for casting here too.
Done
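As a rough illustration of the scale, translate and cast step quoted above, here is a minimal scalar sketch. It assumes the normalized domain is [0, 1] and that rounding is to nearest; neither is stated in this thread, and `ToTableDomain` is a hypothetical name. The cast matters because the table lookup in this scheme is indexed by a 16-bit integer rather than a float.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Maps a value from the ASSUMED normalized domain [0, 1] onto
// [-int16_max, int16_max] and casts it to a 16-bit integer.
inline int16_t ToTableDomain(float normalized) {
  assert(normalized >= 0.0f && normalized <= 1.0f);
  constexpr float kInt16Max = 32767.0f;
  // Translate [0, 1] to [-1, 1], then scale to [-32767, 32767].
  const float scaled = (normalized * 2.0f - 1.0f) * kInt16Max;
  // Assumption: round to nearest before the cast.
  return static_cast<int16_t>(std::lround(scaled));
}
```

The end points 0 and 1 land on -32767 and 32767, and the midpoint 0.5 lands on 0.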
rewriter, loc, output_ty.clone(rewriter.getIntegerType(32)), casted,
table_const);

// 4. The range of table is a 23-bit two's compliment value. Normalize the
complement
Fixed, thanks!
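For context on where a 23-bit result can come from, the sketch below models a 16-bit interpolating table lookup in plain C++, following my reading of the TOSA specification's int16 TABLE behaviour: 513 signed 16-bit entries, with the low 7 bits of the input used as an interpolation fraction. Treat the details as an assumption rather than the PR's code; `InterpolatedLookup` and `MakeLinearTable` are hypothetical names.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

constexpr int kTableSize = 513;  // 512 intervals plus one closing entry

// Interpolating lookup: the upper 9 bits of the (offset) input select a table
// interval and the lower 7 bits interpolate linearly inside it.
inline int32_t InterpolatedLookup(
    const std::array<int16_t, kTableSize>& table, int16_t value) {
  const int32_t offset = static_cast<int32_t>(value) + 32768;  // [0, 65535]
  const int32_t index = offset >> 7;                           // [0, 511]
  const int32_t fraction = offset & 0x7f;                      // [0, 127]
  const int32_t base = table[index];
  const int32_t next = table[index + 1];
  // A 16-bit entry scaled by 2^7 gives a value that fits in 23 signed bits.
  return base * 128 + (next - base) * fraction;
}

// Hypothetical helper for experimenting: a table that is linear in its index.
inline std::array<int16_t, kTableSize> MakeLinearTable() {
  std::array<int16_t, kTableSize> table{};
  for (int i = 0; i < kTableSize; ++i) {
    table[i] = static_cast<int16_t>((i - 256) * 64);
  }
  return table;
}
```

Because each 16-bit entry is scaled by 128, the result has to be rescaled before it can be read as an angle, which is presumably what the normalization in step 4 is for.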
In a similar fashion to tfl.sin and tfl.cos legalization, a tosa.table provides an atan approximation. Additional logic is then used to determine the correct quadrant of the atan2 function.

Signed-off-by: Luke Hutton <luke.hutton@arm.com>
Change-Id: Iae1384009d825d01e5cf48ad7c3ff8fba77114cf
* remove unnecessary type checks
* add note about numerical behaviour of std::atan2
* improve error message for expected inputs
* undo change updating copyright year

Change-Id: Iea8339da437a5ff3e6fe065c715c2c97e696fdbb

* Cleanup of error messages.
* Cleanup casting inputs/outputs.
* Use CHECK-DAG instead of CHECK for constants.
* Spelling

Change-Id: Ibf01ac5b711944bb2efbd7184820a304cbae8501
Signed-off-by: Luke Hutton <luke.hutton@arm.com>
bdd550e to 3986b07
Apologies for the late follow-up on this and thanks for the review @jpienaar, please find my replies below
LG. Check on the consistency part and on the behavior on the TOSA side in case there are 0's, but beyond that good. (And I saw your responses later; I don't understand GitHub PRs some days.)
Thanks! It was actually me who forgot to hit send 😅
In a similar fashion to tfl.sin and tfl.cos legalization, a tosa.table provides an atan approximation. Additional logic is then used to determine the correct quadrant of the atan2 function.
Signed-off-by: Luke Hutton <luke.hutton@arm.com>
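The quadrant logic described above can be sketched in scalar C++ as follows. This is not the PR's lowering: `AtanOnUnitInterval` is a hypothetical stand-in that calls `std::atan` where the legalization would use a tosa.table lookup, inputs are assumed finite, and the handling of (0, 0) and of signed zeros is an assumption that may differ from `std::atan2`.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>

// Hypothetical stand-in for the table: an atan that is only ever queried on
// [0, 1]. The real legalization approximates this with a tosa.table.
inline float AtanOnUnitInterval(float ratio) {
  assert(ratio >= 0.0f && ratio <= 1.0f);
  return std::atan(ratio);
}

// atan2 rebuilt from an atan restricted to [0, 1] plus quadrant fix-ups.
inline float Atan2ViaAtan(float y, float x) {
  constexpr float kPi = 3.14159265358979323846f;
  const float abs_y = std::fabs(y);
  const float abs_x = std::fabs(x);
  const float max_val = std::max(abs_y, abs_x);
  const float min_val = std::min(abs_y, abs_x);
  if (max_val == 0.0f) return 0.0f;  // Assumption: define atan2(0, 0) as 0.

  // The ratio min/max always lies in [0, 1], so the angle is in [0, pi/4].
  float angle = AtanOnUnitInterval(min_val / max_val);
  if (abs_y > abs_x) angle = kPi / 2.0f - angle;  // reflect across the diagonal
  if (x < 0.0f) angle = kPi - angle;              // left half-plane
  if (y < 0.0f) angle = -angle;                   // lower half-plane
  return angle;
}
```

For example, `Atan2ViaAtan(1.0f, -1.0f)` takes the table-style angle pi/4 and maps it to 3*pi/4 through the left half-plane fix-up.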