Skip to content
Permalink
Browse files Browse the repository at this point in the history
GatherNd verifies that an index is valid before reading. (#1286)
  • Loading branch information
alankelly committed Jul 26, 2022
1 parent c1dbef3 commit 4142e47
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 12 deletions.
15 changes: 12 additions & 3 deletions tensorflow/lite/micro/kernels/gather_nd.cc
Expand Up @@ -131,7 +131,8 @@ TfLiteStatus GatherNd(const TfLiteEvalTensor* params,
slice_size *= params->dims->data[i];
}

int remain_flat_size = ElementCount(*params->dims);
int params_flat_size = ElementCount(*params->dims);
int remain_flat_size = params_flat_size;

// Number of elements per dimension
int dims_to_count[MAX_INDICES_ND];
Expand All @@ -147,6 +148,9 @@ TfLiteStatus GatherNd(const TfLiteEvalTensor* params,
IndicesT index = index_data[offset];
from_pos += index * dims_to_count[j];
}
if (from_pos < 0 || from_pos + slice_size > params_flat_size) {
return kTfLiteError;
}
std::memcpy(output_data + i * slice_size, param_data + from_pos,
sizeof(ParamsT) * slice_size);
}
Expand All @@ -158,19 +162,24 @@ TfLiteStatus EvalGatherNd(TfLiteContext* context,
const TfLiteEvalTensor* params,
const TfLiteEvalTensor* indices,
TfLiteEvalTensor* output) {
TfLiteStatus status = kTfLiteError;
switch (params->type) {
case kTfLiteFloat32:
return GatherNd<float, IndicesT>(params, indices, output);
status = GatherNd<float, IndicesT>(params, indices, output);
break;
case kTfLiteInt8:
return GatherNd<int8_t, IndicesT>(params, indices, output);
status = GatherNd<int8_t, IndicesT>(params, indices, output);
break;
default:
TF_LITE_KERNEL_LOG(context,
"Params type '%s' are not supported by gather_nd.",
TfLiteTypeGetName(params->type));
return kTfLiteError;
}
if (status != kTfLiteOk) {
TF_LITE_KERNEL_LOG(context, "gather_nd index out of bounds");
}
return status;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
Expand Down
48 changes: 39 additions & 9 deletions tensorflow/lite/micro/kernels/gather_nd_test.cc
Expand Up @@ -26,8 +26,8 @@ namespace {
template <typename ParamType, typename IndexType>
void TestGatherNd(int* param_dims, const ParamType* param_data, int* index_dims,
const IndexType* index_data, int* output_dims,
ParamType* output_data,
const ParamType* expected_output_data) {
ParamType* output_data, const ParamType* expected_output_data,
const TfLiteStatus expected_status = kTfLiteOk) {
TfLiteIntArray* pdims = IntArrayFromInts(param_dims);
TfLiteIntArray* idims = IntArrayFromInts(index_dims);
TfLiteIntArray* odims = IntArrayFromInts(output_dims);
Expand All @@ -49,14 +49,16 @@ void TestGatherNd(int* param_dims, const ParamType* param_data, int* index_dims,
micro::KernelRunner runner(registration, tensors, tensors_size, inputs_array,
outputs_array, /*builtin_data=*/nullptr);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.InitAndPrepare());
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, runner.Invoke());
TF_LITE_MICRO_EXPECT_EQ(expected_status, runner.Invoke());

// The output tensor's data and shape have been updated by the kernel.
TfLiteTensor* actual_output_tensor = &tensors[2];
TfLiteIntArray* actual_output_dims = actual_output_tensor->dims;
const int output_size = ElementCount(*actual_output_dims);
for (int i = 0; i < output_size; ++i) {
TF_LITE_MICRO_EXPECT_EQ(expected_output_data[i], output_data[i]);
if (expected_status == kTfLiteOk) {
// The output tensor's data and shape have been updated by the kernel.
TfLiteTensor* actual_output_tensor = &tensors[2];
TfLiteIntArray* actual_output_dims = actual_output_tensor->dims;
const int output_size = ElementCount(*actual_output_dims);
for (int i = 0; i < output_size; ++i) {
TF_LITE_MICRO_EXPECT_EQ(expected_output_data[i], output_data[i]);
}
}
}

Expand Down Expand Up @@ -298,4 +300,32 @@ TF_LITE_MICRO_TEST(GatherNd_Int8Int32) {
golden_data);
}

TF_LITE_MICRO_TEST(GatherNd_ReadOOB) {
// For input_dims[], index_dims[], or output_dims[], element 0 is the
// number of dimensions in that array, not the actual dimension data.
int input_dims[] = {2, 2, 2};
int index_dims[] = {2, 2, 2};
const int32_t index_data[] = {0, 1, 2, 0};
const int8_t input_data[] = {1, -1, 1, -2};
int8_t output_data;
int output_dims[] = {1, 0, 0};
tflite::testing::TestGatherNd<int8_t, int32_t>(
input_dims, input_data, index_dims, index_data, output_dims, &output_data,
nullptr, kTfLiteError);
}

TF_LITE_MICRO_TEST(GatherNd_ReadOOBNegative) {
// For input_dims[], index_dims[], or output_dims[], element 0 is the
// number of dimensions in that array, not the actual dimension data.
int input_dims[] = {2, 2, 2};
int index_dims[] = {2, 2, 2};
const int32_t index_data[] = {0, -1, 1, 0};
const int8_t input_data[] = {1, -1, 1, -2};
int8_t output_data;
int output_dims[] = {1, 0, 0};
tflite::testing::TestGatherNd<int8_t, int32_t>(
input_dims, input_data, index_dims, index_data, output_dims, &output_data,
nullptr, kTfLiteError);
}

TF_LITE_MICRO_TESTS_END

0 comments on commit 4142e47

Please sign in to comment.