Skip to content

Commit

Permalink
Fix the issues with the quantized tests in batch_matmul_test.cc.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 635942430
  • Loading branch information
qukhan authored and tensorflower-gardener committed May 21, 2024
1 parent 86f326f commit 42b11fd
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 63 deletions.
4 changes: 3 additions & 1 deletion tensorflow/lite/kernels/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1458,9 +1458,11 @@ cc_test(
deps = [
":test_main",
":test_util",
"//tensorflow/lite:string",
"//tensorflow/lite/c:c_api_types",
"//tensorflow/lite/c:common",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_googletest//:gtest",
"@flatbuffers",
],
)

Expand Down
181 changes: 119 additions & 62 deletions tensorflow/lite/kernels/batch_matmul_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,19 @@ limitations under the License.
#include <stdint.h>

#include <initializer_list>
#include <limits>
#include <map>
#include <numeric>
#include <type_traits>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/string_type.h"

namespace tflite {

Expand Down Expand Up @@ -442,14 +448,19 @@ class HybridBatchMatMulOpModel : public SingleOpModel {
CreateBatchMatMulOptions(builder_, adj_x, adj_y,
asymmetric_quantize_inputs)
.Union());
BuildInterpreter({GetShape(lhs_id_), GetShape(rhs_id_)});
BuildInterpreter({GetShape(lhs_id_), GetShape(rhs_id_)},
/*num_threads=*/-1,
/*allow_fp32_relax_to_fp16=*/false,
/*apply_delegate=*/false);
}
void SetWeights(const std::vector<float>& data) {
SymmetricQuantizeAndPopulate(rhs_id_, data);
AllocateAndDelegate(true);
}

void SetSignedWeights(std::initializer_list<float> f) {
SignedSymmetricQuantizeAndPopulate(rhs_id_, f);
AllocateAndDelegate(true);
}

void SetInput(const std::vector<float>& f) { PopulateTensor(lhs_id_, f); }
Expand Down Expand Up @@ -499,14 +510,14 @@ TEST_P(HybridAsymmetricBatchMatMulOpTest, SimpleTestQuantizedInt8) {

EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
{
196,
196,
196,
246,
246,
246,
193,
193,
193,
247,
247,
247,
},
/*max_abs_error=*/0.64f)));
/*max_abs_error=*/3.f)));
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3}));
}

Expand Down Expand Up @@ -726,12 +737,12 @@ TEST_P(HybridAsymmetricBatchMatMulOpTest, QuantizedInt8BroadcastWeights) {

EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
{
24, 24, 24, //
58, 58, 58, //
196, 196, 196, //
246, 246, 246, //
23, 23, 23, //
57, 57, 57, //
193, 193, 193, //
247, 247, 247, //
},
/*max_abs_error=*/1.3f)));
/*max_abs_error=*/3.f)));
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 3}));
}

Expand All @@ -742,11 +753,16 @@ TEST_P(HybridAsymmetricBatchMatMulOpTest, QuantizedInt8BroadcastBigWeights) {
/*rhs=*/{TensorType_INT8, {10, 9}, 0, 0, 10.0 / 127.0, 0});

m.SetSignedWeights({
1, 1, 1, 17, 17, 17, 26, 26, 26, 2, 2, 2, 18, 18, 18, 27, 27, 27,
3, 3, 3, 19, 19, 19, 28, 28, 28, 4, 4, 4, 20, 20, 20, 29, 29, 29,
5, 5, 5, 21, 21, 21, 30, 30, 30, 6, 6, 6, 22, 22, 22, 31, 31, 31,
7, 7, 7, 23, 23, 23, 32, 32, 32, 8, 8, 8, 24, 24, 24, 33, 33, 33,
9, 9, 9, 25, 25, 25, 34, 34, 34, 10, 10, 10, 26, 26, 26, 35, 35, 35,
1, 1, 1, 17, 17, 17, 26, 26, 26, //
2, 2, 2, 18, 18, 18, 27, 27, 27, //
3, 3, 3, 19, 19, 19, 28, 28, 28, //
4, 4, 4, 20, 20, 20, 29, 29, 29, //
5, 5, 5, 21, 21, 21, 30, 30, 30, //
6, 6, 6, 22, 22, 22, 31, 31, 31, //
7, 7, 7, 23, 23, 23, 32, 32, 32, //
8, 8, 8, 24, 24, 24, 33, 33, 33, //
9, 9, 9, 25, 25, 25, 34, 34, 34, //
10, 10, 10, 26, 26, 26, 35, 35, 35,
});

