micro: port operator ELU kernel from lite with test #47284

Merged 1 commit on Feb 23, 2021
16 changes: 16 additions & 0 deletions tensorflow/lite/micro/kernels/BUILD
@@ -218,6 +218,7 @@ cc_library(
"detection_postprocess.cc",
"elementwise.cc",
"exp.cc",
"elu.cc",
"floor.cc",
"l2norm.cc",
"logical.cc",
@@ -521,6 +522,21 @@ cc_test(
],
)

cc_test(
name = "elu_test",
srcs = [
"elu_test.cc",
],
deps = [
":kernel_runner",
"//tensorflow/lite/c:common",
"//tensorflow/lite/micro:debug_log",
"//tensorflow/lite/micro:op_resolvers",
"//tensorflow/lite/micro:test_helpers",
"//tensorflow/lite/micro/testing:micro_test",
],
)

cc_test(
name = "exp_test",
srcs = ["exp_test.cc"],
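With the target above in place, the new test should run the same way as the neighboring kernel tests, e.g. from a Bazel checkout:

    bazel test //tensorflow/lite/micro/kernels:elu_test
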
124 changes: 65 additions & 59 deletions tensorflow/lite/micro/kernels/elu.cc
@@ -17,7 +17,6 @@ limitations under the License.

#include <algorithm>
#include <cmath>
#include <functional>
#include <limits>

#include "tensorflow/lite/c/common.h"
@@ -28,23 +27,26 @@ limitations under the License.
#include "tensorflow/lite/micro/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace micro {
namespace activations {
namespace {

// Input/output tensor index.
constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;

// OLD-TODO(b/142762739): We should figure out a multi-threading plan for most
// of the activation ops below.

struct OpData {
uint8_t table[256] = {0};
int8_t table[256];
};

using TransformFunc = float (*)(float);

template <typename T>
void PopulateLookupTable(struct OpData* data, const TfLiteTensor* input,
TfLiteTensor* output,
const std::function<float(float)>& transform) {
static_assert(sizeof(T) == 1, "Lookup table valid only for 8bit");
void PopulateLookupTable(const TfLiteTensor* input, const TfLiteTensor* output,
const TransformFunc transform, OpData* data) {
if (sizeof(T) != 1) TF_LITE_FATAL("Lookup table valid only for 8bit");

const float inverse_scale = 1 / output->params.scale;
int32_t maxval = std::numeric_limits<T>::max();
int32_t minval = std::numeric_limits<T>::min();
@@ -56,90 +58,94 @@ void PopulateLookupTable(struct OpData* data, const TfLiteTensor* input,
const int32_t quantized =
static_cast<int32_t>(rescaled + output->params.zero_point);
data->table[static_cast<uint8_t>(static_cast<T>(val))] =
static_cast<uint8_t>(
static_cast<T>(std::max(std::min(maxval, quantized), minval)));
static_cast<T>(std::max(std::min(maxval, quantized), minval));
}
}
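For intuition, each table entry above is a dequantize → ELU → requantize round trip, computed once per representable int8 value. A minimal standalone sketch of the same math, with hypothetical scale and zero-point values that are not taken from this PR:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // Hypothetical quantization parameters, for illustration only.
    constexpr float kInputScale = 0.05f;
    constexpr float kOutputScale = 0.05f;
    constexpr int32_t kInputZeroPoint = 0;
    constexpr int32_t kOutputZeroPoint = 0;

    int8_t EluTableEntry(int8_t q) {
      const float x = kInputScale * (q - kInputZeroPoint);  // dequantize
      const float y = x < 0.0f ? std::exp(x) - 1.0f : x;    // ELU, alpha = 1
      const int32_t quantized =                             // requantize
          static_cast<int32_t>(std::round(y / kOutputScale)) + kOutputZeroPoint;
      // Clamp to the int8 range, mirroring the std::max/std::min above.
      return static_cast<int8_t>(
          std::min<int32_t>(127, std::max<int32_t>(-128, quantized)));
    }
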

// OLD-TODO(b/143696793): move this to optimized_ops.
void EvalUsingLookupTable(struct OpData* data, const TfLiteTensor* input,
TfLiteTensor* output) {
const int size =
MatchingFlatSize(GetTensorShape(input), GetTensorShape(output));
uint8_t* output_data = GetTensorData<uint8_t>(output);
const uint8_t* input_data = GetTensorData<uint8_t>(input);
int i = 0;

for (; i < size; ++i) {
output_data[i] = data->table[input_data[i]];
void EvalUsingLookupTable(const OpData* data, const TfLiteEvalTensor* input,
TfLiteEvalTensor* output) {
const int size = MatchingFlatSize(tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorShape(output));
int8_t* output_data = tflite::micro::GetTensorData<int8_t>(output);
const int8_t* input_data = tflite::micro::GetTensorData<int8_t>(input);

for (int i = 0; i < size; ++i) {
output_data[i] = data->table[static_cast<uint8_t>(input_data[i])];
}
}
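Note the static_cast<uint8_t> on the table index: with two's-complement int8, negative inputs land in the upper half of the 256-entry table, which is exactly where PopulateLookupTable keyed them via its own static_cast<uint8_t>(static_cast<T>(val)):

    // Index mapping under two's-complement int8:
    //   static_cast<uint8_t>(int8_t{-128}) == 128
    //   static_cast<uint8_t>(int8_t{-1})   == 255
    //   static_cast<uint8_t>(int8_t{127})  == 127
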

} // namespace

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
// This is a builtin op, so we don't use the contents in 'buffer', if any.
// Instead, we allocate a new object to carry information from Prepare() to
// Eval().
return nullptr;
}

TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
TF_LITE_ENSURE_OK(context,
GetOutputSafe(context, node, kOutputTensor, &output));
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);

return kTfLiteError;
}

TfLiteStatus EluPrepare(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
OpData* data = reinterpret_cast<OpData*>(node->user_data);

// Use LUT to handle quantized elu path.
if (input->type == kTfLiteInt8) {
PopulateLookupTable<int8_t>(data, input, output, [](float value) {
return value < 0.0 ? std::exp(value) - 1.0f : value;
});
OpData* data = static_cast<OpData*>(node->user_data);
TransformFunc transform = [](float value) {
return value < 0.0f ? std::exp(value) - 1.0f : value;
};
PopulateLookupTable<int8_t>(input, output, transform, data);
}
return GenericPrepare(context, node);

return kTfLiteOk;
}
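Populating the table once in prepare keeps Eval down to a single indexed load per element, at the cost of 256 bytes of persistent arena held in OpData. The transform lambda is the standard ELU with alpha = 1: f(x) = x for x > 0, and exp(x) - 1 otherwise.
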

void* EluInit(TfLiteContext* context, const char* buffer, size_t length) {
// This is a builtin op, so we don't use the contents in 'buffer', if any.
// Instead, we allocate a new object to carry information from Prepare() to
// Eval().
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
return context->AllocatePersistentBuffer(context, sizeof(OpData));
}

TfLiteStatus EluPrepare(TfLiteContext* context, TfLiteNode* node) {
return CalculateOpData(context, node);
}

TfLiteStatus EluEval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input;
TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
TfLiteTensor* output;
TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
const TfLiteEvalTensor* input =
tflite::micro::GetEvalInput(context, node, kInputTensor);
TfLiteEvalTensor* output =
tflite::micro::GetEvalOutput(context, node, kOutputTensor);
switch (input->type) {
case kTfLiteFloat32: {
optimized_ops::Elu(GetTensorShape(input), GetTensorData<float>(input),
GetTensorShape(output), GetTensorData<float>(output));
reference_ops::Elu(tflite::micro::GetTensorShape(input),
tflite::micro::GetTensorData<float>(input),
tflite::micro::GetTensorShape(output),
tflite::micro::GetTensorData<float>(output));
return kTfLiteOk;
}
case kTfLiteInt8: {
OpData* data = reinterpret_cast<OpData*>(node->user_data);
const OpData* data = static_cast<OpData*>(node->user_data);
EvalUsingLookupTable(data, input, output);
return kTfLiteOk;
}
default:
TF_LITE_KERNEL_LOG(
context, "Only float32 and int8 is supported currently, got %s.",
context, "ELU only supports float32 and int8 currently, got %s.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
}
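The float path defers to reference_ops::Elu from the shared reference kernels. Presumably that reduces to an elementwise loop along these lines (a sketch, not the actual reference implementation):

    #include <cmath>

    // Hedged sketch of the float ELU path; the real kernel lives under
    // tensorflow/lite/kernels/internal/reference/.
    void EluFloat(const float* input, float* output, int size) {
      for (int i = 0; i < size; ++i) {
        const float x = input[i];
        output[i] = x < 0.0f ? std::expm1(x) : x;  // expm1(x) == exp(x) - 1
      }
    }
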

} // namespace activations
} // namespace

TfLiteRegistration* Register_ELU() { return nullptr; }
TfLiteRegistration Register_ELU() {
return {/*init=*/EluInit,
/*free=*/nullptr,
/*prepare=*/EluPrepare,
/*invoke=*/EluEval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/0};
}

} // namespace micro
} // namespace ops
} // namespace tflite
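
Downstream, applications pull the kernel in through op registration; a hedged usage sketch, assuming MicroMutableOpResolver gained an AddElu() hook alongside this port:

    #include "tensorflow/lite/micro/micro_mutable_op_resolver.h"

    // Register ELU with a caller-owned resolver sized for one op.
    // AddElu() is an assumption here, not confirmed by this diff.
    void RegisterOps(tflite::MicroMutableOpResolver<1>& resolver) {
      resolver.AddElu();
    }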