[XLA] Add fast path cases for common scatter and gather operations #15185
Can one of the admins verify this patch?
Although there are existing unit tests that catch this change, do you think I should add an explicit device-targeted test that demonstrates this working? It is hard to check that the fast path has been used without checking the form of the XLA HLO graph, but this could be done in the XLA CC tests.
Thanks for the PR!
Sorry for the slow response: I missed the notification email about this change.
Could you also please verify that the new cases are tested by tensorflow/compiler/tests/tensor_array_ops_test.py, and extend the tests to cover them if not?
std::vector<int64> const_indices;
Status status = ctx->ConstantInputAsIntVector(1, &const_indices);
if (status.ok()) {
  bool is_simple_gather = true;
This would be more readable if you added a comment defining what a "simple" gather is.
Perhaps use a more descriptive name, maybe "gather_is_dense_slice"?
std::vector<int64> const_indices;
Status status = ctx->ConstantInputAsIntVector(1, &const_indices);
if (status.ok() && num_indices==value_shape.dim_size(0)) {
Nit: Add a space before and after "==" for consistency.
@@ -352,30 +376,50 @@ class TensorArrayScatterOp : public XlaOpKernel {
const xla::ComputationDataHandle value = ctx->Input(2);
const xla::ComputationDataHandle flow = ctx->Input(3);

auto slice_dims = value_shape.dim_sizes();
slice_dims[0] = 1LL;
bool is_simple = false;
Same here. Add a comment describing what "simple" means here.
if (is_simple) {
  ta = b->Add(ta, value);

Nit: Remove blank line.
}

Remove extra blank line.
Thanks for the comments. I will check the coverage. I'm not sure that tensor_array_ops_test actually tests much of the XLA device-side code. Last time I looked, it was mostly being compiled by the CPU constant removal code. This change certainly passes the tests in that file, but maybe only because the HLO graphs are optimized down to constants before they are compiled. Let me get back to you tomorrow.
Good news: both the scatter and gather changes are hit by the tensor_array_ops_test set of tests.
I think I'll ignore those last 2 comments 😆
Thanks for the PR!
@tensorflow-jenkins test this please
This change checks whether the indices vector passed to a scatter or gather operation is a compile-time constant, and takes a fast path when that vector holds a zero-based, incrementing sequence (0, 1, 2, ...).
This is quite a common case because of tensor-array stack and unstack.
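To make the condition in the description concrete, here is a minimal standalone sketch in plain C++. It is not the PR's code and does not use the XLA API: `IsZeroBasedIota`, `GatherRows`, and `GatherRowsFast` are hypothetical names invented for illustration, and the row-of-floats representation is an assumption. The idea it shows is that when the constant indices are exactly 0, 1, ..., n-1, a gather reduces to taking the first n rows, so no per-index lookup is needed.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

using Rows = std::vector<std::vector<float>>;

// Hypothetical helper (not from the PR): true when `indices` is exactly
// the sequence 0, 1, 2, ..., indices.size() - 1.
bool IsZeroBasedIota(const std::vector<int64_t>& indices) {
  for (std::size_t i = 0; i < indices.size(); ++i) {
    if (indices[i] != static_cast<int64_t>(i)) return false;
  }
  return true;
}

// General path: result[i] = rows[indices[i]], one lookup per index.
// `at` throws std::out_of_range on an invalid index.
Rows GatherRows(const Rows& rows, const std::vector<int64_t>& indices) {
  Rows result;
  result.reserve(indices.size());
  for (int64_t index : indices) {
    result.push_back(rows.at(static_cast<std::size_t>(index)));
  }
  return result;
}

// Fast path (assumed behaviour, for illustration only): if the indices are
// a zero-based incrementing sequence that fits in `rows`, the gather is
// just a slice of the leading rows. Otherwise fall back to the general path.
Rows GatherRowsFast(const Rows& rows, const std::vector<int64_t>& indices) {
  if (indices.size() <= rows.size() && IsZeroBasedIota(indices)) {
    const auto count = static_cast<std::ptrdiff_t>(indices.size());
    return Rows(rows.begin(), rows.begin() + count);
  }
  return GatherRows(rows, indices);
}
```

Both paths are meant to return the same result for any valid index vector; the fast path only skips the per-index work. The scatter case in the diff additionally requires the number of indices to equal the leading dimension of the value, which this sketch does not model.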