Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent heap OOB read in TFLite's gather.cc. #51018

Merged
merged 1 commit into from Jul 30, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
69 changes: 53 additions & 16 deletions tensorflow/lite/kernels/gather.cc
Expand Up @@ -117,8 +117,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}

template <typename InputT, typename PositionsT>
TfLiteStatus Gather(const TfLiteGatherParams& params, const TfLiteTensor* input,
const TfLiteTensor* positions, TfLiteTensor* output) {
TfLiteStatus Gather(TfLiteContext* context, const TfLiteGatherParams& params,
const TfLiteTensor* input, const TfLiteTensor* positions,
TfLiteTensor* output) {
const PositionsT* indexes = GetTensorData<PositionsT>(positions);
bool indices_has_only_positive_elements = true;
const size_t num_indices = positions->bytes / sizeof(PositionsT);
for (size_t i = 0; i < num_indices; i++) {
if (indexes[i] < 0) {
indices_has_only_positive_elements = false;
break;
}
}
TF_LITE_ENSURE(context, indices_has_only_positive_elements);

tflite::GatherParams op_params;
op_params.axis = params.axis;
op_params.batch_dims = params.batch_dims;
Expand All @@ -134,7 +146,18 @@ TfLiteStatus GatherStrings(TfLiteContext* context, const TfLiteTensor* input,
const TfLiteTensor* positions,
TfLiteTensor* output) {
DynamicBuffer buffer;

const PositionT* indexes = GetTensorData<PositionT>(positions);
bool indices_has_only_positive_elements = true;
const size_t num_indices = positions->bytes / sizeof(PositionT);
for (size_t i = 0; i < num_indices; i++) {
if (indexes[i] < 0) {
indices_has_only_positive_elements = false;
break;
}
}
TF_LITE_ENSURE(context, indices_has_only_positive_elements);

const PositionT num_strings = GetStringCount(input);
const int num_indexes = NumElements(positions);

Expand Down Expand Up @@ -163,19 +186,26 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
if (positions->type == kTfLiteInt32) {
switch (input->type) {
case kTfLiteFloat32:
return Gather<float, int32_t>(*params, input, positions, output);
return Gather<float, int32_t>(context, *params, input, positions,
output);
case kTfLiteUInt8:
return Gather<uint8_t, int32_t>(*params, input, positions, output);
return Gather<uint8_t, int32_t>(context, *params, input, positions,
output);
case kTfLiteInt8:
return Gather<int8_t, int32_t>(*params, input, positions, output);
return Gather<int8_t, int32_t>(context, *params, input, positions,
output);
case kTfLiteInt16:
return Gather<int16_t, int32_t>(*params, input, positions, output);
return Gather<int16_t, int32_t>(context, *params, input, positions,
output);
case kTfLiteInt32:
return Gather<int32_t, int32_t>(*params, input, positions, output);
return Gather<int32_t, int32_t>(context, *params, input, positions,
output);
case kTfLiteInt64:
return Gather<int64_t, int32_t>(*params, input, positions, output);
return Gather<int64_t, int32_t>(context, *params, input, positions,
output);
case kTfLiteBool:
return Gather<bool, int32_t>(*params, input, positions, output);
return Gather<bool, int32_t>(context, *params, input, positions,
output);
case kTfLiteString:
return GatherStrings<int32_t>(context, input, positions, output);
default:
Expand All @@ -187,19 +217,26 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
if (positions->type == kTfLiteInt64) {
switch (input->type) {
case kTfLiteFloat32:
return Gather<float, int64_t>(*params, input, positions, output);
return Gather<float, int64_t>(context, *params, input, positions,
output);
case kTfLiteUInt8:
return Gather<uint8_t, int64_t>(*params, input, positions, output);
return Gather<uint8_t, int64_t>(context, *params, input, positions,
output);
case kTfLiteInt8:
return Gather<int8_t, int64_t>(*params, input, positions, output);
return Gather<int8_t, int64_t>(context, *params, input, positions,
output);
case kTfLiteInt16:
return Gather<int16_t, int64_t>(*params, input, positions, output);
return Gather<int16_t, int64_t>(context, *params, input, positions,
output);
case kTfLiteInt32:
return Gather<int32_t, int64_t>(*params, input, positions, output);
return Gather<int32_t, int64_t>(context, *params, input, positions,
output);
case kTfLiteInt64:
return Gather<int64_t, int64_t>(*params, input, positions, output);
return Gather<int64_t, int64_t>(context, *params, input, positions,
output);
case kTfLiteBool:
return Gather<bool, int64_t>(*params, input, positions, output);
return Gather<bool, int64_t>(context, *params, input, positions,
output);
case kTfLiteString:
return GatherStrings<int64_t>(context, input, positions, output);
default:
Expand Down