Skip to content
Permalink
Browse files Browse the repository at this point in the history
Return a TFLite error if gather_nd will result in reading invalid memory
PiperOrigin-RevId: 463054033
  • Loading branch information
tensorflower-gardener committed Jul 25, 2022
1 parent b2df2de commit 595a65a
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 16 deletions.
24 changes: 14 additions & 10 deletions tensorflow/lite/kernels/gather_nd.cc
Expand Up @@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include <stdint.h>

#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
Expand Down Expand Up @@ -102,13 +103,16 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}

template <typename ParamsT, typename IndicesT>
TfLiteStatus GatherNd(const TfLiteTensor* params, const TfLiteTensor* indices,
TfLiteTensor* output) {
reference_ops::GatherNd(
TfLiteStatus GatherNd(TfLiteContext* context, const TfLiteTensor* params,
const TfLiteTensor* indices, TfLiteTensor* output) {
const TfLiteStatus status = reference_ops::GatherNd(
GetTensorShape(params), GetTensorData<ParamsT>(params),
GetTensorShape(indices), GetTensorData<IndicesT>(indices),
GetTensorShape(output), GetTensorData<ParamsT>(output));
return kTfLiteOk;
if (status != kTfLiteOk) {
TF_LITE_KERNEL_LOG(context, "gather_nd index out of bounds");
}
return status;
}

template <typename IndicesT>
Expand Down Expand Up @@ -136,17 +140,17 @@ TfLiteStatus EvalGatherNd(TfLiteContext* context, const TfLiteTensor* params,

switch (params->type) {
case kTfLiteFloat32:
return GatherNd<float, IndicesT>(params, indices, output);
return GatherNd<float, IndicesT>(context, params, indices, output);
case kTfLiteUInt8:
return GatherNd<uint8_t, IndicesT>(params, indices, output);
return GatherNd<uint8_t, IndicesT>(context, params, indices, output);
case kTfLiteInt8:
return GatherNd<int8_t, IndicesT>(params, indices, output);
return GatherNd<int8_t, IndicesT>(context, params, indices, output);
case kTfLiteInt16:
return GatherNd<int16_t, IndicesT>(params, indices, output);
return GatherNd<int16_t, IndicesT>(context, params, indices, output);
case kTfLiteInt32:
return GatherNd<int32_t, IndicesT>(params, indices, output);
return GatherNd<int32_t, IndicesT>(context, params, indices, output);
case kTfLiteInt64:
return GatherNd<int64_t, IndicesT>(params, indices, output);
return GatherNd<int64_t, IndicesT>(context, params, indices, output);
case kTfLiteString:
return GatherNdString<IndicesT>(params, indices, output);
default:
Expand Down
16 changes: 16 additions & 0 deletions tensorflow/lite/kernels/gather_nd_test.cc
Expand Up @@ -73,6 +73,22 @@ TEST(GatherNdOpTest, ElementIndexingIntoMatrix) {
EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray({1.1, 2.2}));
}

TEST(GatherNdOpTest, ErrorOnOutOfBoundsTooLarge) {
GatherNdOpModel m({TensorType_FLOAT32, {2, 2}}, {TensorType_INT32, {2, 2}});
m.SetInput<float>({1.1, 1.2, 2.1, 2.2});
m.SetPositions<int32_t>({0, 0, 2, 0});
EXPECT_EQ(m.Invoke(), kTfLiteError);
m.SetPositions<int32_t>({0, 0, 1, 2});
EXPECT_EQ(m.Invoke(), kTfLiteError);
}

TEST(GatherNdOpTest, ErrorOnOutOfBoundsNegative) {
GatherNdOpModel m({TensorType_FLOAT32, {2, 2}}, {TensorType_INT32, {2, 2}});
m.SetInput<float>({1.1, 1.2, 2.1, 2.2});
m.SetPositions<int32_t>({1, -1, 1, 1});
EXPECT_EQ(m.Invoke(), kTfLiteError);
}

TEST(GatherNdOpTest, SliceIndexingIntoMatrix) {
GatherNdOpModel m({TensorType_FLOAT32, {2, 2}}, {TensorType_INT32, {2, 1}});
m.SetInput<float>({1.1, 1.2, 2.1, 2.2});
Expand Down
21 changes: 15 additions & 6 deletions tensorflow/lite/kernels/internal/reference/reference_ops.h
Expand Up @@ -29,6 +29,7 @@ limitations under the License.
#include "third_party/eigen3/Eigen/Core"
#include "fixedpoint/fixedpoint.h"
#include "ruy/profiler/instrumentation.h" // from @ruy
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
Expand Down Expand Up @@ -595,23 +596,31 @@ inline GatherNdHelperResult GatherNdHelper(const RuntimeShape& params_shape,
return ret;
}

// Implements GatherNd.
// Returns an error if any of the indices_data would cause an out of bounds
// memory read.
template <typename ParamsT, typename IndicesT = int32>
inline void GatherNd(const RuntimeShape& params_shape,
const ParamsT* params_data,
const RuntimeShape& indices_shape,
const IndicesT* indices_data,
const RuntimeShape& output_shape, ParamsT* output_data) {
inline TfLiteStatus GatherNd(const RuntimeShape& params_shape,
const ParamsT* params_data,
const RuntimeShape& indices_shape,
const IndicesT* indices_data,
const RuntimeShape& output_shape,
ParamsT* output_data) {
ruy::profiler::ScopeLabel label("GatherNd");

const GatherNdHelperResult res = GatherNdHelper(params_shape, indices_shape);
for (int i = 0; i < res.n_slices; ++i) {
int from_pos = 0;
int64_t from_pos = 0;
for (int j = 0; j < res.indices_nd; ++j) {
from_pos += indices_data[i * res.indices_nd + j] * res.dims_to_count[j];
}
if (from_pos < 0 || from_pos + res.slice_size > params_shape.FlatSize()) {
return kTfLiteError;
}
std::memcpy(output_data + i * res.slice_size, params_data + from_pos,
sizeof(ParamsT) * res.slice_size);
}
return kTfLiteOk;
}

#ifndef TF_LITE_STATIC_MEMORY
Expand Down

0 comments on commit 595a65a

Please sign in to comment.