Skip to content
Permalink
Browse files Browse the repository at this point in the history
Prevent heap OOB read in TFLite's gather.cc.
Passing negative indices is illegal but there was a missing check so that resulted in OOB accesses.

PiperOrigin-RevId: 387231300
Change-Id: I3111b54b2f232638d795be17efc46abe4ede6bf8
  • Loading branch information
mihaimaruseac authored and tensorflower-gardener committed Jul 28, 2021
1 parent ac72971 commit eb92112
Showing 1 changed file with 53 additions and 16 deletions.
69 changes: 53 additions & 16 deletions tensorflow/lite/kernels/gather.cc
Expand Up @@ -117,8 +117,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}

template <typename InputT, typename PositionsT>
TfLiteStatus Gather(const TfLiteGatherParams& params, const TfLiteTensor* input,
const TfLiteTensor* positions, TfLiteTensor* output) {
TfLiteStatus Gather(TfLiteContext* context, const TfLiteGatherParams& params,
const TfLiteTensor* input, const TfLiteTensor* positions,
TfLiteTensor* output) {
const PositionsT* indexes = GetTensorData<PositionsT>(positions);
bool indices_has_only_positive_elements = true;
const size_t num_indices = positions->bytes / sizeof(PositionsT);
for (size_t i = 0; i < num_indices; i++) {
if (indexes[i] < 0) {
indices_has_only_positive_elements = false;
break;
}
}
TF_LITE_ENSURE(context, indices_has_only_positive_elements);

tflite::GatherParams op_params;
op_params.axis = params.axis;
op_params.batch_dims = params.batch_dims;
Expand All @@ -134,7 +146,18 @@ TfLiteStatus GatherStrings(TfLiteContext* context, const TfLiteTensor* input,
const TfLiteTensor* positions,
TfLiteTensor* output) {
DynamicBuffer buffer;

const PositionT* indexes = GetTensorData<PositionT>(positions);
bool indices_has_only_positive_elements = true;
const size_t num_indices = positions->bytes / sizeof(PositionT);
for (size_t i = 0; i < num_indices; i++) {
if (indexes[i] < 0) {
indices_has_only_positive_elements = false;
break;
}
}
TF_LITE_ENSURE(context, indices_has_only_positive_elements);

const PositionT num_strings = GetStringCount(input);
const int num_indexes = NumElements(positions);

Expand Down Expand Up @@ -163,19 +186,26 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
if (positions->type == kTfLiteInt32) {
switch (input->type) {
case kTfLiteFloat32:
return Gather<float, int32_t>(*params, input, positions, output);
return Gather<float, int32_t>(context, *params, input, positions,
output);
case kTfLiteUInt8:
return Gather<uint8_t, int32_t>(*params, input, positions, output);
return Gather<uint8_t, int32_t>(context, *params, input, positions,
output);
case kTfLiteInt8:
return Gather<int8_t, int32_t>(*params, input, positions, output);
return Gather<int8_t, int32_t>(context, *params, input, positions,
output);
case kTfLiteInt16:
return Gather<int16_t, int32_t>(*params, input, positions, output);
return Gather<int16_t, int32_t>(context, *params, input, positions,
output);
case kTfLiteInt32:
return Gather<int32_t, int32_t>(*params, input, positions, output);
return Gather<int32_t, int32_t>(context, *params, input, positions,
output);
case kTfLiteInt64:
return Gather<int64_t, int32_t>(*params, input, positions, output);
return Gather<int64_t, int32_t>(context, *params, input, positions,
output);
case kTfLiteBool:
return Gather<bool, int32_t>(*params, input, positions, output);
return Gather<bool, int32_t>(context, *params, input, positions,
output);
case kTfLiteString:
return GatherStrings<int32_t>(context, input, positions, output);
default:
Expand All @@ -187,19 +217,26 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
if (positions->type == kTfLiteInt64) {
switch (input->type) {
case kTfLiteFloat32:
return Gather<float, int64_t>(*params, input, positions, output);
return Gather<float, int64_t>(context, *params, input, positions,
output);
case kTfLiteUInt8:
return Gather<uint8_t, int64_t>(*params, input, positions, output);
return Gather<uint8_t, int64_t>(context, *params, input, positions,
output);
case kTfLiteInt8:
return Gather<int8_t, int64_t>(*params, input, positions, output);
return Gather<int8_t, int64_t>(context, *params, input, positions,
output);
case kTfLiteInt16:
return Gather<int16_t, int64_t>(*params, input, positions, output);
return Gather<int16_t, int64_t>(context, *params, input, positions,
output);
case kTfLiteInt32:
return Gather<int32_t, int64_t>(*params, input, positions, output);
return Gather<int32_t, int64_t>(context, *params, input, positions,
output);
case kTfLiteInt64:
return Gather<int64_t, int64_t>(*params, input, positions, output);
return Gather<int64_t, int64_t>(context, *params, input, positions,
output);
case kTfLiteBool:
return Gather<bool, int64_t>(*params, input, positions, output);
return Gather<bool, int64_t>(context, *params, input, positions,
output);
case kTfLiteString:
return GatherStrings<int64_t>(context, input, positions, output);
default:
Expand Down

0 comments on commit eb92112

Please sign in to comment.