[tosa] Legalize tfl.atan2 to tosa.table operation #59100

Merged

Contributor

@lhutton1 lhutton1 commented Jan 4, 2023

In a similar fashion to the tfl.sin and tfl.cos legalizations, a tosa.table provides an atan approximation. Additional logic is then used to determine the correct quadrant for the atan2 function.

Signed-off-by: Luke Hutton <luke.hutton@arm.com>
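
For readers unfamiliar with the approach in the description, the idea of a table-based atan plus quadrant correction can be sketched in plain C++. This is only an illustration under stated assumptions, not the code in this PR: `MakeAtanTable`, `TableAtan`, and `Atan2Sketch` are hypothetical names, the table is a float table with linear interpolation rather than a tosa.table, and the result chosen for (0, 0) is an assumed convention.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical helper (not from the PR): samples atan on [0, 1] at n + 1
// evenly spaced points.
std::vector<float> MakeAtanTable(int n) {
  assert(n > 0);
  std::vector<float> table(n + 1);
  for (int i = 0; i <= n; ++i) {
    table[i] = std::atan(static_cast<float>(i) / static_cast<float>(n));
  }
  return table;
}

// Looks up atan(r) for r in [0, 1], interpolating linearly between entries.
float TableAtan(const std::vector<float>& table, float r) {
  const int n = static_cast<int>(table.size()) - 1;
  const float pos = r * static_cast<float>(n);
  const int i = std::min(static_cast<int>(pos), n - 1);
  const float frac = pos - static_cast<float>(i);
  return table[i] + frac * (table[i + 1] - table[i]);
}

// Sketch of atan2(y, x): approximate atan on [0, 1], then fix up the quadrant.
float Atan2Sketch(const std::vector<float>& table, float y, float x) {
  const float kPi = 3.14159265358979f;
  const float ay = std::fabs(y);
  const float ax = std::fabs(x);
  const float hi = std::max(ax, ay);
  if (hi == 0.0f) return 0.0f;  // Assumed convention for atan2(0, 0).
  // The ratio min/max lies in [0, 1], so the table result lies in [0, pi/4].
  float angle = TableAtan(table, std::min(ax, ay) / hi);
  if (ay > ax) angle = kPi / 2.0f - angle;  // Reflect about pi/4.
  if (x < 0.0f) angle = kPi - angle;        // Left half-plane.
  if (y < 0.0f) angle = -angle;             // Lower half-plane.
  return angle;
}
```

With a 512-entry table, `Atan2Sketch(table, 1.0f, -1.0f)` should agree closely with `std::atan2(1.0f, -1.0f)`. The sketch does not distinguish signed zeros, which `std::atan2` does.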

@google-ml-butler google-ml-butler bot added the size:M CL Change Size: Medium label Jan 4, 2023
@gbaned gbaned requested a review from jpienaar January 5, 2023 15:43
@google-ml-butler google-ml-butler bot added the awaiting review Pull request awaiting review label Jan 5, 2023
Member

@jpienaar jpienaar left a comment

Forgot to hit send on comments ...

Contributor Author

@lhutton1 lhutton1 left a comment

Thanks for the review @jpienaar! I left some responses below

@gbaned
Contributor

gbaned commented Jan 11, 2023

Hi @lhutton1, can you please resolve the conflicts? Thank you!

@gbaned gbaned added stat:awaiting response Status - Awaiting response from author and removed awaiting review Pull request awaiting review labels Jan 11, 2023
@gbaned gbaned added awaiting review Pull request awaiting review and removed stat:awaiting response Status - Awaiting response from author labels Jan 12, 2023
@lhutton1
Contributor Author

Friendly ping @jpienaar

Member

@jpienaar jpienaar left a comment

Thanks for the ping

@@ -640,6 +640,47 @@ func.func @test_cos(%arg0: tensor<10xf32>) -> tensor<*xf32> {

// -----

// CHECK-LABEL: test_atan2
// CHECK-SAME: -> tensor<13x21x3xf32>
// CHECK: %[[VAL_2:.*]] = "tosa.const"() {value = dense<2.000000e+00> : tensor<1x1x1xf32>} : () -> tensor<1x1x1xf32>
Member

Prefer CHECK-DAG for constants (there is an internal check for this that ends up blocking presubmits, but it does not seem to exist upstream)

Contributor Author

done thanks


Type input_y_ety = input_y_ty.getElementType();
Type input_x_ety = input_x_ty.getElementType();
Type output_ety = output_ty.getElementType();
Member

Check for output_ty above too to avoid null here or just use cast above - I don't think TFL::Atan2Op is legal unless it produces a ShapedType

Contributor Author

Oops seems I left these in after the last review - removed them

auto tfl_atan2_op = cast<TFL::Atan2Op>(op);
Location loc = op->getLoc();
Value input_y = tfl_atan2_op.getY();
RankedTensorType input_y_ty = input_y.getType().dyn_cast<RankedTensorType>();
Member

With the new casting mechanisms in LLVM one can now finally do

auto input_y_ty = dyn_cast(input_y.getType());

Prefer not to repeat the type for these (for results of casts we use auto, since the type is clear locally).

Contributor Author

Thanks, looks cleaner. I might have misunderstood the dyn_cast(...) suggestion here, but I was only able to get it working with a templated type dyn_cast<RankedTensorType>(...)

Type input_x_ety = input_x_ty.getElementType();
Type output_ety = output_ty.getElementType();

bool op_is_fp = input_y_ty.getElementType().isF32();
Member

Use input_y_ety?

Contributor Author

Removed as they are only used once anyway


if (!input_y_ty || !input_x_ty) {
return rewriter.notifyMatchFailure(
op, "ConvertTFLAtan2Op: ranked inputs required");
Member

Update these to keep it consistent with the others (skipping the pattern name in error message)

Contributor Author

Done - apologies, I realized I missed this from your previous review


// 2. Scale and translate the normalized domain to the table domain. This
// includes a translating and scaling to [-int16_max, int16_max] and casting
// to an i16.
Member

Add reason for casting here too.

Contributor Author

Done
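
The scale-and-translate step discussed in this thread can be sketched as follows. This is a minimal illustration, assuming the normalized input lies in [0, 1]; `ToTableDomain` is a hypothetical name and the rounding mode is my choice, not something taken from the PR. As far as I understand the TOSA specification, a table lookup takes an integer (i8 or i16) input, which would be the reason for the cast; treat that as an assumption here.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Hypothetical helper (not from the PR): maps r in [0, 1] onto
// [-int16_max, int16_max] and casts the result to a 16-bit integer.
int16_t ToTableDomain(float r) {
  assert(r >= 0.0f && r <= 1.0f);
  const float kInt16Max = 32767.0f;
  // Scale [0, 1] to [0, 2 * int16_max], then translate down by int16_max.
  const float scaled = r * 2.0f * kInt16Max - kInt16Max;
  // Round to nearest before casting (assumed rounding mode).
  return static_cast<int16_t>(std::lround(scaled));
}
```

Using a range symmetric about zero leaves the value -32768 unused, which matches the [-int16_max, int16_max] range named in the code comment.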

rewriter, loc, output_ty.clone(rewriter.getIntegerType(32)), casted,
table_const);

// 4. The range of table is a 23-bit two's compliment value. Normalize the
Member

complement

Contributor Author

Fixed, thanks!
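
As background for the 23-bit range mentioned in the code comment above: the sketch below shows how a 16-bit table entry combined with a 7-bit interpolation fraction yields a 16 + 7 = 23-bit result. It is modeled on my reading of the int16 table-lookup pseudocode in the TOSA specification (513 entries, upper 9 bits as index, lower 7 bits as fraction), so treat the exact bit split as an assumption. `TableLookup` is a hypothetical name and is not the PR's code.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Sketch of a 513-entry int16 table lookup with 7-bit linear interpolation.
int32_t TableLookup(const std::vector<int16_t>& table, int16_t value) {
  assert(table.size() == 513);
  // Shift the signed input into [0, 65535].
  const int32_t shifted = static_cast<int32_t>(value) + 32768;
  const int32_t index = shifted >> 7;       // Upper 9 bits: [0, 511].
  const int32_t fraction = shifted & 0x7f;  // Lower 7 bits: [0, 127].
  const int32_t base = table[index];
  const int32_t slope = table[index + 1] - base;
  // base * 2^7 plus the interpolated part gives a value of about 23 bits.
  return base * 128 + slope * fraction;
}
```

Because the result carries 7 extra fractional bits, it has to be rescaled before it can be interpreted in the output's floating-point domain.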

@gbaned gbaned removed the awaiting review Pull request awaiting review label Jan 18, 2023
In a similar fashion to tfl.sin and tfl.cos legalization,
a tosa.table provides an atan approximation. Additional logic
is then used to determine the correct quadrant of the atan2
function.

Signed-off-by: Luke Hutton <luke.hutton@arm.com>
Change-Id: Iae1384009d825d01e5cf48ad7c3ff8fba77114cf
* remove unnecessary type checks
* add note about numerical behaviour of std::atan2
* improve error message for expected inputs
* undo change updating copyright year

Change-Id: Iea8339da437a5ff3e6fe065c715c2c97e696fdbb
* Cleanup of error messages.
* Cleanup casting inputs/outputs.
* Use CHECK-DAG instead of CHECK for constants.
* Spelling

Change-Id: Ibf01ac5b711944bb2efbd7184820a304cbae8501
Signed-off-by: Luke Hutton <luke.hutton@arm.com>
Contributor Author

@lhutton1 lhutton1 left a comment

Apologies for the late follow-up on this, and thanks for the review @jpienaar. Please find my replies below.

Member

@jpienaar jpienaar left a comment

LG. Please check the consistency part and the behavior on the TOSA side in case there are 0's, but beyond that it looks good. (I only saw your responses later; I don't understand GitHub PRs some days.)

@google-ml-butler google-ml-butler bot added kokoro:force-run Tests on submitted change ready to pull PR ready for merge process labels Jan 26, 2023
@kokoro-team kokoro-team removed the kokoro:force-run Tests on submitted change label Jan 26, 2023
@lhutton1
Contributor Author

Thanks! It was actually me who forgot to hit send 😅

@copybara-service copybara-service bot merged commit 0e89f1a into tensorflow:master Jan 28, 2023
@lhutton1 lhutton1 deleted the tosa-atan2-tfl-legalization branch January 28, 2023 16:18