From 08dd9803160c8e2a0f13ba8ec1e6259955a14c62 Mon Sep 17 00:00:00 2001
From: Kimish Patel
Date: Wed, 5 Nov 2025 10:45:02 -0800
Subject: [PATCH] [Executorch] Add SIMD path for op_quantize

This does not use the Vectorized class directly because the equivalent
APIs do not exist in that class.

Differential Revision: [D84962236](https://our.internmc.facebook.com/intern/diff/D84962236/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or
comments; please review them on
[Phabricator](https://our.internmc.facebook.com/intern/diff/D84962236/)!

[ghstack-poisoned]
---
 kernels/quantized/cpu/op_quantize.cpp       | 294 ++++++++++-
 kernels/quantized/test/op_quantize_test.cpp | 537 +++++++++++++++++++-
 2 files changed, 794 insertions(+), 37 deletions(-)

diff --git a/kernels/quantized/cpu/op_quantize.cpp b/kernels/quantized/cpu/op_quantize.cpp
index 5586f8a77eb..fefb07b1e59 100644
--- a/kernels/quantized/cpu/op_quantize.cpp
+++ b/kernels/quantized/cpu/op_quantize.cpp
@@ -11,6 +11,10 @@
 #include
 #include
 
+#if defined(__aarch64__) || defined(__ARM_NEON__)
+#include <arm_neon.h>
+#endif
+
 /**
  * For an input tensor, use the scale and zero_point arguments to quantize it.
  */
@@ -105,6 +109,143 @@ T quantize_val(
   return static_cast<T>(qvalue);
 }
 
+#if defined(__aarch64__) || defined(__ARM_NEON__)
+
+// Traits for type-specific NEON operations
+template <typename T>
+struct NeonQuantizeTraits;
+
+template <>
+struct NeonQuantizeTraits<uint8_t> {
+  // Narrow int16x8 to uint8x8 with saturation (unsigned)
+  static inline uint8x8_t narrow_and_saturate(int16x8_t v) {
+    return vqmovun_s16(v);
+  }
+
+  // Store uint8x8 to memory
+  static inline void store(uint8_t* ptr, uint8x8_t v) {
+    vst1_u8(ptr, v);
+  }
+
+  // Scalar clamping for uint8
+  static inline uint8_t clamp_scalar(int32_t val) {
+    return static_cast<uint8_t>(std::min(255, std::max(0, val)));
+  }
+};
+
+template <>
+struct NeonQuantizeTraits<int8_t> {
+  // Narrow int16x8 to int8x8 with saturation (signed)
+  static inline int8x8_t narrow_and_saturate(int16x8_t v) {
+    return vqmovn_s16(v);
+  }
+
+  // Store int8x8 to memory
+  static inline void store(int8_t* ptr, int8x8_t v) {
+    vst1_s8(ptr, v);
+  }
+
+  // Scalar clamping for int8
+  static inline int8_t clamp_scalar(int32_t val) {
+    return static_cast<int8_t>(std::min(127, std::max(-128, val)));
+  }
+};
+
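+// For reference: each lane below computes
+//   q = clamp(round_to_nearest_even(x * inv_scale) + zero_point,
+//             quant_min, quant_max)
+// e.g. x = 1.26, inv_scale = 10, zero_point = 3 -> nearbyint(12.6) + 3 = 16.
+// Ties round to even (2.5 -> 2, 3.5 -> 4), matching std::nearbyint in the
+// scalar tails under the default FE_TONEAREST rounding mode.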
+// Unified ARM NEON optimized quantization for contiguous blocks
+// Processes N elements with a single scale/zero_point pair
+// Used for both per-tensor (entire tensor) and per-channel (one block per
+// channel)
+template <typename T>
+void quantize_arm(
+    const float* __restrict__ in,
+    T* __restrict__ out,
+    const int64_t N,
+    const float inv_scale,
+    const int32_t zero_point,
+    const int32_t quant_min,
+    const int32_t quant_max) {
+  using Traits = NeonQuantizeTraits<T>;
+  const float32x4_t vinv_scale = vdupq_n_f32(inv_scale);
+
+#if defined(__aarch64__)
+  // ARMv8: Use vcvtnq_s32_f32 (round to nearest, ties to even) for rounding
+  const int16x8_t vzero_point = vdupq_n_s16(static_cast<int16_t>(zero_point));
+  const int16x8_t vquant_min = vdupq_n_s16(static_cast<int16_t>(quant_min));
+  const int16x8_t vquant_max = vdupq_n_s16(static_cast<int16_t>(quant_max));
+
+  int64_t i = 0;
+  // Process 8 elements at a time
+  for (; i + 8 <= N; i += 8) {
+    const float32x4_t vin0123 = vld1q_f32(in + i);
+    const float32x4_t vin4567 = vld1q_f32(in + i + 4);
+
+    // Multiply by inv_scale and round
+    const int32x4_t v0123_rounded =
+        vcvtnq_s32_f32(vmulq_f32(vin0123, vinv_scale));
+    const int32x4_t v4567_rounded =
+        vcvtnq_s32_f32(vmulq_f32(vin4567, vinv_scale));
+
+    // Combine to int16 and add zero_point
+    int16x8_t v01234567_packed = vqaddq_s16(
+        vqmovn_high_s32(vqmovn_s32(v0123_rounded), v4567_rounded), vzero_point);
+
+    // Clamp to quant_min/quant_max
+    v01234567_packed = vmaxq_s16(v01234567_packed, vquant_min);
+    v01234567_packed = vminq_s16(v01234567_packed, vquant_max);
+
+    // Convert to T (int8/uint8) with saturation using type-specific operation
+    const auto vout01234567 = Traits::narrow_and_saturate(v01234567_packed);
+    Traits::store(out + i, vout01234567);
+  }
+
+  // Handle remaining elements with proper quant_min/quant_max clamping
+  for (; i < N; ++i) {
+    float val = in[i] * inv_scale;
+    int32_t qval = static_cast<int32_t>(std::nearbyint(val)) + zero_point;
+    qval = std::max(quant_min, std::min(quant_max, qval));
+    out[i] = static_cast<T>(qval);
+  }
+
+#else
+  // ARMv7: Use magic float rounding. Adding 12582912.0f (1.5 * 2^23, bit
+  // pattern 0x4B400000) to a small float forces the rounded integer into the
+  // low mantissa bits, so reinterpreting the sum as int32 and subtracting
+  // 0x4B400000 yields round-to-nearest-even; the zero_point is folded into
+  // that offset.
+  const int32x4_t voffset = vdupq_n_s32(zero_point - 0x4B400000);
+  const float32x4_t vmagic_float = vdupq_n_f32(12582912.0f);
+  const int16x8_t vquant_min = vdupq_n_s16(static_cast<int16_t>(quant_min));
+  const int16x8_t vquant_max = vdupq_n_s16(static_cast<int16_t>(quant_max));
+
+  int64_t i = 0;
+  // Process 8 elements at a time
+  for (; i + 8 <= N; i += 8) {
+    const float32x4_t vin0123 = vld1q_f32(in + i);
+    const float32x4_t vin4567 = vld1q_f32(in + i + 4);
+
+    const int32x4_t vraw0123 = vaddq_s32(
+        voffset,
+        vreinterpretq_s32_f32(
+            vaddq_f32(vmagic_float, vmulq_f32(vin0123, vinv_scale))));
+    const int32x4_t vraw4567 = vaddq_s32(
+        voffset,
+        vreinterpretq_s32_f32(
+            vaddq_f32(vmagic_float, vmulq_f32(vin4567, vinv_scale))));
+
+    int16x8_t vraw01234567 =
+        vcombine_s16(vqmovn_s32(vraw0123), vqmovn_s32(vraw4567));
+
+    // Clamp to quant_min/quant_max (the saturating narrow below only clamps
+    // to the full int8/uint8 range)
+    vraw01234567 = vmaxq_s16(vraw01234567, vquant_min);
+    vraw01234567 = vminq_s16(vraw01234567, vquant_max);
+
+    // Convert to T (int8/uint8) with saturation using type-specific operation
+    const auto vout01234567 = Traits::narrow_and_saturate(vraw01234567);
+    Traits::store(out + i, vout01234567);
+  }
+
+  // Handle remaining elements with proper quant_min/quant_max clamping
+  for (; i < N; ++i) {
+    float val = in[i] * inv_scale;
+    int32_t qval = static_cast<int32_t>(std::nearbyint(val)) + zero_point;
+    qval = std::max(quant_min, std::min(quant_max, qval));
+    out[i] = static_cast<T>(qval);
+  }
+#endif
+}
+
+#endif // defined(__aarch64__) || defined(__ARM_NEON__)
+
 Tensor& quantize_per_tensor_out(
     const Tensor& input,
     double scale,
@@ -120,19 +261,44 @@ Tensor& quantize_per_tensor_out(
   check_quantize_per_tensor_args(input, quant_min, quant_max, dtype, out);
 
-  // calculate the quantized input
-#define QUANTIZE_IMPL(IN_CTYPE, OUT_CTYPE, out_dtype)                          \
-  case ScalarType::out_dtype: {                                                \
-    /* Hoist these function calls out of our inner loop because they might not \
-     * get inlined without LTO, particularly in ATen mode. */                  \
-    auto* out_data_ptr = out.mutable_data_ptr<OUT_CTYPE>();                    \
-    const auto* input_data_ptr = input.const_data_ptr<IN_CTYPE>();             \
-    const auto input_numel = input.numel();                                    \
-    for (size_t i = 0; i < input_numel; i++) {                                 \
-      IN_CTYPE value = input_data_ptr[i];                                      \
-      out_data_ptr[i] = quantize_val<IN_CTYPE, OUT_CTYPE>(                     \
-          scale, zero_point, value, quant_min, quant_max);                     \
-    }                                                                          \
+  // Try ARM NEON optimized path for float->int8/uint8 quantization
+#if defined(__aarch64__) || defined(__ARM_NEON__)
+  if (input.scalar_type() == ScalarType::Float) {
+    if (dtype == ScalarType::Byte) {
+      quantize_arm<uint8_t>(
+          input.const_data_ptr<float>(),
+          out.mutable_data_ptr<uint8_t>(),
+          input.numel(),
+          1.0f / static_cast<float>(scale),
+          static_cast<int32_t>(zero_point),
+          static_cast<int32_t>(quant_min),
+          static_cast<int32_t>(quant_max));
+      return out;
+    } else if (dtype == ScalarType::Char) {
+      quantize_arm<int8_t>(
+          input.const_data_ptr<float>(),
+          out.mutable_data_ptr<int8_t>(),
+          input.numel(),
+          1.0f / static_cast<float>(scale),
+          static_cast<int32_t>(zero_point),
+          static_cast<int32_t>(quant_min),
+          static_cast<int32_t>(quant_max));
+      return out;
+    }
+  }
+#endif
+
+  // Fallback scalar implementation for all other cases
+#define QUANTIZE_IMPL(IN_CTYPE, OUT_CTYPE, out_dtype)              \
+  case ScalarType::out_dtype: {                                    \
+    auto* out_data_ptr = out.mutable_data_ptr<OUT_CTYPE>();        \
+    const auto* input_data_ptr = input.const_data_ptr<IN_CTYPE>(); \
+    const auto input_numel = input.numel();                        \
+    for (size_t i = 0; i < input_numel; i++) {                     \
+      IN_CTYPE value = input_data_ptr[i];                          \
+      out_data_ptr[i] = quantize_val<IN_CTYPE, OUT_CTYPE>(         \
+          scale, zero_point, value, quant_min, quant_max);         \
+    }                                                              \
   } break;
 #define CALCULATE_FLOAT_TYPE(IN_CTYPE, in_dtype) \
   case ScalarType::in_dtype:                     \
@@ -284,29 +450,85 @@ Tensor& quantize_per_channel_out(
   const double* scale_data = scale.const_data_ptr<double>();
   const int64_t* zero_point_data = zero_point.const_data_ptr<int64_t>();
 
-  // High-performance single loop with direct channel calculation
-#define QUANTIZE_IMPL(CTYPE_IN, CTYPE_OUT, out_dtype)                           \
-  case ScalarType::out_dtype: {                                                 \
-    auto* out_data_ptr = out.mutable_data_ptr<CTYPE_OUT>();                     \
-    const auto* input_data_ptr = input.const_data_ptr<CTYPE_IN>();              \
-    const int64_t input_numel = input.numel();                                  \
-    const int64_t axis_size = input.size(axis);                                 \
-    /* Calculate the stride pattern for efficient channel index calculation */  \
-    int64_t axis_block_size = 1;                                                \
-    for (int64_t i = axis + 1; i < input.dim(); i++) {                          \
-      axis_block_size *= input.size(i);                                         \
-    }                                                                           \
-    /* Single loop over all elements */                                         \
-    for (int64_t i = 0; i < input_numel; i++) {                                 \
-      /* Calculate which channel this element belongs to */                     \
-      int64_t channel_idx = (i / axis_block_size) % axis_size;                  \
-      /* Get quantization parameters for this channel */                        \
-      double _scale = scale_data[channel_idx];                                  \
-      int64_t _zero_point = zero_point_data[channel_idx];                       \
-      /* Apply quantization */                                                  \
-      out_data_ptr[i] = quantize_val<CTYPE_IN, CTYPE_OUT>(                      \
-          _scale, _zero_point, input_data_ptr[i], quant_min, quant_max);        \
-    }                                                                           \
+  // Calculate the block size for each channel
+  int64_t axis_block_size = 1;
+  for (int64_t i = axis + 1; i < input.dim(); i++) {
+    axis_block_size *= input.size(i);
+  }
+  const int64_t axis_size = input.size(axis);
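+
+  // For example: a [2, 3, 4] input quantized along axis = 1 has
+  // axis_block_size = 4 and axis_size = 3, so its 24 elements form six
+  // contiguous blocks of four, and block b uses the parameters of channel
+  // b % axis_size (0, 1, 2, 0, 1, 2).
+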
+  // Try ARM NEON optimized path for float->int8/uint8 quantization
+#if defined(__aarch64__) || defined(__ARM_NEON__)
+  if (input.scalar_type() == ScalarType::Float) {
+    const int64_t num_blocks = input.numel() / axis_block_size;
+
+    if (dtype == ScalarType::Byte) {
+      auto* out_data_ptr = out.mutable_data_ptr<uint8_t>();
+      const auto* input_data_ptr = input.const_data_ptr<float>();
+
+      // Process each contiguous block (which shares the same scale/zero_point)
+      for (int64_t block = 0; block < num_blocks; ++block) {
+        int64_t channel_idx = block % axis_size;
+        float inv_scale = 1.0f / static_cast<float>(scale_data[channel_idx]);
+        int32_t zp = static_cast<int32_t>(zero_point_data[channel_idx]);
+
+        const float* in_ptr = input_data_ptr + block * axis_block_size;
+        uint8_t* out_ptr = out_data_ptr + block * axis_block_size;
+
+        quantize_arm<uint8_t>(
+            in_ptr,
+            out_ptr,
+            axis_block_size,
+            inv_scale,
+            zp,
+            static_cast<int32_t>(quant_min),
+            static_cast<int32_t>(quant_max));
+      }
+      return out;
+    } else if (dtype == ScalarType::Char) {
+      auto* out_data_ptr = out.mutable_data_ptr<int8_t>();
+      const auto* input_data_ptr = input.const_data_ptr<float>();
+
+      // Process each contiguous block (which shares the same scale/zero_point)
+      for (int64_t block = 0; block < num_blocks; ++block) {
+        int64_t channel_idx = block % axis_size;
+        float inv_scale = 1.0f / static_cast<float>(scale_data[channel_idx]);
+        int32_t zp = static_cast<int32_t>(zero_point_data[channel_idx]);
+
+        const float* in_ptr = input_data_ptr + block * axis_block_size;
+        int8_t* out_ptr = out_data_ptr + block * axis_block_size;
+
+        quantize_arm<int8_t>(
+            in_ptr,
+            out_ptr,
+            axis_block_size,
+            inv_scale,
+            zp,
+            static_cast<int32_t>(quant_min),
+            static_cast<int32_t>(quant_max));
+      }
+      return out;
+    }
+  }
+#endif
+
+  // Fallback scalar implementation
+#define QUANTIZE_IMPL(CTYPE_IN, CTYPE_OUT, out_dtype)                    \
+  case ScalarType::out_dtype: {                                          \
+    auto* out_data_ptr = out.mutable_data_ptr<CTYPE_OUT>();              \
+    const auto* input_data_ptr = input.const_data_ptr<CTYPE_IN>();       \
+    const int64_t input_numel = input.numel();                           \
+    /* Single loop over all elements */                                  \
+    for (int64_t i = 0; i < input_numel; i++) {                          \
+      /* Calculate which channel this element belongs to */              \
+      int64_t channel_idx = (i / axis_block_size) % axis_size;           \
+      /* Get quantization parameters for this channel */                 \
+      double _scale = scale_data[channel_idx];                           \
+      int64_t _zero_point = zero_point_data[channel_idx];                \
+      /* Apply quantization */                                           \
+      out_data_ptr[i] = quantize_val<CTYPE_IN, CTYPE_OUT>(               \
+          _scale, _zero_point, input_data_ptr[i], quant_min, quant_max); \
+    }                                                                    \
   } break;
 #define CALCULATE_FLOAT_TYPE(CTYPE_IN, in_dtype) \
diff --git a/kernels/quantized/test/op_quantize_test.cpp b/kernels/quantized/test/op_quantize_test.cpp
index 4ac835c24ce..b450ec0ee33 100644
--- a/kernels/quantized/test/op_quantize_test.cpp
+++ b/kernels/quantized/test/op_quantize_test.cpp
@@ -14,7 +14,6 @@
 #include
 #include
-#include
 
 using namespace ::testing;
 using executorch::aten::ArrayRef;
@@ -446,3 +445,539 @@ TEST(OpQuantizeOutTest, QuantizePerChannelClampingBehavior) {
 
   EXPECT_TENSOR_EQ(out, expected);
 }
+
+TEST(OpQuantizeOutTest, LargePerChannelClampingSIMDPath) {
+  // Test quant_min/quant_max clamping with large tensor to exercise SIMD path
+  // Shape: [3, 80] with axis=0 (3 channels, 80 elements each)
+  // 80 elements = 10 SIMD iterations (8 elements each)
+  TensorFactory<ScalarType::Float> tf_float;
+  TensorFactory<ScalarType::Double> tf_double;
+  TensorFactory<ScalarType::Long> tf_long;
+
+  const int num_channels = 3;
+  const int block_size = 80;
+  std::vector<float> input_data(num_channels * block_size);
+
+  // Create input data with values that exceed quant_min/quant_max
+  for (int ch = 0; ch < num_channels; ch++) {
+    for (int i = 0; i < block_size; i++) {
+      // Generate values from -150 to 150 to test clamping
+      input_data[ch * block_size + i] =
+          static_cast<float>((i % 40) - 20) * 5.0f * (ch + 1);
+    }
+  }
+  Tensor input = tf_float.make({num_channels, block_size}, input_data);
+
+  // Use uniform scale and zero_point for all channels
+  Tensor scale = tf_double.make({num_channels}, {1.0, 1.0, 1.0});
+  Tensor zero_point = tf_long.make({num_channels}, {0, 0, 0});
+
+  // Set narrow quant_min/quant_max to force
clamping + int64_t quant_min = -20; + int64_t quant_max = 20; + + TensorFactory tfo; + Tensor out = tfo.zeros({num_channels, block_size}); + + // Compute expected values with clamping + std::vector expected_data(num_channels * block_size); + for (int ch = 0; ch < num_channels; ch++) { + double ch_scale = scale.const_data_ptr()[ch]; + int64_t ch_zero_point = zero_point.const_data_ptr()[ch]; + + for (int i = 0; i < block_size; i++) { + int idx = ch * block_size + i; + // Use double precision to avoid overflow + double val = static_cast(input_data[idx]) / ch_scale; + // Clamp before converting to int to avoid overflow + val = std::max(-1000.0, std::min(1000.0, val)); + int32_t qval = static_cast(std::nearbyint(val)) + + static_cast(ch_zero_point); + // Apply quant_min/quant_max clamping + qval = std::max( + static_cast(quant_min), + std::min(static_cast(quant_max), qval)); + expected_data[idx] = static_cast(qval); + } + } + Tensor expected = tfo.make({num_channels, block_size}, expected_data); + + quantize_per_channel_out( + input, scale, zero_point, 0, quant_min, quant_max, ScalarType::Char, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +// Large tensor tests to ensure ARM NEON SIMD path is exercised + +TEST(OpQuantizeOutTest, LargeTensorUInt8SIMDPath) { + // Test with 64 elements to fully exercise SIMD path (8 elements per + // iteration) + TensorFactory tf_float; + + // Create input with known values for verification + std::vector input_data(64); + for (size_t i = 0; i < 64; i++) { + input_data[i] = static_cast(i) * 0.5f; // 0.0, 0.5, 1.0, 1.5, ... + } + Tensor input = tf_float.make({64}, input_data); + + double scale = 0.1; + int64_t zero_point = 10; + int64_t quant_min = 0; + int64_t quant_max = 255; + + TensorFactory tfo; + Tensor out = tfo.zeros({64}); + + // Compute expected values: round(value / scale) + zero_point + std::vector expected_data(64); + for (size_t i = 0; i < 64; i++) { + float val = input_data[i] / static_cast(scale); + int32_t qval = static_cast(std::nearbyint(val)) + zero_point; + qval = std::min(255, std::max(0, qval)); + expected_data[i] = static_cast(qval); + } + Tensor expected = tfo.make({64}, expected_data); + + quantize_per_tensor_out( + input, scale, zero_point, quant_min, quant_max, ScalarType::Byte, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, LargeTensorInt8SIMDPath) { + // Test with 72 elements (9 SIMD iterations of 8) to test both vectorized and + // scalar paths + TensorFactory tf_float; + + std::vector input_data(72); + for (size_t i = 0; i < 72; i++) { + // Mix of positive and negative values + input_data[i] = static_cast(static_cast(i) - 36) * 0.25f; + } + Tensor input = tf_float.make({72}, input_data); + + double scale = 0.2; + int64_t zero_point = 0; + int64_t quant_min = -128; + int64_t quant_max = 127; + + TensorFactory tfo; + Tensor out = tfo.zeros({72}); + + // Compute expected values + std::vector expected_data(72); + for (size_t i = 0; i < 72; i++) { + float val = input_data[i] / static_cast(scale); + int32_t qval = static_cast(std::nearbyint(val)) + zero_point; + qval = std::min(127, std::max(-128, qval)); + expected_data[i] = static_cast(qval); + } + Tensor expected = tfo.make({72}, expected_data); + + quantize_per_tensor_out( + input, scale, zero_point, quant_min, quant_max, ScalarType::Char, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, LargeTensorWithRemainderUInt8) { + // Test with 100 elements (12 SIMD iterations + 4 remainder) to test remainder + // handling + TensorFactory 
tf_float; + + std::vector input_data(100); + for (size_t i = 0; i < 100; i++) { + input_data[i] = static_cast(i % 50) * 0.3f; + } + Tensor input = tf_float.make({100}, input_data); + + double scale = 0.15; + int64_t zero_point = 128; + int64_t quant_min = 0; + int64_t quant_max = 255; + + TensorFactory tfo; + Tensor out = tfo.zeros({100}); + + std::vector expected_data(100); + for (size_t i = 0; i < 100; i++) { + float val = input_data[i] / static_cast(scale); + int32_t qval = static_cast(std::nearbyint(val)) + zero_point; + qval = std::min(255, std::max(0, qval)); + expected_data[i] = static_cast(qval); + } + Tensor expected = tfo.make({100}, expected_data); + + quantize_per_tensor_out( + input, scale, zero_point, quant_min, quant_max, ScalarType::Byte, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, LargeTensorWithRemainderInt8) { + // Test with 99 elements (12 SIMD iterations + 3 remainder) + TensorFactory tf_float; + + std::vector input_data(99); + for (size_t i = 0; i < 99; i++) { + input_data[i] = std::sin(static_cast(i) * 0.1f) * 10.0f; + } + Tensor input = tf_float.make({99}, input_data); + + double scale = 0.1; + int64_t zero_point = 5; + int64_t quant_min = -128; + int64_t quant_max = 127; + + TensorFactory tfo; + Tensor out = tfo.zeros({99}); + + std::vector expected_data(99); + for (size_t i = 0; i < 99; i++) { + float val = input_data[i] / static_cast(scale); + int32_t qval = static_cast(std::nearbyint(val)) + zero_point; + qval = std::min(127, std::max(-128, qval)); + expected_data[i] = static_cast(qval); + } + Tensor expected = tfo.make({99}, expected_data); + + quantize_per_tensor_out( + input, scale, zero_point, quant_min, quant_max, ScalarType::Char, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, VeryLargeTensor2DUInt8) { + // Test with realistic 2D tensor size that would be used in neural networks + // 256x256 = 65536 elements (8192 SIMD iterations) + TensorFactory tf_float; + + std::vector input_data(256 * 256); + for (size_t i = 0; i < 256 * 256; i++) { + // Generate diverse values in a safe range + input_data[i] = + static_cast((static_cast(i % 256) - 128)) * 0.05f; + } + Tensor input = tf_float.make({256, 256}, input_data); + + double scale = 0.05; + int64_t zero_point = 128; + int64_t quant_min = 0; + int64_t quant_max = 255; + + TensorFactory tfo; + Tensor out = tfo.zeros({256, 256}); + + // Compute expected values with proper overflow handling + std::vector expected_data(256 * 256); + for (size_t i = 0; i < 256 * 256; i++) { + // Use double precision to avoid overflow + double val = static_cast(input_data[i]) / scale; + // Clamp before converting to int to avoid overflow + val = std::max(-1000.0, std::min(1000.0, val)); + int32_t qval = static_cast(std::nearbyint(val)) + + static_cast(zero_point); + qval = std::min(255, std::max(0, qval)); + expected_data[i] = static_cast(qval); + } + Tensor expected = tfo.make({256, 256}, expected_data); + + quantize_per_tensor_out( + input, scale, zero_point, quant_min, quant_max, ScalarType::Byte, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, VeryLargeTensor3DInt8) { + // Test with 3D tensor (batch_size=2, height=64, width=128) = 16384 elements + TensorFactory tf_float; + + const size_t total_elements = 2 * 64 * 128; + std::vector input_data(total_elements); + for (size_t i = 0; i < total_elements; i++) { + input_data[i] = std::cos(static_cast(i) * 0.01f) * 8.0f; + } + Tensor input = tf_float.make({2, 64, 128}, input_data); + + double scale = 0.0625; 
// 1/16 + int64_t zero_point = -10; + int64_t quant_min = -128; + int64_t quant_max = 127; + + TensorFactory tfo; + Tensor out = tfo.zeros({2, 64, 128}); + + std::vector expected_data(total_elements); + for (size_t i = 0; i < total_elements; i++) { + float val = input_data[i] / static_cast(scale); + int32_t qval = static_cast(std::nearbyint(val)) + zero_point; + qval = std::min(127, std::max(-128, qval)); + expected_data[i] = static_cast(qval); + } + Tensor expected = tfo.make({2, 64, 128}, expected_data); + + quantize_per_tensor_out( + input, scale, zero_point, quant_min, quant_max, ScalarType::Char, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, EdgeCaseSizesSIMD) { + // Test specific sizes around SIMD boundaries + TensorFactory tf_float; + TensorFactory tfo; + + double scale = 0.1; + int64_t zero_point = 100; + int64_t quant_min = 0; + int64_t quant_max = 255; + + // Test sizes: 7 (just before SIMD), 8 (exactly 1 SIMD), 9 (1 SIMD + 1), 15, + // 16, 17 + std::vector test_sizes = { + 7, 8, 9, 15, 16, 17, 23, 24, 25, 31, 32, 33}; + + for (size_t size : test_sizes) { + std::vector input_data(size); + std::vector expected_data(size); + + for (size_t i = 0; i < size; i++) { + input_data[i] = static_cast(i) * 0.3f; + float val = input_data[i] / static_cast(scale); + int32_t qval = static_cast(std::nearbyint(val)) + zero_point; + qval = std::min(255, std::max(0, qval)); + expected_data[i] = static_cast(qval); + } + + Tensor input = tf_float.make({static_cast(size)}, input_data); + Tensor out = tfo.zeros({static_cast(size)}); + Tensor expected = tfo.make({static_cast(size)}, expected_data); + + quantize_per_tensor_out( + input, scale, zero_point, quant_min, quant_max, ScalarType::Byte, out); + + EXPECT_TENSOR_EQ(out, expected); + } +} + +// Large tensor tests for per-channel quantization to ensure SIMD path is +// exercised + +TEST(OpQuantizeOutTest, LargePerChannelUInt8SIMDPath) { + // Test per-channel quantization with large blocks (64 elements per channel) + // Shape: [4, 64] with axis=1 (4 channels, 64 elements each) + TensorFactory tf_float; + TensorFactory tf_double; + TensorFactory tf_long; + + const int num_channels = 4; + const int block_size = 64; + std::vector input_data(num_channels * block_size); + + // Create varying input data for each channel + for (int ch = 0; ch < num_channels; ch++) { + for (int i = 0; i < block_size; i++) { + input_data[ch * block_size + i] = static_cast((ch + 1) * i) * 0.1f; + } + } + Tensor input = tf_float.make({num_channels, block_size}, input_data); + + // Different scale and zero_point for each channel + Tensor scale = tf_double.make({num_channels}, {0.1, 0.2, 0.15, 0.25}); + Tensor zero_point = tf_long.make({num_channels}, {10, 20, 15, 25}); + + int64_t quant_min = 0; + int64_t quant_max = 255; + + TensorFactory tfo; + Tensor out = tfo.zeros({num_channels, block_size}); + + // Compute expected values + std::vector expected_data(num_channels * block_size); + for (int ch = 0; ch < num_channels; ch++) { + double ch_scale = scale.const_data_ptr()[ch]; + int64_t ch_zero_point = zero_point.const_data_ptr()[ch]; + + for (int i = 0; i < block_size; i++) { + int idx = ch * block_size + i; + float val = input_data[idx] / static_cast(ch_scale); + int32_t qval = static_cast(std::nearbyint(val)) + + static_cast(ch_zero_point); + qval = std::min(255, std::max(0, qval)); + expected_data[idx] = static_cast(qval); + } + } + Tensor expected = tfo.make({num_channels, block_size}, expected_data); + + quantize_per_channel_out( + input, scale, 
zero_point, 0, quant_min, quant_max, ScalarType::Byte, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, LargePerChannelInt8SIMDPath) { + // Test per-channel quantization with int8 and large blocks + // Shape: [3, 100] with axis=1 (3 channels, 100 elements each) + TensorFactory tf_float; + TensorFactory tf_double; + TensorFactory tf_long; + + const int num_channels = 3; + const int block_size = 100; // 12 SIMD iterations + 4 remainder + std::vector input_data(num_channels * block_size); + + // Create varying input data with negative values + for (int ch = 0; ch < num_channels; ch++) { + for (int i = 0; i < block_size; i++) { + input_data[ch * block_size + i] = + static_cast(i - 50) * 0.2f * (ch + 1); + } + } + Tensor input = tf_float.make({num_channels, block_size}, input_data); + + Tensor scale = tf_double.make({num_channels}, {0.1, 0.15, 0.2}); + Tensor zero_point = tf_long.make({num_channels}, {0, -5, 5}); + + int64_t quant_min = -128; + int64_t quant_max = 127; + + TensorFactory tfo; + Tensor out = tfo.zeros({num_channels, block_size}); + + // Compute expected values + std::vector expected_data(num_channels * block_size); + for (int ch = 0; ch < num_channels; ch++) { + double ch_scale = scale.const_data_ptr()[ch]; + int64_t ch_zero_point = zero_point.const_data_ptr()[ch]; + + for (int i = 0; i < block_size; i++) { + int idx = ch * block_size + i; + float val = input_data[idx] / static_cast(ch_scale); + int32_t qval = static_cast(std::nearbyint(val)) + + static_cast(ch_zero_point); + qval = std::min(127, std::max(-128, qval)); + expected_data[idx] = static_cast(qval); + } + } + Tensor expected = tfo.make({num_channels, block_size}, expected_data); + + quantize_per_channel_out( + input, scale, zero_point, 0, quant_min, quant_max, ScalarType::Char, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, VeryLargePerChannel2DUInt8) { + // Test realistic neural network weight tensor + // Shape: [128, 256] with axis=0 (128 channels, 256 elements each) + TensorFactory tf_float; + TensorFactory tf_double; + TensorFactory tf_long; + + const int num_channels = 128; + const int block_size = 256; + const int total_elements = num_channels * block_size; + + std::vector input_data(total_elements); + for (int i = 0; i < total_elements; i++) { + input_data[i] = std::sin(static_cast(i) * 0.01f) * 5.0f; + } + Tensor input = tf_float.make({num_channels, block_size}, input_data); + + // Create varying scales and zero_points for each channel + std::vector scales(num_channels); + std::vector zero_points(num_channels); + for (int ch = 0; ch < num_channels; ch++) { + scales[ch] = 0.02 + (ch % 10) * 0.001; // Varying scales + zero_points[ch] = 128 + (ch % 5); // Varying zero_points + } + Tensor scale = tf_double.make({num_channels}, scales); + Tensor zero_point = tf_long.make({num_channels}, zero_points); + + int64_t quant_min = 0; + int64_t quant_max = 255; + + TensorFactory tfo; + Tensor out = tfo.zeros({num_channels, block_size}); + + // Compute expected values + std::vector expected_data(total_elements); + for (int ch = 0; ch < num_channels; ch++) { + float inv_scale = 1.0f / static_cast(scales[ch]); + int64_t ch_zero_point = zero_points[ch]; + + for (int i = 0; i < block_size; i++) { + int idx = ch * block_size + i; + float val = input_data[idx] * inv_scale; + // Clamp before converting to avoid overflow + val = std::max(-1000.0f, std::min(1000.0f, val)); + int32_t qval = static_cast(std::nearbyint(val)) + + static_cast(ch_zero_point); + + qval = std::min(255, 
std::max(0, qval)); + expected_data[idx] = static_cast(qval); + } + } + Tensor expected = tfo.make({num_channels, block_size}, expected_data); + + quantize_per_channel_out( + input, scale, zero_point, 0, quant_min, quant_max, ScalarType::Byte, out); + + EXPECT_TENSOR_EQ(out, expected); +} + +TEST(OpQuantizeOutTest, PerChannelAxis1LargeBlocks) { + // Test per-channel quantization with axis=1 and large contiguous blocks + // Shape: [2, 3, 64] with axis=1 (2 batches, 3 channels, 64 elements each) + TensorFactory tf_float; + TensorFactory tf_double; + TensorFactory tf_long; + + const int batch_size = 2; + const int num_channels = 3; + const int block_size = 64; + const int total_elements = batch_size * num_channels * block_size; + + std::vector input_data(total_elements); + for (int i = 0; i < total_elements; i++) { + input_data[i] = static_cast(i % 100) * 0.1f; + } + Tensor input = + tf_float.make({batch_size, num_channels, block_size}, input_data); + + Tensor scale = tf_double.make({num_channels}, {0.05, 0.1, 0.15}); + Tensor zero_point = tf_long.make({num_channels}, {100, 110, 120}); + + int64_t quant_min = 0; + int64_t quant_max = 255; + + TensorFactory tfo; + Tensor out = tfo.zeros({batch_size, num_channels, block_size}); + + // Compute expected values + std::vector expected_data(total_elements); + for (int b = 0; b < batch_size; b++) { + for (int ch = 0; ch < num_channels; ch++) { + double ch_scale = scale.const_data_ptr()[ch]; + int64_t ch_zero_point = zero_point.const_data_ptr()[ch]; + + for (int i = 0; i < block_size; i++) { + int idx = (b * num_channels + ch) * block_size + i; + float val = input_data[idx] / static_cast(ch_scale); + int32_t qval = static_cast(std::nearbyint(val)) + + static_cast(ch_zero_point); + qval = std::min(255, std::max(0, qval)); + expected_data[idx] = static_cast(qval); + } + } + } + Tensor expected = + tfo.make({batch_size, num_channels, block_size}, expected_data); + + quantize_per_channel_out( + input, scale, zero_point, 1, quant_min, quant_max, ScalarType::Byte, out); + + EXPECT_TENSOR_EQ(out, expected); +}
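
A standalone sketch of the magic-float rounding used in the ARMv7 path, for reviewers unfamiliar with the trick (illustrative only, not part of the patch; it assumes the default round-to-nearest FE_TONEAREST mode, and the `magic_round` helper below is hypothetical). `12582912.0f` is `1.5 * 2^23` with bit pattern `0x4B400000`, so adding it to a small float leaves round-to-nearest-even of that float in the low bits; reinterpreting the sum as an integer and subtracting `0x4B400000` (with the zero_point folded into the offset, as `voffset` does in `quantize_arm`) recovers the rounded, shifted value:

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Hypothetical scalar model of the NEON sequence
//   vaddq_f32(vmagic_float, x * inv_scale) -> vreinterpretq_s32_f32 -> vaddq_s32(voffset)
// valid for |x| well below 2^22.
int32_t magic_round(float x, int32_t zero_point) {
  const float magic = 12582912.0f; // 1.5 * 2^23, bit pattern 0x4B400000
  const int32_t offset = zero_point - 0x4B400000;
  const float shifted = x + magic; // rounds x to an integer, ties to even
  int32_t bits;
  std::memcpy(&bits, &shifted, sizeof(bits)); // reinterpret float bits as int32
  return bits + offset; // == nearbyint(x) + zero_point
}

int main() {
  const float inputs[] = {12.6f, -3.4f, 2.5f, 3.5f, -0.5f};
  for (float x : inputs) {
    const int32_t a = magic_round(x, /*zero_point=*/3);
    const int32_t b = static_cast<int32_t>(std::nearbyint(x)) + 3;
    std::printf("%6.2f -> magic %d, nearbyint %d\n", x, static_cast<int>(a), static_cast<int>(b));
  }
  return 0;
}
```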