Skip to content
Permalink
Browse files Browse the repository at this point in the history
[tflite]: Insert nullptr checks when obtaining tensors.
As part of ongoing refactoring, `tflite::GetInput`, `tflite::GetOutput`, `tflite::GetTemporary` and `tflite::GetIntermediates` will return `nullptr` in some cases. Hence, we insert the `nullptr` checks on all usages.

We also insert `nullptr` checks on usages of `tflite::GetVariableInput` and `tflite::GetOptionalInputTensor` but only in the cases where there is no obvious check that `nullptr` is acceptable (that is, we only insert the check for the output of these two functions if the tensor is accessed as if it is always not `nullptr`).

PiperOrigin-RevId: 332521299
Change-Id: I29af455bcb48d0b92e58132d951a3badbd772d56
  • Loading branch information
mihaimaruseac authored and tensorflower-gardener committed Sep 18, 2020
1 parent fff2c83 commit 1970c21
Show file tree
Hide file tree
Showing 83 changed files with 2,722 additions and 1,205 deletions.
132 changes: 88 additions & 44 deletions tensorflow/lite/kernels/activations.cc
Expand Up @@ -252,8 +252,10 @@ void* HardSwishInit(TfLiteContext* context, const char* buffer, size_t length) {
TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);

return context->ResizeTensor(context, output,
Expand All @@ -272,8 +274,10 @@ TfLiteStatus ReluPrepare(TfLiteContext* context, TfLiteNode* node) {
ReluOpData* data = reinterpret_cast<ReluOpData*>(node->user_data);
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);

if (input->type == kTfLiteInt8 || input->type == kTfLiteUInt8) {
Expand All @@ -300,12 +304,14 @@ void HardSwishFree(TfLiteContext* context, void* buffer) {

TfLiteStatus HardSwishPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_STATUS(GenericPrepare(context, node));
TfLiteTensor* output = GetOutput(context, node, 0);
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));

if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) {
HardSwishData* data = static_cast<HardSwishData*>(node->user_data);
HardSwishParams* params = &data->params;
const TfLiteTensor* input = GetInput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
params->input_zero_point = input->params.zero_point;
params->output_zero_point = output->params.zero_point;
const float input_scale = input->params.scale;
Expand Down Expand Up @@ -337,8 +343,10 @@ TfLiteStatus HardSwishPrepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus LeakyReluPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);

LeakyReluOpData* data = reinterpret_cast<LeakyReluOpData*>(node->user_data);
Expand Down Expand Up @@ -366,8 +374,10 @@ TfLiteStatus TanhPrepare(TfLiteContext* context, TfLiteNode* node) {

TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);

