Commit df704e1: Add 4-bit conv kernel support

This is the open source equivalent of http://cl/480909268, which has already been reviewed and approved.

BUG=b/248328557
paulinesho committed Oct 14, 2022
1 parent 81f5208 commit df704e1
Showing 5 changed files with 59 additions and 22 deletions.
55 changes: 40 additions & 15 deletions tensorflow/lite/micro/kernels/conv.cc
@@ -17,13 +17,9 @@ limitations under the License.

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/conv.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/conv.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/padding.h"
#include "tensorflow/lite/micro/kernels/kernel_util.h"
#include "tensorflow/lite/micro/micro_log.h"

@@ -57,7 +53,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_MSG(
      context,
      input->type == filter->type ||
-          (input->type == kTfLiteInt16 && filter->type == kTfLiteInt8),
+          (input->type == kTfLiteInt16 && filter->type == kTfLiteInt8) ||
+          (input->type == kTfLiteInt8 && filter->type == kTfLiteInt4),
      "Hybrid models are not supported on TFLite Micro.");

switch (input->type) { // Already know in/out types are same.
@@ -112,16 +109,44 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
          break;
        }
        case kTfLiteInt8: {
-          reference_integer_ops::ConvPerChannel(
-              ConvParamsQuantized(params, data), data.per_channel_output_multiplier,
-              data.per_channel_output_shift, tflite::micro::GetTensorShape(input),
-              tflite::micro::GetTensorData<int8_t>(input),
-              tflite::micro::GetTensorShape(filter),
-              tflite::micro::GetTensorData<int8_t>(filter),
-              tflite::micro::GetTensorShape(bias),
-              tflite::micro::GetOptionalTensorData<int32_t>(bias),
-              tflite::micro::GetTensorShape(output),
-              tflite::micro::GetTensorData<int8_t>(output));
+          switch (filter->type) {
+            case kTfLiteInt4: {
+              int8_t* unpacked_filter_data = nullptr;
+              OpDataConv* op_data = static_cast<OpDataConv*>(node->user_data);
+              unpacked_filter_data = static_cast<int8_t*>(
+                  context->GetScratchBuffer(context, op_data->filter_buffer_index));
+              reference_integer_ops::ConvPerChannelWithPackedInt4Weights(
+                  ConvParamsQuantized(params, data),
+                  data.per_channel_output_multiplier, data.per_channel_output_shift,
+                  tflite::micro::GetTensorShape(input),
+                  tflite::micro::GetTensorData<int8_t>(input),
+                  tflite::micro::GetTensorShape(filter),
+                  tflite::micro::GetTensorData<int8_t>(filter),
+                  unpacked_filter_data, tflite::micro::GetTensorShape(bias),
+                  tflite::micro::GetOptionalTensorData<int32_t>(bias),
+                  tflite::micro::GetTensorShape(output),
+                  tflite::micro::GetTensorData<int8_t>(output));
+              break;
+            }
+            case kTfLiteInt8: {
+              reference_integer_ops::ConvPerChannel(
+                  ConvParamsQuantized(params, data),
+                  data.per_channel_output_multiplier, data.per_channel_output_shift,
+                  tflite::micro::GetTensorShape(input),
+                  tflite::micro::GetTensorData<int8_t>(input),
+                  tflite::micro::GetTensorShape(filter),
+                  tflite::micro::GetTensorData<int8_t>(filter),
+                  tflite::micro::GetTensorShape(bias),
+                  tflite::micro::GetOptionalTensorData<int32_t>(bias),
+                  tflite::micro::GetTensorShape(output),
+                  tflite::micro::GetTensorData<int8_t>(output));
+              break;
+            }
+            default:
+              MicroPrintf("Weight type %s (%d) not supported.",
+                          TfLiteTypeGetName(filter->type), filter->type);
+              return kTfLiteError;
+          }
          break;
        }
        default:
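For orientation: the new int4 case differs from the int8 case only in the weight handling — ConvPerChannelWithPackedInt4Weights expands the packed filter into the scratch buffer before running the usual per-channel convolution. Below is a minimal sketch of that unpack step, assuming two 4-bit values per byte with the first element in the low nibble and sign extension on expansion; the helper name is hypothetical, and the real logic lives inside the reference kernel.

#include <cstdint>

// Hypothetical sketch: expand packed int4 weights (two per byte, low nibble
// first) into sign-extended int8 values.
void UnpackInt4ToInt8(const int8_t* packed, int num_elements,
                      int8_t* unpacked) {
  for (int i = 0; i < num_elements; ++i) {
    const uint8_t byte = static_cast<uint8_t>(packed[i / 2]);
    const uint8_t nibble = (i % 2 == 0) ? (byte & 0x0F) : (byte >> 4);
    // Sign-extend 4 bits to 8: nibbles 8..15 map to -8..-1.
    unpacked[i] = static_cast<int8_t>(static_cast<int8_t>(nibble << 4) >> 4);
  }
}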
4 changes: 4 additions & 0 deletions tensorflow/lite/micro/kernels/conv.h
@@ -45,6 +45,10 @@ struct OpDataConv {
  // uint8_t these would be 0 and 255.
  int32_t output_activation_min;
  int32_t output_activation_max;
+
+  // A buffer used to store unpacked filter values. This is used if the source
+  // tensor is of n-bit precision that cannot be easily processed by kernels.
+  int filter_buffer_index;
};

extern const int kConvInputTensor;
15 changes: 10 additions & 5 deletions tensorflow/lite/micro/kernels/conv_common.cc
@@ -14,12 +14,8 @@ limitations under the License.
==============================================================================*/

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/conv.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/conv.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/padding.h"
#include "tensorflow/lite/micro/kernels/conv.h"
@@ -188,6 +184,15 @@ TfLiteStatus ConvPrepare(TfLiteContext* context, TfLiteNode* node) {
      context, node, params, input_width, input_height, filter_width,
      filter_height, output_width, output_height, input->type, data));

+  if (filter->type == kTfLiteInt4) {
+    int filter_size =
+        RuntimeShape(filter->dims->size,
+                     reinterpret_cast<const int32_t*>(filter->dims->data))
+            .FlatSize();
+    context->RequestScratchBufferInArena(context, filter_size,
+                                         &data->filter_buffer_index);
+  }
+
  micro_context->DeallocateTempTfLiteTensor(filter);
  micro_context->DeallocateTempTfLiteTensor(input);
  micro_context->DeallocateTempTfLiteTensor(output);
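To make the sizing concrete: the request above reserves one int8 byte per filter element (FlatSize counts elements, not packed bytes), so the unpacked scratch copy is twice the size of the packed int4 data. Illustrative arithmetic with a hypothetical filter shape, not taken from this commit:

// An int4 filter of shape [8, 3, 3, 4]:
constexpr int kFilterElements = 8 * 3 * 3 * 4;           // FlatSize() == 288
constexpr int kScratchBytes = kFilterElements;           // 288 bytes, one int8 each
constexpr int kPackedBytes = (kFilterElements + 1) / 2;  // 144 bytes as stored

The index written to data->filter_buffer_index here in Prepare is the same one Eval resolves through context->GetScratchBuffer in the conv.cc hunk above.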
4 changes: 2 additions & 2 deletions tensorflow/lite/micro/memory_arena_threshold_test.cc
@@ -98,7 +98,7 @@ constexpr int kTestConvModelOnlyTotalSize = 9488;
// TODO(b/207157610): replace magic number that depends on OPs
constexpr int kTestConvModelOnlyTailSize = 1744;
constexpr int kTestConvModelPersistentTfLiteTensorDataSize = 128;
-constexpr int kTestConvModelPersistentBufferDataSize = 680;
+constexpr int kTestConvModelPersistentBufferDataSize = 712;
#else
// Total size contributed by the conv model excluding the
// RecordingMicroAllocator's overhead
@@ -109,7 +109,7 @@ constexpr int kTestConvModelOnlyTotalSize = 9760;
// TODO(b/207157610): replace magic number that depends on OPs
constexpr int kTestConvModelOnlyTailSize = 2016;
constexpr int kTestConvModelPersistentTfLiteTensorDataSize = 224;
-constexpr int kTestConvModelPersistentBufferDataSize = 680;
+constexpr int kTestConvModelPersistentBufferDataSize = 704;
#endif
constexpr int kTestConvModelHeadSize = 7744;
constexpr int kTestConvModelOpRuntimeDataSize = 136;
3 changes: 3 additions & 0 deletions tensorflow/lite/micro/memory_helpers.cc
@@ -90,6 +90,9 @@ TfLiteStatus TfLiteTypeSizeOf(TfLiteType type, size_t* size) {
    case kTfLiteComplex128:
      *size = sizeof(double) * 2;
      break;
+    case kTfLiteInt4:
+      *size = sizeof(int8_t);
+      break;
    default:
      return kTfLiteError;
  }
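With this case added, allocation planning budgets a full byte per int4 element, i.e. the unpacked footprint rather than the packed size. A minimal usage sketch against the signature shown above:

size_t type_size = 0;
if (TfLiteTypeSizeOf(kTfLiteInt4, &type_size) == kTfLiteOk) {
  // type_size == 1: one byte per int4 element, matching the sign-extended
  // int8 form the kernels consume after unpacking.
}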
