Skip to content
Permalink
Browse files Browse the repository at this point in the history
[tflite]: Insert nullptr checks when obtaining tensors.
As part of ongoing refactoring, `tflite::GetInput`, `tflite::GetOutput`, `tflite::GetTemporary` and `tflite::GetIntermediates` will return `nullptr` in some cases. Hence, we insert the `nullptr` checks on all usages.

We also insert `nullptr` checks on usages of `tflite::GetVariableInput` and `tflite::GetOptionalInputTensor` but only in the cases where there is no obvious check that `nullptr` is acceptable (that is, we only insert the check for the output of these two functions if the tensor is accessed as if it is always not `nullptr`).

PiperOrigin-RevId: 332518902
Change-Id: I92eb164a6101ac3cca66090061a9b56a97288236
  • Loading branch information
mihaimaruseac authored and tensorflower-gardener committed Sep 18, 2020
1 parent ed69c61 commit cd31fd0
Showing 1 changed file with 16 additions and 7 deletions.
23 changes: 16 additions & 7 deletions tensorflow/lite/micro/test_helpers.cc
Expand Up @@ -601,7 +601,8 @@ TfLiteStatus SimpleStatefulOp::Prepare(TfLiteContext* context,
OpData* data = reinterpret_cast<OpData*>(node->user_data);

// Make sure that the input is in uint8_t with at least 1 data entry.
const TfLiteTensor* input = tflite::GetInput(context, node, kInputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
if (input->type != kTfLiteUInt8) return kTfLiteError;
if (NumElements(input->dims) == 0) return kTfLiteError;

Expand All @@ -622,7 +623,8 @@ TfLiteStatus SimpleStatefulOp::Invoke(TfLiteContext* context,
OpData* data = reinterpret_cast<OpData*>(node->user_data);
*data->invoke_count += 1;

const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
const uint8_t* input_data = GetTensorData<uint8_t>(input);
int size = NumElements(input->dims);

Expand All @@ -641,9 +643,13 @@ TfLiteStatus SimpleStatefulOp::Invoke(TfLiteContext* context,
}
}

TfLiteTensor* median = GetOutput(context, node, kMedianTensor);
TfLiteTensor* median;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kMedianTensor, &median));
uint8_t* median_data = GetTensorData<uint8_t>(median);
TfLiteTensor* invoke_count = GetOutput(context, node, kInvokeCount);
TfLiteTensor* invoke_count;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kInvokeCount, &invoke_count));
int32_t* invoke_count_data = GetTensorData<int32_t>(invoke_count);

median_data[0] = sorting_buffer[size / 2];
Expand Down Expand Up @@ -681,11 +687,14 @@ TfLiteStatus MockCustom::Prepare(TfLiteContext* context, TfLiteNode* node) {
}

TfLiteStatus MockCustom::Invoke(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = tflite::GetInput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
const int32_t* input_data = input->data.i32;
const TfLiteTensor* weight = tflite::GetInput(context, node, 1);
const TfLiteTensor* weight;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &weight));
const uint8_t* weight_data = weight->data.uint8;
TfLiteTensor* output = GetOutput(context, node, 0);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
int32_t* output_data = output->data.i32;
output_data[0] =
0; // Catch output tensor sharing memory with an input tensor
Expand Down

0 comments on commit cd31fd0

Please sign in to comment.