From 22453e13c0d6ac6d273926209905c73675f10483 Mon Sep 17 00:00:00 2001 From: ANSHUMAN TRIPATHY Date: Thu, 21 Mar 2019 10:27:43 +0530 Subject: [PATCH] Lite: Fully_connected Op code refactored --- tensorflow/lite/kernels/fully_connected.cc | 53 +++++++--------------- 1 file changed, 16 insertions(+), 37 deletions(-) diff --git a/tensorflow/lite/kernels/fully_connected.cc b/tensorflow/lite/kernels/fully_connected.cc index 55cde983abccdd..a07ed3ca61bafb 100644 --- a/tensorflow/lite/kernels/fully_connected.cc +++ b/tensorflow/lite/kernels/fully_connected.cc @@ -277,17 +277,6 @@ TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node, return kTfLiteOk; } -#define TF_LITE_MACRO_DISPATCH(macro_name, params, target_namespace) \ - if (params->activation == kTfLiteActNone) { \ - macro_name(target_namespace, kNone); \ - } \ - if (params->activation == kTfLiteActRelu) { \ - macro_name(target_namespace, kRelu); \ - } \ - if (params->activation == kTfLiteActRelu6) { \ - macro_name(target_namespace, kRelu6); \ - } - namespace { void FullyConnectedInt8(const OpData* data, const TfLiteTensor* input, const TfLiteTensor* filter, const TfLiteTensor* bias, @@ -343,38 +332,29 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/1); return EvalHybrid(context, node, params, data, input, filter, bias, input_quantized, scaling_factors, output); - } else if (kernel_type == kReference) { - switch (output->type) { - case kTfLiteUInt8: - TF_LITE_FULLY_CONNECTED(reference_ops, uint8_t); - break; - case kTfLiteInt8: - FullyConnectedInt8(data, input, filter, bias, output, gemm_context); - break; - case kTfLiteInt16: - TF_LITE_FULLY_CONNECTED(reference_ops, int16_t); - break; - default: - context->ReportError( - context, - "Quantized FullyConnected expects output data type uint8 or int16"); - return kTfLiteError; - } } else { switch (output->type) { case kTfLiteUInt8: - TF_LITE_FULLY_CONNECTED(optimized_ops, uint8_t); + if (kernel_type == kReference) { + TF_LITE_FULLY_CONNECTED(reference_ops, uint8_t); + } else { + TF_LITE_FULLY_CONNECTED(optimized_ops, uint8_t); + } break; case kTfLiteInt8: FullyConnectedInt8(data, input, filter, bias, output, gemm_context); break; case kTfLiteInt16: - TF_LITE_FULLY_CONNECTED(optimized_ops, int16_t); + if (kernel_type == kReference) { + TF_LITE_FULLY_CONNECTED(reference_ops, int16_t); + } else { + TF_LITE_FULLY_CONNECTED(optimized_ops, int16_t); + } break; default: - context->ReportError( - context, - "Quantized FullyConnected expects output data type uint8 or int16"); + context->ReportError(context, + "Quantized FullyConnected expects output data " + "type uint8, int8 or int16"); return kTfLiteError; } } @@ -457,8 +437,6 @@ TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node, return kTfLiteOk; } -#undef TF_LITE_MACRO_DISPATCH - template TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { auto* params = @@ -501,8 +479,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return kTfLiteError; } default: - context->ReportError(context, "Type %d not currently supported.", - filter->type); + context->ReportError(context, + "Filter data type %s currently not supported.", + TfLiteTypeGetName(filter->type)); return kTfLiteError; } return kTfLiteOk;