Skip to content

Commit bbd388b

Browse files
committed
Updates based on discussion with Nat.
1 parent 0a469be commit bbd388b

File tree

6 files changed

+156
-171
lines changed

6 files changed

+156
-171
lines changed

tensorflow/lite/micro/kernels/BUILD

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -47,7 +47,9 @@ cc_library(
4747

4848
cc_library(
4949
name = "fully_connected",
50-
srcs = select({
50+
srcs = [
51+
"fully_connected_common.cc",
52+
] + select({
5153
"//conditions:default": [
5254
"fully_connected.cc",
5355
],

tensorflow/lite/micro/kernels/cmsis-nn/fully_connected.cc

Lines changed: 46 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -13,24 +13,24 @@ See the License for the specific language governing permissions and
1313
limitations under the License.
1414
==============================================================================*/
1515

16-
#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
16+
#include "tensorflow/lite/micro/kernels/fully_connected.h"
1717

1818
#include "CMSIS/NN/Include/arm_nnfunctions.h"
1919
#include "tensorflow/lite/c/builtin_op_data.h"
2020
#include "tensorflow/lite/c/common.h"
2121
#include "tensorflow/lite/kernels/internal/common.h"
2222
#include "tensorflow/lite/kernels/internal/quantization_util.h"
23+
#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
2324
#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
2425
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
2526
#include "tensorflow/lite/kernels/kernel_util.h"
26-
#include "tensorflow/lite/micro/kernels/fully_connected.h"
2727
#include "tensorflow/lite/micro/kernels/kernel_util.h"
2828

2929
namespace tflite {
3030
namespace {
3131

3232
struct OpData {
33-
OpDataFullyConnectedReference reference_op_data;
33+
OpDataFullyConnected reference_op_data;
3434

3535
// Index to buffer for optimizations if applicable.
3636
int buffer_idx;
@@ -49,9 +49,9 @@ TfLiteStatus CalculateOpData(TfLiteContext* context,
4949
OpData* data) {
5050
// Set buffer index to a reset value
5151
data->buffer_idx = -1;
52-
return CalculateOpDataFullyConnectedReference(context, activation, data_type,
53-
input, filter, bias, output,
54-
&(data->reference_op_data));
52+
return CalculateOpDataFullyConnected(context, activation, data_type, input,
53+
filter, bias, output,
54+
&(data->reference_op_data));
5555
}
5656

5757
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
@@ -177,8 +177,16 @@ TfLiteStatus EvalQuantizedInt8(TfLiteContext* context, TfLiteNode* node,
177177
tflite::micro::GetTensorData<int8_t>(output)),
178178
ARM_MATH_SUCCESS);
179179
} else {
180-
return EvalQuantizedInt8FullyConnectedReference(
181-
context, node, data.reference_op_data, input, filter, bias, output);
180+
tflite::reference_integer_ops::FullyConnected(
181+
FullyConnectedParamsQuantized(data.reference_op_data),
182+
tflite::micro::GetTensorShape(input),
183+
tflite::micro::GetTensorData<int8_t>(input),
184+
tflite::micro::GetTensorShape(filter),
185+
tflite::micro::GetTensorData<int8_t>(filter),
186+
tflite::micro::GetTensorShape(bias),
187+
tflite::micro::GetTensorData<int32_t>(bias),
188+
tflite::micro::GetTensorShape(output),
189+
tflite::micro::GetTensorData<int8_t>(output));
182190
}
183191
return kTfLiteOk;
184192
}
@@ -202,21 +210,41 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
202210

203211
// Checks in Prepare ensure input, output and filter types are all the same.
204212
switch (input->type) {
205-
case kTfLiteFloat32:
206-
return EvalFloatFullyConnectedReference(context, node, params->activation,
207-
input, filter, bias, output);
208-
case kTfLiteInt8:
213+
case kTfLiteFloat32: {
214+
tflite::reference_ops::FullyConnected(
215+
FullyConnectedParamsFloat(params->activation),
216+
tflite::micro::GetTensorShape(input),
217+
tflite::micro::GetTensorData<float>(input),
218+
tflite::micro::GetTensorShape(filter),
219+
tflite::micro::GetTensorData<float>(filter),
220+
tflite::micro::GetTensorShape(bias),
221+
tflite::micro::GetTensorData<float>(bias),
222+
tflite::micro::GetTensorShape(output),
223+
tflite::micro::GetTensorData<float>(output));
224+
break;
225+
}
226+
case kTfLiteInt8: {
209227
return EvalQuantizedInt8(context, node, data, input, filter, bias,
210228
output);
211-
212-
case kTfLiteUInt8:
213-
return EvalQuantizedFullyConnectedReference(
214-
context, node, data.reference_op_data, input, filter, bias, output);
215-
216-
default:
229+
}
230+
case kTfLiteUInt8: {
231+
tflite::reference_ops::FullyConnected(
232+
FullyConnectedParamsQuantized(data.reference_op_data),
233+
tflite::micro::GetTensorShape(input),
234+
tflite::micro::GetTensorData<uint8_t>(input),
235+
tflite::micro::GetTensorShape(filter),
236+
tflite::micro::GetTensorData<uint8_t>(filter),
237+
tflite::micro::GetTensorShape(bias),
238+
tflite::micro::GetTensorData<int32_t>(bias),
239+
tflite::micro::GetTensorShape(output),
240+
tflite::micro::GetTensorData<uint8_t>(output));
241+
break;
242+
}
243+
default: {
217244
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
218245
TfLiteTypeGetName(input->type), input->type);
219246
return kTfLiteError;
247+
}
220248
}
221249
return kTfLiteOk;
222250
}

tensorflow/lite/micro/kernels/fully_connected.cc

Lines changed: 56 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -13,26 +13,32 @@ See the License for the specific language governing permissions and
1313
limitations under the License.
1414
==============================================================================*/
1515

16-
#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
16+
#include "tensorflow/lite/micro/kernels/fully_connected.h"
1717

1818
#include "tensorflow/lite/c/builtin_op_data.h"
1919
#include "tensorflow/lite/c/common.h"
2020
#include "tensorflow/lite/kernels/internal/common.h"
2121
#include "tensorflow/lite/kernels/internal/quantization_util.h"
22+
#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
2223
#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
2324
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
2425
#include "tensorflow/lite/kernels/kernel_util.h"
25-
#include "tensorflow/lite/micro/kernels/fully_connected.h"
2626
#include "tensorflow/lite/micro/kernels/kernel_util.h"
2727

2828
namespace tflite {
2929
namespace {
3030

31+
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
32+
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
33+
return context->AllocatePersistentBuffer(context,
34+
sizeof(OpDataFullyConnected));
35+
}
36+
3137
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
3238
TFLITE_DCHECK(node->user_data != nullptr);
3339
TFLITE_DCHECK(node->builtin_data != nullptr);
3440

35-
auto* data = static_cast<OpDataFullyConnectedReference*>(node->user_data);
41+
auto* data = static_cast<OpDataFullyConnected*>(node->user_data);
3642
const auto params =
3743
static_cast<const TfLiteFullyConnectedParams*>(node->builtin_data);
3844

@@ -51,9 +57,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
5157
TF_LITE_ENSURE_MSG(context, input->type == filter->type,
5258
"Hybrid models are not supported on TFLite Micro.");
5359

54-
return CalculateOpDataFullyConnectedReference(context, params->activation,
55-
input->type, input, filter,
56-
bias, output, data);
60+
return CalculateOpDataFullyConnected(context, params->activation, input->type,
61+
input, filter, bias, output, data);
5762
}
5863

5964
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
@@ -72,33 +77,64 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
7277

7378
TFLITE_DCHECK(node->user_data != nullptr);
7479
const auto& data =
75-
*(static_cast<const OpDataFullyConnectedReference*>(node->user_data));
80+
*(static_cast<const OpDataFullyConnected*>(node->user_data));
7681

7782
// Checks in Prepare ensure input, output and filter types are all the same.
7883
switch (input->type) {
79-
case kTfLiteFloat32:
80-
return EvalFloatFullyConnectedReference(context, node, params->activation,
81-
input, filter, bias, output);
82-
case kTfLiteInt8:
83-
return EvalQuantizedInt8FullyConnectedReference(
84-
context, node, data, input, filter, bias, output);
85-
86-
case kTfLiteUInt8:
87-
return EvalQuantizedFullyConnectedReference(context, node, data, input,
88-
filter, bias, output);
89-
90-
default:
84+
case kTfLiteFloat32: {
85+
tflite::reference_ops::FullyConnected(
86+
FullyConnectedParamsFloat(params->activation),
87+
tflite::micro::GetTensorShape(input),
88+
tflite::micro::GetTensorData<float>(input),
89+
tflite::micro::GetTensorShape(filter),
90+
tflite::micro::GetTensorData<float>(filter),
91+
tflite::micro::GetTensorShape(bias),
92+
tflite::micro::GetTensorData<float>(bias),
93+
tflite::micro::GetTensorShape(output),
94+
tflite::micro::GetTensorData<float>(output));
95+
break;
96+
}
97+
98+
case kTfLiteInt8: {
99+
tflite::reference_integer_ops::FullyConnected(
100+
FullyConnectedParamsQuantized(data),
101+
tflite::micro::GetTensorShape(input),
102+
tflite::micro::GetTensorData<int8_t>(input),
103+
tflite::micro::GetTensorShape(filter),
104+
tflite::micro::GetTensorData<int8_t>(filter),
105+
tflite::micro::GetTensorShape(bias),
106+
tflite::micro::GetTensorData<int32_t>(bias),
107+
tflite::micro::GetTensorShape(output),
108+
tflite::micro::GetTensorData<int8_t>(output));
109+
break;
110+
}
111+
112+
case kTfLiteUInt8: {
113+
tflite::reference_ops::FullyConnected(
114+
FullyConnectedParamsQuantized(data),
115+
tflite::micro::GetTensorShape(input),
116+
tflite::micro::GetTensorData<uint8_t>(input),
117+
tflite::micro::GetTensorShape(filter),
118+
tflite::micro::GetTensorData<uint8_t>(filter),
119+
tflite::micro::GetTensorShape(bias),
120+
tflite::micro::GetTensorData<int32_t>(bias),
121+
tflite::micro::GetTensorShape(output),
122+
tflite::micro::GetTensorData<uint8_t>(output));
123+
break;
124+
}
125+
default: {
91126
TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
92127
TfLiteTypeGetName(input->type), input->type);
93128
return kTfLiteError;
129+
}
94130
}
95131
return kTfLiteOk;
96132
}
97133

98134
} // namespace
99135

100136
TfLiteRegistration Register_FULLY_CONNECTED() {
101-
return {/*init=*/InitFullyConnectedReference,
137+
return {/*init=*/Init,
102138
/*free=*/nullptr,
103139
/*prepare=*/Prepare,
104140
/*invoke=*/Eval,

tensorflow/lite/micro/kernels/fully_connected.h

Lines changed: 13 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@ limitations under the License.
2323

2424
namespace tflite {
2525

26-
struct OpDataFullyConnectedReference {
26+
struct OpDataFullyConnected {
2727
// The scaling factor from input to output (aka the 'real multiplier') can
2828
// be represented as a fixed point multiplier plus a left shift.
2929
int32_t output_multiplier;
@@ -38,52 +38,27 @@ struct OpDataFullyConnectedReference {
3838
int32_t input_zero_point;
3939
int32_t filter_zero_point;
4040
int32_t output_zero_point;
41-
42-
// Returns a FullyConnectedParams struct with all the parameters needed for a
43-
// quantized fully connected computation.
44-
FullyConnectedParams ToQuantizedParams() const {
45-
FullyConnectedParams op_params;
46-
op_params.input_offset = -input_zero_point;
47-
op_params.weights_offset = -filter_zero_point;
48-
op_params.output_offset = output_zero_point;
49-
op_params.output_multiplier = output_multiplier;
50-
op_params.output_shift = output_shift;
51-
op_params.quantized_activation_min = output_activation_min;
52-
op_params.quantized_activation_max = output_activation_max;
53-
return op_params;
54-
}
5541
};
5642

5743
extern const int kFullyConnectedInputTensor;
5844
extern const int kFullyConnectedWeightsTensor;
5945
extern const int kFullyConnectedBiasTensor;
6046
extern const int kFullyConnectedOutputTensor;
6147

62-
TfLiteStatus CalculateOpDataFullyConnectedReference(
48+
// Returns a FullyConnectedParams struct with all the parameters needed for a
49+
// float computation.
50+
FullyConnectedParams FullyConnectedParamsFloat(
51+
TfLiteFusedActivation activation);
52+
53+
// Returns a FullyConnectedParams struct with all the parameters needed for a
54+
// quantized computation.
55+
FullyConnectedParams FullyConnectedParamsQuantized(
56+
const OpDataFullyConnected& op_data);
57+
58+
TfLiteStatus CalculateOpDataFullyConnected(
6359
TfLiteContext* context, TfLiteFusedActivation activation,
6460
TfLiteType data_type, const TfLiteTensor* input, const TfLiteTensor* filter,
65-
const TfLiteTensor* bias, TfLiteTensor* output,
66-
OpDataFullyConnectedReference* data);
67-
68-
void* InitFullyConnectedReference(TfLiteContext* context, const char* buffer,
69-
size_t length);
70-
71-
TfLiteStatus EvalFloatFullyConnectedReference(
72-
TfLiteContext* context, TfLiteNode* node, TfLiteFusedActivation activation,
73-
const TfLiteEvalTensor* input, const TfLiteEvalTensor* filter,
74-
const TfLiteEvalTensor* bias, TfLiteEvalTensor* output);
75-
76-
TfLiteStatus EvalQuantizedInt8FullyConnectedReference(
77-
TfLiteContext* context, TfLiteNode* node,
78-
const OpDataFullyConnectedReference& data, const TfLiteEvalTensor* input,
79-
const TfLiteEvalTensor* filter, const TfLiteEvalTensor* bias,
80-
TfLiteEvalTensor* output);
81-
82-
TfLiteStatus EvalQuantizedFullyConnectedReference(
83-
TfLiteContext* context, TfLiteNode* node,
84-
const OpDataFullyConnectedReference& data, const TfLiteEvalTensor* input,
85-
const TfLiteEvalTensor* filter, const TfLiteEvalTensor* bias,
86-
TfLiteEvalTensor* output);
61+
const TfLiteTensor* bias, TfLiteTensor* output, OpDataFullyConnected* data);
8762

8863
// This is the most generic TfLiteRegistration. The actual supported types may
8964
// still be target dependent. The only requirement is that every implementation

0 commit comments

Comments (0)