if (kernel_type == kFixedPointOptimized) {
Expand Down Expand Up @@ -451,8 +461,10 @@ TfLiteStatus SigmoidPrepare(TfLiteContext* context, TfLiteNode* node) {

TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);

if (kernel_type == kFixedPointOptimized) {
Expand Down Expand Up @@ -546,8 +558,10 @@ TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {

TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
if (output->type == kTfLiteInt16) {
TF_LITE_ENSURE(context, input->type == kTfLiteInt8 ||
input->type == kTfLiteUInt8 ||
Expand Down Expand Up @@ -614,8 +628,10 @@ TfLiteStatus LogSoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {

TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);

if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) {
Expand Down Expand Up @@ -650,9 +666,12 @@ TfLiteStatus LogSoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus PreluPrepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* alpha = GetInput(context, node, 1);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
const TfLiteTensor* alpha;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &alpha));
PreluOpData* data = reinterpret_cast<PreluOpData*>(node->user_data);

TF_LITE_ENSURE_TYPES_EQ(context, input->type, alpha->type);
Expand Down Expand Up @@ -704,8 +723,10 @@ TfLiteStatus PreluPrepare(TfLiteContext* context, TfLiteNode* node) {
}

TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
const ReluOpData* data = reinterpret_cast<ReluOpData*>(node->user_data);
switch (input->type) {
case kTfLiteFloat32: {
Expand All @@ -732,8 +753,10 @@ TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) {
}

TfLiteStatus Relu1Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
const ReluOpData* data = reinterpret_cast<ReluOpData*>(node->user_data);
switch (input->type) {
case kTfLiteFloat32: {
Expand Down Expand Up @@ -763,8 +786,10 @@ template <KernelType kernel_type>
TfLiteStatus HardSwishEval(TfLiteContext* context, TfLiteNode* node) {
HardSwishData* data = static_cast<HardSwishData*>(node->user_data);

const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
switch (input->type) {
case kTfLiteFloat32: {
if (kernel_type == kReference) {
Expand Down Expand Up @@ -814,8 +839,10 @@ TfLiteStatus HardSwishEval(TfLiteContext* context, TfLiteNode* node) {
}

TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
ReluOpData* data = reinterpret_cast<ReluOpData*>(node->user_data);
switch (input->type) {
case kTfLiteFloat32: {
Expand Down Expand Up @@ -845,8 +872,10 @@ TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) {
template <KernelType kernel_type>
TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) {
OpData* data = reinterpret_cast<OpData*>(node->user_data);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
switch (input->type) {
case kTfLiteFloat32: {
if (kernel_type == kReference) {
Expand Down Expand Up @@ -919,8 +948,10 @@ template <KernelType kernel_type>
TfLiteStatus SigmoidEval(TfLiteContext* context, TfLiteNode* node) {
OpData* data = reinterpret_cast<OpData*>(node->user_data);

const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
switch (input->type) {
case kTfLiteFloat32: {
if (kernel_type == kReference) {
Expand Down Expand Up @@ -1067,8 +1098,10 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(node->builtin_data);
SoftmaxOpData* data = reinterpret_cast<SoftmaxOpData*>(node->user_data);

const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));

switch (input->type) {
case kTfLiteFloat32: {
Expand Down Expand Up @@ -1122,8 +1155,10 @@ template <KernelType kernel_type>
TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) {
const LogSoftmaxOpData* data =
reinterpret_cast<LogSoftmaxOpData*>(node->user_data);
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
switch (input->type) {
case kTfLiteFloat32: {
SoftmaxParams op_params;
Expand Down Expand Up @@ -1183,9 +1218,12 @@ T ApplyPrelu(T input, T alpha) {

template <KernelType kernel_type>
TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, 0);
const TfLiteTensor* alpha = GetInput(context, node, 1);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
const TfLiteTensor* alpha;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &alpha));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
const PreluOpData* data = reinterpret_cast<PreluOpData*>(node->user_data);
switch (input->type) {
case kTfLiteFloat32: {
Expand Down Expand Up @@ -1294,8 +1332,10 @@ void QuantizeLeakyRelu(const TfLiteTensor* input, TfLiteTensor* output,
}

TfLiteStatus LeakyReluEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
const auto* params =
reinterpret_cast<TfLiteLeakyReluParams*>(node->builtin_data);
const LeakyReluOpData* data =
Expand Down Expand Up @@ -1332,8 +1372,10 @@ TfLiteStatus LeakyReluEval(TfLiteContext* context, TfLiteNode* node) {
}

TfLiteStatus EluPrepare(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
OpData* data = reinterpret_cast<OpData*>(node->user_data);

// Use LUT to handle quantized elu path.
Expand All @@ -1346,8 +1388,10 @@ TfLiteStatus EluPrepare(TfLiteContext* context, TfLiteNode* node) {
}

TfLiteStatus EluEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
switch (input->type) {
case kTfLiteFloat32: {
optimized_ops::Elu(GetTensorShape(input), GetTensorData<float>(input),
Expand Down
24 changes: 18 additions & 6 deletions tensorflow/lite/kernels/add.cc
Expand Up @@ -91,9 +91,15 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
output->type = input2->type;
Expand Down Expand Up @@ -358,9 +364,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteAddParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data);

const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
const TfLiteTensor* input2;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor2, &input2));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) {
EvalAdd<kernel_type>(context, node, params, data, input1, input2, output);
Expand Down
20 changes: 16 additions & 4 deletions tensorflow/lite/kernels/add_n.cc
Expand Up @@ -33,13 +33,18 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE(context, num_inputs >= 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
output->type = input1->type;

// Check that all input tensors have the same shape and type.
for (int i = kInputTensor1 + 1; i < num_inputs; ++i) {
const TfLiteTensor* input = GetInput(context, node, i);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &input));
TF_LITE_ENSURE(context, HaveSameShapes(input1, input));
TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input->type);
}
Expand All @@ -55,15 +60,22 @@ template <typename T>
void EvalAddN(TfLiteContext* context, TfLiteNode* node) {
// TODO(haoliang): Initialize all_inputs only once during init.
VectorOfTensors<T> all_inputs(*context, *node->inputs);
// Safe to use unchecked since caller checks that tensor is valid
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
int num_inputs = NumInputs(node);
// Safe to use unchecked since caller checks that tensor is valid
const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
reference_ops::AddN<T>(GetTensorShape(input1), num_inputs, all_inputs.data(),
GetTensorData<T>(output));
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
const TfLiteTensor* input1;
TF_LITE_ENSURE_OK(context,
GetInputSafe(context, node, kInputTensor1, &input1));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
if (output->type == kTfLiteFloat32) {
EvalAddN<float>(context, node);
} else if (output->type == kTfLiteInt32) {
Expand Down

0 comments on commit 1970c21

Please sign in to comment.