diff --git a/kernels/quantized/cpu/op_quantize.cpp b/kernels/quantized/cpu/op_quantize.cpp
index fefb07b1e59..e52b9a371e6 100644
--- a/kernels/quantized/cpu/op_quantize.cpp
+++ b/kernels/quantized/cpu/op_quantize.cpp
@@ -7,6 +7,7 @@
  */

 #include <executorch/runtime/kernel/kernel_includes.h>
+#include <executorch/extension/threadpool/threadpool.h>
 #include <algorithm>
 #include <cinttypes>
 #include <cmath>
@@ -461,51 +462,104 @@ Tensor& quantize_per_channel_out(
 #if defined(__aarch64__) || defined(__ARM_NEON__)
   if (input.scalar_type() == ScalarType::Float) {
     const int64_t num_blocks = input.numel() / axis_block_size;
+    const int64_t total_elements = input.numel();
+    constexpr int64_t MIN_ELEMENTS_FOR_PARALLEL = 512;
+    const bool use_parallel = (total_elements >= MIN_ELEMENTS_FOR_PARALLEL);

     if (dtype == ScalarType::Byte) {
       auto* out_data_ptr = out.mutable_data_ptr<uint8_t>();
       const auto* input_data_ptr = input.const_data_ptr<float>();
-      // Process each contiguous block (which shares the same scale/zero_point)
-      for (int64_t block = 0; block < num_blocks; ++block) {
-        int64_t channel_idx = block % axis_size;
-        float inv_scale = 1.0f / static_cast<float>(scale_data[channel_idx]);
-        int32_t zp = static_cast<int32_t>(zero_point_data[channel_idx]);
-
-        const float* in_ptr = input_data_ptr + block * axis_block_size;
-        uint8_t* out_ptr = out_data_ptr + block * axis_block_size;
-
-        quantize_arm(
-            in_ptr,
-            out_ptr,
-            axis_block_size,
-            inv_scale,
-            zp,
-            static_cast<uint8_t>(quant_min),
-            static_cast<uint8_t>(quant_max));
+      if (use_parallel) {
+        ::executorch::extension::parallel_for(
+            0, num_blocks, 1, [&](const int64_t begin, const int64_t end) {
+              for (int64_t block = begin; block < end; ++block) {
+                int64_t channel_idx = block % axis_size;
+                float inv_scale =
+                    1.0f / static_cast<float>(scale_data[channel_idx]);
+                int32_t zp = static_cast<int32_t>(zero_point_data[channel_idx]);
+
+                const float* in_ptr = input_data_ptr + block * axis_block_size;
+                uint8_t* out_ptr = out_data_ptr + block * axis_block_size;
+
+                quantize_arm(
+                    in_ptr,
+                    out_ptr,
+                    axis_block_size,
+                    inv_scale,
+                    zp,
+                    static_cast<uint8_t>(quant_min),
+                    static_cast<uint8_t>(quant_max));
+              }
+            });
+      } else {
+        // Process each contiguous block (which shares the same
+        // scale/zero_point)
+        for (int64_t block = 0; block < num_blocks; ++block) {
+          int64_t channel_idx = block % axis_size;
+          float inv_scale = 1.0f / static_cast<float>(scale_data[channel_idx]);
+          int32_t zp = static_cast<int32_t>(zero_point_data[channel_idx]);
+
+          const float* in_ptr = input_data_ptr + block * axis_block_size;
+          uint8_t* out_ptr = out_data_ptr + block * axis_block_size;
+
+          quantize_arm(
+              in_ptr,
+              out_ptr,
+              axis_block_size,
+              inv_scale,
+              zp,
+              static_cast<uint8_t>(quant_min),
+              static_cast<uint8_t>(quant_max));
+        }
       }
       return out;
     } else if (dtype == ScalarType::Char) {
       auto* out_data_ptr = out.mutable_data_ptr<int8_t>();
       const auto* input_data_ptr = input.const_data_ptr<float>();
-      // Process each contiguous block (which shares the same scale/zero_point)
-      for (int64_t block = 0; block < num_blocks; ++block) {
-        int64_t channel_idx = block % axis_size;
-        float inv_scale = 1.0f / static_cast<float>(scale_data[channel_idx]);
-        int32_t zp = static_cast<int32_t>(zero_point_data[channel_idx]);
-
-        const float* in_ptr = input_data_ptr + block * axis_block_size;
-        int8_t* out_ptr = out_data_ptr + block * axis_block_size;
-
-        quantize_arm(
-            in_ptr,
-            out_ptr,
-            axis_block_size,
-            inv_scale,
-            zp,
-            static_cast<int8_t>(quant_min),
-            static_cast<int8_t>(quant_max));
+      if (use_parallel) {
+        ::executorch::extension::parallel_for(
+            0, num_blocks, 1, [&](const int64_t begin, const int64_t end) {
+              for (int64_t block = begin; block < end; ++block) {
+                int64_t channel_idx = block % axis_size;
+                float inv_scale =
+                    1.0f / static_cast<float>(scale_data[channel_idx]);
+                int32_t zp = static_cast<int32_t>(zero_point_data[channel_idx]);
+
+                const float* in_ptr = input_data_ptr + block * axis_block_size;
+                int8_t* out_ptr = out_data_ptr + block * axis_block_size;
+
+                quantize_arm(
+                    in_ptr,
+                    out_ptr,
+                    axis_block_size,
+                    inv_scale,
+                    zp,
+                    static_cast<int8_t>(quant_min),
+                    static_cast<int8_t>(quant_max));
+              }
+            });
+      } else {
+        // Process each contiguous block (which shares the same
+        // scale/zero_point)
+        for (int64_t block = 0; block < num_blocks; ++block) {
+          int64_t channel_idx = block % axis_size;
+          float inv_scale = 1.0f / static_cast<float>(scale_data[channel_idx]);
+          int32_t zp = static_cast<int32_t>(zero_point_data[channel_idx]);
+
+          const float* in_ptr = input_data_ptr + block * axis_block_size;
+          int8_t* out_ptr = out_data_ptr + block * axis_block_size;
+
+          quantize_arm(
+              in_ptr,
+              out_ptr,
+              axis_block_size,
+              inv_scale,
+              zp,
+              static_cast<int8_t>(quant_min),
+              static_cast<int8_t>(quant_max));
+        }
       }
       return out;
     }
diff --git a/kernels/quantized/cpu/targets.bzl b/kernels/quantized/cpu/targets.bzl
index 1da0d482485..88a3823c5f3 100644
--- a/kernels/quantized/cpu/targets.bzl
+++ b/kernels/quantized/cpu/targets.bzl
@@ -52,6 +52,9 @@ _QUANT_OPS = (
     ),
     op_target(
         name = "op_quantize",
+        deps = [
+            "//executorch/extension/threadpool:threadpool",
+        ],
     ),
 )
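
Note on the change above: parallelism is applied at block granularity. Each block is axis_block_size contiguous elements sharing one scale/zero_point, so every [begin, end) chunk handed to the lambda writes a disjoint slice of out and no synchronization is needed; tensors under 512 total elements (MIN_ELEMENTS_FOR_PARALLEL) keep the serial path to avoid threadpool dispatch overhead. Below is a minimal standalone sketch of the same partitioning pattern, assuming the parallel_for signature used in the diff (range begin/end, grain size, chunk lambda) and the executorch/extension/threadpool/threadpool.h header implied by the new Buck dep; the scalar inner loop is an illustrative stand-in for the NEON quantize_arm helper, not the kernel's actual rounding code.

// Sketch only (not part of the patch): per-block parallel quantization with
// the same work partitioning as the diff above.
#include <executorch/extension/threadpool/threadpool.h>

#include <cmath>
#include <cstdint>

void quantize_blocks_parallel(
    const float* in,
    uint8_t* out,
    int64_t num_blocks,
    int64_t block_size, // corresponds to axis_block_size in the diff
    const double* scales, // one scale per channel
    const int64_t* zero_points, // one zero_point per channel
    int64_t axis_size) {
  ::executorch::extension::parallel_for(
      0, num_blocks, /*grain_size=*/1,
      [&](const int64_t begin, const int64_t end) {
        // Chunks are disjoint block ranges, so each task writes a
        // non-overlapping region of `out`; no locking is required.
        for (int64_t block = begin; block < end; ++block) {
          const int64_t channel = block % axis_size;
          const float inv_scale = 1.0f / static_cast<float>(scales[channel]);
          const int32_t zp = static_cast<int32_t>(zero_points[channel]);
          for (int64_t i = 0; i < block_size; ++i) {
            // Scalar stand-in for quantize_arm: scale, round, shift, clamp.
            const float v =
                std::nearbyintf(in[block * block_size + i] * inv_scale) + zp;
            out[block * block_size + i] =
                static_cast<uint8_t>(std::fmin(std::fmax(v, 0.0f), 255.0f));
          }
        }
      });
}

A grain size of 1 leaves chunking entirely to the threadpool; raising it would force at least that many blocks per task, which matters only when per-block work is very small.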