TFTRT: Add ExpandDims, Squeeze ops and unit tests. #23909

trevor-m
Contributor

Add TF-TRT conversion for the ExpandDims and Squeeze ops, plus unit tests for both.

These ops will allow 1D convolutions to be converted.

@trevor-m trevor-m changed the title from "Add ExpandDims, Squeeze ops and unit tests." to "TFTRT: Add ExpandDims, Squeeze ops and unit tests." Nov 21, 2018
@Harshini-Gadige Harshini-Gadige added the awaiting review label Nov 26, 2018
tensorflow/contrib/tensorrt/convert/convert_nodes.cc
const int input_rank = input_dims.size();
// Mark axes to remove by setting them to 0.
TFAttrs attrs(node_def);
auto squeeze_dims = attrs.get<std::vector<int>>("squeeze_dims");
Member

squeeze_dims is deprecated; we should use axis instead. See https://www.tensorflow.org/api_docs/python/tf/squeeze.

Contributor Author

@trevor-m trevor-m Nov 30, 2018

While the Python API now uses axis, it appears that NodeDef.attr() only contains squeeze_dims, even though we set the .axis_ field when creating the node def attributes: https://github.com/tensorflow/tensorflow/pull/23909/files#diff-5b5e9a9d60ab4ad5474d419d568fb6faR2110

Member

Thanks for pointing that out!
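
For readers following the thread, the snippet under review marks the squeezed axes by setting them to 0 and then drops them. Below is a minimal standalone sketch of that idea with no TensorFlow or TensorRT dependency. `SqueezeShape` and its error handling are hypothetical names chosen for illustration, not the PR's code, and the sketch assumes every input dimension is positive.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper, not the PR's code. Writes to *out the shape left after
// removing the axes listed in squeeze_dims. Returns false if an axis is out of
// range or does not have size 1. Assumes every entry of dims is positive.
bool SqueezeShape(std::vector<int> dims, const std::vector<int>& squeeze_dims,
                  std::vector<int>* out) {
  assert(out != nullptr);
  const int rank = static_cast<int>(dims.size());
  for (int axis : squeeze_dims) {
    // Convert a negative axis to the corresponding positive axis.
    if (axis < 0) axis += rank;
    if (axis < 0 || axis >= rank) return false;
    // Only size-1 axes can be squeezed. An axis that is already marked (0)
    // also fails this check, so a repeated axis is rejected.
    if (dims[axis] != 1) return false;
    dims[axis] = 0;  // Mark the axis for removal.
  }
  // Keep every dimension that was not marked.
  out->clear();
  for (int d : dims) {
    if (d != 0) out->push_back(d);
  }
  return true;
}
```

With an empty `squeeze_dims` list this sketch returns the shape unchanged, whereas `tf.squeeze` without an axis removes every size-1 dimension; that case is left out here.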

@tensorflowbutler tensorflowbutler removed the awaiting review label Nov 30, 2018
@trevor-m trevor-m force-pushed the tmorris_tftrt_expanddims_squeeze_ops branch from f51f7b8 to 9284cda November 30, 2018 21:34
@Harshini-Gadige Harshini-Gadige added the awaiting review label Dec 4, 2018
tensorflow::Status ConvertExpandDims(OpConverterParams* params) {
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
if (inputs.size() != 2) {

Isn't axis an attribute?

@pooyadavoodi

@azaks2 @smit-hinsu
Could you please review?

@googlebot

So there's good news and bad news.

👍 The good news is that everyone who needs to sign a CLA (the pull request submitter and all commit authors) has done so. Everything is all good there.

😕 The bad news is that it appears that one or more commits were authored or co-authored by someone other than the pull request submitter. We need to confirm that all authors are ok with their commits being contributed to this project. Please have them confirm that here in the pull request.

Note to project maintainer: This is a terminal state, meaning the cla/google commit status will not change from this state. It's up to you to confirm consent of all the commit author(s), set the cla label to yes (if enabled on your project), and then merge this pull request when appropriate.

@googlebot googlebot added cla: no and removed cla: yes labels Dec 5, 2018
@alextp
Contributor

alextp commented Dec 5, 2018

Can you rebase this PR onto master? It's a little hard to review as-is.

@trevor-m trevor-m force-pushed the tmorris_tftrt_expanddims_squeeze_ops branch from 4f742ca to 06fe3c8 December 5, 2018 19:39
@googlebot

CLAs look good, thanks!

@googlebot googlebot added cla: yes and removed cla: no labels Dec 5, 2018
@alextp alextp removed their request for review December 5, 2018 19:47
}
// Convert negative axis to corresponding positive axis.
if (axis < 0) axis += input_rank + 1;
if (input_tensor.is_tensor() && axis == 0) {
Member

input_tensor.is_tensor() must be true now.

Contributor Author

Is it ok to leave it for now? We will need to support strided slice for weights later for dynamic reshaping.

Member

If so, then please leave a comment.

Contributor Author

Oh whoops, got this mixed up with the other PR. Removing these.
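
As background for the `axis += input_rank + 1` line in the snippet above: ExpandDims produces a tensor with one more dimension than its input, so a negative axis is offset by `rank + 1` rather than `rank`, and `-1` means "append after the last dimension". Here is a minimal sketch of that normalization. `ExpandDimsShape` is a hypothetical helper, not the PR's converter, and it ignores any batch-dimension handling the real converter may need.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper, not the PR's code. Inserts a size-1 dimension at
// `axis` and writes the result to *out. Valid axes are in
// [-(rank + 1), rank]; anything else returns false.
bool ExpandDimsShape(const std::vector<int>& dims, int axis,
                     std::vector<int>* out) {
  assert(out != nullptr);
  const int rank = static_cast<int>(dims.size());
  // Convert a negative axis to the corresponding positive axis. The output
  // has rank + 1 dimensions, hence the extra + 1.
  if (axis < 0) axis += rank + 1;
  if (axis < 0 || axis > rank) return false;
  *out = dims;
  out->insert(out->begin() + axis, 1);
  return true;
}
```

For an input of shape {2, 3}, axis 0 gives {1, 2, 3} and axis -1 gives {2, 3, 1}.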

@aaroey aaroey added kokoro:force-run and removed awaiting review labels Dec 5, 2018
@kokoro-team kokoro-team removed the kokoro:force-run label Dec 5, 2018
aaroey previously approved these changes Dec 5, 2018
Member

@aaroey aaroey left a comment

Thanks for the PR!

@aaroey aaroey added kokoro:force-run and ready to pull labels Dec 5, 2018
@kokoro-team kokoro-team removed the kokoro:force-run label Dec 5, 2018
@smit-hinsu smit-hinsu removed their request for review December 6, 2018 08:13
@aaroey
Member

aaroey commented Dec 6, 2018

Hi @trevor-m, would you help to fix this error:

error: ignoring return value of function declared with 'warn_unused_result' attribute [-Werror,-Wunused-result]
  TensorShapeUtils::MakeShape(shape, &tensor_shape);

Thanks.

@trevor-m
Contributor Author

trevor-m commented Dec 6, 2018

> Hi @trevor-m, would you help to fix this error:
>
> error: ignoring return value of function declared with 'warn_unused_result' attribute [-Werror,-Wunused-result]
>   TensorShapeUtils::MakeShape(shape, &tensor_shape);
>
> Thanks.

Done. Thanks for reviewing!
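
The thread does not show which fix was applied. In general, this class of error goes away once the returned status is consumed instead of discarded; in TensorFlow code that is commonly done with the `TF_RETURN_IF_ERROR` macro. The sketch below reproduces the pattern without TensorFlow: `Status`, `MakeShape`, and `BuildShape` are simplified stand-ins written for illustration, not the real `tensorflow::Status` or `TensorShapeUtils::MakeShape`.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Simplified stand-in for tensorflow::Status (hypothetical).
class Status {
 public:
  Status() : ok_(true) {}
  explicit Status(const std::string& msg) : ok_(false), msg_(msg) {}
  bool ok() const { return ok_; }
  const std::string& message() const { return msg_; }

 private:
  bool ok_;
  std::string msg_;
};

// Stand-in for TensorShapeUtils::MakeShape: copies dims into *shape and
// rejects negative dimensions. The attribute makes GCC and Clang warn when
// a caller drops the result.
#if defined(__GNUC__) || defined(__clang__)
__attribute__((warn_unused_result))
#endif
Status MakeShape(const std::vector<int>& dims, std::vector<int>* shape) {
  assert(shape != nullptr);
  for (int d : dims) {
    if (d < 0) return Status("negative dimension");
  }
  *shape = dims;
  return Status();
}

// Writing `MakeShape(dims, shape);` on its own line here would trigger
// -Wunused-result. Capturing and propagating the status avoids that.
Status BuildShape(const std::vector<int>& dims, std::vector<int>* shape) {
  Status status = MakeShape(dims, shape);
  if (!status.ok()) return status;
  return Status();
}
```

Propagating the status also means a malformed shape surfaces as a conversion error rather than being silently ignored.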

aaroey previously approved these changes Dec 6, 2018
@aaroey
Member

aaroey commented Dec 6, 2018

Thanks @trevor-m. I found that rank_two_test and unary_test are failing, probably because they use squeeze ops. Could you double-check? If you want to add an op that is unsupported, you may use self.trt_incompatible_op.

@tensorflow-copybara tensorflow-copybara merged commit fc4b184 into tensorflow:master Dec 7, 2018
tensorflow-copybara pushed a commit that referenced this pull request Dec 7, 2018