
[XLA] Add fast path cases for common scatter and gather operations #15185

Merged
2 commits merged on Dec 14, 2017


DavidNorman
Contributor

This change checks whether the indices vector passed to a scatter or gather operation is a compile-time constant and, when it holds a zero-based incrementing sequence (0, 1, 2, ...), emits a fast-path operation instead of the general one.
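
The decision described above can be sketched in standalone C++. This is a minimal illustration, not the PR's code: `std::optional` stands in for whether the indices input could be resolved to a constant (the PR checks the `Status` returned by `ctx->ConstantInputAsIntVector`), and `IsZeroBasedIncrementing` and `UseFastPath` are hypothetical names.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// True if `indices` is exactly 0, 1, 2, ..., n-1 (an empty vector trivially
// qualifies).
bool IsZeroBasedIncrementing(const std::vector<int64_t>& indices) {
  for (std::size_t i = 0; i < indices.size(); ++i) {
    if (indices[i] != static_cast<int64_t>(i)) return false;
  }
  return true;
}

// Hypothetical decision helper. `const_indices` is std::nullopt when the
// indices could not be resolved to a constant at compile time; the general
// scatter/gather lowering must then be used.
bool UseFastPath(const std::optional<std::vector<int64_t>>& const_indices) {
  return const_indices.has_value() && IsZeroBasedIncrementing(*const_indices);
}
```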

This case is common because TensorArray stack and unstack generate exactly these indices.
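
For context, TensorArray `stack()` reads every element by gathering indices 0 to size-1, and `unstack()` writes its input by scattering at indices 0 to n-1, so both hit the pattern above. The sketch below models only the gather side, using a hypothetical `ToyTensorArray` stored as one flat row-major buffer; the type and function names are illustrative assumptions, not TensorFlow or XLA APIs.

```cpp
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

// Toy model of a TensorArray (an assumption for illustration): `size`
// elements, each a row of `row_len` floats, stored row-major in one buffer.
struct ToyTensorArray {
  int64_t size;
  int64_t row_len;
  std::vector<float> data;  // size * row_len values
};

// General gather: copies the rows named by `indices`, in order.
std::vector<float> Gather(const ToyTensorArray& ta,
                          const std::vector<int64_t>& indices) {
  std::vector<float> out;
  out.reserve(indices.size() * static_cast<std::size_t>(ta.row_len));
  for (int64_t idx : indices) {
    auto first = ta.data.begin() + idx * ta.row_len;
    out.insert(out.end(), first, first + ta.row_len);
  }
  return out;
}

// stack() reads every element in order, i.e. it gathers indices 0..size-1,
// which is exactly the zero-based incrementing pattern.
std::vector<float> Stack(const ToyTensorArray& ta) {
  std::vector<int64_t> indices(static_cast<std::size_t>(ta.size));
  std::iota(indices.begin(), indices.end(), int64_t{0});
  return Gather(ta, indices);
}
```

Gathering 0..size-1 returns the whole buffer unchanged, which is why the general per-index gather is wasted work in this case.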

@tensorflow-jenkins
Collaborator

Can one of the admins verify this patch?

@DavidNorman
Contributor Author

Although there are existing unit tests that exercise this change, do you think I should add an explicit device-targeted test that demonstrates it working? It is hard to check that the fast path has been used without inspecting the form of the XLA HLO graph, but this could be done in the XLA C++ tests.

caisq requested a review from hawkinsp on December 7, 2017 16:13
caisq self-assigned this on Dec 7, 2017
caisq added the awaiting review label on Dec 7, 2017
Contributor

hawkinsp left a comment

Thanks for the PR!

Sorry for the slow response: I missed the notification email about this change.

Could you also please verify that the new cases are tested by tensorflow/compiler/tests/tensor_array_ops_test.py, and extend the tests to cover them if not?

std::vector<int64> const_indices;
Status status = ctx->ConstantInputAsIntVector(1, &const_indices);
if (status.ok()) {
  bool is_simple_gather = true;
Contributor

This would be more readable if you added a comment defining what a "simple" gather is.

Perhaps use a more descriptive name, such as "gather_is_dense_slice"?
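
One way to read the suggested name: a gather is a dense slice when its constant indices are 0, 1, ..., k-1, because the result is then just the leading k rows of the input, one contiguous range. The sketch below assumes a flat row-major buffer; `GatherIsDenseSlice` and `SliceLeadingRows` are hypothetical names, and the `k <= num_rows` bound check is an added safety assumption rather than something shown in the diff.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// A gather "is a dense slice" when its constant indices are 0, 1, ..., k-1
// and k does not exceed the number of rows. Such a gather returns the
// leading k rows unchanged, i.e. one contiguous slice of the input.
bool GatherIsDenseSlice(const std::vector<int64_t>& indices, int64_t num_rows) {
  if (static_cast<int64_t>(indices.size()) > num_rows) return false;
  for (std::size_t i = 0; i < indices.size(); ++i) {
    if (indices[i] != static_cast<int64_t>(i)) return false;
  }
  return true;
}

// Fast path: copy rows [0, k) of a row-major [num_rows, row_len] buffer as
// one contiguous range instead of looping over the indices.
std::vector<float> SliceLeadingRows(const std::vector<float>& data,
                                    int64_t row_len, int64_t k) {
  return std::vector<float>(data.begin(), data.begin() + k * row_len);
}
```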


std::vector<int64> const_indices;
Status status = ctx->ConstantInputAsIntVector(1, &const_indices);
if (status.ok() && num_indices==value_shape.dim_size(0)) {
Contributor

Nit: Add a space before and after "==" for consistency.

@@ -352,30 +376,50 @@ class TensorArrayScatterOp : public XlaOpKernel {
const xla::ComputationDataHandle value = ctx->Input(2);
const xla::ComputationDataHandle flow = ctx->Input(3);

auto slice_dims = value_shape.dim_sizes();
slice_dims[0] = 1LL;
bool is_simple = false;
Contributor

Same here. Add a comment describing what "simple" means here.
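
A possible definition of "simple" for the scatter, assuming the fast path is the single `b->Add(ta, value)` in the diff: that add is only shape-correct when `value` has the same shape as the whole TensorArray, which means the indices must be exactly 0, 1, ..., n-1 with n equal to the array's leading dimension. `ScatterCoversWholeArray` is a hypothetical name used only for this sketch.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// A scatter is "simple" (hypothetical name: ScatterCoversWholeArray) when
// its constant indices are exactly 0, 1, ..., n-1 and n equals the leading
// dimension of the TensorArray. Every element is then written exactly once,
// in order, so `value` has the same shape as the whole array.
bool ScatterCoversWholeArray(const std::vector<int64_t>& indices,
                             int64_t ta_leading_dim) {
  if (static_cast<int64_t>(indices.size()) != ta_leading_dim) return false;
  for (std::size_t i = 0; i < indices.size(); ++i) {
    if (indices[i] != static_cast<int64_t>(i)) return false;
  }
  return true;
}
```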


if (is_simple) {
  ta = b->Add(ta, value);

Contributor

Nit: Remove blank line.

}


Contributor

Remove extra blank line.
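
Why a single `Add` can replace the scatter: if each TensorArray write accumulates into the buffer (as the `b->Add(ta, value)` in the diff suggests), then scattering rows at indices 0, 1, ..., n-1 into an n-row array touches every element exactly once, in order, which is the same as one elementwise add. The sketch below checks that equivalence on a flat row-major `std::vector<float>`; `ScatterAddRows` and `AddWholeArray` are hypothetical stand-ins for the general and fast lowerings, not XLA calls.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// General scatter, modeled as one add-into-row per index: row i of `value`
// is added into row indices[i] of the row-major `ta` buffer.
void ScatterAddRows(std::vector<float>& ta, int64_t row_len,
                    const std::vector<int64_t>& indices,
                    const std::vector<float>& value) {
  for (std::size_t i = 0; i < indices.size(); ++i) {
    const int64_t src_row = static_cast<int64_t>(i);
    for (int64_t j = 0; j < row_len; ++j) {
      ta[static_cast<std::size_t>(indices[i] * row_len + j)] +=
          value[static_cast<std::size_t>(src_row * row_len + j)];
    }
  }
}

// Fast path when indices are 0..n-1 and n is the array's leading dimension:
// the per-row updates tile the whole buffer, so one elementwise add over
// equally sized buffers gives the same result.
void AddWholeArray(std::vector<float>& ta, const std::vector<float>& value) {
  for (std::size_t k = 0; k < ta.size(); ++k) ta[k] += value[k];
}
```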

@DavidNorman
Contributor Author

Thanks for the comments. I will check the coverage.

I'm not sure that tensor_array_ops_test actually tests much of the XLA device-side code. Last time I looked, it was mostly being compiled by the CPU constant-removal code. This change certainly passes the tests in that file, but maybe only because the HLO graphs are optimized down to constants before they are compiled.

Let me get back to you tomorrow.

@DavidNorman
Contributor Author

Good news: both the scatter and gather changes are exercised by the tests in tensor_array_ops_test.

@DavidNorman
Contributor Author

I think I'll ignore those last 2 comments 😆

Contributor

hawkinsp left a comment

Thanks for the PR!

hawkinsp added the awaiting testing (then merge) label and removed the awaiting review label on Dec 13, 2017
@caisq
Contributor

caisq commented Dec 14, 2017

@tensorflow-jenkins test this please

caisq added the kokoro:force-run label on Dec 14, 2017
kokoro-team removed the kokoro:force-run label on Dec 14, 2017
caisq merged commit 2bb302e into tensorflow:master on Dec 14, 2017