
TOSA legalization updates for spec v0.22, Part 1 #48193

Merged
merged 3 commits into tensorflow:master on Apr 14, 2021

Conversation

@sjarus commented Mar 30, 2021

  • Updated gather, gather_nd and resize.
  • Fixed precision of TFL avgpool2d and quantize; added squared_difference.
  • Added left/right shift, leaky_relu, and one_hot.
  • Updated relu6/relu_n1_to1.
  • Improved numerical precision in tf.fakequant legalization.

Signed-off-by: Suraj Sudhir <suraj.sudhir@arm.com>
Change-Id: Idf23c3b24342f75ee7d1eb22dab6ffe27e8710b1

google-ml-butler bot added the size:XL label on Mar 30, 2021
google-cla bot added the cla: yes label on Mar 30, 2021
@sjarus (Author) commented Mar 30, 2021

@stellaraccident and @rsuderman, here is the Part 1 set of updates to the TF/TFLite-to-TOSA legalizations, aligned with the recent LLVM-side update. That change was picked up by TensorFlow yesterday, so we're pushing this out.

@rsuderman (Contributor) commented

> @stellaraccident and @rsuderman, here is the Part 1 set of updates to the TF/TFLite-to-TOSA legalizations, aligned with the recent LLVM-side update. That change was picked up by TensorFlow yesterday, so we're pushing this out.

Should be fine. We can revalidate internally when we land the CL.

@@ -79,6 +79,7 @@ cc_library(
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
"//tensorflow/core:framework",
"//tensorflow/core/kernels:conv_grad_shape_utils",
"//tensorflow/core/kernels:fake_quant_ops",
Contributor:
I'm a little confused why this is added. It feels unrelated to the rest of the CL.

Author:
It implements tensorflow::Nudge, which is invoked from convertFakeQuantOp() in legalize_common.cc.

Reviewer:
Thanks - I will try to review in detail tonight, but I have a lot in my queue and it may slip to tomorrow.

Reviewer:
This will add a cross dependency on a large body of TensorFlow to the TFLite path (est. 1000 additional source files?). I would rather just copy the Nudge function locally. (this has been done a couple of times for dependency reasons already)

Contributor:
Thanks for the suggestion. We will try to replace it with a local implementation in the next round.
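For reference, a rough sketch of what a file-local copy of the nudging helper might look like (logic paraphrased from tensorflow/core/kernels/fake_quant_ops_functor.h; the exact clamping and rounding should be verified against that source before relying on it):

#include <cmath>

// Sketch of a file-local copy of tensorflow::Nudge (illustrative only).
static void LocalNudge(const float min, const float max, const int quant_min,
                       const int quant_max, float* nudged_min,
                       float* nudged_max, float* scale) {
  const float quant_min_float = static_cast<float>(quant_min);
  const float quant_max_float = static_cast<float>(quant_max);
  *scale = (max - min) / (quant_max_float - quant_min_float);
  // Choose the zero point that lands on the integer grid, clamped to the
  // quantized range, then recompute the representable (nudged) min/max.
  const float zero_point_from_min = quant_min_float - min / *scale;
  float nudged_zero_point;
  if (zero_point_from_min < quant_min_float) {
    nudged_zero_point = quant_min_float;
  } else if (zero_point_from_min > quant_max_float) {
    nudged_zero_point = quant_max_float;
  } else {
    nudged_zero_point = std::round(zero_point_from_min);
  }
  *nudged_min = (quant_min_float - nudged_zero_point) * (*scale);
  *nudged_max = (quant_max_float - nudged_zero_point) * (*scale);
}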

gbaned self-assigned this on Mar 31, 2021
gbaned added this to Assigned Reviewer in PR Queue via automation on Mar 31, 2021
@stellaraccident left a comment
Thank you for the contributions. There are a number of comments, ranging from local style to design to testing. I think the main design point has to do with the override_zero_point attribute and doing that in a way that does not introduce a cross-pass dependency. The other big request is upgraded test coverage. I'm happy to jump on some kind of shared space and help educate on some practices here. In addition, I'd suggest looking at the way some of the legalize_tf tests are done on the tensorflow/xla side for an idea of what we are looking for.

auto tfl_avgpool2d_op = cast<TFL::AveragePool2DOp>(op);

auto input_type =
tfl_avgpool2d_op.input().getType().dyn_cast<RankedTensorType>();

input_type is constrained by the op definition to be a RankedTensorType, right? If that is the case, why dyn_cast (you aren't actually checking for !nullptr)? Suggest either using cast if this is an invariant, or continuing to use dyn_cast and early-returning if any dyn_cast returns nullptr.
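For illustration, the two alternatives might look like this (illustrative sketch, not the exact PR code):

// Option 1: rankedness is an invariant guaranteed by the op definition.
auto input_type = tfl_avgpool2d_op.input().getType().cast<RankedTensorType>();

// Option 2: keep dyn_cast, but bail out of the pattern when it fails.
auto input_type =
    tfl_avgpool2d_op.input().getType().dyn_cast<RankedTensorType>();
if (!input_type) return failure();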

Contributor:
My bad. Will fix in next round.

@@ -128,6 +128,48 @@ struct ConvertUint8QConstOp : public RewritePattern {
}
};

struct ConvertAveragePool2DOp : public RewritePattern {

Why isn't this an OpRewritePattern<TFL::AveragePool2DOp>? This would then simplify your constructor and let you remove the cast in the first line of matchAndRewrite.

Contributor:
Do you have an example of this that we can simply follow?
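For illustration, a minimal sketch of an op-specific pattern (names and body are illustrative, based on the MLIR OpRewritePattern API of this time frame, not the exact code from this PR):

struct ConvertAveragePool2DOp
    : public OpRewritePattern<TFL::AveragePool2DOp> {
  using OpRewritePattern<TFL::AveragePool2DOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TFL::AveragePool2DOp tfl_avgpool2d_op,
                                PatternRewriter& rewriter) const override {
    // The typed op parameter replaces cast<TFL::AveragePool2DOp>(op) and the
    // explicit constructor required by the generic RewritePattern base class.
    auto input_type =
        tfl_avgpool2d_op.input().getType().dyn_cast<RankedTensorType>();
    if (!input_type) return failure();
    // ... build the tosa.avg_pool2d replacement here ...
    return success();
  }
};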


// Annotate attribute to match TFLite average pool2d rounding behavior.
auto key =
mlir::Identifier::get(kOverrideZeropointAttrName, builder.getContext());

Prefer builder.getIdentifier(...)

Contributor:
Thanks. Will update in next round.

auto value = builder.getI32IntegerAttr(override_zeropoint);
op->setAttr(key, value);

return success();

I'm still reading how this fits together, but as written, this pattern sketches me out because it mutates in place and doesn't seem to either produce a new op or alter its match behavior based on what gets set. I don't think this is going to have the desired effect. I will try to discern the goal and offer a suggestion elsewhere.

I can't quite see why you need an in-place pattern here (versus a helper at the point of transformation which checks these conditions and returns the zero-point override).

Contributor:
Please see my comment on top.

if (input_type.getElementType().isa<UniformQuantizedType>()) {
// Search for attribute annotated by --tosa-convert-tfl-uint8 pass.
// This is needed in average pool2d to match TFLite rounding behavior.
IntegerAttr override_zeropoint_attr =

These kinds of cross-pass, action-at-a-distance attribute dependencies need to be avoided. Why can't you have a helper that performs the check done in the pattern which sets this attribute, and just call it here (i.e. a static calculateZeropointOverride(TFL::AveragePool2DOp op))?

Even better, you could define the above to return an IntegerAttr conditionally if the input_type is quantized (or nullptr otherwise), and then collapse all of the branching in this transform.
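A rough sketch of the suggested helper (illustrative only; how the override value is actually derived would have to come from the existing pass logic, not from the quantized type as shown here):

// Hypothetical helper: returns an I32 zero-point override attribute when the
// input is uniformly quantized, or a null attribute otherwise, so the caller
// can collapse its branching on input_type.
static IntegerAttr calculateZeropointOverride(TFL::AveragePool2DOp op,
                                              OpBuilder& builder) {
  auto input_type = op.input().getType().dyn_cast<RankedTensorType>();
  if (!input_type) return {};
  auto qtype = input_type.getElementType().dyn_cast<UniformQuantizedType>();
  if (!qtype) return {};
  // Placeholder: compute the override from the quantized type itself rather
  // than from an attribute annotated by an earlier pass.
  return builder.getI32IntegerAttr(static_cast<int32_t>(qtype.getZeroPoint()));
}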

Contributor:
Explained in the comment at the top: QU8 and QI8 are not distinguishable after the convert_tfl_uint8 pass.

Author:
This piece appears to be the largest blocker for this update. We're aware - as Kevin mentioned - that the approach isn't ideal. Is it OK if we get all the other changes done, then do a separate update with a cleaner implementation for this cross-pass issue?

for (int32_t i = 0; i < indices_type.getRank(); i++) {
int32_t dim = indices_type.getShape()[i];
N *= dim;
if (i >= axis)

Unlike LLVM style, Google style wants braces around these.

Contributor:
Will add braces in next round.

Author:
Done

axis = indices_type.getRank();
}

int32_t N = 1, W = 1, C = 1;

Prefer just int instead of int32_t unless there is a strong reason (here and below).

Contributor:
Thanks, will update in the next round.

// CHECK: tosa.reshape
// CHECK: tosa.transpose
// CHECK: tosa.reshape
func @test_one_hot(%arg0: tensor<4x4xi32>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<4x4x2xf32> {

General note on tests (and a specific note on this one): these are non-trivial lowerings and we should have more explicit tests. The tests do not need to verify every detail of the patterns, but should minimally capture SSA values and constants. Getting the testing granularity right is a bit of an art, so if you want help with this, we can break some examples out into a doc or gist and work on it together if that is helpful (I have been meaning to ask for test upgrades for some time and realize there may be some education needed here).

Contributor:
Ah, so we can check the SSA values as well; I didn't realize we could do that. I assume all we need to do is append them to "// CHECK: tosa.XXX", for example "// CHECK: tosa.XXX(%0)"? Or is it not as simple as I expect?

Contributor:
You can see similar examples in the tensorflow to hlo lowering tests:

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir#L64

The problem with not validating the large sequences is that incorrect constants, or changes to which values are passed to which ops, can slip through.
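For illustration, a check sequence that captures and reuses SSA values might look like the following (op sequence and constants are illustrative, not the actual one_hot lowering):

// CHECK: %[[VAR0:.*]] = "tosa.const"() {value = dense<[0, 2, 1]> : tensor<3xi32>}
// CHECK: %[[VAR1:.*]] = "tosa.reshape"(%arg0)
// CHECK: %[[VAR2:.*]] = "tosa.transpose"(%[[VAR1]], %[[VAR0]])
// CHECK: %[[VAR3:.*]] = "tosa.reshape"(%[[VAR2]])
// CHECK: return %[[VAR3]]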


@@ -105,9 +105,16 @@ func @test_relu6(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
// -----

I didn't double check, but I think that some of the op conversions added are missing tests (gather?). Please audit test coverage for both ops and critical branches in the patterns.

Contributor:
Thanks. Will fix in next round.

Contributor:
Still missing tests for gather support. Because we are adding functionality for a gather lowering, we will need a test to validate that the lowering works correctly.


@armkevincheng (Contributor) commented

I agree the cross-pass design is bad, but that was the only way I could think of to achieve what I want.

To briefly explain what I'm doing here:
We're trying to match the rounding behavior of TFL::AveragePool2DOp, which differs between QI8 and QU8 (since it rounds on its storage type/range).

In the TFL-to-TOSA pass pipeline, the convert_tfl_uint8 pass runs first and converts all QU8 into QI8, so the passes after it, e.g. legalize_tfl, only need to deal with QI8 but at the same time cannot distinguish between the two cases.

The trick we played here (which I agree is bad) is to annotate the override_zeropoint attribute, based on whether the type is QU8 or QI8, before QU8 is converted. When the pipeline reaches legalize_tfl, it checks whether that attribute exists; if it does, we override the zero point with the value stored in the attribute and build the TOSA::AveragePool2dOp with it.
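As background, a sketch of why the signedness information disappears (an illustrative assumption about how a uint8-to-int8 storage rewrite typically shifts the zero point, not a statement about the exact convert_tfl_uint8 implementation):

// Illustrative only: a QU8 type with zero point zp is typically rewritten to
// an equivalent QI8 type with zero point zp - 128, so afterwards the
// originally-unsigned and originally-signed cases look identical.
int32_t qu8_zero_point = 128;                   // hypothetical original QU8 zero point
int32_t qi8_zero_point = qu8_zero_point - 128;  // equivalent QI8 zero point after the rewrite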

Could something you mentioned above solve the problem without using the cross-pass design?

@armkevincheng (Contributor) commented

Again, thanks for all the feedback. It's all very helpful, and we'll prepare the next round of review as soon as possible.

- Use tensorflow::Nudge implementation locally to avoid build dependency
- Cleanup in coding artifacts
- Disable cross-pass AvgPool2d int8/uint8 handling until better solution

Change-Id: I1e0421ee336e4e137aebfa2856870fa3838169e6
Signed-off-by: Suraj Sudhir <suraj.sudhir@arm.com>
gbaned added the awaiting review label on Apr 8, 2021
@sjarus (Author) commented Apr 9, 2021

We've removed avgpool until the cross-pass dependency is resolved, but have updated the other items. We hope we can upstream this and follow up on the remaining pieces, including the rest of the TOSA v0.22 changes, in a separate PR.



if (!element_type) return failure();

// In some cases output_type is dynamic shape as tensor<*xelement_type>
Contributor:
This kind of case usually occurs when TensorFlow's shape propagation is not working correctly. Passes should assume valid shape propagation, as it is assumed-correct upstream behavior.

alpha = tmpAttr.getValueAsDouble();
}

Value const_zero = getTosaConstTensorSingleF32(rewriter, op, 0.0);
Contributor:
Sounds fine.

PR Queue automation moved this from Assigned Reviewer to Reviewer Requested Changes Apr 12, 2021
@sjarus (Author) commented Apr 12, 2021

We're updating test_one_hot to follow the suggested approach. If that looks good, we'll follow up with a cleanup of the entire set of tests in a separate PR, since it is an independent task. However, we intend to add the gather tests to this PR, hopefully enabling us to close this one out.

Fixed tfl.quantize legalization output_type setup.

Change-Id: I4a4826f653941299916229d8e2d0e094342f8ef4
Signed-off-by: Suraj Sudhir <suraj.sudhir@arm.com>
sjarus requested a review from rsuderman on April 12, 2021 at 22:10
@rsuderman (Contributor) left a comment
Last two comments related to tests. Once these wrap we should be good to land.

// CHECK: tosa.reshape
// CHECK: tosa.transpose
// CHECK: tosa.reshape
func @test_one_hot(%arg0: tensor<4x4xi32>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<4x4x2xf32> {
Contributor:
Nit: you can drop all the shape information in the line to make things more succinct. I.e., the following line:

// CHECK: %[[VAR0:.*]] = "tosa.const"() {value = dense<[0, 2, 1]> : tensor<3xi32>} : () -> tensor<3xi32>

can be simplified to:

// CHECK: %[[VAR0:.*]] = "tosa.const"() {value = dense<[0, 2, 1]> : tensor<3xi32>}

Generally we want to make sure CHECK tests are mostly a sequence of correct values.

// CHECK: tosa.reshape
// CHECK: tosa.mul
// CHECK: tosa.reshape
// CHECK: tosa.add
func @test_fakequant_with_min_max_args(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
Contributor:
Perform the same value validation on the rest of the added tests. If they are short (e.g. fewer than 3-4 ops) it's fine without it, but most of these tests are pretty complex and we want to avoid future errors.

Author:
The plan is to clean up all the tests in a separate PR. Is that a workable option? We tried it on test_one_hot to get feedback on the right way to go about it, rather than changing everything and then potentially having to do a second pass over everything.

@sjarus (Author) commented Apr 13, 2021

Please let us know if we can update all the TF + TFL tests in a separate PR, using test_one_hot as the suggested template. Since this is largely independent of this PR and the ones that follow it, we can parallelize the test-update work.

sjarus requested a review from rsuderman on April 14, 2021 at 16:21
@rsuderman (Contributor) left a comment
We can land under the guarantee that the followup appears reasonably soon. We tend to avoid landing tests with brittle checks in case TF canonicalizations are updated. This can cause some unexpected failures.

google-ml-butler bot added the kokoro:force-run and ready to pull labels on Apr 14, 2021
PR Queue automation moved this from Reviewer Requested Changes to Approved by Reviewer Apr 14, 2021
kokoro-team removed the kokoro:force-run label on Apr 14, 2021
@sjarus (Author) commented Apr 14, 2021

> We can land under the guarantee that the followup appears reasonably soon. We tend to avoid landing tests with brittle checks in case TF canonicalizations are updated. This can cause some unexpected failures.

Our current estimate is for the updated tests for both TF and TFL legalization to constitute a new PR within the next week. Is that reasonable?

sjarus requested a review from rsuderman on April 14, 2021 at 21:22
PR Queue automation moved this from Approved by Reviewer to Reviewer Requested Changes Apr 14, 2021
@rsuderman (Contributor) commented
Found a set of changes required to fix some internal failures, mostly unused variable issues.

@rsuderman (Contributor) commented

Working on landing this internally, patching the changes in myself.

@sjarus (Author) commented Apr 14, 2021

> Working on landing this internally, patching the changes in myself.

Oops didn't see this. Thanks for handling this!

copybara-service bot merged commit d9bcf21 into tensorflow:master on Apr 14, 2021
PR Queue automation moved this from Reviewer Requested Changes to Merged Apr 14, 2021