TFTRT: Add ExpandDims, Squeeze ops and unit tests. #23909
const int input_rank = input_dims.size();
// Mark axes to remove by setting them to 0.
TFAttrs attrs(node_def);
auto squeeze_dims = attrs.get<std::vector<int>>("squeeze_dims");
squeeze_dims is deprecated, we should use axis, see https://www.tensorflow.org/api_docs/python/tf/squeeze.
While the API now uses axis, it appears the NodeDef.attr() only contains squeeze_dims, even though when creating the node def attributes we use the .axis_ field: https://github.com/tensorflow/tensorflow/pull/23909/files#diff-5b5e9a9d60ab4ad5474d419d568fb6faR2110
See Op definition for Squeeze: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/ops/array_ops.cc#L1980
Thanks for pointing that out!
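For readers following the thread above, here is a minimal shape-only sketch of the "mark axes to remove by setting them to 0" idea from the snippet under review. It is not the converter from this PR: the helper name SqueezeDims is hypothetical, it works on a plain std::vector<int> rather than TensorRT dims, it assumes no real dimension has size 0, and it ignores any batch-dimension handling the real converter may need. An empty axis list removes every size-1 dimension, which matches the documented behavior of tf.squeeze.

```cpp
#include <stdexcept>
#include <vector>

// Hypothetical helper for illustration only (not part of TF-TRT).
// Returns the shape left after squeezing `axes` out of `dims`.
// Negative axes count from the end; an empty `axes` squeezes all size-1 dims.
std::vector<int> SqueezeDims(std::vector<int> dims,
                             const std::vector<int>& axes) {
  const int rank = static_cast<int>(dims.size());
  if (axes.empty()) {
    // No axes given: mark every size-1 dim for removal.
    for (int& d : dims) {
      if (d == 1) d = 0;
    }
  }
  // Mark axes to remove by setting them to 0.
  for (int axis : axes) {
    if (axis < -rank || axis >= rank) {
      throw std::out_of_range("squeeze axis out of range");
    }
    if (axis < 0) axis += rank;
    if (dims[axis] != 1) {
      throw std::invalid_argument("cannot squeeze a dim whose size is not 1");
    }
    dims[axis] = 0;
  }
  // Keep only the dims that were not marked.
  std::vector<int> out;
  for (int d : dims) {
    if (d != 0) out.push_back(d);
  }
  return out;
}
```

Using 0 as the marker keeps the removal pass to a single filter loop, at the cost of assuming that 0 never appears as a genuine dimension size.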
f51f7b8 to 9284cda
tensorflow::Status ConvertExpandDims(OpConverterParams* params) {
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
if (inputs.size() != 2) {
Isn't axis an attribute?
@azaks2 @smit-hinsu
So there's good news and bad news. 👍 The good news is that everyone that needs to sign a CLA (the pull request submitter and all commit authors) have done so. Everything is all good there. 😕 The bad news is that it appears that one or more commits were authored or co-authored by someone other than the pull request submitter. We need to confirm that all authors are ok with their commits being contributed to this project. Please have them confirm that here in the pull request. Note to project maintainer: This is a terminal state, meaning the
Can you rebase this PR onto master? It's a little hard to review as-is.
4f742ca to 06fe3c8
CLAs look good, thanks!
}
// Convert negative axis to corresponding positive axis.
if (axis < 0) axis += input_rank + 1;
if (input_tensor.is_tensor() && axis == 0) {
input_tensor.is_tensor() must be true now.
Is it ok to leave it for now? We will need to support strided slice for weights later for dynamic reshaping.
If so, then please leave a comment.
Oh whoops, got this mixed up with the other PR. Removing these.
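The negative-axis conversion in the snippet above adds input_rank + 1 rather than input_rank, because ExpandDims can insert the new dimension at any of rank + 1 positions, including the very end. A minimal shape-only sketch of that rule follows. The helper name ExpandDimsShape is hypothetical, it operates on a plain std::vector<int>, and it is not the ConvertExpandDims function from this PR.

```cpp
#include <stdexcept>
#include <vector>

// Hypothetical helper for illustration only (not part of TF-TRT).
// Returns `dims` with a new dimension of size 1 inserted at `axis`.
std::vector<int> ExpandDimsShape(const std::vector<int>& dims, int axis) {
  const int rank = static_cast<int>(dims.size());
  // There are rank + 1 valid insertion points, so axis lies in [-rank - 1, rank].
  if (axis < -rank - 1 || axis > rank) {
    throw std::out_of_range("expand_dims axis out of range");
  }
  // Convert negative axis to corresponding positive axis.
  if (axis < 0) axis += rank + 1;
  std::vector<int> out(dims);
  out.insert(out.begin() + axis, 1);
  return out;
}
```

With this convention, axis -1 on a shape of rank 2 maps to position 2, so the new dimension lands after the last existing one.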
Thanks for the PR!
Hi @trevor-m, could you help fix this error:
Thanks.
Done. Thanks for reviewing!
Thanks @trevor-m. I found that rank_two_test and unary_test are failing, probably because they use squeeze ops. Could you double-check? If you want to add an op that is unsupported you may use
…eze_ops PiperOrigin-RevId: 224471142
Add conversion for ExpandDims, Squeeze ops, and unit tests for both.
These ops will allow 1D convolutions to be converted.
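To illustrate why these two ops matter for 1D convolutions: TensorFlow typically lowers a 1D convolution to ExpandDims, then Conv2D, then Squeeze. The shape-only sketch below walks through that pattern under stated assumptions: an [N, W, C] input, a [K, C_in, C_out] filter, stride 1, and VALID padding. The function name is hypothetical and nothing here is taken from the PR's code.

```cpp
#include <vector>

// Hypothetical shape walk-through for illustration only.
// input:  [N, W, C_in]      filter: [K, C_in, C_out]
// Assumes stride 1 and VALID padding.
std::vector<int> Conv1DShapeViaConv2D(const std::vector<int>& input,
                                      const std::vector<int>& filter) {
  // ExpandDims at axis 1: [N, W, C_in] -> [N, 1, W, C_in].
  const std::vector<int> x = {input[0], 1, input[1], input[2]};
  // ExpandDims at axis 0: [K, C_in, C_out] -> [1, K, C_in, C_out].
  const std::vector<int> f = {1, filter[0], filter[1], filter[2]};
  // Conv2D with VALID padding and stride 1 shrinks each spatial dim
  // by (filter size - 1) and replaces the channel count with C_out.
  const std::vector<int> y = {x[0], x[1] - f[0] + 1, x[2] - f[1] + 1, f[3]};
  // Squeeze axis 1: [N, 1, W_out, C_out] -> [N, W_out, C_out].
  return {y[0], y[2], y[3]};
}
```

For example, an input of shape [8, 10, 3] with a filter of shape [4, 3, 16] passes through [8, 1, 10, 3] and [8, 1, 7, 16] on its way to the final [8, 7, 16].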