-
Notifications
You must be signed in to change notification settings - Fork 74k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Tosa] Add legalization of BroadcastTo #60692
Conversation
Hi @Tai78641 Can you please take a look on below internal errors. Thank you! TensorFlow crashed, please file a bug on https://github.com/tensorflow/tensorflow/issues with the trace below. |
8da0835
to
3ed47c9
Compare
@gbaned I rebased and double checked. tf-to-tosa-pipeline.mlir passes for me. |
@jpienaar I cannot figure out why the Py+CPP Test Suite failures. |
3ed47c9
to
0f18bab
Compare
rebased to use dyn_cast(...) style |
int32_t num_elements = 1; | ||
SmallVector<int64_t> new_shape; | ||
for (int i = 0; i < shape_rank; i++) { | ||
auto shape_dim = shape_elems.getValues<IntegerAttr>()[i].getInt(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Has TOSA been moved to properties for shape?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not sure I understand the question, which probably means the answer is no?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think Jacques is referring to using shape_elems
. Generally we try to use the getType()
value as its "assumed to be true" value. Obviously this is TFLite so either the shape_elems
or the getType()
could be under defined. I would recommend pulling both and consolidating the correct value.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I forgot this was TFL, I was thinking about requiring IntegerAttr contruction here which could be avoided by properties.
0f18bab
to
c650975
Compare
|
||
if (element_type.isa<FloatType>()) { | ||
// F32: legalize to broadcastable Add with (-0.f) | ||
std::vector<float> values(num_elements, -0.f); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For consistency you should change over to llvm::SmallVector>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
int32_t num_elements = 1; | ||
SmallVector<int64_t> new_shape; | ||
for (int i = 0; i < shape_rank; i++) { | ||
auto shape_dim = shape_elems.getValues<IntegerAttr>()[i].getInt(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think Jacques is referring to using shape_elems
. Generally we try to use the getType()
value as its "assumed to be true" value. Obviously this is TFLite so either the shape_elems
or the getType()
could be under defined. I would recommend pulling both and consolidating the correct value.
For tf/tfl broadcast-to operator with input and shape where: - shape is compile time constant, and - shape's rank is greater than or equal to input's rank, and - input element type is not complex, and - input element type is not integer whose bitwidth is greater than 32 will convert to tosa operators as follows: 1. if input element type is floating point, add input with constant -0.f of the broadcast shape 2. if input element type is i1, logical-or input with constant 'false' of the broadcast shape 3. if input element type is i32, add input with constant 0 (i32) of the broadcast shape 4. otherwise, cast input to i32, add with constant 0 (i32) of the broadcast shape, and cast back to original element type added tf/tfl lit tests Signed-off-by: Tai Ly <tai.ly@arm.com> Change-Id: I12302adcf1c791d452a5b5a928e63e5ffcd523bc
c650975
to
eb9e49a
Compare
op, | ||
"shape's constant value has different elements than its static " | ||
"dimension"); | ||
return std::nullopt; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
added this check that shape_elems (derived from compile time constant for input "shape") is consistent with the getType() for input "shape".
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this be part of the op verification? (I mean I think this is TFL side, so you could add a TODO here and I could ping folks)
@rsuderman please have a look at this again. thanks |
op, | ||
"shape's constant value has different elements than its static " | ||
"dimension"); | ||
return std::nullopt; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should this be part of the op verification? (I mean I think this is TFL side, so you could add a TODO here and I could ping folks)
// reshape input to shape_rank | ||
SmallVector<int64_t> reshaped_shape((shape_rank - input_rank), 1); | ||
for (auto dim : input_type.getShape()) { | ||
reshaped_shape.push_back(dim); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You could also use insert method on SmallVector. (append_range helper in LLVM could also be useful)
aec0be7
into
tensorflow:master
I had to make some modifications to land. I believe it should still be good but you may need to validate that things are working as intended. |
For tf/tfl boradcast-to operator with input and shape where:
will convert to tosa operators as follows:
added tf/tfl lit tests
Change-Id: I12302adcf1c791d452a5b5a928e63e5ffcd523bc