@@ -13,26 +13,32 @@ See the License for the specific language governing permissions and
1313limitations under the License.
1414==============================================================================*/
1515
16- #include " tensorflow/lite/kernels/internal/reference /fully_connected.h"
16+ #include " tensorflow/lite/micro/kernels /fully_connected.h"
1717
1818#include " tensorflow/lite/c/builtin_op_data.h"
1919#include " tensorflow/lite/c/common.h"
2020#include " tensorflow/lite/kernels/internal/common.h"
2121#include " tensorflow/lite/kernels/internal/quantization_util.h"
22+ #include " tensorflow/lite/kernels/internal/reference/fully_connected.h"
2223#include " tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
2324#include " tensorflow/lite/kernels/internal/tensor_ctypes.h"
2425#include " tensorflow/lite/kernels/kernel_util.h"
25- #include " tensorflow/lite/micro/kernels/fully_connected.h"
2626#include " tensorflow/lite/micro/kernels/kernel_util.h"
2727
2828namespace tflite {
2929namespace {
3030
31+ void * Init (TfLiteContext* context, const char * buffer, size_t length) {
32+ TFLITE_DCHECK (context->AllocatePersistentBuffer != nullptr );
33+ return context->AllocatePersistentBuffer (context,
34+ sizeof (OpDataFullyConnected));
35+ }
36+
3137TfLiteStatus Prepare (TfLiteContext* context, TfLiteNode* node) {
3238 TFLITE_DCHECK (node->user_data != nullptr );
3339 TFLITE_DCHECK (node->builtin_data != nullptr );
3440
35- auto * data = static_cast <OpDataFullyConnectedReference *>(node->user_data );
41+ auto * data = static_cast <OpDataFullyConnected *>(node->user_data );
3642 const auto params =
3743 static_cast <const TfLiteFullyConnectedParams*>(node->builtin_data );
3844
@@ -51,9 +57,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
5157 TF_LITE_ENSURE_MSG (context, input->type == filter->type ,
5258 " Hybrid models are not supported on TFLite Micro." );
5359
54- return CalculateOpDataFullyConnectedReference (context, params->activation ,
55- input->type , input, filter,
56- bias, output, data);
60+ return CalculateOpDataFullyConnected (context, params->activation , input->type ,
61+ input, filter, bias, output, data);
5762}
5863
5964TfLiteStatus Eval (TfLiteContext* context, TfLiteNode* node) {
@@ -72,33 +77,64 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
7277
7378 TFLITE_DCHECK (node->user_data != nullptr );
7479 const auto & data =
75- *(static_cast <const OpDataFullyConnectedReference *>(node->user_data ));
80+ *(static_cast <const OpDataFullyConnected *>(node->user_data ));
7681
7782 // Checks in Prepare ensure input, output and filter types are all the same.
7883 switch (input->type ) {
79- case kTfLiteFloat32 :
80- return EvalFloatFullyConnectedReference (context, node, params->activation ,
81- input, filter, bias, output);
82- case kTfLiteInt8 :
83- return EvalQuantizedInt8FullyConnectedReference (
84- context, node, data, input, filter, bias, output);
85-
86- case kTfLiteUInt8 :
87- return EvalQuantizedFullyConnectedReference (context, node, data, input,
88- filter, bias, output);
89-
90- default :
84+ case kTfLiteFloat32 : {
85+ tflite::reference_ops::FullyConnected (
86+ FullyConnectedParamsFloat (params->activation ),
87+ tflite::micro::GetTensorShape (input),
88+ tflite::micro::GetTensorData<float >(input),
89+ tflite::micro::GetTensorShape (filter),
90+ tflite::micro::GetTensorData<float >(filter),
91+ tflite::micro::GetTensorShape (bias),
92+ tflite::micro::GetTensorData<float >(bias),
93+ tflite::micro::GetTensorShape (output),
94+ tflite::micro::GetTensorData<float >(output));
95+ break ;
96+ }
97+
98+ case kTfLiteInt8 : {
99+ tflite::reference_integer_ops::FullyConnected (
100+ FullyConnectedParamsQuantized (data),
101+ tflite::micro::GetTensorShape (input),
102+ tflite::micro::GetTensorData<int8_t >(input),
103+ tflite::micro::GetTensorShape (filter),
104+ tflite::micro::GetTensorData<int8_t >(filter),
105+ tflite::micro::GetTensorShape (bias),
106+ tflite::micro::GetTensorData<int32_t >(bias),
107+ tflite::micro::GetTensorShape (output),
108+ tflite::micro::GetTensorData<int8_t >(output));
109+ break ;
110+ }
111+
112+ case kTfLiteUInt8 : {
113+ tflite::reference_ops::FullyConnected (
114+ FullyConnectedParamsQuantized (data),
115+ tflite::micro::GetTensorShape (input),
116+ tflite::micro::GetTensorData<uint8_t >(input),
117+ tflite::micro::GetTensorShape (filter),
118+ tflite::micro::GetTensorData<uint8_t >(filter),
119+ tflite::micro::GetTensorShape (bias),
120+ tflite::micro::GetTensorData<int32_t >(bias),
121+ tflite::micro::GetTensorShape (output),
122+ tflite::micro::GetTensorData<uint8_t >(output));
123+ break ;
124+ }
125+ default : {
91126 TF_LITE_KERNEL_LOG (context, " Type %s (%d) not supported." ,
92127 TfLiteTypeGetName (input->type ), input->type );
93128 return kTfLiteError ;
129+ }
94130 }
95131 return kTfLiteOk ;
96132}
97133
98134} // namespace
99135
100136TfLiteRegistration Register_FULLY_CONNECTED () {
101- return {/* init=*/ InitFullyConnectedReference ,
137+ return {/* init=*/ Init ,
102138 /* free=*/ nullptr ,
103139 /* prepare=*/ Prepare,
104140 /* invoke=*/ Eval,
0 commit comments