m.SetInput({
Expand All @@ -761,12 +777,12 @@ TEST_P(HybridAsymmetricBatchMatMulOpTest, QuantizedInt8BroadcastBigWeights) {
EXPECT_THAT(m.GetOutput(),
ElementsAreArray(ArrayFloatNear(
{
23, 23, 23, 295, 295, 295, 449, 449, 449, //
60, 60, 60, 364, 364, 364, 533, 533, 533, //
195, 195, 195, 1429, 1429, 1429, 2124, 2124, 2124, //
250, 250, 250, 1512, 1512, 1512, 2213, 2213, 2213 //
23, 23, 23, 295, 295, 295, 448, 448, 448, //
57, 57, 57, 361, 361, 361, 532, 532, 532, //
193, 193, 193, 1425, 1425, 1425, 2118, 2118, 2118, //
247, 247, 247, 1511, 1511, 1511, 2222, 2222, 2222 //
},
/*max_abs_error=*/1.3f)));
/*max_abs_error=*/10.0f)));
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 9}));
}

Expand All @@ -777,9 +793,27 @@ TEST_P(HybridAsymmetricBatchMatMulOpTest, QuantizedInt8BroadcastInputs) {
/*rhs=*/{TensorType_INT8, {2, 10, 3}, 0, 0, 10.0 / 127.0, 0});

m.SetSignedWeights({
1, -3, 1, 2, -2, 2, 3, -1, 3, 4, 0, 4, 5, 1, 5, 6, 2, 6, 7, 3,
7, 8, 4, 8, 9, 5, 9, 10, 6, 10, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4,
4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 10,
1, -3, 1, //
2, -2, 2, //
3, -1, 3, //
4, 0, 4, //
5, 1, 5, //
6, 2, 6, //
7, 3, 7, //
8, 4, 8, //
9, 5, 9, //
10, 6, 10, //

1, 1, 1, //
2, 2, 2, //
3, 3, 3, //
4, 4, 4, //
5, 5, 5, //
6, 6, 6, //
7, 7, 7, //
8, 8, 8, //
9, 9, 9, //
10, 10, 10,
});

m.SetInput({
Expand All @@ -791,12 +825,12 @@ TEST_P(HybridAsymmetricBatchMatMulOpTest, QuantizedInt8BroadcastInputs) {

EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
{
24, -45, 24, //
58, -18, 58, //
24, 24, 24, //
58, 58, 58, //
23, -45, 23, //
57, -19, 57, //
23, 23, 23, //
57, 57, 57, //
},
/*max_abs_error=*/0.64f)));
/*max_abs_error=*/1.5f)));
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 3}));
}

Expand Down Expand Up @@ -832,14 +866,14 @@ TEST_P(HybridSymmetricBatchMatMulOpTest, SimpleTestQuantizedInt8) {

EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
{
194,
194,
194,
248,
248,
248,
193,
193,
193,
247,
247,
247,
},
/*max_abs_error=*/0.64f)));
/*max_abs_error=*/1.5f)));
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3}));
}

Expand All @@ -866,12 +900,12 @@ TEST_P(HybridSymmetricBatchMatMulOpTest, QuantizedInt8BroadcastWeights) {

EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
{
24, 24, 24, //
56, 56, 56, //
194, 194, 194, //
248, 248, 248, //
23, 23, 23, //
57, 57, 57, //
193, 193, 193, //
247, 247, 247, //
},
/*max_abs_error=*/1.3f)));
/*max_abs_error=*/1.5f)));
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 3}));
}

Expand All @@ -883,11 +917,16 @@ TEST_P(HybridSymmetricBatchMatMulOpTest, QuantizedInt8BroadcastBigWeights) {
{TensorType_FLOAT32}, false);

