From df704e13533ec9bd8622986e77e0eb7ff052e5cc Mon Sep 17 00:00:00 2001 From: Pauline Sho Date: Fri, 14 Oct 2022 12:13:47 -0400 Subject: [PATCH] Add 4-bit conv kernel support Add 4-bit conv kernel support It is the open source equivalent of http://cl/480909268, which has already been reviewed and approved. BUG=b/248328557 --- tensorflow/lite/micro/kernels/conv.cc | 55 ++++++++++++++----- tensorflow/lite/micro/kernels/conv.h | 4 ++ tensorflow/lite/micro/kernels/conv_common.cc | 15 +++-- .../lite/micro/memory_arena_threshold_test.cc | 4 +- tensorflow/lite/micro/memory_helpers.cc | 3 + 5 files changed, 59 insertions(+), 22 deletions(-) diff --git a/tensorflow/lite/micro/kernels/conv.cc b/tensorflow/lite/micro/kernels/conv.cc index f609a008e4..36c57655a1 100644 --- a/tensorflow/lite/micro/kernels/conv.cc +++ b/tensorflow/lite/micro/kernels/conv.cc @@ -17,13 +17,9 @@ limitations under the License. #include "tensorflow/lite/c/builtin_op_data.h" #include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/kernels/internal/common.h" -#include "tensorflow/lite/kernels/internal/quantization_util.h" #include "tensorflow/lite/kernels/internal/reference/conv.h" #include "tensorflow/lite/kernels/internal/reference/integer_ops/conv.h" -#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/kernel_util.h" -#include "tensorflow/lite/kernels/padding.h" #include "tensorflow/lite/micro/kernels/kernel_util.h" #include "tensorflow/lite/micro/micro_log.h" @@ -57,7 +53,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_MSG( context, input->type == filter->type || - (input->type == kTfLiteInt16 && filter->type == kTfLiteInt8), + (input->type == kTfLiteInt16 && filter->type == kTfLiteInt8) || + (input->type == kTfLiteInt8 && filter->type == kTfLiteInt4), "Hybrid models are not supported on TFLite Micro."); switch (input->type) { // Already know in/out types are same. 
@@ -112,16 +109,44 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { break; } case kTfLiteInt8: { - reference_integer_ops::ConvPerChannel( - ConvParamsQuantized(params, data), data.per_channel_output_multiplier, - data.per_channel_output_shift, tflite::micro::GetTensorShape(input), - tflite::micro::GetTensorData(input), - tflite::micro::GetTensorShape(filter), - tflite::micro::GetTensorData(filter), - tflite::micro::GetTensorShape(bias), - tflite::micro::GetOptionalTensorData(bias), - tflite::micro::GetTensorShape(output), - tflite::micro::GetTensorData(output)); + switch (filter->type) { + case kTfLiteInt4: { + int8_t* unpacked_filter_data = nullptr; + OpDataConv* op_data = static_cast(node->user_data); + unpacked_filter_data = static_cast( + context->GetScratchBuffer(context, op_data->filter_buffer_index)); + reference_integer_ops::ConvPerChannelWithPackedInt4Weights( + ConvParamsQuantized(params, data), + data.per_channel_output_multiplier, data.per_channel_output_shift, + tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(filter), + tflite::micro::GetTensorData(filter), + unpacked_filter_data, tflite::micro::GetTensorShape(bias), + tflite::micro::GetOptionalTensorData(bias), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); + break; + } + case kTfLiteInt8: { + reference_integer_ops::ConvPerChannel( + ConvParamsQuantized(params, data), + data.per_channel_output_multiplier, data.per_channel_output_shift, + tflite::micro::GetTensorShape(input), + tflite::micro::GetTensorData(input), + tflite::micro::GetTensorShape(filter), + tflite::micro::GetTensorData(filter), + tflite::micro::GetTensorShape(bias), + tflite::micro::GetOptionalTensorData(bias), + tflite::micro::GetTensorShape(output), + tflite::micro::GetTensorData(output)); + break; + } + default: + MicroPrintf("Weight type %s (%d) not supported.", + TfLiteTypeGetName(filter->type), filter->type); + return 
kTfLiteError; + } break; } default: diff --git a/tensorflow/lite/micro/kernels/conv.h b/tensorflow/lite/micro/kernels/conv.h index 06b35e1e55..06e9db43a0 100644 --- a/tensorflow/lite/micro/kernels/conv.h +++ b/tensorflow/lite/micro/kernels/conv.h @@ -45,6 +45,10 @@ struct OpDataConv { // uint8_t these would be 0 and 255. int32_t output_activation_min; int32_t output_activation_max; + + // A buffer used to store unpacked filter values. This is used if the source + // tensor is of n-bit precision that cannot be easily processed by kernels. + int filter_buffer_index; }; extern const int kConvInputTensor; diff --git a/tensorflow/lite/micro/kernels/conv_common.cc b/tensorflow/lite/micro/kernels/conv_common.cc index 7115f7babe..c548c932c2 100644 --- a/tensorflow/lite/micro/kernels/conv_common.cc +++ b/tensorflow/lite/micro/kernels/conv_common.cc @@ -14,12 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/lite/c/builtin_op_data.h" +#include "tensorflow/lite/c/c_api_types.h" #include "tensorflow/lite/c/common.h" -#include "tensorflow/lite/kernels/internal/common.h" -#include "tensorflow/lite/kernels/internal/quantization_util.h" -#include "tensorflow/lite/kernels/internal/reference/conv.h" -#include "tensorflow/lite/kernels/internal/reference/integer_ops/conv.h" -#include "tensorflow/lite/kernels/internal/tensor_ctypes.h" #include "tensorflow/lite/kernels/kernel_util.h" #include "tensorflow/lite/kernels/padding.h" #include "tensorflow/lite/micro/kernels/conv.h" @@ -188,6 +184,15 @@ TfLiteStatus ConvPrepare(TfLiteContext* context, TfLiteNode* node) { context, node, params, input_width, input_height, filter_width, filter_height, output_width, output_height, input->type, data)); + if (filter->type == kTfLiteInt4) { + int filter_size = + RuntimeShape(filter->dims->size, + reinterpret_cast(filter->dims->data)) + .FlatSize(); + context->RequestScratchBufferInArena(context, filter_size, + 
&data->filter_buffer_index); + } + micro_context->DeallocateTempTfLiteTensor(filter); micro_context->DeallocateTempTfLiteTensor(input); micro_context->DeallocateTempTfLiteTensor(output); diff --git a/tensorflow/lite/micro/memory_arena_threshold_test.cc b/tensorflow/lite/micro/memory_arena_threshold_test.cc index 57e556378c..1d664237ff 100644 --- a/tensorflow/lite/micro/memory_arena_threshold_test.cc +++ b/tensorflow/lite/micro/memory_arena_threshold_test.cc @@ -98,7 +98,7 @@ constexpr int kTestConvModelOnlyTotalSize = 9488; // TODO(b/207157610): replace magic number that depends on OPs constexpr int kTestConvModelOnlyTailSize = 1744; constexpr int kTestConvModelPersistentTfLiteTensorDataSize = 128; -constexpr int kTestConvModelPersistentBufferDataSize = 680; +constexpr int kTestConvModelPersistentBufferDataSize = 712; #else // Total size contributed by the conv model excluding the // RecordingMicroAllocator's overhead @@ -109,7 +109,7 @@ constexpr int kTestConvModelOnlyTotalSize = 9760; // TODO(b/207157610): replace magic number that depends on OPs constexpr int kTestConvModelOnlyTailSize = 2016; constexpr int kTestConvModelPersistentTfLiteTensorDataSize = 224; -constexpr int kTestConvModelPersistentBufferDataSize = 680; +constexpr int kTestConvModelPersistentBufferDataSize = 704; #endif constexpr int kTestConvModelHeadSize = 7744; constexpr int kTestConvModelOpRuntimeDataSize = 136; diff --git a/tensorflow/lite/micro/memory_helpers.cc b/tensorflow/lite/micro/memory_helpers.cc index 32a661dd26..f041d37dfd 100644 --- a/tensorflow/lite/micro/memory_helpers.cc +++ b/tensorflow/lite/micro/memory_helpers.cc @@ -90,6 +90,9 @@ TfLiteStatus TfLiteTypeSizeOf(TfLiteType type, size_t* size) { case kTfLiteComplex128: *size = sizeof(double) * 2; break; + case kTfLiteInt4: + *size = sizeof(int8_t); + break; default: return kTfLiteError; }