Skip to content
Permalink
Browse files Browse the repository at this point in the history
Prevent a division by 0 in division ops.
PiperOrigin-RevId: 385223169
Change-Id: Ia4228960b5d2aa44480385f74bdd70d21a3613c3
  • Loading branch information
mihaimaruseac authored and tensorflower-gardener committed Jul 16, 2021
1 parent 9579070 commit 1e206ba
Showing 1 changed file with 16 additions and 1 deletion.
17 changes: 16 additions & 1 deletion tensorflow/lite/kernels/div.cc
Expand Up @@ -216,9 +216,23 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));

if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) {
// TODO(b/193904910): This can written with C++ templates
#define TF_LITE_CHECK_DIV_NON_ZERO(data_type) \
const auto* input2_data = GetTensorData<data_type>(input2); \
const size_t input2_elements = input2->bytes / sizeof(data_type); \
for (size_t i = 0; i < input2_elements; i++) { \
TF_LITE_ENSURE(context, input2_data[i] != 0); \
}

if (output->type == kTfLiteFloat32) {
// Div by zero seems ok in this case, just like in TF case infinities are
// returned. So we don't do a check at this point.
EvalDiv<kernel_type>(context, node, params, data, input1, input2, output);
} else if (output->type == kTfLiteInt32) {
TF_LITE_CHECK_DIV_NON_ZERO(int32_t);
EvalDiv<kernel_type>(context, node, params, data, input1, input2, output);
} else if (output->type == kTfLiteUInt8) {
TF_LITE_CHECK_DIV_NON_ZERO(uint8_t);
TF_LITE_ENSURE_OK(
context, EvalQuantized<kernel_type>(context, node, params, data, input1,
input2, output));
Expand All @@ -229,6 +243,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
output->type);
return kTfLiteError;
}
#undef TF_LITE_CHECK_DIV_NON_ZERO

return kTfLiteOk;
}
Expand Down

0 comments on commit 1e206ba

Please sign in to comment.