Improve comparison kernels to cover scale values greater than one
PiperOrigin-RevId: 433050921
sngyhan authored and tensorflower-gardener committed Mar 7, 2022
1 parent 39907ef commit a989426
Showing 2 changed files with 34 additions and 5 deletions.
19 changes: 14 additions & 5 deletions tensorflow/lite/kernels/comparisons.cc
@@ -81,6 +81,17 @@ TfLiteStatus ComparisonPrepareStringAllowed(TfLiteContext* context,
return ComparisonPrepareCommon(context, node, true);
}

void QuantizeMultiplier(double double_multiplier, int32_t* quantized_multiplier,
int* left_shift) {
if (double_multiplier < 1.0) {
QuantizeMultiplierSmallerThanOneExp(double_multiplier, quantized_multiplier,
left_shift);
} else {
QuantizeMultiplierGreaterThanOne(double_multiplier, quantized_multiplier,
left_shift);
}
}
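The new helper simply dispatches between TFLite's two existing conversion routines depending on whether the real-valued scale is below 1.0 or not. As background, the sketch below (a standalone illustration, not the library code; DecomposeScale is a hypothetical name) shows the representation both routines produce: a scale s is stored as a 32-bit fixed-point multiplier plus a power-of-two exponent, so that s ≈ multiplier * 2^shift / 2^31. For s < 1 the exponent is zero or negative; for s >= 1 it is positive, which is the case this change lets the quantized comparison path handle.

// Standalone sketch of the multiplier/shift representation used above.
// Not TFLite code; DecomposeScale is a hypothetical name for illustration.
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Decompose scale as quantized_multiplier * 2^shift / 2^31, with the
// multiplier normalized into [2^30, 2^31). Scales below 1.0 yield a zero or
// negative shift; scales of 1.0 or more yield a positive shift.
void DecomposeScale(double scale, int32_t* quantized_multiplier, int* shift) {
  assert(scale > 0.0);
  const double q = std::frexp(scale, shift);  // scale = q * 2^shift, q in [0.5, 1)
  int64_t q_fixed = static_cast<int64_t>(std::round(q * (1LL << 31)));
  if (q_fixed == (1LL << 31)) {  // Rounding pushed q up to the next power of two.
    q_fixed /= 2;
    ++*shift;
  }
  *quantized_multiplier = static_cast<int32_t>(q_fixed);
}

int main() {
  for (double scale : {0.0078125, 0.996, 1.0, 2.5}) {
    int32_t multiplier;
    int shift;
    DecomposeScale(scale, &multiplier, &shift);
    // Rebuild the scale from the pair to confirm the decomposition round-trips.
    const double rebuilt =
        static_cast<double>(multiplier) / (1LL << 31) * std::ldexp(1.0, shift);
    std::printf("scale=%g -> multiplier=%ld shift=%d rebuilt=%g\n", scale,
                static_cast<long>(multiplier), shift, rebuilt);
  }
  return 0;
}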

template <typename input_dtype, reference_ops::ComparisonFn<int32> opname>
void ComparisonQuantized(const TfLiteTensor* input1, const TfLiteTensor* input2,
TfLiteTensor* output, bool requires_broadcast) {
@@ -90,13 +101,11 @@ void ComparisonQuantized(const TfLiteTensor* input1, const TfLiteTensor* input2,
const int left_shift = 8;

int32 input1_multiplier;
- int input1_shift;
- QuantizeMultiplierSmallerThanOneExp(input1->params.scale,
- &input1_multiplier, &input1_shift);
int32 input2_multiplier;
+ int input1_shift;
int input2_shift;
- QuantizeMultiplierSmallerThanOneExp(input2->params.scale,
- &input2_multiplier, &input2_shift);
+ QuantizeMultiplier(input1->params.scale, &input1_multiplier, &input1_shift);
+ QuantizeMultiplier(input2->params.scale, &input2_multiplier, &input2_shift);

ComparisonParams op_params;
op_params.left_shift = left_shift;
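For reference, here is a simplified view of how a (multiplier, shift) pair produced above is then applied to a quantized input before the integer comparison. This is not TFLite's saturating fixed-point code; ApplyMultiplier is a hypothetical helper that uses plain 64-bit arithmetic to show the rescaling also works when the represented scale is greater than one.

// Simplified sketch; ApplyMultiplier is a hypothetical stand-in for TFLite's
// saturating fixed-point routines and assumes shift < 31.
#include <cstdint>
#include <cstdio>

// Approximates x * multiplier * 2^shift / 2^31 with a rounding right shift.
int64_t ApplyMultiplier(int64_t x, int32_t quantized_multiplier, int shift) {
  const int64_t product = x * static_cast<int64_t>(quantized_multiplier);
  const int total_shift = 31 - shift;
  return (product + (int64_t{1} << (total_shift - 1))) >> total_shift;
}

int main() {
  // A scale of 2.5 decomposes as 0.625 * 2^31 with a shift of 2 (see the
  // earlier sketch).
  const int32_t multiplier = 1342177280;  // 0.625 * 2^31
  const int shift = 2;
  // left_shift = 8 matches the constant in the hunk above: the quantized
  // value is amplified before the per-input rescaling is applied.
  const int left_shift = 8;
  const int64_t shifted = static_cast<int64_t>(20) << left_shift;
  const int64_t rescaled = ApplyMultiplier(shifted, multiplier, shift);
  std::printf("rescaled = %lld\n",
              static_cast<long long>(rescaled));  // 20 * 256 * 2.5 = 12800
  return 0;
}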
20 changes: 20 additions & 0 deletions tensorflow/lite/kernels/comparisons_test.cc
@@ -653,6 +653,26 @@ TEST(ComparisonsTest, QuantizedInt8GreaterWithBroadcast) {
}
}

TEST(ComparisonsTest,
QuantizedInt8GreaterWithBroadcastMultiplierGreaterThanOne) {
const float kMin = -127.f;
const float kMax = 127.f;
std::vector<std::vector<int>> test_shapes = {
{6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
for (int i = 0; i < test_shapes.size(); ++i) {
ComparisonOpModel model({TensorType_INT8, test_shapes[i], kMin, kMax},
{TensorType_INT8, {}, kMin, kMax}, TensorType_INT8,
BuiltinOperator_GREATER);
model.QuantizeAndPopulate<int8_t>(model.input1(),
{572, -2, -71, 8, 11, 20});
model.QuantizeAndPopulate<int8_t>(model.input2(), {8});
model.Invoke();
EXPECT_THAT(model.GetOutput(),
ElementsAre(true, false, false, false, true, true))
<< "With shape number " << i;
}
}
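The expected vector can be sanity-checked with a rough quantize-then-compare model. The scale and zero point below are illustrative assumptions (SingleOpModel derives its own parameters from kMin/kMax), and Quantize/Dequantize are hypothetical helpers; the point is that 572 saturates at the int8 maximum when quantized, so it still compares greater than 8, while 8 compared with itself yields false.

// Rough model of the test's expectations. scale/zero_point are illustrative
// assumptions, not the values SingleOpModel actually computes.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

int8_t Quantize(float value, float scale, int zero_point) {
  const int q = static_cast<int>(std::round(value / scale)) + zero_point;
  return static_cast<int8_t>(std::min(127, std::max(-128, q)));  // saturate
}

float Dequantize(int8_t q, float scale, int zero_point) {
  return scale * static_cast<float>(q - zero_point);
}

int main() {
  const float scale = 1.0f;  // illustrative
  const int zero_point = 0;  // illustrative
  const std::vector<float> input1 = {572.0f, -2.0f, -71.0f, 8.0f, 11.0f, 20.0f};
  const int8_t q2 = Quantize(8.0f, scale, zero_point);
  for (float v : input1) {
    const int8_t q1 = Quantize(v, scale, zero_point);  // 572 saturates to 127
    const bool greater =
        Dequantize(q1, scale, zero_point) > Dequantize(q2, scale, zero_point);
    std::printf("%g > 8 ? %s\n", v, greater ? "true" : "false");
  }
  return 0;  // prints true, false, false, false, true, true
}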

TEST(ComparisonsTest, QuantizedUInt8GreaterEqualWithBroadcast) {
const float kMin = -1.f;
const float kMax = 128.f;
