Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XLA] Add fast path cases for common scatter and gather operations #15185

Merged
merged 2 commits into from
Dec 14, 2017
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
87 changes: 66 additions & 21 deletions tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,32 @@ class TensorArrayGatherOp : public XlaOpKernel {

xla::ComputationDataHandle ta = resource->value;

// Look for the case where the gather takes a simple slice from the
// tensor array (0, 1, 2, 3, 4, ..., N)
std::vector<int64> const_indices;
Status status = ctx->ConstantInputAsIntVector(1, &const_indices);
if (status.ok()) {
bool gather_is_dense_slice = true;
for (auto i = 0; i < const_indices.size(); i++) {
if (const_indices[i] != i) {
gather_is_dense_slice = false;
break;
}
}

if (gather_is_dense_slice) {
std::vector<int64> begin(ta_shape.dims(), 0);
std::vector<int64> strides(ta_shape.dims(), 1);
std::vector<int64> end(ta_shape.dims(), 1);
end[0] = const_indices.size();
for (auto i = 1; i < ta_shape.dims(); i++) {
end[i] = ta_shape.dim_size(i);
}
ctx->SetOutput(0, b->Slice(ta, begin, end, strides));
return;
}
}

xla::ComputationDataHandle gather = XlaComputeGatherDynamicSlice(
ctx, ta, ta_shape, indices, indices_shape, 0, dtype_, index_type, b);
ctx->SetOutput(0, gather);
Expand Down Expand Up @@ -352,28 +378,47 @@ class TensorArrayScatterOp : public XlaOpKernel {
const xla::ComputationDataHandle value = ctx->Input(2);
const xla::ComputationDataHandle flow = ctx->Input(3);

auto slice_dims = value_shape.dim_sizes();
slice_dims[0] = 1LL;

std::vector<int64> value_starts(value_shape.dims(), 0);
auto value_ends = value_shape.dim_sizes();

std::vector<int64> value_strides(value_shape.dims(), 1);

// For every (index, value) pair, update the corresponding TensorArray
// storage.
for (int i = 0; i < num_indices; ++i) {
// Slice out part of the value.
value_starts[0] = i;
value_ends[0] = i + 1;
auto slice = b->Slice(value, value_starts, value_ends, value_strides);
// Look for the case where the scatter is for each sub-tensor in order. The
// tensor array implementation allows for this to be a straight addition.
bool scatter_all_elements_in_order = false;
std::vector<int64> const_indices;
Status status = ctx->ConstantInputAsIntVector(1, &const_indices);
if (status.ok() && num_indices == value_shape.dim_size(0)) {
scatter_all_elements_in_order = true;
for (auto i = 0; i < num_indices; i++) {
if (const_indices[i] != i) {
scatter_all_elements_in_order = false;
break;
}
}
}

// start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
auto index = b->Slice(indices, {i}, {i + 1}, {1});
auto start_indices =
b->Pad(b->Reshape(index, {1}), b->ConstantR0<int32>(0),
xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}}));
ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices);
if (scatter_all_elements_in_order) {
ta = b->Add(ta, value);
} else {
auto slice_dims = value_shape.dim_sizes();
slice_dims[0] = 1LL;

std::vector<int64> value_starts(value_shape.dims(), 0);
auto value_ends = value_shape.dim_sizes();

std::vector<int64> value_strides(value_shape.dims(), 1);

// For every (index, value) pair, update the corresponding TensorArray
// storage.
for (int i = 0; i < num_indices; ++i) {
// Slice out part of the value.
value_starts[0] = i;
value_ends[0] = i + 1;
auto slice = b->Slice(value, value_starts, value_ends, value_strides);

// start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0].
auto index = b->Slice(indices, {i}, {i + 1}, {1});
auto start_indices =
b->Pad(b->Reshape(index, {1}), b->ConstantR0<int32>(0),
xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}}));
ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices);
}
}

resource->value = ta;
Expand Down