
Merge pull request #47284 from ddavis-2015:Elu-pr5
PiperOrigin-RevId: 359044163
Change-Id: Ib1b74b5286a1dc6b989c06083f030fbddc4750ed
tensorflower-gardener committed Feb 23, 2021
2 parents 3f113a3 + 97eeecc commit 59da245
Showing 5 changed files with 199 additions and 111 deletions.
16 changes: 16 additions & 0 deletions tensorflow/lite/micro/kernels/BUILD
@@ -221,6 +221,7 @@ cc_library(
"detection_postprocess.cc",
"elementwise.cc",
"exp.cc",
"elu.cc",
"floor.cc",
"l2norm.cc",
"logical.cc",
@@ -524,6 +525,21 @@ cc_test(
],
)

+cc_test(
+    name = "elu_test",
+    srcs = [
+        "elu_test.cc",
+    ],
+    deps = [
+        ":kernel_runner",
+        "//tensorflow/lite/c:common",
+        "//tensorflow/lite/micro:debug_log",
+        "//tensorflow/lite/micro:op_resolvers",
+        "//tensorflow/lite/micro:test_helpers",
+        "//tensorflow/lite/micro/testing:micro_test",
+    ],
+)

cc_test(
name = "exp_test",
srcs = ["exp_test.cc"],
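For context, the new kernel is meant to be pulled into an application through the micro op resolver like any other builtin. A minimal usage sketch follows (not part of this diff; it assumes an AddElu() hook on MicroMutableOpResolver, presumably added in one of the other changed files not shown here):

#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"

// Hypothetical registration sketch: AddElu() is assumed to wrap the new
// Register_ELU() under BuiltinOperator_ELU; the resolver change itself is
// not visible in this diff.
void RegisterOps() {
  static tflite::MicroMutableOpResolver<1> resolver;
  resolver.AddElu();
  // resolver can now be handed to tflite::MicroInterpreter with the model.
}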
124 changes: 65 additions & 59 deletions tensorflow/lite/micro/kernels/elu.cc
@@ -17,7 +17,6 @@ limitations under the License.

#include <algorithm>
#include <cmath>
-#include <functional>
#include <limits>

#include "tensorflow/lite/c/common.h"
@@ -28,23 +27,26 @@ limitations under the License.
#include "tensorflow/lite/micro/kernels/kernel_util.h"

namespace tflite {
-namespace ops {
-namespace micro {
-namespace activations {
namespace {

+// Input/output tensor index.
+constexpr int kInputTensor = 0;
+constexpr int kOutputTensor = 0;

// OLD-TODO(b/142762739): We should figure out a multi-threading plan for most
// of the activation ops below.

struct OpData {
-  uint8_t table[256] = {0};
+  int8_t table[256];
};

+using TransformFunc = float (*)(float);

template <typename T>
-void PopulateLookupTable(struct OpData* data, const TfLiteTensor* input,
-                         TfLiteTensor* output,
-                         const std::function<float(float)>& transform) {
-  static_assert(sizeof(T) == 1, "Lookup table valid only for 8bit");
+void PopulateLookupTable(const TfLiteTensor* input, const TfLiteTensor* output,
+                         const TransformFunc transform, OpData* data) {
+  if (sizeof(T) != 1) TF_LITE_FATAL("Lookup table valid only for 8bit");

const float inverse_scale = 1 / output->params.scale;
int32_t maxval = std::numeric_limits<T>::max();
int32_t minval = std::numeric_limits<T>::min();
@@ -56,90 +58,94 @@ void PopulateLookupTable(struct OpData* data, const TfLiteTensor* input,
const int32_t quantized =
static_cast<int32_t>(rescaled + output->params.zero_point);
data->table[static_cast<uint8_t>(static_cast<T>(val))] =
-        static_cast<uint8_t>(
-            static_cast<T>(std::max(std::min(maxval, quantized), minval)));
+        static_cast<T>(std::max(std::min(maxval, quantized), minval));
}
}

// OLD-TODO(b/143696793): move this to optimized_ops.
-void EvalUsingLookupTable(struct OpData* data, const TfLiteTensor* input,
-                          TfLiteTensor* output) {
-  const int size =
-      MatchingFlatSize(GetTensorShape(input), GetTensorShape(output));
-  uint8_t* output_data = GetTensorData<uint8_t>(output);
-  const uint8_t* input_data = GetTensorData<uint8_t>(input);
-  int i = 0;

-  for (; i < size; ++i) {
-    output_data[i] = data->table[input_data[i]];
+void EvalUsingLookupTable(const OpData* data, const TfLiteEvalTensor* input,
+                          TfLiteEvalTensor* output) {
+  const int size = MatchingFlatSize(tflite::micro::GetTensorShape(input),
+                                    tflite::micro::GetTensorShape(output));
+  int8_t* output_data = tflite::micro::GetTensorData<int8_t>(output);
+  const int8_t* input_data = tflite::micro::GetTensorData<int8_t>(input);

+  for (int i = 0; i < size; ++i) {
+    output_data[i] = data->table[static_cast<uint8_t>(input_data[i])];
}
}

} // namespace

-void* Init(TfLiteContext* context, const char* buffer, size_t length) {
-  // This is a builtin op, so we don't use the contents in 'buffer', if any.
-  // Instead, we allocate a new object to carry information from Prepare() to
-  // Eval().
-  return nullptr;
-}

-TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) {
+TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input;
-  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
+  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  TfLiteTensor* output;
-  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
+  TF_LITE_ENSURE_OK(context,
+                    GetOutputSafe(context, node, kOutputTensor, &output));
TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);

-  return kTfLiteError;
-}

-TfLiteStatus EluPrepare(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* input;
-  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
-  TfLiteTensor* output;
-  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
-  OpData* data = reinterpret_cast<OpData*>(node->user_data);

// Use LUT to handle quantized elu path.
if (input->type == kTfLiteInt8) {
-    PopulateLookupTable<int8_t>(data, input, output, [](float value) {
-      return value < 0.0 ? std::exp(value) - 1.0f : value;
-    });
+    OpData* data = static_cast<OpData*>(node->user_data);
+    TransformFunc transform = [](float value) {
+      return value < 0.0f ? std::exp(value) - 1.0f : value;
+    };
+    PopulateLookupTable<int8_t>(input, output, transform, data);
}
-  return GenericPrepare(context, node);

+  return kTfLiteOk;
}

+void* EluInit(TfLiteContext* context, const char* buffer, size_t length) {
+  // This is a builtin op, so we don't use the contents in 'buffer', if any.
+  // Instead, we allocate a new object to carry information from Prepare() to
+  // Eval().
+  TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
+  return context->AllocatePersistentBuffer(context, sizeof(OpData));
+}

+TfLiteStatus EluPrepare(TfLiteContext* context, TfLiteNode* node) {
+  return CalculateOpData(context, node);
+}

TfLiteStatus EluEval(TfLiteContext* context, TfLiteNode* node) {
-  const TfLiteTensor* input;
-  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
-  TfLiteTensor* output;
-  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
+  const TfLiteEvalTensor* input =
+      tflite::micro::GetEvalInput(context, node, kInputTensor);
+  TfLiteEvalTensor* output =
+      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
switch (input->type) {
case kTfLiteFloat32: {
-      optimized_ops::Elu(GetTensorShape(input), GetTensorData<float>(input),
-                         GetTensorShape(output), GetTensorData<float>(output));
+      reference_ops::Elu(tflite::micro::GetTensorShape(input),
+                         tflite::micro::GetTensorData<float>(input),
+                         tflite::micro::GetTensorShape(output),
+                         tflite::micro::GetTensorData<float>(output));
return kTfLiteOk;
}
case kTfLiteInt8: {
-      OpData* data = reinterpret_cast<OpData*>(node->user_data);
+      const OpData* data = static_cast<OpData*>(node->user_data);
EvalUsingLookupTable(data, input, output);
return kTfLiteOk;
}
default:
TF_LITE_KERNEL_LOG(
context, "Only float32 and int8 is supported currently, got %s.",
context, "ELU only supports float32 and int8 currently, got %s.",
TfLiteTypeGetName(input->type));
return kTfLiteError;
}
}

-}  // namespace activations
} // namespace

-TfLiteRegistration* Register_ELU() { return nullptr; }
+TfLiteRegistration Register_ELU() {
+  return {/*init=*/EluInit,
+          /*free=*/nullptr,
+          /*prepare=*/EluPrepare,
+          /*invoke=*/EluEval,
+          /*profiling_string=*/nullptr,
+          /*builtin_code=*/0,
+          /*custom_name=*/nullptr,
+          /*version=*/0};
+}

-}  // namespace micro
-}  // namespace ops
} // namespace tflite
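The int8 path above works entirely off the 256-entry table built during Prepare. A self-contained sketch of the same per-entry math follows (illustrative only; the function and parameter names here are placeholders, not part of the commit):

#include <algorithm>
#include <cmath>
#include <cstdint>

// For each of the 256 possible int8 inputs: dequantize, apply ELU, then
// requantize into the output scale, so Eval needs only one table lookup
// per element.
void BuildEluTable(float input_scale, int32_t input_zero_point,
                   float output_scale, int32_t output_zero_point,
                   int8_t table[256]) {
  const float inverse_scale = 1.0f / output_scale;
  for (int32_t val = -128; val <= 127; ++val) {
    const float dequantized = input_scale * (val - input_zero_point);
    const float transformed =
        dequantized < 0.0f ? std::exp(dequantized) - 1.0f : dequantized;
    const float rescaled = std::round(transformed * inverse_scale);
    const int32_t quantized =
        static_cast<int32_t>(rescaled) + output_zero_point;
    // Index by the raw int8 bit pattern; clamp the result to the int8 range.
    table[static_cast<uint8_t>(static_cast<int8_t>(val))] = static_cast<int8_t>(
        std::max<int32_t>(-128, std::min<int32_t>(127, quantized)));
  }
}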
