Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Micro transpose op ported and tested for TFLM #48192

Closed
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ node_modules
__pycache__
*.swp
.vscode/
venv/
cmake_build/
tensorflow/contrib/cmake/_build/
.idea/**
Expand Down
11 changes: 8 additions & 3 deletions tensorflow/lite/core/api/flatbuffer_conversions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,10 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
return ParseTanh(op, error_reporter, allocator, builtin_data);
}

case BuiltinOperator_TRANSPOSE: {
return ParseTranspose(op, error_reporter, allocator, builtin_data);
}

case BuiltinOperator_TRANSPOSE_CONV: {
return ParseTransposeConv(op, error_reporter, allocator, builtin_data);
}
Expand Down Expand Up @@ -806,7 +810,6 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_SLICE:
case BuiltinOperator_TILE:
case BuiltinOperator_TOPK_V2:
case BuiltinOperator_TRANSPOSE:
case BuiltinOperator_RANGE:
case BuiltinOperator_SQUARED_DIFFERENCE:
case BuiltinOperator_REVERSE_V2:
Expand Down Expand Up @@ -2059,8 +2062,10 @@ TfLiteStatus ParseTanh(const Operator*, ErrorReporter*, BuiltinDataAllocator*,
// We have this parse function instead of directly returning kTfLiteOk from the
// switch-case in ParseOpData because this function is used as part of the
// selective registration for the OpResolver implementation in micro.
TfLiteStatus ParseTranspose(const Operator*, ErrorReporter*,
BuiltinDataAllocator*, void**) {
// Transpose carries no builtin options in the flatbuffer, so there is
// nothing to extract here; the function exists (and takes the standard
// parse-function parameters) so it can participate in micro's selective
// registration, as noted in the comment above.
// All parameters are intentionally unused.
TfLiteStatus ParseTranspose(const Operator* op,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data) {
return kTfLiteOk;
}

Expand Down
7 changes: 4 additions & 3 deletions tensorflow/lite/core/api/flatbuffer_conversions.h
Original file line number Diff line number Diff line change
Expand Up @@ -334,9 +334,10 @@ TfLiteStatus ParseSvdf(const Operator* op, ErrorReporter* error_reporter,
TfLiteStatus ParseTanh(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);

TfLiteStatus ParseTranspose(const Operator* op, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseTranspose(const Operator* op,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);

TfLiteStatus ParseTransposeConv(const Operator* op,
ErrorReporter* error_reporter,
Expand Down
1 change: 1 addition & 0 deletions tensorflow/lite/micro/all_ops_resolver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ AllOpsResolver::AllOpsResolver() {
AddSub();
AddSvdf();
AddTanh();
AddTranspose();
AddTransposeConv();
AddUnpack();
}
Expand Down
16 changes: 15 additions & 1 deletion tensorflow/lite/micro/kernels/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ cc_library(
"sub.cc",
"svdf_common.cc",
"tanh.cc",
"transpose.cc",
"transpose_conv.cc",
"unpack.cc",
"zeros_like.cc",
Expand Down Expand Up @@ -1089,6 +1090,19 @@ cc_test(
],
)

# Unit test for the TFLM Transpose kernel (transpose.cc).
cc_test(
    name = "transpose_test",
    srcs = [
        "transpose_test.cc",
    ],
    deps = [
        "//tensorflow/lite/c:common",
        "//tensorflow/lite/micro:micro_framework",
        "//tensorflow/lite/micro:op_resolvers",
        "//tensorflow/lite/micro/testing:micro_test",
    ],
)

cc_test(
name = "transpose_conv_test",
srcs = [
Expand Down Expand Up @@ -1130,4 +1144,4 @@ cc_test(
"//tensorflow/lite/micro:test_helpers",
"//tensorflow/lite/micro/testing:micro_test",
],
)
)
1 change: 1 addition & 0 deletions tensorflow/lite/micro/kernels/micro_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ TfLiteRegistration Register_SOFTMAX();
TfLiteRegistration Register_SPACE_TO_BATCH_ND();
TfLiteRegistration Register_SQUEEZE();
TfLiteRegistration Register_SVDF();
TfLiteRegistration Register_TRANSPOSE();
TfLiteRegistration Register_TRANSPOSE_CONV();
TfLiteRegistration Register_ZEROS_LIKE();

Expand Down
171 changes: 57 additions & 114 deletions tensorflow/lite/micro/kernels/transpose.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,170 +12,113 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stdint.h>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
#include "tensorflow/lite/micro/memory_helpers.h"
#include "tensorflow/lite/micro/micro_utils.h"
#include "tensorflow/lite/kernels/internal/reference/transpose.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace transpose {

// This file has two implementations of Transpose.
enum KernelType {
kReference,
kGenericOptimized,
};
namespace {

// Tensor indices used by this op: input 0 is the data tensor, input 1 is
// the 1-D permutation vector, output 0 is the transposed result.
constexpr int kInputTensor = 0;
constexpr int kPermTensor = 1;
constexpr int kOutputTensor = 0;

struct TransposeContext {
TransposeContext(TfLiteContext* context, TfLiteNode* node) {
input = GetInput(context, node, 0);
perm = GetInput(context, node, 1);
output = GetOutput(context, node, 0);
}
const TfLiteTensor* input;
const TfLiteTensor* perm;
TfLiteTensor* output;
TransposeContext(TfLiteContext* context, TfLiteNode* node) {
input = GetInput(context, node, kInputTensor);
perm = GetInput(context, node, kPermTensor);
output = GetOutput(context, node, kOutputTensor);
}
const TfLiteTensor* input;
const TfLiteTensor* perm;
TfLiteTensor* output;
};

// Validates the permutation tensor and resizes the output so that output
// dimension d has the size of input dimension perm[d].  Returns an error
// status if the permutation is not a 1-D tensor of `rank` in-range entries.
TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
                                TransposeContext* op_context) {
  const int rank = NumDimensions(op_context->input);
  const int32_t* perm = GetTensorData<int32_t>(op_context->perm);

  // The permutation must be 1-D and hold exactly one entry per input dim.
  TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->perm), 1);
  TF_LITE_ENSURE_EQ(context, op_context->perm->dims->data[0], rank);

  // Each entry must name a valid input dimension.
  for (int d = 0; d < rank; ++d) {
    TF_LITE_ENSURE_MSG(context, (perm[d] >= 0 && perm[d] < rank),
                       "Transpose op permutations array is out of bounds.");
  }

  // Build the output shape by permuting the input shape.
  TfLiteIntArray* in_dims = op_context->input->dims;
  TfLiteIntArray* out_dims = TfLiteIntArrayCopy(in_dims);
  for (int d = 0; d < rank; ++d) {
    out_dims->data[d] = in_dims->data[perm[d]];
  }

  return context->ResizeTensor(context, op_context->output, out_dims);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

TransposeContext op_context(context, node);
TransposeContext op_context(context, node);

// Ensure validity of input tensor.
TF_LITE_ENSURE_MSG(context, NumDimensions(op_context.input) <= 5,
"Transpose op only supports 1D-5D input arrays.");
TF_LITE_ENSURE_TYPES_EQ(context, op_context.input->type,
op_context.output->type);
// Ensure validity of input tensor.
TF_LITE_ENSURE_MSG(context, NumDimensions(op_context.input) <= 5,
"Transpose op only supports 1D-5D input arrays.");
TF_LITE_ENSURE_TYPES_EQ(context, op_context.input->type,
op_context.output->type);

if (!IsConstantTensor(op_context.perm)) {
SetTensorToDynamic(op_context.output);
return kTfLiteOk;
}
return ResizeOutputTensor(context, &op_context);
}

template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
TransposeContext op_context(context, node);

// Resize the output tensor if the output tensor is dynamic.
if (IsDynamicTensor(op_context.output)) {
TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
}
// Retrieve the perm permutation array
const int32_t* perm_data = GetTensorData<int32_t>(op_context.perm);

const int* perm_data = GetTensorData<int32_t>(op_context.perm);
// Number of permutation entries (the length of the 1-D perm tensor,
// which equals the input tensor's rank)
const int size = op_context.perm->dims->data[0];

// Prepare a params object holding the perm data while performing
// the conversion
TransposeParams params;
params.perm_count = size;
for (int i = 0; i < size; ++i) {
params.perm[i] = perm_data[i];
}

#define TF_LITE_TRANSPOSE(type, scalar) \
type::Transpose(params, GetTensorShape(op_context.input), \
// Helper macro that invokes the reference Transpose for a given scalar type
#define TF_LITE_TRANSPOSE(scalar) \
reference_ops::Transpose(params, GetTensorShape(op_context.input), \
GetTensorData<scalar>(op_context.input), \
GetTensorShape(op_context.output), \
GetTensorData<scalar>(op_context.output))

// Transpose kernel only does rearranging values not numeric evaluations on
// each cell. It's safe to implement per size of scalar type and this trick
// keeps the total code size in a reasonable range.
// Transpose only rearranges values at the byte level, so what matters is
// the size of the scalar type, not its semantics. Dispatching on size lets
// several tensor types share one instantiation, keeping code size small.
switch (op_context.input->type) {
case kTfLiteFloat32:
case kTfLiteInt32:
if (kernel_type == kGenericOptimized) {
TF_LITE_TRANSPOSE(optimized_ops, int32_t);
} else {
TF_LITE_TRANSPOSE(reference_ops, int32_t);
}
TF_LITE_TRANSPOSE(int32_t);
break;
case kTfLiteUInt8:
case kTfLiteInt8:
if (kernel_type == kGenericOptimized) {
TF_LITE_TRANSPOSE(optimized_ops, int8_t);
} else {
TF_LITE_TRANSPOSE(reference_ops, int8_t);
}
case kTfLiteUInt8:
TF_LITE_TRANSPOSE(int8_t);
break;
case kTfLiteInt16:
TF_LITE_TRANSPOSE(reference_ops, int16_t);
break;
case kTfLiteInt64:
TF_LITE_TRANSPOSE(reference_ops, int64_t);
break;
case kTfLiteBool:
if (sizeof(bool) == 1) {
if (kernel_type == kGenericOptimized) {
TF_LITE_TRANSPOSE(optimized_ops, int8_t);
} else {
TF_LITE_TRANSPOSE(reference_ops, int8_t);
}
} else {
TF_LITE_TRANSPOSE(reference_ops, bool);
}
TF_LITE_TRANSPOSE(int16_t);
break;
default:
TF_LITE_KERNEL_LOG(context,
"Type %s is currently not supported by Transpose.",
TfLiteTypeGetName(op_context.input->type));
return kTfLiteError;
}

#undef TF_LITE_TRANSPOSE

return kTfLiteOk;
}

} // namespace transpose

TfLiteRegistration* Register_TRANSPOSE_REF() {
static TfLiteRegistration r = {nullptr, nullptr, transpose::Prepare,
transpose::Eval<transpose::kReference>};
return &r;
}

TfLiteRegistration* Register_TRANSPOSE_GENERIC_OPTIMIZED() {
static TfLiteRegistration r = {nullptr, nullptr, transpose::Prepare,
transpose::Eval<transpose::kGenericOptimized>};
return &r;
}

TfLiteRegistration* Register_TRANSPOSE() {
return Register_TRANSPOSE_GENERIC_OPTIMIZED();
} // namespace transpose

// Returns the registration record for the micro Transpose kernel.
// No per-node init/free is needed: Prepare validates the tensors and Eval
// performs the data rearrangement directly.
// NOTE(review): version is set to 2 — confirm this matches the op version
// expected by the micro op resolver / converter.
TfLiteRegistration Register_TRANSPOSE() {
return {/*init=*/nullptr,
/*free=*/nullptr,
/*prepare=*/Prepare,
/*invoke=*/Eval,
/*profiling_string=*/nullptr,
/*builtin_code=*/0,
/*custom_name=*/nullptr,
/*version=*/2};
}

} // namespace builtin
} // namespace ops
} // namespace tflite
} // namespace tflite
Loading