diff --git a/tensorflow/lite/micro/BUILD b/tensorflow/lite/micro/BUILD
index f55c5bf869053a..1a20c7b28216b8 100644
--- a/tensorflow/lite/micro/BUILD
+++ b/tensorflow/lite/micro/BUILD
@@ -107,6 +107,7 @@ cc_library(
         "//tensorflow/lite/core/api",
         "//tensorflow/lite/kernels:op_macros",
         "//tensorflow/lite/kernels/internal:compatibility",
+        "//tensorflow/lite/micro/kernels:conv",
         "//tensorflow/lite/micro/kernels:ethosu",
         "//tensorflow/lite/micro/kernels:fully_connected",
         "//tensorflow/lite/micro/kernels:micro_ops",
diff --git a/tensorflow/lite/micro/kernels/BUILD b/tensorflow/lite/micro/kernels/BUILD
index 8ec2b84f8f17a5..2a0a2d9e58da1d 100644
--- a/tensorflow/lite/micro/kernels/BUILD
+++ b/tensorflow/lite/micro/kernels/BUILD
@@ -36,6 +36,43 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "conv",
+    srcs = [
+        "conv_common.cc",
+    ] + select({
+        "//conditions:default": [
+            "conv.cc",
+        ],
+        ":xtensa_hifimini": [
+            "xtensa/conv.cc",
+        ],
+    }),
+    hdrs = ["conv.h"],
+    copts = micro_copts(),
+    visibility = [
+        # Kernel variants need to be visible to the examples and benchmarks.
+        ":micro",
+    ],
+    deps = [
+        ":fixedpoint_utils",
+        ":kernel_util",
+        ":xtensa",
+        "//tensorflow/lite/c:common",
+        "//tensorflow/lite/kernels/internal:common",
+        "//tensorflow/lite/kernels/internal:quantization_util",
+        "//tensorflow/lite/kernels/internal:reference_base",
+        "//tensorflow/lite/kernels/internal:tensor",
+        "//tensorflow/lite/kernels:kernel_util",
+        "//tensorflow/lite/kernels:padding",
+    ] + select({
+        "//conditions:default": [],
+        ":xtensa_hifimini": [
+            #"//third_party/xtensa/cstub64s:hifi_mini",
+        ],
+    }),
+)
+
 cc_library(
     name = "conv_test_common",
     srcs = [
@@ -211,14 +248,12 @@ cc_library(
         "zeros_like.cc",
     ] + select({
         "//conditions:default": [
-            "conv.cc",
             "depthwise_conv.cc",
             "quantize.cc",
             "softmax.cc",
             "svdf.cc",
         ],
         ":xtensa_hifimini": [
-            "xtensa/conv.cc",
             "xtensa/depthwise_conv.cc",
             "xtensa/quantize.cc",
             "xtensa/softmax.cc",
@@ -895,6 +930,7 @@ cc_test(
         ":kernel_runner",
         "//tensorflow/lite/c:common",
         "//tensorflow/lite/micro:micro_utils",
+        "//tensorflow/lite/micro:op_resolvers",
         "//tensorflow/lite/micro:test_helpers",
         "//tensorflow/lite/micro/testing:micro_test",
     ],
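The new :conv library compiles conv_common.cc everywhere and uses select() to swap in a platform-specific conv.cc (the reference kernel by default, xtensa/conv.cc for HiFi Mini). Every selected variant has the same minimal surface: it provides its own Init and Eval and reuses the shared ConvPrepare declared in conv.h. A hypothetical skeleton of such a variant, reduced to that contract (the Eval body is elided; this is illustrative, not a real kernel from this patch):

    #include "tensorflow/lite/c/common.h"
    #include "tensorflow/lite/micro/kernels/conv.h"

    namespace tflite {
    namespace {

    void* Init(TfLiteContext* context, const char* buffer, size_t length) {
      // One shared OpDataConv per node instance, persisted in the arena.
      return context->AllocatePersistentBuffer(context, sizeof(OpDataConv));
    }

    TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
      // A real variant dispatches on the input tensor type here (elided).
      return kTfLiteOk;
    }

    }  // namespace

    TfLiteRegistration Register_CONV_2D() {
      return {/*init=*/Init,
              /*free=*/nullptr,
              /*prepare=*/ConvPrepare,  // Shared across all variants.
              /*invoke=*/Eval,
              /*profiling_string=*/nullptr,
              /*builtin_code=*/0,
              /*custom_name=*/nullptr,
              /*version=*/0};
    }

    }  // namespace tflite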
diff --git a/tensorflow/lite/micro/kernels/cmsis_nn/conv.cc b/tensorflow/lite/micro/kernels/cmsis_nn/conv.cc
index 3691cd90224b8e..e82fee667f1171 100644
--- a/tensorflow/lite/micro/kernels/cmsis_nn/conv.cc
+++ b/tensorflow/lite/micro/kernels/cmsis_nn/conv.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/lite/kernels/internal/reference/conv.h"
+#include "tensorflow/lite/micro/kernels/conv.h"
 
 #include "CMSIS/NN/Include/arm_nn_types.h"
 #include "CMSIS/NN/Include/arm_nnfunctions.h"
@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/kernels/internal/common.h"
 #include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/reference/conv.h"
 #include "tensorflow/lite/kernels/internal/reference/integer_ops/conv.h"
 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
@@ -30,90 +31,13 @@
 namespace tflite {
 namespace {
 
-constexpr int kInputTensor = 0;
-constexpr int kFilterTensor = 1;
-constexpr int kBiasTensor = 2;
-constexpr int kOutputTensor = 0;
-
-// Conv is quantized along dimension 0:
-// https://www.tensorflow.org/lite/performance/quantization_spec
-constexpr int kConvQuantizedDimension = 0;
-
 struct OpData {
-  TfLitePaddingValues padding;
-
-  // Cached tensor zero point values for quantized operations.
-  int32_t input_zero_point;
-  int32_t filter_zero_point;
-  int32_t output_zero_point;
-
-  // The scaling factor from input to output (aka the 'real multiplier') can
-  // be represented as a fixed point multiplier plus a left shift.
-  int32_t output_multiplier;
-  int output_shift;
-
-  // Per channel output multiplier and shift.
-  int32_t* per_channel_output_multiplier;
-  int32_t* per_channel_output_shift;
-
-  // The range of the fused activation layer. For example for kNone and
-  // uint8_t these would be 0 and 255.
-  int32_t output_activation_min;
-  int32_t output_activation_max;
+  OpDataConv reference_op_data;
 
   // Index to buffer for optimizations if applicable.
   int buffer_idx;
 };
 
-inline PaddingType RuntimePaddingType(TfLitePadding padding) {
-  switch (padding) {
-    case TfLitePadding::kTfLitePaddingSame:
-      return PaddingType::kSame;
-    case TfLitePadding::kTfLitePaddingValid:
-      return PaddingType::kValid;
-    case TfLitePadding::kTfLitePaddingUnknown:
-    default:
-      return PaddingType::kNone;
-  }
-}
-
-TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
-                             const TfLiteConvParams* params, int width,
-                             int height, int filter_width, int filter_height,
-                             int out_width, int out_height,
-                             const TfLiteType data_type, OpData* data) {
-  bool has_bias = node->inputs->size == 3;
-  // Check number of inputs/outputs
-  TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2);
-  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
-
-  // Matching GetWindowedOutputSize in TensorFlow.
-  auto padding = params->padding;
-  data->padding = ComputePaddingHeightWidth(
-      params->stride_height, params->stride_width,
-      params->dilation_height_factor, params->dilation_width_factor, height,
-      width, filter_height, filter_width, padding, &out_height, &out_width);
-
-  // Note that quantized inference requires that all tensors have their
-  // parameters set. This is usually done during quantized training.
-  if (data_type != kTfLiteFloat32) {
-    const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-    const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
-    const TfLiteTensor* bias =
-        GetOptionalInputTensor(context, node, kBiasTensor);
-    TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-    int num_channels = filter->dims->data[kConvQuantizedDimension];
-
-    TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams(
-        context, input, filter, bias, output, params->activation,
-        &data->output_multiplier, &data->output_shift,
-        &data->output_activation_min, &data->output_activation_max,
-        data->per_channel_output_multiplier,
-        reinterpret_cast<int*>(data->per_channel_output_shift), num_channels));
-  }
-  return kTfLiteOk;
-}
-
 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
   return context->AllocatePersistentBuffer(context, sizeof(OpData));
@@ -124,12 +48,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TFLITE_DCHECK(node->builtin_data != nullptr);
 
   int32_t buf_size = 0;
-  const auto params = static_cast<const TfLiteConvParams*>(node->builtin_data);
+  const auto& params =
+      *(static_cast<const TfLiteConvParams*>(node->builtin_data));
   OpData* data = static_cast<OpData*>(node->user_data);
 
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
-  const TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  const TfLiteTensor* input = GetInput(context, node, kConvInputTensor);
+  const TfLiteTensor* filter = GetInput(context, node, kConvWeightsTensor);
+  const TfLiteTensor* output = GetOutput(context, node, kConvOutputTensor);
 
   RuntimeShape input_shape = GetTensorShape(input);
   RuntimeShape output_shape = GetTensorShape(output);
@@ -161,34 +86,31 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   // non-int8 cases. Protect this section with a if (input->type == kTfLiteInt8)
   // when the issue is fixed.
   const int num_channels = filter->dims->data[kConvQuantizedDimension];
-  data->per_channel_output_multiplier =
+  data->reference_op_data.per_channel_output_multiplier =
       static_cast<int32_t*>(context->AllocatePersistentBuffer(
           context, num_channels * sizeof(int32_t)));
-  data->per_channel_output_shift =
+  data->reference_op_data.per_channel_output_shift =
       static_cast<int32_t*>(context->AllocatePersistentBuffer(
           context, num_channels * sizeof(int32_t)));
 
-  TF_LITE_ENSURE_STATUS(CalculateOpData(
+  TF_LITE_ENSURE_STATUS(CalculateOpDataConv(
       context, node, params, input_dims.w, input_dims.h, filter_dims.w,
-      filter_dims.h, output_dims.w, output_dims.h, input->type, data));
-
-  data->input_zero_point = input->params.zero_point;
-  data->filter_zero_point = filter->params.zero_point;
-  data->output_zero_point = output->params.zero_point;
+      filter_dims.h, output_dims.w, output_dims.h, input->type,
+      &data->reference_op_data));
 
   if (input->type == kTfLiteInt8) {
     // Initialize cmsis_nn convolution parameters
     cmsis_nn_conv_params conv_params;
     conv_params.input_offset = -input->params.zero_point;
     conv_params.output_offset = output->params.zero_point;
-    conv_params.stride.h = params->stride_height;
-    conv_params.stride.w = params->stride_width;
-    conv_params.dilation.h = params->dilation_height_factor;
-    conv_params.dilation.w = params->dilation_width_factor;
-    conv_params.padding.h = data->padding.height;
-    conv_params.padding.w = data->padding.width;
-    conv_params.activation.min = data->output_activation_min;
-    conv_params.activation.max = data->output_activation_max;
+    conv_params.stride.h = params.stride_height;
+    conv_params.stride.w = params.stride_width;
+    conv_params.dilation.h = params.dilation_height_factor;
+    conv_params.dilation.w = params.dilation_width_factor;
+    conv_params.padding.h = data->reference_op_data.padding.height;
+    conv_params.padding.w = data->reference_op_data.padding.width;
+    conv_params.activation.min = data->reference_op_data.output_activation_min;
+    conv_params.activation.max = data->reference_op_data.output_activation_max;
 
     buf_size = arm_convolve_wrapper_s8_get_buffer_size(
         &conv_params, &input_dims, &filter_dims, &output_dims);
@@ -203,73 +125,34 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   return kTfLiteOk;
 }
 
-TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
-                           TfLiteConvParams* params, const OpData& data,
-                           const TfLiteEvalTensor* input,
-                           const TfLiteEvalTensor* filter,
-                           const TfLiteEvalTensor* bias,
-                           TfLiteEvalTensor* im2col,
-                           TfLiteEvalTensor* hwcn_weights,
-                           TfLiteEvalTensor* output) {
-  const int32_t input_offset = -data.input_zero_point;
-  const int32_t filter_offset = -data.filter_zero_point;
-  const int32_t output_offset = data.output_zero_point;
-
-  ConvParams op_params;
-  op_params.padding_type = RuntimePaddingType(params->padding);
-  op_params.padding_values.width = data.padding.width;
-  op_params.padding_values.height = data.padding.height;
-  op_params.stride_width = params->stride_width;
-  op_params.stride_height = params->stride_height;
-  op_params.dilation_width_factor = params->dilation_width_factor;
-  op_params.dilation_height_factor = params->dilation_height_factor;
-  op_params.input_offset = input_offset;
-  op_params.weights_offset = filter_offset;
-  op_params.output_offset = output_offset;
-  op_params.output_multiplier = data.output_multiplier;
-  op_params.output_shift = -data.output_shift;
-  op_params.quantized_activation_min = data.output_activation_min;
-  op_params.quantized_activation_max = data.output_activation_max;
-
-  reference_ops::Conv(op_params, tflite::micro::GetTensorShape(input),
-                      tflite::micro::GetTensorData<uint8_t>(input),
-                      tflite::micro::GetTensorShape(filter),
-                      tflite::micro::GetTensorData<uint8_t>(filter),
-                      tflite::micro::GetTensorShape(bias),
-                      tflite::micro::GetTensorData<int32_t>(bias),
-                      tflite::micro::GetTensorShape(output),
-                      tflite::micro::GetTensorData<uint8_t>(output),
-                      tflite::micro::GetTensorShape(im2col),
-                      tflite::micro::GetTensorData<uint8_t>(im2col), nullptr);
-  return kTfLiteOk;
-}
-
 TfLiteStatus EvalQuantizedPerChannel(
-    TfLiteContext* context, TfLiteNode* node, TfLiteConvParams* params,
+    TfLiteContext* context, TfLiteNode* node, const TfLiteConvParams& params,
     const OpData& data, const TfLiteEvalTensor* input,
     const TfLiteEvalTensor* filter, const TfLiteEvalTensor* bias,
     TfLiteEvalTensor* output, TfLiteEvalTensor* im2col) {
   cmsis_nn_conv_params conv_params;
-  conv_params.dilation.h = params->dilation_height_factor;
-  conv_params.dilation.w = params->dilation_width_factor;
+  conv_params.dilation.h = params.dilation_height_factor;
+  conv_params.dilation.w = params.dilation_width_factor;
 
   // TODO(#43557) Remove checks for dilation and call to reference
   // implementation when dilation is supported in the optimized implementation
   // by CMSIS-NN.
   if (conv_params.dilation.h == 1 && conv_params.dilation.w == 1) {
     // Initialize cmsis_nn convolution parameters
-    conv_params.input_offset = -data.input_zero_point;
-    conv_params.output_offset = data.output_zero_point;
-    conv_params.stride.h = params->stride_height;
-    conv_params.stride.w = params->stride_width;
-    conv_params.padding.h = data.padding.height;
-    conv_params.padding.w = data.padding.width;
-    conv_params.activation.min = data.output_activation_min;
-    conv_params.activation.max = data.output_activation_max;
+    conv_params.input_offset = -data.reference_op_data.input_zero_point;
+    conv_params.output_offset = data.reference_op_data.output_zero_point;
+    conv_params.stride.h = params.stride_height;
+    conv_params.stride.w = params.stride_width;
+    conv_params.padding.h = data.reference_op_data.padding.height;
+    conv_params.padding.w = data.reference_op_data.padding.width;
+    conv_params.activation.min = data.reference_op_data.output_activation_min;
+    conv_params.activation.max = data.reference_op_data.output_activation_max;
 
     // Initialize cmsis_nn per channel quantization parameters
     cmsis_nn_per_channel_quant_params quant_params;
-    quant_params.multiplier =
-        const_cast<int32_t*>(data.per_channel_output_multiplier);
-    quant_params.shift = const_cast<int32_t*>(data.per_channel_output_shift);
+    quant_params.multiplier = const_cast<int32_t*>(
+        data.reference_op_data.per_channel_output_multiplier);
+    quant_params.shift =
+        const_cast<int32_t*>(data.reference_op_data.per_channel_output_shift);
 
     RuntimeShape filter_shape = tflite::micro::GetTensorShape(filter);
     RuntimeShape input_shape = tflite::micro::GetTensorShape(input);
@@ -340,22 +223,11 @@ TfLiteStatus EvalQuantizedPerChannel(
                   tflite::micro::GetTensorData<int8_t>(output)),
               ARM_MATH_SUCCESS);
   } else {
-    // TODO(b/154032858): Investigate removing extra copies.
-    ConvParams op_params;
-    op_params.input_offset = -data.input_zero_point;
-    op_params.output_offset = data.output_zero_point;
-    op_params.stride_height = params->stride_height;
-    op_params.stride_width = params->stride_width;
-    op_params.dilation_height_factor = params->dilation_height_factor;
-    op_params.dilation_width_factor = params->dilation_width_factor;
-    op_params.padding_values.height = data.padding.height;
-    op_params.padding_values.width = data.padding.width;
-    op_params.quantized_activation_min = data.output_activation_min;
-    op_params.quantized_activation_max = data.output_activation_max;
-
     reference_integer_ops::ConvPerChannel(
-        op_params, data.per_channel_output_multiplier,
-        data.per_channel_output_shift, tflite::micro::GetTensorShape(input),
+        ConvParamsQuantized(params, data.reference_op_data),
+        data.reference_op_data.per_channel_output_multiplier,
+        data.reference_op_data.per_channel_output_shift,
+        tflite::micro::GetTensorShape(input),
         tflite::micro::GetTensorData<int8_t>(input),
         tflite::micro::GetTensorShape(filter),
         tflite::micro::GetTensorData<int8_t>(filter),
@@ -367,54 +239,20 @@ TfLiteStatus EvalQuantizedPerChannel(
   return kTfLiteOk;
 }
 
-TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
-                       TfLiteConvParams* params, const OpData& data,
-                       const TfLiteEvalTensor* input,
-                       const TfLiteEvalTensor* filter,
-                       const TfLiteEvalTensor* bias, TfLiteEvalTensor* im2col,
-                       TfLiteEvalTensor* hwcn_weights,
-                       TfLiteEvalTensor* output) {
-  float output_activation_min, output_activation_max;
-  CalculateActivationRange(params->activation, &output_activation_min,
-                           &output_activation_max);
-  // TODO(b/154032858): Investigate removing extra copies.
-  ConvParams op_params;
-  op_params.padding_type = RuntimePaddingType(params->padding);
-  op_params.padding_values.width = data.padding.width;
-  op_params.padding_values.height = data.padding.height;
-  op_params.stride_width = params->stride_width;
-  op_params.stride_height = params->stride_height;
-  op_params.dilation_width_factor = params->dilation_width_factor;
-  op_params.dilation_height_factor = params->dilation_height_factor;
-  op_params.float_activation_min = output_activation_min;
-  op_params.float_activation_max = output_activation_max;
-
-  reference_ops::Conv(op_params, tflite::micro::GetTensorShape(input),
-                      tflite::micro::GetTensorData<float>(input),
-                      tflite::micro::GetTensorShape(filter),
-                      tflite::micro::GetTensorData<float>(filter),
-                      tflite::micro::GetTensorShape(bias),
-                      tflite::micro::GetTensorData<float>(bias),
-                      tflite::micro::GetTensorShape(output),
-                      tflite::micro::GetTensorData<float>(output),
-                      tflite::micro::GetTensorShape(im2col),
-                      tflite::micro::GetTensorData<float>(im2col));
-  return kTfLiteOk;
-}
-
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data);
+  const auto& params =
+      *(reinterpret_cast<TfLiteConvParams*>(node->builtin_data));
 
   const TfLiteEvalTensor* input =
-      tflite::micro::GetEvalInput(context, node, kInputTensor);
+      tflite::micro::GetEvalInput(context, node, kConvInputTensor);
   const TfLiteEvalTensor* filter =
-      tflite::micro::GetEvalInput(context, node, kFilterTensor);
+      tflite::micro::GetEvalInput(context, node, kConvWeightsTensor);
   const TfLiteEvalTensor* bias =
       (NumInputs(node) == 3)
-          ? tflite::micro::GetEvalInput(context, node, kBiasTensor)
+          ? tflite::micro::GetEvalInput(context, node, kConvBiasTensor)
           : nullptr;
   TfLiteEvalTensor* output =
-      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
+      tflite::micro::GetEvalOutput(context, node, kConvOutputTensor);
 
   TFLITE_DCHECK(node->user_data != nullptr);
   const OpData& data = *(static_cast<const OpData*>(node->user_data));
@@ -424,18 +262,38 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
                      "Hybrid models are not supported on TFLite Micro.");
 
   switch (input->type) {  // Already know in/out types are same.
-    case kTfLiteFloat32:
-      EvalFloat(context, node, params, data, input, filter, bias, nullptr,
-                nullptr, output);
+    case kTfLiteFloat32: {
+      tflite::reference_ops::Conv(
+          ConvParamsFloat(params, data.reference_op_data),
+          tflite::micro::GetTensorShape(input),
+          tflite::micro::GetTensorData<float>(input),
+          tflite::micro::GetTensorShape(filter),
+          tflite::micro::GetTensorData<float>(filter),
+          tflite::micro::GetTensorShape(bias),
+          tflite::micro::GetTensorData<float>(bias),
+          tflite::micro::GetTensorShape(output),
+          tflite::micro::GetTensorData<float>(output),
+          tflite::micro::GetTensorShape(nullptr), nullptr);
       break;
+    }
     case kTfLiteInt8:
       return EvalQuantizedPerChannel(context, node, params, data, input,
                                      filter, bias, output, nullptr);
       break;
-    case kTfLiteUInt8:
-      return EvalQuantized(context, node, params, data, input, filter, bias,
-                           nullptr, nullptr, output);
+    case kTfLiteUInt8: {
+      reference_ops::Conv(ConvParamsQuantized(params, data.reference_op_data),
+                          tflite::micro::GetTensorShape(input),
+                          tflite::micro::GetTensorData<uint8_t>(input),
+                          tflite::micro::GetTensorShape(filter),
+                          tflite::micro::GetTensorData<uint8_t>(filter),
+                          tflite::micro::GetTensorShape(bias),
+                          tflite::micro::GetTensorData<int32_t>(bias),
+                          tflite::micro::GetTensorShape(output),
+                          tflite::micro::GetTensorData<uint8_t>(output),
+                          tflite::micro::GetTensorShape(nullptr), nullptr,
+                          nullptr);
       break;
+    }
     default:
       TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
                          TfLiteTypeGetName(input->type), input->type);
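This is the composition pattern the CMSIS-NN port now uses: wrap the shared OpDataConv and add only backend-specific state, here the index of an arena scratch buffer sized by arm_convolve_wrapper_s8_get_buffer_size(). A hedged sketch of the surrounding scratch-buffer plumbing; RequestScratchIfNeeded and ResolveScratch are illustrative helper names, not code from this patch:

    #include "tensorflow/lite/c/common.h"
    #include "tensorflow/lite/micro/kernels/conv.h"

    namespace tflite {

    struct OpData {
      OpDataConv reference_op_data;
      // Index to buffer for optimizations if applicable.
      int buffer_idx;
    };

    // In Prepare: reserve arena scratch space when CMSIS-NN asks for any.
    TfLiteStatus RequestScratchIfNeeded(TfLiteContext* context, OpData* data,
                                        int32_t buf_size) {
      if (buf_size > 0) {
        return context->RequestScratchBufferInArena(context, buf_size,
                                                    &data->buffer_idx);
      }
      data->buffer_idx = -1;
      return kTfLiteOk;
    }

    // In Eval: resolve the saved index back to a pointer for the CMSIS-NN
    // context's scratch buffer field.
    void* ResolveScratch(TfLiteContext* context, const OpData& data) {
      return (data.buffer_idx > -1)
                 ? context->GetScratchBuffer(context, data.buffer_idx)
                 : nullptr;
    }

    }  // namespace tflite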
diff --git a/tensorflow/lite/micro/kernels/conv.cc b/tensorflow/lite/micro/kernels/conv.cc
index dc821df5418d24..4530f941b926f1 100644
--- a/tensorflow/lite/micro/kernels/conv.cc
+++ b/tensorflow/lite/micro/kernels/conv.cc
@@ -13,12 +13,13 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/lite/kernels/internal/reference/conv.h"
+#include "tensorflow/lite/micro/kernels/conv.h"
 
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/kernels/internal/common.h"
 #include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/reference/conv.h"
 #include "tensorflow/lite/kernels/internal/reference/integer_ops/conv.h"
 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
@@ -28,294 +29,74 @@ limitations under the License.
 namespace tflite {
 namespace {
 
-constexpr int kInputTensor = 0;
-constexpr int kFilterTensor = 1;
-constexpr int kBiasTensor = 2;
-constexpr int kOutputTensor = 0;
-
-// Conv is quantized along dimension 0:
-// https://www.tensorflow.org/lite/performance/quantization_spec
-constexpr int kConvQuantizedDimension = 0;
-
-// This file has 2 implementation of Conv.
-
-struct OpData {
-  TfLitePaddingValues padding;
-
-  // Cached tensor zero point values for quantized operations.
-  int32_t input_zero_point;
-  int32_t filter_zero_point;
-  int32_t output_zero_point;
-
-  // The scaling factor from input to output (aka the 'real multiplier') can
-  // be represented as a fixed point multiplier plus a left shift.
-  int32_t output_multiplier;
-  int output_shift;
-
-  // Per channel output multiplier and shift.
-  int32_t* per_channel_output_multiplier;
-  int32_t* per_channel_output_shift;
-
-  // The range of the fused activation layer. For example for kNone and
-  // uint8_t these would be 0 and 255.
-  int32_t output_activation_min;
-  int32_t output_activation_max;
-};
-
-inline PaddingType RuntimePaddingType(TfLitePadding padding) {
-  switch (padding) {
-    case TfLitePadding::kTfLitePaddingSame:
-      return PaddingType::kSame;
-    case TfLitePadding::kTfLitePaddingValid:
-      return PaddingType::kValid;
-    case TfLitePadding::kTfLitePaddingUnknown:
-    default:
-      return PaddingType::kNone;
-  }
-}
-
-TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
-                             const TfLiteConvParams* params, int width,
-                             int height, int filter_width, int filter_height,
-                             int out_width, int out_height,
-                             const TfLiteType data_type, OpData* data) {
-  bool has_bias = node->inputs->size == 3;
-  // Check number of inputs/outputs
-  TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2);
-  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
-
-  // Matching GetWindowedOutputSize in TensorFlow.
-  auto padding = params->padding;
-  data->padding = ComputePaddingHeightWidth(
-      params->stride_height, params->stride_width,
-      params->dilation_height_factor, params->dilation_width_factor, height,
-      width, filter_height, filter_width, padding, &out_height, &out_width);
-
-  // Note that quantized inference requires that all tensors have their
-  // parameters set. This is usually done during quantized training.
-  if (data_type != kTfLiteFloat32) {
-    const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-    TF_LITE_ENSURE(context, input != nullptr);
-    const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
-    TF_LITE_ENSURE(context, filter != nullptr);
-    const TfLiteTensor* bias =
-        GetOptionalInputTensor(context, node, kBiasTensor);
-    TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-    TF_LITE_ENSURE(context, output != nullptr);
-    int output_channels = filter->dims->data[kConvQuantizedDimension];
-
-    TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams(
-        context, input, filter, bias, output, params->activation,
-        &data->output_multiplier, &data->output_shift,
-        &data->output_activation_min, &data->output_activation_max,
-        data->per_channel_output_multiplier,
-        reinterpret_cast<int*>(data->per_channel_output_shift),
-        output_channels));
-  }
-  return kTfLiteOk;
-}
-
 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
-  return context->AllocatePersistentBuffer(context, sizeof(OpData));
-}
-
-TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
-  TFLITE_DCHECK(node->user_data != nullptr);
-  TFLITE_DCHECK(node->builtin_data != nullptr);
-
-  OpData* data = static_cast<OpData*>(node->user_data);
-  const auto params = static_cast<const TfLiteConvParams*>(node->builtin_data);
-
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-  TF_LITE_ENSURE(context, output != nullptr);
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  TF_LITE_ENSURE(context, input != nullptr);
-  const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
-  TF_LITE_ENSURE(context, filter != nullptr);
-
-  int input_width = input->dims->data[2];
-  int input_height = input->dims->data[1];
-  int filter_width = filter->dims->data[2];
-  int filter_height = filter->dims->data[1];
-  int output_width = output->dims->data[2];
-  int output_height = output->dims->data[1];
-
-  // Dynamically allocate per-channel quantization parameters.
-  const int num_channels = filter->dims->data[kConvQuantizedDimension];
-  data->per_channel_output_multiplier =
-      static_cast<int32_t*>(context->AllocatePersistentBuffer(
-          context, num_channels * sizeof(int32_t)));
-  data->per_channel_output_shift =
-      static_cast<int32_t*>(context->AllocatePersistentBuffer(
-          context, num_channels * sizeof(int32_t)));
-
-  // All per-channel quantized tensors need valid zero point and scale arrays.
-  if (input->type == kTfLiteInt8) {
-    TF_LITE_ENSURE_EQ(context, filter->quantization.type,
-                      kTfLiteAffineQuantization);
-
-    const auto* affine_quantization =
-        static_cast<TfLiteAffineQuantization*>(filter->quantization.params);
-    TF_LITE_ENSURE(context, affine_quantization);
-    TF_LITE_ENSURE(context, affine_quantization->scale);
-    TF_LITE_ENSURE(context, affine_quantization->zero_point);
-
-    TF_LITE_ENSURE(context,
-                   affine_quantization->scale->size == 1 ||
-                       affine_quantization->scale->size ==
-                           filter->dims->data[kConvQuantizedDimension]);
-    TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size,
-                      affine_quantization->zero_point->size);
-  }
-
-  TF_LITE_ENSURE_STATUS(CalculateOpData(
-      context, node, params, input_width, input_height, filter_width,
-      filter_height, output_width, output_height, input->type, data));
-
-  data->input_zero_point = input->params.zero_point;
-  data->filter_zero_point = filter->params.zero_point;
-  data->output_zero_point = output->params.zero_point;
-
-  return kTfLiteOk;
-}  // namespace conv
-
-void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
-                   TfLiteConvParams* params, const OpData& data,
-                   const TfLiteEvalTensor* input,
-                   const TfLiteEvalTensor* filter, const TfLiteEvalTensor* bias,
-                   TfLiteEvalTensor* im2col, TfLiteEvalTensor* hwcn_weights,
-                   TfLiteEvalTensor* output) {
-  const int32_t input_offset = -data.input_zero_point;
-  const int32_t filter_offset = -data.filter_zero_point;
-  const int32_t output_offset = data.output_zero_point;
-
-  // TODO(b/154032858): Investigate removing extra copies.
-  ConvParams op_params;
-  op_params.padding_type = RuntimePaddingType(params->padding);
-  op_params.padding_values.width = data.padding.width;
-  op_params.padding_values.height = data.padding.height;
-  op_params.stride_width = params->stride_width;
-  op_params.stride_height = params->stride_height;
-  op_params.dilation_width_factor = params->dilation_width_factor;
-  op_params.dilation_height_factor = params->dilation_height_factor;
-  op_params.input_offset = input_offset;
-  op_params.weights_offset = filter_offset;
-  op_params.output_offset = output_offset;
-  op_params.output_multiplier = data.output_multiplier;
-  op_params.output_shift = -data.output_shift;
-  op_params.quantized_activation_min = data.output_activation_min;
-  op_params.quantized_activation_max = data.output_activation_max;
-  reference_ops::Conv(op_params, tflite::micro::GetTensorShape(input),
-                      tflite::micro::GetTensorData<uint8_t>(input),
-                      tflite::micro::GetTensorShape(filter),
-                      tflite::micro::GetTensorData<uint8_t>(filter),
-                      tflite::micro::GetTensorShape(bias),
-                      tflite::micro::GetTensorData<int32_t>(bias),
-                      tflite::micro::GetTensorShape(output),
-                      tflite::micro::GetTensorData<uint8_t>(output),
-                      tflite::micro::GetTensorShape(im2col),
-                      tflite::micro::GetTensorData<uint8_t>(im2col), nullptr);
-}
-
-void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
-                             TfLiteConvParams* params, const OpData& data,
-                             const TfLiteEvalTensor* input,
-                             const TfLiteEvalTensor* filter,
-                             const TfLiteEvalTensor* bias,
-                             TfLiteEvalTensor* output,
-                             TfLiteEvalTensor* im2col) {
-  // TODO(b/154032858): Investigate removing extra copies.
-  ConvParams op_params;
-  op_params.input_offset = -data.input_zero_point;
-  op_params.output_offset = data.output_zero_point;
-  op_params.stride_height = params->stride_height;
-  op_params.stride_width = params->stride_width;
-  op_params.dilation_height_factor = params->dilation_height_factor;
-  op_params.dilation_width_factor = params->dilation_width_factor;
-  op_params.padding_values.height = data.padding.height;
-  op_params.padding_values.width = data.padding.width;
-  op_params.quantized_activation_min = data.output_activation_min;
-  op_params.quantized_activation_max = data.output_activation_max;
-
-  reference_integer_ops::ConvPerChannel(
-      op_params, data.per_channel_output_multiplier,
-      data.per_channel_output_shift, tflite::micro::GetTensorShape(input),
-      tflite::micro::GetTensorData<int8_t>(input),
-      tflite::micro::GetTensorShape(filter),
-      tflite::micro::GetTensorData<int8_t>(filter),
-      tflite::micro::GetTensorShape(bias),
-      tflite::micro::GetTensorData<int32_t>(bias),
-      tflite::micro::GetTensorShape(output),
-      tflite::micro::GetTensorData<int8_t>(output));
-}
-
-void EvalFloat(TfLiteContext* context, TfLiteNode* node,
-               TfLiteConvParams* params, const OpData& data,
-               const TfLiteEvalTensor* input, const TfLiteEvalTensor* filter,
-               const TfLiteEvalTensor* bias, TfLiteEvalTensor* im2col,
-               TfLiteEvalTensor* hwcn_weights, TfLiteEvalTensor* output) {
-  float output_activation_min, output_activation_max;
-  CalculateActivationRange(params->activation, &output_activation_min,
-                           &output_activation_max);
-  // TODO(b/154032858): Investigate removing extra copies.
-  ConvParams op_params;
-  op_params.padding_type = RuntimePaddingType(params->padding);
-  op_params.padding_values.width = data.padding.width;
-  op_params.padding_values.height = data.padding.height;
-  op_params.stride_width = params->stride_width;
-  op_params.stride_height = params->stride_height;
-  op_params.dilation_width_factor = params->dilation_width_factor;
-  op_params.dilation_height_factor = params->dilation_height_factor;
-  op_params.float_activation_min = output_activation_min;
-  op_params.float_activation_max = output_activation_max;
-
-  reference_ops::Conv(op_params, tflite::micro::GetTensorShape(input),
-                      tflite::micro::GetTensorData<float>(input),
-                      tflite::micro::GetTensorShape(filter),
-                      tflite::micro::GetTensorData<float>(filter),
-                      tflite::micro::GetTensorShape(bias),
-                      tflite::micro::GetTensorData<float>(bias),
-                      tflite::micro::GetTensorShape(output),
-                      tflite::micro::GetTensorData<float>(output),
-                      tflite::micro::GetTensorShape(im2col),
-                      tflite::micro::GetTensorData<float>(im2col));
+  return context->AllocatePersistentBuffer(context, sizeof(OpDataConv));
 }
 
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
-  auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data);
-
   const TfLiteEvalTensor* input =
-      tflite::micro::GetEvalInput(context, node, kInputTensor);
+      tflite::micro::GetEvalInput(context, node, kConvInputTensor);
   const TfLiteEvalTensor* filter =
-      tflite::micro::GetEvalInput(context, node, kFilterTensor);
+      tflite::micro::GetEvalInput(context, node, kConvWeightsTensor);
   const TfLiteEvalTensor* bias =
       (NumInputs(node) == 3)
-          ? tflite::micro::GetEvalInput(context, node, kBiasTensor)
+          ? tflite::micro::GetEvalInput(context, node, kConvBiasTensor)
           : nullptr;
   TfLiteEvalTensor* output =
-      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
+      tflite::micro::GetEvalOutput(context, node, kConvOutputTensor);
+
+  TFLITE_DCHECK(node->builtin_data != nullptr);
+  const auto& params =
+      *(reinterpret_cast<TfLiteConvParams*>(node->builtin_data));
 
   TFLITE_DCHECK(node->user_data != nullptr);
-  const OpData& data = *(static_cast<const OpData*>(node->user_data));
+  const auto& data = *(static_cast<const OpDataConv*>(node->user_data));
 
   TF_LITE_ENSURE_EQ(context, input->type, output->type);
   TF_LITE_ENSURE_MSG(context, input->type == filter->type,
                      "Hybrid models are not supported on TFLite Micro.");
 
   switch (input->type) {  // Already know in/out types are same.
-    case kTfLiteFloat32:
-      EvalFloat(context, node, params, data, input, filter, bias, nullptr,
-                nullptr, output);
+    case kTfLiteFloat32: {
+      tflite::reference_ops::Conv(
+          ConvParamsFloat(params, data), tflite::micro::GetTensorShape(input),
+          tflite::micro::GetTensorData<float>(input),
+          tflite::micro::GetTensorShape(filter),
+          tflite::micro::GetTensorData<float>(filter),
+          tflite::micro::GetTensorShape(bias),
+          tflite::micro::GetTensorData<float>(bias),
+          tflite::micro::GetTensorShape(output),
+          tflite::micro::GetTensorData<float>(output),
+          tflite::micro::GetTensorShape(nullptr), nullptr);
       break;
-    case kTfLiteInt8:
-      EvalQuantizedPerChannel(context, node, params, data, input, filter, bias,
-                              output, nullptr);
+    }
+    case kTfLiteInt8: {
+      reference_integer_ops::ConvPerChannel(
+          ConvParamsQuantized(params, data),
+          data.per_channel_output_multiplier, data.per_channel_output_shift,
+          tflite::micro::GetTensorShape(input),
+          tflite::micro::GetTensorData<int8_t>(input),
+          tflite::micro::GetTensorShape(filter),
+          tflite::micro::GetTensorData<int8_t>(filter),
+          tflite::micro::GetTensorShape(bias),
+          tflite::micro::GetTensorData<int32_t>(bias),
+          tflite::micro::GetTensorShape(output),
+          tflite::micro::GetTensorData<int8_t>(output));
       break;
-    case kTfLiteUInt8:
-      EvalQuantized(context, node, params, data, input, filter, bias, nullptr,
-                    nullptr, output);
+    }
+    case kTfLiteUInt8: {
+      reference_ops::Conv(ConvParamsQuantized(params, data),
+                          tflite::micro::GetTensorShape(input),
+                          tflite::micro::GetTensorData<uint8_t>(input),
+                          tflite::micro::GetTensorShape(filter),
+                          tflite::micro::GetTensorData<uint8_t>(filter),
+                          tflite::micro::GetTensorShape(bias),
+                          tflite::micro::GetTensorData<int32_t>(bias),
+                          tflite::micro::GetTensorShape(output),
+                          tflite::micro::GetTensorData<uint8_t>(output),
+                          tflite::micro::GetTensorShape(nullptr), nullptr,
+                          nullptr);
       break;
+    }
     default:
       TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
                          TfLiteTypeGetName(input->type), input->type);
@@ -329,7 +110,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 
 TfLiteRegistration Register_CONV_2D() {
   return {/*init=*/Init,
           /*free=*/nullptr,
-          /*prepare=*/Prepare,
+          /*prepare=*/ConvPrepare,
           /*invoke=*/Eval,
           /*profiling_string=*/nullptr,
           /*builtin_code=*/0,
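With Register_CONV_2D() now backed by the shared ConvPrepare, the kernel is pulled into tests and applications the usual way; the BUILD change above adds //tensorflow/lite/micro:op_resolvers to the conv test for exactly this. A small usage sketch (the resolver capacity of 1 is arbitrary for illustration):

    #include "tensorflow/lite/micro/micro_mutable_op_resolver.h"

    // Registers the conv kernel so the interpreter can resolve CONV_2D nodes.
    TfLiteStatus RegisterConv(tflite::MicroMutableOpResolver<1>& resolver) {
      return resolver.AddConv2D();  // Wires up tflite::Register_CONV_2D().
    }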
diff --git a/tensorflow/lite/micro/kernels/conv.h b/tensorflow/lite/micro/kernels/conv.h
new file mode 100644
index 00000000000000..46bc7318b0870f
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/conv.h
@@ -0,0 +1,77 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_LITE_MICRO_KERNELS_CONV_H_
+#define TENSORFLOW_LITE_MICRO_KERNELS_CONV_H_
+
+#include <cstdint>
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/types.h"
+
+namespace tflite {
+
+struct OpDataConv {
+  TfLitePaddingValues padding;
+
+  // Cached tensor zero point values for quantized operations.
+  int32_t input_zero_point;
+  int32_t filter_zero_point;
+  int32_t output_zero_point;
+
+  // The scaling factor from input to output (aka the 'real multiplier') can
+  // be represented as a fixed point multiplier plus a left shift.
+  int32_t output_multiplier;
+  int output_shift;
+
+  // Per channel output multiplier and shift.
+  int32_t* per_channel_output_multiplier;
+  int32_t* per_channel_output_shift;
+
+  // The range of the fused activation layer. For example for kNone and
+  // uint8_t these would be 0 and 255.
+  int32_t output_activation_min;
+  int32_t output_activation_max;
+};
+
+extern const int kConvInputTensor;
+extern const int kConvWeightsTensor;
+extern const int kConvBiasTensor;
+extern const int kConvOutputTensor;
+extern const int kConvQuantizedDimension;
+
+// Returns a ConvParams struct with all the parameters needed for a
+// float computation.
+ConvParams ConvParamsFloat(const TfLiteConvParams& params,
+                           const OpDataConv& data);
+
+// Returns a ConvParams struct with all the parameters needed for a
+// quantized computation.
+ConvParams ConvParamsQuantized(const TfLiteConvParams& params,
+                               const OpDataConv& data);
+
+TfLiteStatus CalculateOpDataConv(TfLiteContext* context, TfLiteNode* node,
+                                 const TfLiteConvParams& params, int width,
+                                 int height, int filter_width,
+                                 int filter_height, int out_width,
+                                 int out_height, const TfLiteType data_type,
+                                 OpDataConv* data);
+
+TfLiteStatus ConvPrepare(TfLiteContext* context, TfLiteNode* node);
+
+}  // namespace tflite
+
+#endif  // TENSORFLOW_LITE_MICRO_KERNELS_CONV_H_
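The comment on OpDataConv above is the key contract for quantized execution: the real-valued rescale factor (input scale times filter scale, divided by output scale) is cached as a Q31 fixed-point multiplier plus a power-of-two shift. A hedged sketch of that round trip using existing kernels/internal helpers; RequantizeAccumulator is an illustrative name, not part of this patch:

    #include <cstdint>

    #include "tensorflow/lite/kernels/internal/common.h"
    #include "tensorflow/lite/kernels/internal/quantization_util.h"

    int32_t RequantizeAccumulator(int32_t acc, double real_multiplier) {
      int32_t quantized_multiplier;
      int shift;
      // Splits real_multiplier into multiplier * 2^shift, multiplier in Q31.
      tflite::QuantizeMultiplier(real_multiplier, &quantized_multiplier, &shift);
      // Applies the rescale to a raw convolution accumulator.
      return tflite::MultiplyByQuantizedMultiplier(acc, quantized_multiplier,
                                                   shift);
    }

Note that the kernels negate the cached output_shift when filling ConvParams (op_params.output_shift = -data.output_shift in ConvParamsQuantized below), so the sign convention depends on which helper consumes the value.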
diff --git a/tensorflow/lite/micro/kernels/conv_common.cc b/tensorflow/lite/micro/kernels/conv_common.cc
new file mode 100644
index 00000000000000..a4a36ae1e1da9c
--- /dev/null
+++ b/tensorflow/lite/micro/kernels/conv_common.cc
@@ -0,0 +1,182 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/common.h"
+#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/reference/conv.h"
+#include "tensorflow/lite/kernels/internal/reference/integer_ops/conv.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/kernels/padding.h"
+#include "tensorflow/lite/micro/kernels/conv.h"
+#include "tensorflow/lite/micro/kernels/kernel_util.h"
+
+namespace tflite {
+
+const int kConvInputTensor = 0;
+const int kConvWeightsTensor = 1;
+const int kConvBiasTensor = 2;
+const int kConvOutputTensor = 0;
+
+// Conv is quantized along dimension 0:
+// https://www.tensorflow.org/lite/performance/quantization_spec
+const int kConvQuantizedDimension = 0;
+
+// Returns a ConvParams struct with all the parameters needed for a
+// float computation.
+ConvParams ConvParamsFloat(const TfLiteConvParams& params,
+                           const OpDataConv& data) {
+  ConvParams op_params;
+  CalculateActivationRange(params.activation, &op_params.float_activation_min,
+                           &op_params.float_activation_max);
+  op_params.padding_type = tflite::micro::RuntimePaddingType(params.padding);
+  op_params.padding_values.width = data.padding.width;
+  op_params.padding_values.height = data.padding.height;
+  op_params.stride_width = params.stride_width;
+  op_params.stride_height = params.stride_height;
+  op_params.dilation_width_factor = params.dilation_width_factor;
+  op_params.dilation_height_factor = params.dilation_height_factor;
+  return op_params;
+}
+
+// Returns a ConvParams struct with all the parameters needed for a
+// quantized computation.
+ConvParams ConvParamsQuantized(const TfLiteConvParams& params,
+                               const OpDataConv& data) {
+  ConvParams op_params;
+  op_params.input_offset = -data.input_zero_point;
+  op_params.weights_offset = -data.filter_zero_point;
+  op_params.output_offset = data.output_zero_point;
+  op_params.output_multiplier = data.output_multiplier;
+  op_params.output_shift = -data.output_shift;
+  op_params.padding_type = tflite::micro::RuntimePaddingType(params.padding);
+  op_params.padding_values.height = data.padding.height;
+  op_params.padding_values.width = data.padding.width;
+  op_params.stride_height = params.stride_height;
+  op_params.stride_width = params.stride_width;
+  op_params.dilation_height_factor = params.dilation_height_factor;
+  op_params.dilation_width_factor = params.dilation_width_factor;
+  op_params.quantized_activation_min = data.output_activation_min;
+  op_params.quantized_activation_max = data.output_activation_max;
+  return op_params;
+}
+
+TfLiteStatus CalculateOpDataConv(TfLiteContext* context, TfLiteNode* node,
+                                 const TfLiteConvParams& params, int width,
+                                 int height, int filter_width,
+                                 int filter_height, int out_width,
+                                 int out_height, const TfLiteType data_type,
+                                 OpDataConv* data) {
+  bool has_bias = node->inputs->size == 3;
+  // Check number of inputs/outputs
+  TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2);
+  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
+
+  // Matching GetWindowedOutputSize in TensorFlow.
+  auto padding = params.padding;
+  data->padding = ComputePaddingHeightWidth(
+      params.stride_height, params.stride_width, params.dilation_height_factor,
+      params.dilation_width_factor, height, width, filter_height, filter_width,
+      padding, &out_height, &out_width);
+
+  const TfLiteTensor* input = GetInput(context, node, kConvInputTensor);
+  TF_LITE_ENSURE(context, input != nullptr);
+  const TfLiteTensor* filter = GetInput(context, node, kConvWeightsTensor);
+  TF_LITE_ENSURE(context, filter != nullptr);
+  const TfLiteTensor* bias =
+      GetOptionalInputTensor(context, node, kConvBiasTensor);
+  TfLiteTensor* output = GetOutput(context, node, kConvOutputTensor);
+  TF_LITE_ENSURE(context, output != nullptr);
+
+  // Note that quantized inference requires that all tensors have their
+  // parameters set. This is usually done during quantized training.
+  if (data_type != kTfLiteFloat32) {
+    int output_channels = filter->dims->data[kConvQuantizedDimension];
+
+    TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams(
+        context, input, filter, bias, output, params.activation,
+        &data->output_multiplier, &data->output_shift,
+        &data->output_activation_min, &data->output_activation_max,
+        data->per_channel_output_multiplier,
+        reinterpret_cast<int*>(data->per_channel_output_shift),
+        output_channels));
+  }
+
+  data->input_zero_point = input->params.zero_point;
+  data->filter_zero_point = filter->params.zero_point;
+  data->output_zero_point = output->params.zero_point;
+
+  return kTfLiteOk;
+}
+
+TfLiteStatus ConvPrepare(TfLiteContext* context, TfLiteNode* node) {
+  TFLITE_DCHECK(node->user_data != nullptr);
+  TFLITE_DCHECK(node->builtin_data != nullptr);
+
+  OpDataConv* data = static_cast<OpDataConv*>(node->user_data);
+  const auto& params =
+      *(static_cast<const TfLiteConvParams*>(node->builtin_data));
+
+  TfLiteTensor* output = GetOutput(context, node, kConvOutputTensor);
+  TF_LITE_ENSURE(context, output != nullptr);
+  const TfLiteTensor* input = GetInput(context, node, kConvInputTensor);
+  TF_LITE_ENSURE(context, input != nullptr);
+  const TfLiteTensor* filter = GetInput(context, node, kConvWeightsTensor);
+  TF_LITE_ENSURE(context, filter != nullptr);
+
+  const int input_width = input->dims->data[2];
+  const int input_height = input->dims->data[1];
+  const int filter_width = filter->dims->data[2];
+  const int filter_height = filter->dims->data[1];
+  const int output_width = output->dims->data[2];
+  const int output_height = output->dims->data[1];
+
+  // Dynamically allocate per-channel quantization parameters.
+  const int num_channels = filter->dims->data[kConvQuantizedDimension];
+  data->per_channel_output_multiplier =
+      static_cast<int32_t*>(context->AllocatePersistentBuffer(
+          context, num_channels * sizeof(int32_t)));
+  data->per_channel_output_shift =
+      static_cast<int32_t*>(context->AllocatePersistentBuffer(
+          context, num_channels * sizeof(int32_t)));
+
+  // All per-channel quantized tensors need valid zero point and scale arrays.
+  if (input->type == kTfLiteInt8) {
+    TF_LITE_ENSURE_EQ(context, filter->quantization.type,
+                      kTfLiteAffineQuantization);
+
+    const auto* affine_quantization =
+        static_cast<TfLiteAffineQuantization*>(filter->quantization.params);
+    TFLITE_DCHECK(affine_quantization != nullptr);
+    TFLITE_DCHECK(affine_quantization->scale != nullptr);
+    TFLITE_DCHECK(affine_quantization->zero_point != nullptr);
+
+    TF_LITE_ENSURE(context,
+                   affine_quantization->scale->size == 1 ||
+                       affine_quantization->scale->size ==
+                           filter->dims->data[kConvQuantizedDimension]);
+    TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size,
+                      affine_quantization->zero_point->size);
+  }
+
+  TF_LITE_ENSURE_STATUS(CalculateOpDataConv(
+      context, node, params, input_width, input_height, filter_width,
+      filter_height, output_width, output_height, input->type, data));
+
+  return kTfLiteOk;
+}
+}  // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/kernel_util.cc b/tensorflow/lite/micro/kernels/kernel_util.cc
index deca92b648fd09..d769f9e5e73757 100644
--- a/tensorflow/lite/micro/kernels/kernel_util.cc
+++ b/tensorflow/lite/micro/kernels/kernel_util.cc
@@ -37,5 +37,17 @@ const RuntimeShape GetTensorShape(const TfLiteEvalTensor* tensor) {
   return RuntimeShape(dims_size, dims_data);
 }
 
+PaddingType RuntimePaddingType(TfLitePadding padding) {
+  switch (padding) {
+    case TfLitePadding::kTfLitePaddingSame:
+      return PaddingType::kSame;
+    case TfLitePadding::kTfLitePaddingValid:
+      return PaddingType::kValid;
+    case TfLitePadding::kTfLitePaddingUnknown:
+    default:
+      return PaddingType::kNone;
+  }
+}
+
 }  // namespace micro
 }  // namespace tflite
diff --git a/tensorflow/lite/micro/kernels/kernel_util.h b/tensorflow/lite/micro/kernels/kernel_util.h
index 79cd58ec04545a..043fb0215ec6ac 100644
--- a/tensorflow/lite/micro/kernels/kernel_util.h
+++ b/tensorflow/lite/micro/kernels/kernel_util.h
@@ -18,6 +18,7 @@ limitations under the License.
 
 #include <cstdint>
 
+#include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/kernels/internal/compatibility.h"
 #include "tensorflow/lite/kernels/internal/types.h"
@@ -69,6 +70,8 @@ const RuntimeShape GetTensorShape(const TfLiteEvalTensor* tensor);
 bool HaveSameShapes(const TfLiteEvalTensor* input1,
                     const TfLiteEvalTensor* input2);
 
+PaddingType RuntimePaddingType(TfLitePadding padding);
+
 }  // namespace micro
 }  // namespace tflite
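RuntimePaddingType() moves out of the per-kernel anonymous namespaces into tflite::micro so that ConvParamsFloat() and ConvParamsQuantized() can share it. A minimal usage sketch, with the mapping as in the switch above (PaddingTypeExample is an illustrative name):

    #include "tensorflow/lite/micro/kernels/kernel_util.h"

    void PaddingTypeExample() {
      // kTfLitePaddingSame -> kSame, kTfLitePaddingValid -> kValid; unknown
      // padding deliberately degrades to PaddingType::kNone instead of failing.
      const tflite::PaddingType padding =
          tflite::micro::RuntimePaddingType(kTfLitePaddingSame);
      (void)padding;
    }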
diff --git a/tensorflow/lite/micro/kernels/xtensa/conv.cc b/tensorflow/lite/micro/kernels/xtensa/conv.cc
index 41a11a86de9fb9..04c2c15af7fcff 100644
--- a/tensorflow/lite/micro/kernels/xtensa/conv.cc
+++ b/tensorflow/lite/micro/kernels/xtensa/conv.cc
@@ -13,12 +13,13 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/lite/kernels/internal/reference/conv.h"
+#include "tensorflow/lite/micro/kernels/conv.h"
 
 #include "tensorflow/lite/c/builtin_op_data.h"
 #include "tensorflow/lite/c/common.h"
 #include "tensorflow/lite/kernels/internal/common.h"
 #include "tensorflow/lite/kernels/internal/quantization_util.h"
+#include "tensorflow/lite/kernels/internal/reference/conv.h"
 #include "tensorflow/lite/kernels/internal/reference/integer_ops/conv.h"
 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
@@ -30,36 +31,6 @@
 namespace tflite {
 namespace {
 
-constexpr int kInputTensor = 0;
-constexpr int kFilterTensor = 1;
-constexpr int kBiasTensor = 2;
-constexpr int kOutputTensor = 0;
-
-// Conv is quantized along dimension 0:
-// https://www.tensorflow.org/lite/performance/quantization_spec
-constexpr int kConvQuantizedDimension = 0;
-
-struct OpData {
-  TfLitePaddingValues padding;
-  // The scaling factor from input to output (aka the 'real multiplier') can
-  // be represented as a fixed point multiplier plus a left shift.
-  int32_t output_multiplier;
-  int output_shift;
-
-  // Cached tensor zero point values for quantized operations.
-  int32_t input_zero_point;
-  int32_t output_zero_point;
-
-  // Per channel output multiplier and shift.
-  int32_t* per_channel_output_multiplier;
-  int32_t* per_channel_output_shift;
-
-  // The range of the fused activation layer. For example for kNone and
-  // uint8_t these would be 0 and 255.
-  int32_t output_activation_min;
-  int32_t output_activation_max;
-};
-
 #if defined(HIFIMINI)
 void ConvPerChannel(const ConvParams& params, const int32_t* output_multiplier,
                     const int32_t* output_shift,
@@ -263,164 +234,27 @@ inline void Conv1x32Input32x32Filter(
 }
 #endif
 
-TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteNode* node,
-                             TfLiteConvParams* params, int width, int height,
-                             int filter_width, int filter_height, int out_width,
-                             int out_height, const TfLiteType data_type,
-                             OpData* data) {
-  bool has_bias = node->inputs->size == 3;
-  // Check number of inputs/outputs
-  TF_LITE_ENSURE(context, has_bias || node->inputs->size == 2);
-  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
-
-  // Matching GetWindowedOutputSize in TensorFlow.
-  auto padding = params->padding;
-  data->padding = ComputePaddingHeightWidth(
-      params->stride_height, params->stride_width,
-      params->dilation_height_factor, params->dilation_width_factor, height,
-      width, filter_height, filter_width, padding, &out_height, &out_width);
-
-  // Note that quantized inference requires that all tensors have their
-  // parameters set. This is usually done during quantized training.
-  if (data_type != kTfLiteFloat32) {
-    const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-    const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
-    const TfLiteTensor* bias =
-        GetOptionalInputTensor(context, node, kBiasTensor);
-    TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-    int output_channels = filter->dims->data[kConvQuantizedDimension];
-
-    return tflite::PopulateConvolutionQuantizationParams(
-        context, input, filter, bias, output, params->activation,
-        &data->output_multiplier, &data->output_shift,
-        &data->output_activation_min, &data->output_activation_max,
-        data->per_channel_output_multiplier,
-        reinterpret_cast<int*>(data->per_channel_output_shift),
-        output_channels);
-  }
-  return kTfLiteOk;
-}
-
 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
-  return context->AllocatePersistentBuffer(context, sizeof(OpData));
-}
-
-TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
-  TFLITE_DCHECK(node->user_data != nullptr);
-  TFLITE_DCHECK(node->builtin_data != nullptr);
-  auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data);
-
-  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-  const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-  const TfLiteTensor* filter = GetInput(context, node, kFilterTensor);
-
-  auto* op_data = reinterpret_cast<OpData*>(node->user_data);
-
-  int input_width = input->dims->data[2];
-  int input_height = input->dims->data[1];
-  int filter_width = filter->dims->data[2];
-  int filter_height = filter->dims->data[1];
-  int output_width = output->dims->data[2];
-  int output_height = output->dims->data[1];
-
-  // Per channel quantization is only needed for int8_t inference. For other
-  // quantized types, only a single scale and zero point is needed.
-  const int num_channels = filter->dims->data[kConvQuantizedDimension];
-  // Dynamically allocate per-channel quantization parameters.
-  op_data->per_channel_output_multiplier =
-      reinterpret_cast<int32_t*>(context->AllocatePersistentBuffer(
-          context, num_channels * sizeof(int32_t)));
-  op_data->per_channel_output_shift =
-      reinterpret_cast<int32_t*>(context->AllocatePersistentBuffer(
-          context, num_channels * sizeof(int32_t)));
-  op_data->input_zero_point = input->params.zero_point;
-  op_data->output_zero_point = output->params.zero_point;
-
-  // All per-channel quantized tensors need valid zero point and scale arrays.
-  if (input->type == kTfLiteInt8) {
-    TF_LITE_ENSURE_EQ(context, filter->quantization.type,
-                      kTfLiteAffineQuantization);
-
-    const auto* affine_quantization =
-        reinterpret_cast<TfLiteAffineQuantization*>(
-            filter->quantization.params);
-    TF_LITE_ENSURE(context, affine_quantization);
-    TF_LITE_ENSURE(context, affine_quantization->scale);
-    TF_LITE_ENSURE(context, affine_quantization->zero_point);
-
-    TF_LITE_ENSURE(context,
-                   affine_quantization->scale->size == 1 ||
-                       affine_quantization->scale->size ==
-                           filter->dims->data[kConvQuantizedDimension]);
-    TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size,
-                      affine_quantization->zero_point->size);
-  }
-
-  return CalculateOpData(context, node, params, input_width, input_height,
-                         filter_width, filter_height, output_width,
-                         output_height, input->type, op_data);
-}
-
-void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
-                             TfLiteConvParams* params, OpData* data,
-                             const TfLiteEvalTensor* input,
-                             const TfLiteEvalTensor* filter,
-                             const TfLiteEvalTensor* bias,
-                             TfLiteEvalTensor* output,
-                             TfLiteEvalTensor* im2col) {
-  // TODO(b/154032858): Investigate removing extra copies.
-  ConvParams op_params;
-  op_params.input_offset = -data->input_zero_point;
-  op_params.output_offset = data->output_zero_point;
-  op_params.stride_height = params->stride_height;
-  op_params.stride_width = params->stride_width;
-  op_params.dilation_height_factor = params->dilation_height_factor;
-  op_params.dilation_width_factor = params->dilation_width_factor;
-  op_params.padding_values.height = data->padding.height;
-  op_params.padding_values.width = data->padding.width;
-  op_params.quantized_activation_min = data->output_activation_min;
-  op_params.quantized_activation_max = data->output_activation_max;
-
-#if defined(HIFIMINI)
-  ConvPerChannel(op_params, data->per_channel_output_multiplier,
-                 data->per_channel_output_shift,
-                 tflite::micro::GetTensorShape(input),
-                 tflite::micro::GetTensorData<int8_t>(input),
-                 tflite::micro::GetTensorShape(filter),
-                 tflite::micro::GetTensorData<int8_t>(filter),
-                 tflite::micro::GetTensorShape(bias),
-                 tflite::micro::GetTensorData<int32_t>(bias),
-                 tflite::micro::GetTensorShape(output),
-                 tflite::micro::GetTensorData<int8_t>(output));
-#else
-  reference_integer_ops::ConvPerChannel(
-      op_params, data->per_channel_output_multiplier,
-      data->per_channel_output_shift, tflite::micro::GetTensorShape(input),
-      tflite::micro::GetTensorData<int8_t>(input),
-      tflite::micro::GetTensorShape(filter),
-      tflite::micro::GetTensorData<int8_t>(filter),
-      tflite::micro::GetTensorShape(bias),
-      tflite::micro::GetTensorData<int32_t>(bias),
-      tflite::micro::GetTensorShape(output),
-      tflite::micro::GetTensorData<int8_t>(output));
-#endif
+  return context->AllocatePersistentBuffer(context, sizeof(OpDataConv));
 }
 
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   TFLITE_DCHECK(node->user_data != nullptr);
   TFLITE_DCHECK(node->builtin_data != nullptr);
-  auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data);
-  auto* op_data = reinterpret_cast<OpData*>(node->user_data);
+  const auto& params =
+      *(reinterpret_cast<TfLiteConvParams*>(node->builtin_data));
+  const auto& op_data = *(reinterpret_cast<OpDataConv*>(node->user_data));
 
   TfLiteEvalTensor* output =
-      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
+      tflite::micro::GetEvalOutput(context, node, kConvOutputTensor);
   const TfLiteEvalTensor* input =
-      tflite::micro::GetEvalInput(context, node, kInputTensor);
+      tflite::micro::GetEvalInput(context, node, kConvInputTensor);
   const TfLiteEvalTensor* filter =
-      tflite::micro::GetEvalInput(context, node, kFilterTensor);
+      tflite::micro::GetEvalInput(context, node, kConvWeightsTensor);
   const TfLiteEvalTensor* bias =
       (NumInputs(node) == 3)
-          ? tflite::micro::GetEvalInput(context, node, kBiasTensor)
+          ? tflite::micro::GetEvalInput(context, node, kConvBiasTensor)
           : nullptr;
 
 #if defined(HIFIMINI)
@@ -430,10 +264,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
       input_dims[3] == 32 && filter_dims[0] == 32 && filter_dims[1] == 1 &&
       filter_dims[2] == 1 && filter_dims[3] == 32) {
     Conv1x32Input32x32Filter(
-        -op_data->input_zero_point, op_data->output_zero_point,
-        op_data->output_activation_min, op_data->output_activation_max,
-        op_data->per_channel_output_multiplier,
-        op_data->per_channel_output_shift, tflite::micro::GetTensorShape(input),
+        -op_data.input_zero_point, op_data.output_zero_point,
+        op_data.output_activation_min, op_data.output_activation_max,
+        op_data.per_channel_output_multiplier, op_data.per_channel_output_shift,
+        tflite::micro::GetTensorShape(input),
         tflite::micro::GetTensorData<int8_t>(input),
         tflite::micro::GetTensorShape(filter),
         tflite::micro::GetTensorData<int8_t>(filter),
@@ -446,10 +280,35 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 #endif
 
   switch (input->type) {
-    case kTfLiteInt8:
-      EvalQuantizedPerChannel(context, node, params, op_data, input, filter,
-                              bias, output, nullptr);
+    case kTfLiteInt8: {
+#if defined(HIFIMINI)
+      ConvPerChannel(ConvParamsQuantized(params, op_data),
+                     op_data.per_channel_output_multiplier,
+                     op_data.per_channel_output_shift,
+                     tflite::micro::GetTensorShape(input),
+                     tflite::micro::GetTensorData<int8_t>(input),
+                     tflite::micro::GetTensorShape(filter),
+                     tflite::micro::GetTensorData<int8_t>(filter),
+                     tflite::micro::GetTensorShape(bias),
+                     tflite::micro::GetTensorData<int32_t>(bias),
+                     tflite::micro::GetTensorShape(output),
+                     tflite::micro::GetTensorData<int8_t>(output));
+#else
+      reference_integer_ops::ConvPerChannel(
+          ConvParamsQuantized(params, op_data),
+          op_data.per_channel_output_multiplier,
+          op_data.per_channel_output_shift,
+          tflite::micro::GetTensorShape(input),
+          tflite::micro::GetTensorData<int8_t>(input),
+          tflite::micro::GetTensorShape(filter),
+          tflite::micro::GetTensorData<int8_t>(filter),
+          tflite::micro::GetTensorShape(bias),
+          tflite::micro::GetTensorData<int32_t>(bias),
+          tflite::micro::GetTensorShape(output),
+          tflite::micro::GetTensorData<int8_t>(output));
+#endif
       break;
+    }
     default:
       TF_LITE_KERNEL_LOG(context, "Type %s (%d) not supported.",
                          TfLiteTypeGetName(input->type), input->type);
@@ -462,7 +321,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
 
 TfLiteRegistration Register_CONV_2D() {
   return {/*init=*/Init,
           /*free=*/nullptr,
-          /*prepare=*/Prepare,
+          /*prepare=*/ConvPrepare,
           /*invoke=*/Eval,
           /*profiling_string=*/nullptr,
           /*builtin_code=*/0,
diff --git a/tensorflow/lite/micro/tools/make/Makefile b/tensorflow/lite/micro/tools/make/Makefile
index 73d7fedd48a22c..5ee2586811a5b2 100644
--- a/tensorflow/lite/micro/tools/make/Makefile
+++ b/tensorflow/lite/micro/tools/make/Makefile
@@ -316,6 +316,7 @@ tensorflow/lite/micro/kernels/circular_buffer.cc \
 tensorflow/lite/micro/kernels/comparisons.cc \
 tensorflow/lite/micro/kernels/concatenation.cc \
 tensorflow/lite/micro/kernels/conv.cc \
+tensorflow/lite/micro/kernels/conv_common.cc \
 tensorflow/lite/micro/kernels/conv_test_common.cc \
 tensorflow/lite/micro/kernels/depthwise_conv.cc \
 tensorflow/lite/micro/kernels/dequantize.cc \
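One detail the refactoring keeps implicit: the per-channel arrays that ConvPrepare allocates are filled inside PopulateConvolutionQuantizationParams with one multiplier/shift pair per output channel (dimension 0 of the filter, per kConvQuantizedDimension). Conceptually that amounts to the loop below; this is a hedged illustration of the math, not the library routine itself, and FillPerChannelParams is an invented name:

    #include <cstdint>

    #include "tensorflow/lite/kernels/internal/quantization_util.h"

    // Derives fixed-point requantization parameters for each output channel.
    void FillPerChannelParams(double input_scale, const float* filter_scales,
                              double output_scale, int num_channels,
                              int32_t* per_channel_multiplier,
                              int32_t* per_channel_shift) {
      for (int c = 0; c < num_channels; ++c) {
        const double effective_scale =
            input_scale * static_cast<double>(filter_scales[c]) / output_scale;
        int shift = 0;
        tflite::QuantizeMultiplier(effective_scale, &per_channel_multiplier[c],
                                   &shift);
        per_channel_shift[c] = shift;
      }
    }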