Check bounds for reads and writes in scatter_nd
PiperOrigin-RevId: 463365479
alankelly authored and tensorflower-gardener committed Jul 26, 2022
1 parent b6d1794 commit b4d4b4c
Showing 3 changed files with 60 additions and 14 deletions.
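Context for the change: ScatterND copies each slice of updates into output at the flat offset computed from the matching row of indices, and before this commit the reference kernel trusted those indices with only a debug-mode DCHECK. A minimal standalone sketch of the 1-D case, reusing the numbers from the ReadAndWriteArrayLimits test added below (illustrative code, not TFLite's API):

#include <cstdio>
#include <vector>

// Illustrative 1-D ScatterND: output[indices[i]] += updates[i].
// Without the bounds check, an index outside [0, output.size()) would
// read or write out of bounds, which is the bug class this commit closes.
int main() {
  std::vector<int> indices = {4, 3, 1, 0, 2};
  std::vector<int> updates = {1, 2, 3, 7, 9};
  std::vector<int> output(5, 0);
  for (size_t i = 0; i < indices.size(); ++i) {
    const int idx = indices[i];
    if (idx < 0 || idx >= static_cast<int>(output.size())) {
      std::printf("index %d out of bounds\n", idx);
      return 1;  // the patched kernel returns kTfLiteError instead
    }
    output[idx] += updates[i];
  }
  for (int v : output) std::printf("%d ", v);  // prints: 7 3 9 2 1
  std::printf("\n");
  return 0;
}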
19 changes: 13 additions & 6 deletions tensorflow/lite/kernels/internal/reference/reference_ops.h
@@ -656,11 +656,12 @@ inline TfLiteStatus GatherNdString(const RuntimeShape& params_shape,
#endif

template <typename IndicesT, typename UpdatesT>
-inline void ScatterNd(const RuntimeShape& indices_shape,
-                      const IndicesT* indices_data,
-                      const RuntimeShape& updates_shape,
-                      const UpdatesT* updates_data,
-                      const RuntimeShape& output_shape, UpdatesT* output_data) {
+inline TfLiteStatus ScatterNd(const RuntimeShape& indices_shape,
+                              const IndicesT* indices_data,
+                              const RuntimeShape& updates_shape,
+                              const UpdatesT* updates_data,
+                              const RuntimeShape& output_shape,
+                              UpdatesT* output_data) {
ruy::profiler::ScopeLabel label("ScatterNd");

int n_slices = 1;
@@ -683,18 +684,24 @@ inline void ScatterNd(const RuntimeShape& indices_shape,
remain_flat_size = dims_to_count[i];
}

+  if (n_slices * slice_size > updates_shape.FlatSize()) {
+    return kTfLiteError;
+  }
  memset(output_data, 0, sizeof(UpdatesT) * output_flat_size);
  for (int i = 0; i < n_slices; ++i) {
    int to_pos = 0;
    for (int j = 0; j < indices_nd; ++j) {
      IndicesT idx = indices_data[i * indices_nd + j];
-      TFLITE_DCHECK(0 <= idx && idx < output_shape.Dims(j));
      to_pos += idx * dims_to_count[j];
    }
+    if (to_pos < 0 || to_pos + slice_size > output_flat_size) {
+      return kTfLiteError;
+    }
    for (int j = 0; j < slice_size; j++) {
      output_data[to_pos + j] += updates_data[i * slice_size + j];
    }
  }
+  return kTfLiteOk;
}

template <typename T>
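In the hunk above, dims_to_count[j] is the flat stride of output dimension j and to_pos is the flat offset of the slice addressed by one index row. The commit adds a read-side check (n_slices * slice_size must not exceed the number of update elements) and a write-side check (the slice starting at to_pos must fit inside the output). A small self-contained sketch of that offset arithmetic with made-up shapes; the function name and shapes below are illustrative, not part of the commit:

#include <cstdio>
#include <vector>

// Sketch of the stride/offset arithmetic used by the reference ScatterNd.
// dims_to_count[j] is the flat stride of output dimension j; to_pos is the
// flat start of the slice addressed by one index row. Returns false where
// the patched kernel would return kTfLiteError.
bool SliceInBounds(const std::vector<int>& output_dims, int indices_nd,
                   const std::vector<int>& index_row, int slice_size,
                   int* to_pos_out) {
  int output_flat_size = 1;
  for (int d : output_dims) output_flat_size *= d;
  std::vector<int> dims_to_count(indices_nd, 0);
  int remain = output_flat_size;
  for (int j = 0; j < indices_nd; ++j) {
    dims_to_count[j] = remain / output_dims[j];
    remain = dims_to_count[j];
  }
  int to_pos = 0;
  for (int j = 0; j < indices_nd; ++j) {
    to_pos += index_row[j] * dims_to_count[j];
  }
  *to_pos_out = to_pos;
  return to_pos >= 0 && to_pos + slice_size <= output_flat_size;
}

int main() {
  int to_pos = 0;
  // Output shape {4, 3}, scattering whole rows, so slice_size == 3.
  bool ok = SliceInBounds({4, 3}, /*indices_nd=*/1, {2}, /*slice_size=*/3, &to_pos);
  std::printf("index 2 -> to_pos %d, %s\n", to_pos, ok ? "ok" : "out of bounds");
  ok = SliceInBounds({4, 3}, 1, {5}, 3, &to_pos);  // to_pos 15, 15 + 3 > 12
  std::printf("index 5 -> to_pos %d, %s\n", to_pos, ok ? "ok" : "out of bounds");
  return 0;
}

With a one-dimensional output and slice_size == 1, the write-side test reduces to the plain 0 <= idx < output size check from the first sketch.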
26 changes: 18 additions & 8 deletions tensorflow/lite/kernels/scatter_nd.cc
@@ -129,11 +129,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
template <typename IndicesT, typename UpdatesT>
TfLiteStatus ScatterNd(const TfLiteTensor* indices, const TfLiteTensor* updates,
TfLiteTensor* output) {
-  reference_ops::ScatterNd(
+  return reference_ops::ScatterNd(
GetTensorShape(indices), GetTensorData<IndicesT>(indices),
GetTensorShape(updates), GetTensorData<UpdatesT>(updates),
GetTensorShape(output), GetTensorData<UpdatesT>(output));
-  return kTfLiteOk;
}

template <typename IndicesT>
@@ -149,25 +148,36 @@ TfLiteStatus EvalScatterNd(TfLiteContext* context, const TfLiteTensor* indices,
ResizeOutputTensor<IndicesT>(context, shape, output));
}

+  TfLiteStatus status = kTfLiteError;
  switch (updates->type) {
    case kTfLiteFloat32:
-      return ScatterNd<IndicesT, float>(indices, updates, output);
+      status = ScatterNd<IndicesT, float>(indices, updates, output);
+      break;
    case kTfLiteUInt8:
-      return ScatterNd<IndicesT, uint8_t>(indices, updates, output);
+      status = ScatterNd<IndicesT, uint8_t>(indices, updates, output);
+      break;
    case kTfLiteBool:
-      return ScatterNd<IndicesT, bool>(indices, updates, output);
+      status = ScatterNd<IndicesT, bool>(indices, updates, output);
+      break;
    case kTfLiteInt8:
-      return ScatterNd<IndicesT, int8_t>(indices, updates, output);
+      status = ScatterNd<IndicesT, int8_t>(indices, updates, output);
+      break;
    case kTfLiteInt32:
-      return ScatterNd<IndicesT, int32_t>(indices, updates, output);
+      status = ScatterNd<IndicesT, int32_t>(indices, updates, output);
+      break;
    case kTfLiteInt64:
-      return ScatterNd<IndicesT, int64_t>(indices, updates, output);
+      status = ScatterNd<IndicesT, int64_t>(indices, updates, output);
+      break;
default:
TF_LITE_KERNEL_LOG(
context, "Updates of type '%s' are not supported by scatter_nd.",
TfLiteTypeGetName(updates->type));
return kTfLiteError;
}
+  if (status != kTfLiteOk) {
+    TF_LITE_KERNEL_LOG(context, "scatter_nd index out of bounds");
+  }
+  return status;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
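The kernel wrapper now propagates the status instead of returning directly from each case: the result is collected in status, the switch breaks out, and a single diagnostic is logged on failure. The default: branch still returns immediately, so the unsupported-type message is not followed by the out-of-bounds one. A hedged sketch of the same collect-then-log dispatch pattern, using stand-in names rather than the TFLite types:

#include <cstdio>

enum Status { kOk = 0, kError = 1 };

// Stand-in for the templated ScatterNd<IndicesT, UpdatesT> helper.
template <typename T>
Status DoScatter(bool in_bounds) { return in_bounds ? kOk : kError; }

// Mirrors the patched EvalScatterNd: dispatch on a runtime type tag,
// collect the status, log once, then propagate it to the caller.
Status Dispatch(int type_tag, bool in_bounds) {
  Status status = kError;
  switch (type_tag) {
    case 0: status = DoScatter<float>(in_bounds); break;
    case 1: status = DoScatter<int>(in_bounds); break;
    default: std::printf("unsupported type\n"); return kError;
  }
  if (status != kOk) std::printf("index out of bounds\n");
  return status;
}

int main() {
  std::printf("float path: %d\n", Dispatch(0, true));   // 0 (kOk)
  std::printf("oob path:   %d\n", Dispatch(1, false));  // logs, then 1 (kError)
  return 0;
}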
29 changes: 29 additions & 0 deletions tensorflow/lite/kernels/scatter_nd_test.cc
@@ -361,5 +361,34 @@ TEST(ScatterNdOpTest, DynamicShape) {
/*2, 3*/ 1, 2, 3, 4, 5}));
}

TEST(ScatterNdOpTest, ReadAndWriteArrayLimits) {
ScatterNdOpModel m({TensorType_INT32, {5, 1}}, {TensorType_INT32, {5}},
{TensorType_INT32, {1}});
m.SetIndices<int32_t>({4, 3, 1, 0, 2});
m.SetUpdates<int32_t>({1, 2, 3, 7, 9});
m.SetShape<int32_t>({5});
ASSERT_EQ(m.Invoke(), kTfLiteOk);
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({5}));
EXPECT_THAT(m.GetOutput<int32_t>(), ElementsAreArray({7, 3, 9, 2, 1}));
}

TEST(ScatterNdOpTest, OOBRead) {
ScatterNdOpModel m({TensorType_INT32, {1, 1}}, {TensorType_INT32, {1}},
{TensorType_INT32, {1}});
m.SetIndices<int32_t>({4});
m.SetUpdates<int32_t>({1});
m.SetShape<int32_t>({1});
ASSERT_EQ(m.Invoke(), kTfLiteError);
}

TEST(ScatterNdOpTest, OOBWrites) {
ScatterNdOpModel m({TensorType_INT32, {5, 1}}, {TensorType_INT32, {5}},
{TensorType_INT32, {1}});
m.SetIndices<int32_t>({4, 3, 1, -0x38, 0x38});
m.SetUpdates<int32_t>({1, 2, 3, 0x44444444, 0x55555555});
m.SetShape<int32_t>({1});
ASSERT_EQ(m.Invoke(), kTfLiteError);
}

} // namespace
} // namespace tflite
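The new tests pin down both directions of the check. In OOBRead and OOBWrites the declared output shape is {1}, so output_flat_size, dims_to_count[0] and slice_size are all 1, and any index other than 0 fails the new write-side test; the kernel aborts on the first offending slice. The conspicuous update values 0x44444444 and 0x55555555 would have made a stray write easy to spot before the fix. A tiny sketch of the arithmetic behind OOBWrites (illustrative only):

#include <cstdio>

// Reproduces the bound check for the OOBWrites test: output shape {1},
// so slice_size == 1 and only index 0 would be accepted.
int main() {
  const int output_flat_size = 1;
  const int stride = 1;      // dims_to_count[0] for a {1} output
  const int slice_size = 1;
  const int indices[] = {4, 3, 1, -0x38, 0x38};
  for (int idx : indices) {
    const int to_pos = idx * stride;
    const bool ok = to_pos >= 0 && to_pos + slice_size <= output_flat_size;
    std::printf("index %d -> to_pos %d -> %s\n", idx, to_pos,
                ok ? "in bounds" : "kTfLiteError");
  }
  return 0;
}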
