From 3f789b8ed0e367c46eafc0ac44ff44cf444712ef Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Wed, 5 Nov 2025 10:44:58 -0800 Subject: [PATCH] [Executorch] parallelize op_choose_qparams When doing prefill for quantized kv cache, with large prefill length, parallelizing this op helps. Differential Revision: [D84962234](https://our.internmc.facebook.com/intern/diff/D84962234/) **NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D84962234/)! [ghstack-poisoned] --- kernels/quantized/cpu/op_choose_qparams.cpp | 48 +++++++--- kernels/quantized/cpu/targets.bzl | 1 + .../quantized/test/op_choose_qparams_test.cpp | 95 +++++++++++++++++++ 3 files changed, 133 insertions(+), 11 deletions(-) diff --git a/kernels/quantized/cpu/op_choose_qparams.cpp b/kernels/quantized/cpu/op_choose_qparams.cpp index 5335f4bfbd2..acb8e100af6 100644 --- a/kernels/quantized/cpu/op_choose_qparams.cpp +++ b/kernels/quantized/cpu/op_choose_qparams.cpp @@ -8,6 +8,7 @@ #include #include +#include #include #include #include @@ -202,17 +203,42 @@ void choose_qparams_per_token( num_tokens *= input.size(i); } auto token_dim_size = input.size(input.dim() - 1); - for (auto i = 0; i < num_tokens; i++) { - // vec_minf uses std::min_element. Check if it actually - // gets vectorized. - float min = torch::executor::vec_minf(x_fp32, token_dim_size); - float max = torch::executor::vec_maxf(x_fp32, token_dim_size); - double scale; - int32_t zero_point; - calculate_scale_and_zero_point(min, max, qmin, qmax, scale, zero_point); - scale_out.mutable_data_ptr()[i] = scale; - zero_point_out.mutable_data_ptr()[i] = zero_point; - x_fp32 += token_dim_size; + + const int64_t total_elements = num_tokens * token_dim_size; + constexpr int64_t MIN_ELEMENTS_FOR_PARALLEL = 512; + const bool use_parallel = total_elements >= MIN_ELEMENTS_FOR_PARALLEL; + + if (use_parallel) { + auto* scale_data = scale_out.mutable_data_ptr(); + auto* zero_point_data = zero_point_out.mutable_data_ptr(); + + ::executorch::extension::parallel_for( + 0, num_tokens, 1, [&](const int64_t begin, const int64_t end) { + for (int64_t i = begin; i < end; i++) { + const float* token_data = x_fp32 + i * token_dim_size; + float min = torch::executor::vec_minf(token_data, token_dim_size); + float max = torch::executor::vec_maxf(token_data, token_dim_size); + double scale; + int32_t zero_point; + calculate_scale_and_zero_point( + min, max, qmin, qmax, scale, zero_point); + scale_data[i] = scale; + zero_point_data[i] = zero_point; + } + }); + } else { + for (auto i = 0; i < num_tokens; i++) { + // vec_minf uses std::min_element. Check if it actually + // gets vectorized. + float min = torch::executor::vec_minf(x_fp32, token_dim_size); + float max = torch::executor::vec_maxf(x_fp32, token_dim_size); + double scale; + int32_t zero_point; + calculate_scale_and_zero_point(min, max, qmin, qmax, scale, zero_point); + scale_out.mutable_data_ptr()[i] = scale; + zero_point_out.mutable_data_ptr()[i] = zero_point; + x_fp32 += token_dim_size; + } } } } // namespace diff --git a/kernels/quantized/cpu/targets.bzl b/kernels/quantized/cpu/targets.bzl index f29f1f013b7..1da0d482485 100644 --- a/kernels/quantized/cpu/targets.bzl +++ b/kernels/quantized/cpu/targets.bzl @@ -9,6 +9,7 @@ _QUANT_OPS = ( name = "op_choose_qparams", deps = [ "//executorch/kernels/portable/cpu:vec_ops", + "//executorch/extension/threadpool:threadpool", ], ), op_target( diff --git a/kernels/quantized/test/op_choose_qparams_test.cpp b/kernels/quantized/test/op_choose_qparams_test.cpp index 13426bfdd86..dc92df80488 100644 --- a/kernels/quantized/test/op_choose_qparams_test.cpp +++ b/kernels/quantized/test/op_choose_qparams_test.cpp @@ -15,6 +15,7 @@ #include #include +#include #include using namespace ::testing; @@ -163,3 +164,97 @@ TEST(OpChooseQparamsPerTokenAsymmetricTensorOutTest, DynamicShapeFloat) { EXPECT_TENSOR_CLOSE_WITH_TOL(scale_out, new_expected_scale, 1e-4, 1e-4); EXPECT_TENSOR_EQ(zero_point_out, new_expected_zero_point); } + +TEST( + OpChooseQparamsPerTokenAsymmetricTensorOutTest, + LargeInputParallelization) { + et_pal_init(); + TensorFactory tf_float; + TensorFactory tf_double; + TensorFactory tf_long; + + // Create input with 8 tokens x 128 elements per token = 1024 total elements + // This exceeds the MIN_ELEMENTS_FOR_PARALLEL threshold of 512 + const int num_tokens = 8; + const int token_size = 128; + std::vector input_data(num_tokens * token_size); + + // Generate test data with known min/max per token for easier verification + std::vector expected_min(num_tokens); + std::vector expected_max(num_tokens); + + for (int i = 0; i < num_tokens; i++) { + float token_min = -1.0f * (i + 1); + float token_max = 2.0f * (i + 1); + expected_min[i] = token_min; + expected_max[i] = token_max; + + for (int j = 0; j < token_size; j++) { + // Linearly interpolate between min and max + float t = j / static_cast(token_size - 1); + input_data[i * token_size + j] = token_min + t * (token_max - token_min); + } + } + + Tensor input = tf_float.make({num_tokens, token_size}, input_data); + Tensor scale_out = tf_double.zeros({num_tokens, 1}); + Tensor zero_point_out = tf_long.zeros({num_tokens, 1}); + + choose_qparams_per_token_asymmetric_out( + input, ScalarType::Float, scale_out, zero_point_out); + + // Manually calculate expected scale and zero_point using the same algorithm + // as calculate_scale_and_zero_point function + const int32_t qmin = -128; + const int32_t qmax = 127; + const float SMALL_SCALE_THRESHOLD = 6.1e-5f; + + for (int i = 0; i < num_tokens; i++) { + float min = std::min(expected_min[i], 0.0f); + float max = std::max(expected_max[i], 0.0f); + + // Calculate scale + double scale = (static_cast(max) - min) / (qmax - qmin); + if (float(scale) == 0.0f || std::isinf(1.0f / float(scale))) { + scale = 0.1; + } + + // Cut off small scale + if (scale < SMALL_SCALE_THRESHOLD) { + scale = SMALL_SCALE_THRESHOLD; + if (min == 0.0f) { + max = SMALL_SCALE_THRESHOLD * (qmax - qmin); + } else if (max == 0.0f) { + min = -SMALL_SCALE_THRESHOLD * (qmax - qmin); + } else { + float amplifier = SMALL_SCALE_THRESHOLD / scale; + min *= amplifier; + max *= amplifier; + } + } + + // Calculate zero_point + double zero_point_from_min = qmin - min / scale; + double zero_point_from_max = qmax - max / scale; + double zero_point_from_min_error = std::abs(qmin) - std::abs(min / scale); + double zero_point_from_max_error = std::abs(qmax) - std::abs(max / scale); + double initial_zero_point = + zero_point_from_min_error < zero_point_from_max_error + ? zero_point_from_min + : zero_point_from_max; + + int32_t nudged_zero_point = 0; + if (initial_zero_point < qmin) { + nudged_zero_point = qmin; + } else if (initial_zero_point > qmax) { + nudged_zero_point = qmax; + } else { + nudged_zero_point = + std::nearbyint(static_cast(initial_zero_point)); + } + + // Verify computed values match expected + EXPECT_NEAR(scale_out.const_data_ptr()[i], scale, 1e-6); + EXPECT_EQ(zero_point_out.const_data_ptr()[i], nudged_zero_point); + } +}