Skip to content
Permalink
Browse files Browse the repository at this point in the history
Add out of bounds array check to dynamic_stitch_op.
PiperOrigin-RevId: 506418249
  • Loading branch information
changm authored and tensorflower-gardener committed Feb 1, 2023
1 parent e66027f commit ee004b1
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 6 deletions.
4 changes: 4 additions & 0 deletions tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc
Expand Up @@ -146,6 +146,10 @@ class DynamicStitchOp : public XlaOpKernel {
for (int input_num = 0; input_num < indices.size(); input_num++) {
for (int i = 0; i < indices[input_num].shape().dimensions(0); ++i) {
int index = indices[input_num].Get<int>({i});
OP_REQUIRES(
ctx, index >= 0,
errors::InvalidArgument("indices[", index, "] is out of range"));

src_input_vector[index] = input_num;
src_slice_vector[index] = i;
if (!src_index_used[index]) {
Expand Down
17 changes: 11 additions & 6 deletions tensorflow/core/kernels/dynamic_stitch_op.cc
Expand Up @@ -97,6 +97,17 @@ class DynamicStitchOpImplBase : public OpKernel {

*first_dim_size = max_index + 1;

for (const Tensor& indices : *indices_inputs) {
auto indices_vec = indices.flat<int32>();

for (int i = 0; i < indices_vec.size(); i++) {
int32_t index = internal::SubtleMustCopy(indices_vec(i));
OP_REQUIRES(
c, FastBoundsCheck(index, *first_dim_size),
errors::InvalidArgument("indices[", i, "] is out of range"));
}
}

// Validate that data[i].shape = indices[i].shape + constant
OP_REQUIRES_OK(c, c->input_list("data", data_inputs));
const Tensor& data0 = (*data_inputs)[0];
Expand Down Expand Up @@ -265,9 +276,6 @@ class DynamicStitchOpImplCPU : public DynamicStitchOpImplBase<T> {
const T* data_base = data_flat.data();
for (int i = 0; i < indices_vec.size(); i++) {
int32_t index = internal::SubtleMustCopy(indices_vec(i));
OP_REQUIRES(
c, FastBoundsCheck(index, first_dim_size),
errors::InvalidArgument("indices[", i, "] is out of range"));
memcpy(merged_base + index * slice_size, data_base + i * slice_size,
slice_bytes);
}
Expand All @@ -277,9 +285,6 @@ class DynamicStitchOpImplCPU : public DynamicStitchOpImplBase<T> {
// Copy slice data[i] to merged[indices[i]]
Eigen::DSizes<Eigen::DenseIndex, 2> data_indices(i, 0);
int32_t index = internal::SubtleMustCopy(indices_vec(i));
OP_REQUIRES(
c, FastBoundsCheck(index, first_dim_size),
errors::InvalidArgument("indices[", i, "] is out of range"));
Eigen::DSizes<Eigen::DenseIndex, 2> merged_indices(index, 0);
merged_flat.slice(merged_indices, sizes) =
data_flat.slice(data_indices, sizes);
Expand Down
Expand Up @@ -226,6 +226,19 @@ def testErrorDataAndIndicesSizeMismatch(self):
with self.assertRaises(ValueError):
self.stitch_op(indices, data)

def testOutOfBoundsIndexRaisesInvalidArgument(self):
with self.assertRaisesRegex(errors.InvalidArgumentError, "out of range"):
indices = [[-1000], [405], [519], [758], [1015]]
data = [
[110.27793884277344],
[120.29475402832031],
[157.2418212890625],
[157.2626953125],
[188.45382690429688],
]

self.evaluate(self.stitch_op(indices, data))


class DynamicStitchTest(DynamicStitchTestBase, test.TestCase):

Expand Down

0 comments on commit ee004b1

Please sign in to comment.