Skip to content
Permalink
Browse files Browse the repository at this point in the history
[tflite] Test for kTfLiteOptionalTensor in GetInput.
`GetInput`, `GetVariableInput` and `GetOutput` all fail to check for the case where `node->inputs->data[index]` is the special `kTfLiteOptionalTensor` value (-1) which then causes `context->tensors[node->inputs->data[index]]` to read from invalid memory location.

This fix makes `GetInput` and related return `nullptr` in those cases, asking the caller to check for `nullptr`. This is better than having `GetOptionalInputTensor` and `GetOptionalOutputTensor` (does not exist but could be added) as using the patched `GetInput` in error would be caught by a sanitizer test in the default optimized build (due to the `-fsanitize=null` option).

PiperOrigin-RevId: 332512190
Change-Id: Iabca54da2f2de02b6ece3c38b54f76d4277d689e
  • Loading branch information
mihaimaruseac authored and tensorflower-gardener committed Sep 18, 2020
1 parent 204945b commit 46d5b08
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 12 deletions.
28 changes: 20 additions & 8 deletions tensorflow/lite/kernels/kernel_util.cc
Expand Up @@ -32,11 +32,17 @@ namespace {

inline TfLiteTensor* GetMutableInput(const TfLiteContext* context,
const TfLiteNode* node, int index) {
if (context->tensors != nullptr) {
return &context->tensors[node->inputs->data[index]];
} else {
return context->GetTensor(context, node->inputs->data[index]);
if (index >= 0 && index < node->inputs->size) {
const int tensor_index = node->inputs->data[index];
if (tensor_index != kTfLiteOptionalTensor) {
if (context->tensors != nullptr) {
return &context->tensors[tensor_index];
} else {
return context->GetTensor(context, tensor_index);
}
}
}
return nullptr;
}

} // anonymous namespace.
Expand All @@ -54,11 +60,17 @@ TfLiteTensor* GetVariableInput(TfLiteContext* context, const TfLiteNode* node,

TfLiteTensor* GetOutput(TfLiteContext* context, const TfLiteNode* node,
int index) {
if (context->tensors != nullptr) {
return &context->tensors[node->outputs->data[index]];
} else {
return context->GetTensor(context, node->outputs->data[index]);
if (index >= 0 && index < node->outputs->size) {
const int tensor_index = node->outputs->data[index];
if (tensor_index != kTfLiteOptionalTensor) {
if (context->tensors != nullptr) {
return &context->tensors[tensor_index];
} else {
return context->GetTensor(context, tensor_index);
}
}
}
return nullptr;
}

const TfLiteTensor* GetOptionalInputTensor(const TfLiteContext* context,
Expand Down
68 changes: 64 additions & 4 deletions tensorflow/lite/kernels/kernel_util.h
Expand Up @@ -29,18 +29,46 @@ namespace tflite {
// benchmark_model for MobileNet + MobileBERT is unaffected. If such a change is
// made, move the newly non-inlined function declarations to the top of this
// header file.

// Note: You must check if result is not null:
//
// TfLiteTensor* my_tensor = GetInput(context, node, kMyTensorIdx);
// TF_LITE_ENSURE(context, my_tensor != nullptr);
//
// This is because the index might point to the optional tensor constant
// (kTfLiteOptionalTensor) in which case there is no tensor to return.
const TfLiteTensor* GetInput(const TfLiteContext* context,
const TfLiteNode* node, int index);

// Note: You must check if result is not null:
// TfLiteTensor* my_tensor = GetVariableInput(context, node, kMyTensorIdx);
// TF_LITE_ENSURE(context, my_tensor != nullptr);
//
// TfLiteTensor* my_tensor = GetVariableInput(context, node, kMyTensorIdx);
// TF_LITE_ENSURE(context, my_tensor != nullptr);
//
// This is because the index might point to the optional tensor constant
// (kTfLiteOptionalTensor) in which case there is no tensor to return.
TfLiteTensor* GetVariableInput(TfLiteContext* context, const TfLiteNode* node,
int index);

// Note: You must check if result is not null:
//
// TfLiteTensor* my_tensor = GetOutput(context, node, kMyTensorIdx);
// TF_LITE_ENSURE(context, my_tensor != nullptr);
//
// This is because the index might point to the optional tensor constant
// (kTfLiteOptionalTensor) in which case there is no tensor to return.
TfLiteTensor* GetOutput(TfLiteContext* context, const TfLiteNode* node,
int index);

// Note: You must check if result is not null:
//
// TfLiteTensor* my_tensor = GetOptionalInputTensor(context, node, kIdx);
// TF_LITE_ENSURE(context, my_tensor != nullptr);
//
// This is because the index might point to the optional tensor constant
// (kTfLiteOptionalTensor) in which case there is no tensor to return.
//
// Deprecated. GetInput has the same functionality.
const TfLiteTensor* GetOptionalInputTensor(const TfLiteContext* context,
const TfLiteNode* node, int index);

Expand All @@ -50,14 +78,46 @@ inline int SizeOfDimension(const TfLiteTensor* t, int dim) {
}

#ifndef TF_LITE_STATIC_MEMORY
// Note: You must check if result is not null:
//
// TfLiteTensor* my_tensor = GetTemporary(context, node, kMyTensorIdx);
// TF_LITE_ENSURE(context, my_tensor != nullptr);
//
// This is because the index might point to the optional tensor constant
// (kTfLiteOptionalTensor) in which case there is no tensor to return.
inline TfLiteTensor* GetTemporary(TfLiteContext* context,
const TfLiteNode* node, int index) {
return &context->tensors[node->temporaries->data[index]];
if (index >= 0 && index < node->temporaries->size) {
const int tensor_index = node->temporaries->data[index];
if (tensor_index != kTfLiteOptionalTensor) {
if (context->tensors != nullptr) {
return &context->tensors[tensor_index];
}
}
}
return nullptr;
}

// Note: You must check if result is not null:
//
// TfLiteTensor* my_tensor = GetIntermediates(context, node, kMyTensorIdx);
// TF_LITE_ENSURE(context, my_tensor != nullptr);
//
// This is because the index might point to the optional tensor constant
// (kTfLiteOptionalTensor) in which case there is no tensor to return.
inline const TfLiteTensor* GetIntermediates(TfLiteContext* context,
const TfLiteNode* node, int index) {
return &context->tensors[node->intermediates->data[index]];
if (index >= 0 && index < node->intermediates->size) {
const int tensor_index = node->intermediates->data[index];
if (tensor_index != kTfLiteOptionalTensor) {
if (context->tensors != nullptr) {
return &context->tensors[tensor_index];
}
}
}
return nullptr;
}

inline int NumIntermediates(const TfLiteNode* node) {
return node->intermediates->size;
}
Expand Down

0 comments on commit 46d5b08

Please sign in to comment.