TOSA legalization updates for spec v0.22, Part 1 #48193
Conversation
- Updated gather, gather_nd and resize.
- Fix precision of TFL avgpool2d, quantize. Add squareddifference.
- Add left/right shift, leaky_relu, one_hot.
- Update relu6/relu_n1_to1.
- Numerical precision in tf.fakequant legalization.

Signed-off-by: Suraj Sudhir <suraj.sudhir@arm.com>
Change-Id: Idf23c3b24342f75ee7d1eb22dab6ffe27e8710b1
@stellaraccident and @rsuderman, here is the part 1 set of updates to the legalizations from TF/TFLite to TOSA, aligned to the recent LLVM-side update. That change was picked up by TensorFlow yesterday, so we're pushing this out.
Should be fine. We can revalidate internally when we land the CL.
tensorflow/compiler/mlir/tosa/BUILD
Outdated
@@ -79,6 +79,7 @@ cc_library(
        "//tensorflow/compiler/mlir/lite:tensorflow_lite",
        "//tensorflow/core:framework",
        "//tensorflow/core/kernels:conv_grad_shape_utils",
        "//tensorflow/core/kernels:fake_quant_ops",
I'm a little confused why this is added. It feels unrelated to the rest of the CL.
It implements `tensorflow::Nudge`, which is invoked from `convertFakeQuantOp()` in legalize_common.cc.
Thanks - I will try to review in detail tonight, but I have a lot in my queue and it may slip to tomorrow.
This will add a cross dependency on a large body of TensorFlow to the TFLite path (est. 1000 additional source files?). I would rather just copy the Nudge function locally. (this has been done a couple of times for dependency reasons already)
Thanks for the suggestion. Will try to replace it with a local implementation in the next round.
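For reference, a sketch of what a local copy might look like, paraphrased from TensorFlow's fake-quant kernel `Nudge`; the name `NudgeLocal` is hypothetical, and the exact signature and rounding details should be checked against tensorflow/core/kernels/fake_quant_ops_functor.h:

```cpp
#include <cmath>

// Paraphrase of tensorflow::Nudge: adjust [min, max] so that zero is exactly
// representable on the [quant_min, quant_max] integer grid, and return the
// nudged range plus the quantization scale.
static void NudgeLocal(const float min, const float max, const int quant_min,
                       const int quant_max, float* nudged_min,
                       float* nudged_max, float* scale) {
  const float quant_min_float = static_cast<float>(quant_min);
  const float quant_max_float = static_cast<float>(quant_max);
  *scale = (max - min) / (quant_max_float - quant_min_float);

  // Zero point implied by the raw range, clamped to the grid and rounded.
  const float zero_point_from_min = quant_min_float - min / *scale;
  float nudged_zero_point;
  if (zero_point_from_min < quant_min_float) {
    nudged_zero_point = quant_min_float;
  } else if (zero_point_from_min > quant_max_float) {
    nudged_zero_point = quant_max_float;
  } else {
    nudged_zero_point = std::round(zero_point_from_min);
  }

  *nudged_min = (quant_min_float - nudged_zero_point) * (*scale);
  *nudged_max = (quant_max_float - nudged_zero_point) * (*scale);
}
```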
Thank you for the contributions. There are a number of comments ranging from local/style, to design, to testing. I think the main design point has to do with the `override_zero_point` attribute and doing that in a way that does not introduce a cross-pass dependency. The other big request is upgraded test coverage. I'm happy to jump on some kind of shared space and help educate on some practices here. In addition, I'd suggest looking at the way some of the `legalize_tf` tests are done on the tensorflow/xla side for a better idea of what we are looking for.
auto tfl_avgpool2d_op = cast<TFL::AveragePool2DOp>(op);

auto input_type =
    tfl_avgpool2d_op.input().getType().dyn_cast<RankedTensorType>();
`input_type` is constrained by the op definition to be `RankedTensorType`, right? If that is the case, why `dyn_cast` (you aren't actually checking if it is !nullptr)? Suggest either using `cast` if this is an invariant, or continuing to use `dyn_cast` and early-returning if any `dyn_cast` returns nullptr.
My bad. Will fix in the next round.
@@ -128,6 +128,48 @@ struct ConvertUint8QConstOp : public RewritePattern {
  }
};

struct ConvertAveragePool2DOp : public RewritePattern {
Why isn't this an `OpRewritePattern<TFL::AveragePool2DOp>`? This would then simplify your constructor and let you remove the cast in the first line of `matchAndRewrite`.
Do you have an example for this that we can dumbly follow?
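For what it's worth, a minimal sketch of the typed-pattern shape being suggested; the `matchAndRewrite` body here is illustrative only, not the PR's actual lowering:

```cpp
// Typed pattern: the framework does the dyn_cast for us, so matchAndRewrite
// receives a TFL::AveragePool2DOp directly and the manual cast<...>(op) in
// the first line goes away.
struct ConvertAveragePool2DOp : public OpRewritePattern<TFL::AveragePool2DOp> {
  using OpRewritePattern<TFL::AveragePool2DOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TFL::AveragePool2DOp op,
                                PatternRewriter& rewriter) const override {
    auto input_type = op.input().getType().dyn_cast<RankedTensorType>();
    if (!input_type) return failure();
    // ... the actual rewrite logic would go here ...
    return failure();
  }
};
```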
// Annotate attribute to match TFLite average pool2d rounding behavior.
auto key =
    mlir::Identifier::get(kOverrideZeropointAttrName, builder.getContext());
Prefer `builder.getIdentifier(...)`.
Thanks. Will update in the next round.
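Assuming the `Builder::getIdentifier` convenience of MLIR at the time, the change would be roughly:

```cpp
// Same identifier as before, without spelling out the context explicitly.
auto key = builder.getIdentifier(kOverrideZeropointAttrName);
```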
auto value = builder.getI32IntegerAttr(override_zeropoint);
op->setAttr(key, value);

return success();
I'm still reading how this fits together, but as written, this pattern sketches me out because it is mutating in place and doesn't seem to either produce a new op or alter its match behavior based on what gets set. I don't think this is going to have the desired effect. Will try to discern the goal and offer a suggestion elsewhere.
I can't quite see why you need an in-place pattern here (versus a helper at the point of transformation which checks these conditions and returns the zero point override).
Please see my top-level comment.
if (input_type.getElementType().isa<UniformQuantizedType>()) {
  // Search for attribute annotated by --tosa-convert-tfl-uint8 pass.
  // This is needed in average pool2d to match TFLite rounding behavior.
  IntegerAttr override_zeropoint_attr =
These kinds of cross-pass, action-at-a-distance attribute dependencies need to be avoided. Why can't you have a helper which performs the check in the pattern which sets this attribute, and just call it here (i.e. a static `calculateZeropointOverride(TFL::AveragePool2DOp op)`)?
Even better, you could define the above to return an `IntegerAttr` conditionally if the input_type is quantized (or nullptr otherwise), then collapse all of the branching in this transform.
Explained in the top-level comment. QU8 and QI8 are not distinguishable after the convert_tfl_uint8 pass.
This piece appears to be the largest blocker for this update. We're aware - as Kevin mentioned - that the approach isn't ideal. Is it ok if we get all the other changes done, then do a separate update with a cleaner implementation of this cross-pass issue?
for (int32_t i = 0; i < indices_type.getRank(); i++) {
  int32_t dim = indices_type.getShape()[i];
  N *= dim;
  if (i >= axis)
Unlike LLVM style, Google style wants braces around these.
Will add braces in the next round.
Done
  axis = indices_type.getRank();
}

int32_t N = 1, W = 1, C = 1;
Prefer just `int` instead of `int32_t` unless there is a strong reason (here and below).
Thanks, will update in the next round.
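Taken together, the two style notes would turn the loop from the diff above into something like the following; the body of the `if` is elided in the diff excerpt, so a placeholder stands in:

```cpp
// Plain int for loop-local values, and braces even on one-statement bodies.
int N = 1, W = 1, C = 1;
for (int i = 0; i < indices_type.getRank(); i++) {
  int dim = indices_type.getShape()[i];
  N *= dim;
  if (i >= axis) {
    // ... per-dimension handling from the original loop ...
  }
}
```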
// CHECK: tosa.reshape
// CHECK: tosa.transpose
// CHECK: tosa.reshape
func @test_one_hot(%arg0: tensor<4x4xi32>, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<4x4x2xf32> {
General note on tests (and a specific note on this one): these are non-trivial lowerings and we should have more explicit tests. The tests do not need to verify every detail of the patterns, but should minimally capture SSA values and constants. Getting the testing granularity right is a bit of an art, so if you want help with this, we can bust some examples out to a doc or gist and work on it together if that is helpful (I've been meaning to ask for test upgrades for some time, and realize there may be some education needed here).
Ah, so we can check the SSA values as well. Didn't realize we could do that.
I'd assume all we need to do is append that to `// CHECK: tosa.XXX`, for example `// CHECK: tosa.XXX(%0)`?
Or is it not as simple as I expect?
You can see similar examples in the tensorflow to hlo lowering tests:
The problem with not validating the large sequences is incorrect constants or changing which values are passed to which ops.
Nit: You can drop all the shape information in the line to make things more succinct.
I.e. The following line:
// CHECK: %[[VAR0:.*]] = "tosa.const"() {value = dense<[0, 2, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
Can be simplified to
// CHECK: %[[VAR0:.*]] = "tosa.const"() {value = dense<[0, 2, 1]> : tensor<3xi32>}
Generally we want to make sure check tests are mostly a sequence of correct values.
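Putting the two suggestions together, a value-capturing version of the checks might look like this (the op sequence is illustrative, not the actual test): capture each result once with `%[[NAME:.*]]`, then reuse `%[[NAME]]` wherever that value should feed a later op.

```mlir
// CHECK: %[[VAR0:.*]] = "tosa.const"() {value = dense<[0, 2, 1]> : tensor<3xi32>}
// CHECK: %[[VAR1:.*]] = "tosa.transpose"(%arg0, %[[VAR0]])
// CHECK: %[[VAR2:.*]] = "tosa.reshape"(%[[VAR1]])
```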
@@ -105,9 +105,16 @@ func @test_relu6(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
// -----
I didn't double check, but I think that some of the op conversions added are missing tests (gather?). Please audit test coverage for both ops and critical branches in the patterns.
Thanks. Will fix in the next round.
Still missing tests for gather support. Because we are adding functionality for a gather lowering, we will need a test to validate that the lowering is working correctly.
I agree the cross-pass design is bad, but it was the only way I could think of to achieve what I want. Briefly, here is what's happening: in the TFL-to-TOSA pass pipeline, the convert_tfl_uint8 pass is called first and converts all the QU8 into QI8, so the passes after it, e.g. legalize_tfl, only need to deal with QI8, but at the same time they can't distinguish between those two cases. The trick we played here (which is bad, I agree) is to annotate the override_zeropoint attribute based on whether it's QU8 or QI8, before QU8 is converted. When the pipeline reaches legalize_tfl, it checks whether such an attribute exists. If it does, then we override the zero point with what's stored in the attribute, and we build TOSA::AveragePool2dOp with it. Could something you mentioned above solve the problem without using the cross-pass design?
Again, thanks for all the feedback. It's all pretty helpful, and we'll prepare the next round of review as soon as possible.
- Use tensorflow::Nudge implementation locally to avoid build dependency
- Clean up coding artifacts
- Disable cross-pass AvgPool2d int8/uint8 handling until a better solution exists

Change-Id: I1e0421ee336e4e137aebfa2856870fa3838169e6
Signed-off-by: Suraj Sudhir <suraj.sudhir@arm.com>
We've removed avgpool until the cross-pass dependency is resolved, but have updated the other things. We hope we can upstream this and follow up on the remaining pieces in a separate PR, alongside the remaining TOSA v0.22 changes still to come.
if (!element_type) return failure();

// In some cases output_type is dynamic shape as tensor<*xelement_type>
This kind of case usually occurs when TensorFlow's shape propagation is not working correctly. Passes should assume valid shape propagation, as it is an assumed-correct upstream behavior.
  alpha = tmpAttr.getValueAsDouble();
}

Value const_zero = getTosaConstTensorSingleF32(rewriter, op, 0.0);
Sounds fine.
We're updating test_one_hot to follow the suggested approach. If that looks good, we'll follow up with a cleanup of the entire set of tests in a follow-up PR, since it is an independent task. However, we intend to add the gather tests to this PR, hopefully enabling us to close this one out.
Fixed tfl.quantize legalization output_type setup.

Change-Id: I4a4826f653941299916229d8e2d0e094342f8ef4
Signed-off-by: Suraj Sudhir <suraj.sudhir@arm.com>
Last two comments relate to tests. Once these wrap up, we should be good to land.
// CHECK: tosa.reshape
// CHECK: tosa.mul
// CHECK: tosa.reshape
// CHECK: tosa.add
func @test_fakequant_with_min_max_args(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
Perform the same value validation on the rest of the added tests. If they are short (e.g. fewer than 3-4 ops), it's fine without them, but most of these tests are pretty complex and we want to avoid future errors.
The plan is to clean up all the tests in a separate PR. Is that a workable option? We tried test_one_hot to get feedback on the right way to go about it, rather than changing everything and then potentially having to do a second pass over everything.
Please let us know if we can update all the TF + TFL tests in a separate PR, using the suggested template of test_one_hot. Since this is largely independent of this PR and the ones to follow it, we can parallelize the test update work.
We can land under the guarantee that the follow-up appears reasonably soon. We tend to avoid landing tests with brittle checks, in case TF canonicalizations are updated; that can cause some unexpected failures.
Current estimate is for the updated tests for both TF and TFL legalization to constitute a new PR within the next week. Is that reasonable?
Found a set of changes required to fix some internal failures, mostly unused-variable issues.
Working on landing this internally, patching the changes in myself.
Oops, didn't see this. Thanks for handling it!