Skip to content

Commit

Permalink
Reverts 1165601
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 636479468
  • Loading branch information
tensorflower-gardener committed May 23, 2024
1 parent 65ecf09 commit 9aafb41
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 125 deletions.
2 changes: 2 additions & 0 deletions tensorflow/compiler/mlir/lite/ir/tfl_ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -1633,6 +1633,8 @@ def TFL_EluOp: TFL_Op<"elu", [

def TFL_EmbeddingLookupOp: TFL_Op<"embedding_lookup",
[Pure,
PredOpTrait<"value and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 1>>,
TFL_OperandHasRank<0, 1>,
TFL_OperandHasRankAtLeast<1, 2>,
DynamicRangeQuantizedOpInterface,
Expand Down
8 changes: 8 additions & 0 deletions tensorflow/compiler/mlir/lite/tests/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2031,6 +2031,14 @@ func.func @testEmbeddingLookup(%arg0 : tensor<?xi32>, %arg1 : tensor<?x?xf32>) -

// -----

// Negative verifier test for 'tfl.embedding_lookup': the value operand (%arg1)
// has element type i8 while the result has element type f32, so the op is
// expected to be rejected with the "value and output must have same element
// type" diagnostic checked by the expected-error directive below.
func.func @testEmbeddingLookupValueAndResultElementTypeTraitFailed(%arg0 : tensor<?xi32>, %arg1 : tensor<?x?xi8>) -> tensor<?xf32> {
// expected-error @+1 {{'tfl.embedding_lookup' op failed to verify that value and output must have same element type}}
%0 = "tfl.embedding_lookup"(%arg0, %arg1) : (tensor<?xi32>,tensor<?x?xi8>) -> tensor<?xf32>
func.return %0 : tensor<?xf32>
}

// -----

func.func @testWrongQuantizedLocalResponseNormalization(%arg0 : tensor<1x56x56x192x!quant.uniform<u8:f32, 0.02>>) -> tensor<1x56x56x192x!quant.uniform<u8:f32, 0.02>> {
// expected-error @+1 {{'tfl.local_response_normalization' op operand #0 must be tensor of 32-bit float values, but got 'tensor<1x56x56x192x!quant.uniform<u8:f32, 2.000000e-02>>'}}
%0 = "tfl.local_response_normalization"(%arg0) {alpha = 9.99999974E-5 : f32, beta = 5.000000e-01 : f32, bias = 2.000000e+00 : f32, radius = 5 : i32} : (tensor<1x56x56x192x!quant.uniform<u8:f32, 0.02>>) -> tensor<1x56x56x192x!quant.uniform<u8:f32, 0.02>>
Expand Down
29 changes: 1 addition & 28 deletions tensorflow/lite/kernels/embedding_lookup.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,24 +64,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
for (int i = 2; i < NumDimensions(value); i++) {
outputSize->data[i] = SizeOfDimension(value, i);
}

if (value->quantization.type == kTfLiteAffineQuantization) {
const auto qparams = static_cast<const TfLiteAffineQuantization*>(
value->quantization.params);
TF_LITE_ENSURE(context, qparams->scale != nullptr);
TF_LITE_ENSURE(context, qparams->zero_point != nullptr);
// Only support symmetric quantization for now.
TF_LITE_ENSURE(context, qparams->zero_point->data[0] == 0);
if (qparams->scale->size > 1 || qparams->zero_point->size > 1) {
// Per-axis quantization must have quantized_dimension == 0 and correct
// sizes for scale and zero_point.
TF_LITE_ENSURE(context, qparams->quantized_dimension == 0);
const int row_size = SizeOfDimension(value, 0);
TF_LITE_ENSURE(context, qparams->scale->size == row_size);
TF_LITE_ENSURE(context, qparams->zero_point->size == row_size);
}
}

return context->ResizeTensor(context, output, outputSize);
}

Expand Down Expand Up @@ -119,6 +101,7 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
const TfLiteTensor* lookup, const TfLiteTensor* value,
TfLiteTensor* output) {
const int row_size = SizeOfDimension(value, 0);
const double scaling_factor = value->params.scale;

// col_size after we flatten tensor into 2D.
int col_size = 1;
Expand All @@ -142,16 +125,6 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
// Dequantize embedding values.
// TODO(alanchiao): refactor scalar multiply into separate function
// for ease of adding a neon equivalent if ever necessary.
double scaling_factor = value->params.scale;
if (value->quantization.type == kTfLiteAffineQuantization) {
const auto qparams = static_cast<const TfLiteAffineQuantization*>(
value->quantization.params);
if (qparams->scale->size > 1) {
// get this row's scale for per-axis quantization
scaling_factor = qparams->scale->data[idx];
}
}

for (int j = 0; j < col_size; j++) {
output_ptr[j + i * col_size] =
value_ptr[j + idx * col_size] * scaling_factor;
Expand Down
103 changes: 6 additions & 97 deletions tensorflow/lite/kernels/embedding_lookup_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ License.
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "tensorflow/lite/core/interpreter.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/test_util.h"
#include "tensorflow/lite/schema/schema_generated.h"
Expand All @@ -37,22 +38,12 @@ using ::testing::ElementsAreArray;

class BaseEmbeddingLookupOpModel : public SingleOpModel {
public:
BaseEmbeddingLookupOpModel(
std::initializer_list<int> index_shape,
std::initializer_list<int> weight_shape,
TensorType weight_type = TensorType_FLOAT32,
TensorType output_type = TensorType_FLOAT32,
const std::vector<float>& per_channel_quantization_scales = {}) {
BaseEmbeddingLookupOpModel(std::initializer_list<int> index_shape,
std::initializer_list<int> weight_shape,
TensorType weight_type = TensorType_FLOAT32,
TensorType output_type = TensorType_FLOAT32) {
input_ = AddInput(TensorType_INT32);
if (per_channel_quantization_scales.empty()) {
weight_ = AddInput(weight_type);
} else {
std::vector<int64_t> per_channel_quantization_offsets(
per_channel_quantization_scales.size(), 0);
weight_ = AddInput({weight_type, weight_shape, 0, 0, 0, 0, true,
per_channel_quantization_scales,
per_channel_quantization_offsets, 0});
}
weight_ = AddInput(weight_type);
output_ = AddOutput(output_type);
SetBuiltinOp(BuiltinOperator_EMBEDDING_LOOKUP, BuiltinOptions_NONE, 0);
BuildInterpreter({index_shape, weight_shape});
Expand Down Expand Up @@ -110,22 +101,6 @@ class HybridEmbeddingLookupOpModel : public BaseEmbeddingLookupOpModel {
}
};

// Test model for the embedding lookup op that forwards a list of per-channel
// quantization scales to BaseEmbeddingLookupOpModel. The output type passed to
// the base class is always TensorType_FLOAT32.
class PerAxisHybridEmbeddingLookupOpModel : public BaseEmbeddingLookupOpModel {
public:
// index_shape, weight_shape: shapes forwarded unchanged to the base model.
// per_channel_quantization_scales: scales forwarded to the base model.
// type: passed in the base constructor's weight_type position.
PerAxisHybridEmbeddingLookupOpModel(
std::initializer_list<int> index_shape,
std::initializer_list<int> weight_shape,
const std::vector<float>& per_channel_quantization_scales,
TensorType type)
: BaseEmbeddingLookupOpModel(index_shape, weight_shape, type,
TensorType_FLOAT32,
per_channel_quantization_scales) {}

// Fills the weight tensor from float `data` by delegating to
// PerChannelSymmetricQuantizeAndPopulate.
void SetSignedWeight(std::initializer_list<float> data) {
PerChannelSymmetricQuantizeAndPopulate(weight_, data);
}
};

// TODO(ahentz): write more tests that exercise the details of the op, such as
// lookup errors and variable input shapes.
TEST(EmbeddingLookupOpTest, SimpleTest) {
Expand Down Expand Up @@ -286,71 +261,5 @@ TEST(EmbeddingLookupHybridOpTest, Simple3DTestQuantized) {
}));
}

// Looks up rows {1, 0, 2} of a 3x8 INT8 weight tensor built with three
// per-channel scales, and checks that the float output contains the original
// rows in lookup order, within kTestTolerance.
TEST(PerAxisHybridEmbeddingLookupHybridOpTest, PerAxisSimple2DTestInt8) {
PerAxisHybridEmbeddingLookupOpModel m(
{3}, {3, 8}, {0.00102, 0.0089, 0.016772}, TensorType_INT8);
m.SetInput({1, 0, 2});
m.SetSignedWeight({
0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
});

ASSERT_EQ(m.Invoke(), kTfLiteOk);

// Expected output is the weight rows reordered to match the lookup indices.
EXPECT_THAT(m.GetOutput<float>(),
ElementsAreArray(ArrayFloatNear(
{
1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
},
kTestTolerance)));
}

// Same scenario as the 2D per-axis test but with a {3, 2, 4} weight shape:
// looks up rows {1, 0, 2} of an INT8 weight tensor built with three
// per-channel scales and checks the float output within kTestTolerance.
TEST(PerAxisHybridEmbeddingLookupHybridOpTest, PerAxisSimple3DTestInt8) {
PerAxisHybridEmbeddingLookupOpModel m(
{3}, {3, 2, 4}, {0.00102, 0.0089, 0.016772}, TensorType_INT8);
m.SetInput({1, 0, 2});
m.SetSignedWeight({
0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
});

ASSERT_EQ(m.Invoke(), kTfLiteOk);

// Expected output is the weight rows reordered to match the lookup indices.
EXPECT_THAT(m.GetOutput<float>(),
ElementsAreArray(ArrayFloatNear(
{
1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
},
kTestTolerance)));
}

// Same scenario as the 2D per-axis test but with a {3, 2, 2, 2} weight shape:
// looks up rows {1, 0, 2} of an INT8 weight tensor built with three
// per-channel scales and checks the float output within kTestTolerance.
TEST(PerAxisHybridEmbeddingLookupHybridOpTest, PerAxisSimple4DTestInt8) {
PerAxisHybridEmbeddingLookupOpModel m(
{3}, {3, 2, 2, 2}, {0.00102, 0.0089, 0.016772}, TensorType_INT8);
m.SetInput({1, 0, 2});
m.SetSignedWeight({
0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
});

ASSERT_EQ(m.Invoke(), kTfLiteOk);

// Expected output is the weight rows reordered to match the lookup indices.
EXPECT_THAT(m.GetOutput<float>(),
ElementsAreArray(ArrayFloatNear(
{
1.00, -1.01, 1.02, 1.03, 1.10, 1.11, 1.12, 1.13, // Row 1
0.00, 0.01, 0.02, 0.03, 0.10, 0.11, 0.12, 0.13, // Row 0
2.00, 2.01, 2.02, 2.03, 2.10, 2.11, 2.12, 2.13, // Row 2
},
kTestTolerance)));
}

} // namespace
} // namespace tflite

0 comments on commit 9aafb41

Please sign in to comment.