
1.12.0-rc2 cherry-pick request: Various XLA scatter improvements. #23235

Merged: 9 commits, Oct 25, 2018


tatatodd (Contributor)

This PR cherry-picks the following internal CLs (listed by PiperOrigin-RevId):

PiperOrigin-RevId: 215687800
PiperOrigin-RevId: 216412467
PiperOrigin-RevId: 216437329
PiperOrigin-RevId: 216448063
PiperOrigin-RevId: 216624225
PiperOrigin-RevId: 216798034
PiperOrigin-RevId: 216921512
PiperOrigin-RevId: 216968475

This simply runs a kernel on every element of the updates tensor that figures
out the right indices to perform the update and applies it with an atomic
operation.
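As an illustration only, the per-element scheme can be modeled serially in Python. This is a simplified sketch, not XLA's emitter: it assumes a 1-D operand, treats `updates[i]` as a window written starting at operand position `indices[i]`, assumes every index is in bounds, and uses a hypothetical `scatter_1d` helper. Each `(i, j)` pair stands in for one GPU thread.

```python
from typing import Callable, List


def scatter_1d(operand: List[float], indices: List[int],
               updates: List[List[float]],
               combine: Callable[[float, float], float]) -> List[float]:
    """Serial model of a scatter kernel that runs once per element of `updates`.

    Simplifying assumptions (not from the PR): the operand is 1-D, updates[i]
    is a window written starting at operand position indices[i], and every
    index is in bounds. Each (i, j) pair stands in for one GPU thread.
    """
    out = list(operand)  # output buffer starts as a copy of the operand
    for i, window in enumerate(updates):
        for j, value in enumerate(window):  # one "thread" per update element
            target = indices[i] + j         # where this element lands
            # Different threads can hit the same target, so on a GPU this
            # read-modify-write has to be a single atomic operation.
            out[target] = combine(out[target], value)
    return out


print(scatter_1d([0, 0, 0, 0], [0, 1], [[1, 2], [10, 20]], lambda a, b: a + b))
# prints [1, 12, 20, 0]
```

Operand position 1 receives contributions from both windows, which is why the update has to be atomic once the loops become parallel threads. In the serial model the order is fixed; on a GPU it is not, so an overwrite combiner on overlapping windows has no well-defined winner.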

Currently we emit a CAS for plain (i.e. non-add) updates, which is inefficient.
Also TuplePointsToAnalysis doesn't know that it should alias the operand and
output buffers of a scatter, which would avoid a copy.
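To show why a CAS is the costly path, here is a toy Python model of a compare-and-swap retry loop. `AtomicCell`, `atomic_update_cas` and `atomic_overwrite` are hypothetical names, and a `threading.Lock` stands in for hardware atomicity. The PR only says a CAS is emitted for plain updates and that this is inefficient. The idea that a pure overwrite could use a single atomic store is an assumption included for contrast, not a description of the follow-up change.

```python
import threading
from typing import Callable


class AtomicCell:
    """Toy stand-in for one word of device memory (hypothetical helper)."""

    def __init__(self, value: float) -> None:
        self._value = value
        self._lock = threading.Lock()  # makes each method below atomic

    def load(self) -> float:
        with self._lock:
            return self._value

    def store(self, value: float) -> None:
        with self._lock:
            self._value = value

    def compare_and_swap(self, expected: float, desired: float) -> float:
        """Write `desired` only if the cell still holds `expected`; return the old value.

        Compares by value; a hardware CAS compares bit patterns instead.
        """
        with self._lock:
            old = self._value
            if old == expected:
                self._value = desired
            return old


def atomic_update_cas(cell: AtomicCell, value: float,
                      combine: Callable[[float, float], float]) -> None:
    """Generic fallback: retry the read-modify-write until no other thread interferes."""
    old = cell.load()
    while True:
        seen = cell.compare_and_swap(old, combine(old, value))
        if seen == old:  # our write went through
            return
        old = seen       # another thread won the race; recompute from its value


def atomic_overwrite(cell: AtomicCell, value: float) -> None:
    """An overwrite ignores the old value, so a single atomic store suffices."""
    cell.store(value)


cell = AtomicCell(0.0)


def add_many(n: int) -> None:
    for _ in range(n):
        atomic_update_cas(cell, 1.0, lambda a, b: a + b)


workers = [threading.Thread(target=add_many, args=(100,)) for _ in range(8)]
for w in workers:
    w.start()
for w in workers:
    w.join()
print(cell.load())  # 800.0
```

Each CAS attempt costs a load, a compare and a possible retry under contention. An update that does not depend on the old value needs none of that.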

PiperOrigin-RevId: 216412467
This avoids a copy.

PiperOrigin-RevId: 216437329
We have a 1-element thunk sequence if we're not copying. That is still two
thunks (the sequence plus its single element), and HLO profiling gets confused
if it sees two thunks for the same instruction and one of them claims to be the
whole instruction.

PiperOrigin-RevId: 216448063
We fuse everything into the scatter now, and emit two kernels. The first kernel
fills the output buffer with the computation fused into the scatter operand.
The second kernel is a regular scatter, which also contains the fused
operations from the updates and scatter_indices inputs.
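The two-kernel split can be sketched in Python under simplifying assumptions. The fused computations are modeled as hypothetical elementwise callables (`operand_fn`, `indices_fn`, `updates_fn`), the operand is 1-D, and all indices are in bounds. This is a model of the data flow, not the emitted kernels.

```python
from typing import Callable, List


def fused_scatter(operand_src: List[float], operand_fn: Callable[[float], float],
                  indices_src: List[int], indices_fn: Callable[[int], int],
                  updates_src: List[List[float]], updates_fn: Callable[[float], float],
                  combine: Callable[[float, float], float]) -> List[float]:
    # Kernel 1: fill the output buffer with the computation fused into the
    # scatter operand (here a hypothetical elementwise operand_fn).
    out = [operand_fn(x) for x in operand_src]

    # Kernel 2: a regular scatter whose updates and scatter indices are
    # computed on the fly by the fused producers, so neither input is
    # materialized in its own buffer first.
    for i, raw_index in enumerate(indices_src):
        start = indices_fn(raw_index)
        for j, raw_update in enumerate(updates_src[i]):
            out[start + j] = combine(out[start + j], updates_fn(raw_update))
    return out


print(fused_scatter([1, 2, 3, 4], lambda x: 10 * x,
                    [0], lambda k: k + 1,
                    [[5, 6]], lambda u: 2 * u,
                    lambda a, b: a + b))
# prints [10, 30, 42, 40]
```

The first kernel has to finish before the second starts, because the scatter combines into values the first kernel wrote.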

PiperOrigin-RevId: 216624225
The bounds check added the window offset to the index first and then compared
the result against the window dimension, so it was only correct for the first
element of a window. Instead, compare the scatter index, which is the same for
all elements of a window.
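Here is a minimal Python sketch of the difference, assuming the kernel skips any window that does not fit entirely inside the operand. `element_in_bounds_buggy` is a hypothetical reconstruction of the old check from the description above, not the actual emitter code.

```python
def element_in_bounds_buggy(scatter_index: int, window_offset: int,
                            window_size: int, operand_dim: int) -> bool:
    # Hypothetical reconstruction of the old check: the element's offset
    # within the window is added first, then the sum is checked against
    # the room left for a full window.
    index = scatter_index + window_offset
    return 0 <= index <= operand_dim - window_size


def window_in_bounds_fixed(scatter_index: int, window_size: int,
                           operand_dim: int) -> bool:
    # Fixed check: test the scatter (start) index, which is shared by every
    # element of the window, so the whole window is applied or skipped together.
    return 0 <= scatter_index <= operand_dim - window_size


# Operand dimension 5, window size 2, window starting at index 3 (covers 3 and 4).
print([element_in_bounds_buggy(3, j, 2, 5) for j in range(2)])  # [True, False]
print(window_in_bounds_fixed(3, 2, 5))                           # True
```

The window at index 3 fits, but the per-element check rejects its second element, so only half of the window would be written. The fixed check gives a single answer for the whole window.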

PiperOrigin-RevId: 216921512
The tuple buffer is never read, so stop emitting code to fill it. Filling a
typical root tuple takes an H2D (host-to-device) memcpy and a host callback,
both of which are somewhat slow.

This helps tiny models and inference benchmarks, where the host/device syncs
can be a significant part of the runtime of the entire computation.

PiperOrigin-RevId: 216968475
@tatatodd (Contributor, Author)

I'm ignoring the clang-format check, since it's suggesting weird formatting changes, and it's not critical anyway.

@tatatodd tatatodd merged commit e72c9eb into tensorflow:r1.12 Oct 25, 2018