m.SetSignedWeights({
1, 1, 1, 17, 17, 17, 26, 26, 26, 2, 2, 2, 18, 18, 18, 27, 27, 27,
3, 3, 3, 19, 19, 19, 28, 28, 28, 4, 4, 4, 20, 20, 20, 29, 29, 29,
5, 5, 5, 21, 21, 21, 30, 30, 30, 6, 6, 6, 22, 22, 22, 31, 31, 31,
7, 7, 7, 23, 23, 23, 32, 32, 32, 8, 8, 8, 24, 24, 24, 33, 33, 33,
9, 9, 9, 25, 25, 25, 34, 34, 34, 10, 10, 10, 26, 26, 26, 35, 35, 35,
1, 1, 1, 17, 17, 17, 26, 26, 26, //
2, 2, 2, 18, 18, 18, 27, 27, 27, //
3, 3, 3, 19, 19, 19, 28, 28, 28, //
4, 4, 4, 20, 20, 20, 29, 29, 29, //
5, 5, 5, 21, 21, 21, 30, 30, 30, //
6, 6, 6, 22, 22, 22, 31, 31, 31, //
7, 7, 7, 23, 23, 23, 32, 32, 32, //
8, 8, 8, 24, 24, 24, 33, 33, 33, //
9, 9, 9, 25, 25, 25, 34, 34, 34, //
10, 10, 10, 26, 26, 26, 35, 35, 35,
});

m.SetInput({
Expand All @@ -902,12 +941,12 @@ TEST_P(HybridSymmetricBatchMatMulOpTest, QuantizedInt8BroadcastBigWeights) {
EXPECT_THAT(m.GetOutput(),
ElementsAreArray(ArrayFloatNear(
{
23, 23, 23, 296, 296, 296, 451, 451, 451, //
58, 58, 58, 362, 362, 362, 529, 529, 529, //
193, 193, 193, 1424, 1424, 1424, 2118, 2118, 2118, //
253, 253, 253, 1519, 1519, 1519, 2223, 2223, 2223 //
23, 23, 23, 295, 295, 295, 448, 448, 448, //
57, 57, 57, 361, 361, 361, 532, 532, 532, //
193, 193, 193, 1425, 1425, 1425, 2118, 2118, 2118, //
247, 247, 247, 1511, 1511, 1511, 2222, 2222, 2222 //
},
/*max_abs_error=*/1.3f)));
/*max_abs_error=*/10.0f)));
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 9}));
}

Expand All @@ -919,9 +958,27 @@ TEST_P(HybridSymmetricBatchMatMulOpTest, QuantizedInt8BroadcastInputs) {
{TensorType_FLOAT32}, false);

m.SetSignedWeights({
1, -3, 1, 2, -2, 2, 3, -1, 3, 4, 0, 4, 5, 1, 5, 6, 2, 6, 7, 3,
7, 8, 4, 8, 9, 5, 9, 10, 6, 10, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4,
4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 10,
1, -3, 1, //
2, -2, 2, //
3, -1, 3, //
4, 0, 4, //
5, 1, 5, //
6, 2, 6, //
7, 3, 7, //
8, 4, 8, //
9, 5, 9, //
10, 6, 10, //

1, 1, 1, //
2, 2, 2, //
3, 3, 3, //
4, 4, 4, //
5, 5, 5, //
6, 6, 6, //
7, 7, 7, //
8, 8, 8, //
9, 9, 9, //
10, 10, 10,
});

m.SetInput({
Expand All @@ -933,12 +990,12 @@ TEST_P(HybridSymmetricBatchMatMulOpTest, QuantizedInt8BroadcastInputs) {

EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
{
24, -45, 24, //
56, -19, 56, //
24, 24, 24, //
56, 56, 56, //
23, -45, 23, //
57, -19, 57, //
23, 23, 23, //
57, 57, 57, //
},
/*max_abs_error=*/0.64f)));
/*max_abs_error=*/1.5f)));
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 3}));
}

Expand Down

0 comments on commit 42b11fd

Please sign in to comment.