From a7ba41b02280601ba480d43a2567ae99f6e037f6 Mon Sep 17 00:00:00 2001
From: "Xia, Weiwen"
Date: Mon, 20 Nov 2023 14:48:30 +0800
Subject: [PATCH] Overload vec::dequantize to eliminate rounding error for quantized sigmoid

[ghstack-poisoned]
---
// NOTE(review): the C++ hunks below add a dequantize(scale, zero_point) overload
// that computes (value - zero_point) * scale directly; qsigmoid_kernel is switched
// to that overload and no longer builds scale_neg_zp_premul_vec. The Python hunk
// adds a regression test that references issue #107030.
// NOTE(review): template argument lists look stripped in this copy of the patch
// (for example a bare Vectorized parameter type and the zarch struct header);
// presumably Vectorized<float> and similar - confirm against the upstream commit
// before applying.
 aten/src/ATen/cpu/vec/vec256/vec256_qint.h | 54 ++++++++++++++++++
 .../cpu/vec/vec256/vsx/vec256_qint32_vsx.h | 14 +++++
 .../ATen/cpu/vec/vec256/zarch/vec256_zarch.h | 10 ++++
 aten/src/ATen/cpu/vec/vec512/vec512_qint.h | 55 +++++++++++++++++++
 .../cpu/kernels/QuantizedOpKernels.cpp | 4 +-
 test/quantization/core/test_quantized_op.py | 17 ++++++
 6 files changed, 151 insertions(+), 3 deletions(-)

diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_qint.h b/aten/src/ATen/cpu/vec/vec256/vec256_qint.h
index 16550a6af20bc..ee14de69324fa 100644
--- a/aten/src/ATen/cpu/vec/vec256/vec256_qint.h
+++ b/aten/src/ATen/cpu/vec/vec256/vec256_qint.h
@@ -305,6 +305,13 @@ struct Vectorized : public Vectorizedqi {
     return {vec::fmadd(scale, Vectorized(float_vals), scale_zp_premul)};
   }
 
+  float_vec_return_type dequantize(  // two-argument overload: scale and zero_point only
+      Vectorized scale,
+      Vectorized zero_point) const {
+    __m256 float_vals = _mm256_cvtepi32_ps(vals);  // convert the packed int32 lanes of vals to float
+    return {(Vectorized(float_vals) - zero_point) * scale};  // subtract zero_point first, then scale
+  }
+
   static Vectorized quantize(
       const float_vec_return_type& rhs,
       float scale,
@@ -520,6 +527,26 @@ struct Vectorized : public Vectorizedqi {
     return {val0, val1, val2, val3};
   }
 
+  float_vec_return_type dequantize(  // two-argument overload: scale and zero_point only
+      Vectorized scale,
+      Vectorized zero_point) const {
+    __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0));  // int_valK: 64-bit element K of vals, broadcast
+    __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1));
+    __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2));
+    __m128i int_val3 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 3));
+
+    __m256 float_val0 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val0));  // cvtepi8_epi32 is defined outside this patch; presumably widens int8 to int32 - TODO confirm
+    __m256 float_val1 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val1));
+    __m256 float_val2 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val2));
+    __m256 float_val3 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val3));
+
+    auto val0 = (Vectorized(float_val0) - zero_point) * scale;  // subtract zero_point first, then scale
+    auto val1 = (Vectorized(float_val1) - zero_point) * scale;
+    auto val2 = (Vectorized(float_val2) - zero_point) * scale;
+    auto val3 = (Vectorized(float_val3) - zero_point) * scale;
+    return {val0, val1, val2, val3};  // one float vector per 64-bit chunk of vals
+  }
+
   static Vectorized quantize(
       const float_vec_return_type& rhs,
       float /*scale*/,
@@ -698,6 +725,26 @@ struct Vectorized : public Vectorizedqi {
     return {val0, val1, val2, val3};
   }
 
+  float_vec_return_type dequantize(  // two-argument overload: scale and zero_point only
+      Vectorized scale,
+      Vectorized zero_point) const {
+    __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0));  // int_valK: 64-bit element K of vals, broadcast
+    __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1));
+    __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2));
+    __m128i int_val3 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 3));
+
+    __m256 float_val0 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val0));  // cvtepu8_epi32 is defined outside this patch; presumably widens uint8 to int32 - TODO confirm
+    __m256 float_val1 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val1));
+    __m256 float_val2 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val2));
+    __m256 float_val3 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val3));
+
+    auto val0 = (Vectorized(float_val0) - zero_point) * scale;  // subtract zero_point first, then scale
+    auto val1 = (Vectorized(float_val1) - zero_point) * scale;
+    auto val2 = (Vectorized(float_val2) - zero_point) * scale;
+    auto val3 = (Vectorized(float_val3) - zero_point) * scale;
+    return {val0, val1, val2, val3};  // one float vector per 64-bit chunk of vals
+  }
+
   static Vectorized quantize(
       const float_vec_return_type& rhs,
       float /*scale*/,
@@ -853,6 +900,13 @@ struct VectorizedQuantizedConverter {
     return rv;
   }
 
+  float_vec_return_type dequantize(  // two-argument overload that forwards to the three-argument one
+      Vectorized scale,
+      Vectorized zero_point) const {
+    Vectorized scale_zp_premul;  // default-constructed placeholder; NOTE(review): presumably unused by the three-argument dequantize here - confirm
+    return dequantize(scale, zero_point, scale_zp_premul);
+  }
+
  protected:
   VectorizedQuantizedConverter() {}
 };
diff --git a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_qint32_vsx.h b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_qint32_vsx.h
index a85730c9a6df8..746a5e27a5c10 100644
--- a/aten/src/ATen/cpu/vec/vec256/vsx/vec256_qint32_vsx.h
+++ b/aten/src/ATen/cpu/vec/vec256/vsx/vec256_qint32_vsx.h
@@ -121,6 +121,20 @@ struct Vectorized {
         vec_madd(scale_vec1, float_vals1, scale_zp_premul1)}};
   }
 
+  float_vec_return_type dequantize(  // two-argument overload: scale and zero_point only
+      Vectorized scale,
+      Vectorized zero_point) const {
+    vfloat32 float_vals0 = vec_float(_vec0);  // float view of each of the two stored halves
+    vfloat32 float_vals1 = vec_float(_vec1);
+    vfloat32 scale_vec0 = scale.vec0();  // matching halves of scale and zero_point
+    vfloat32 scale_vec1 = scale.vec1();
+    vfloat32 zero_point0 = zero_point.vec0();
+    vfloat32 zero_point1 = zero_point.vec1();
+    return {Vectorized{
+        (float_vals0 - zero_point0) * scale_vec0,  // subtract zero_point first, then scale
+        (float_vals1 - zero_point1) * scale_vec1}};
+  }
+
   static Vectorized quantize(
       const float_vec_return_type& rhs,
       float scale,
diff --git a/aten/src/ATen/cpu/vec/vec256/zarch/vec256_zarch.h b/aten/src/ATen/cpu/vec/vec256/zarch/vec256_zarch.h
index 25ca208ee24cb..70b130421cdfd 100644
--- a/aten/src/ATen/cpu/vec/vec256/zarch/vec256_zarch.h
+++ b/aten/src/ATen/cpu/vec/vec256/zarch/vec256_zarch.h
@@ -1730,6 +1730,16 @@ struct Vectorized()>> {
     return {fmadd(scale, float_val, scale_zp_premul)};
   }
 
+  template <
+      typename U = T,
+      std::enable_if_t::float_num_vecs() == 1, int> = 0>  // overload participates only when float_num_vecs() == 1
+  float_vec_return_type dequantize(
+      Vectorized scale,
+      Vectorized zero_point) const {
+    auto float_val = convert_to_float(_vec);  // convert_to_float is defined outside this patch
+    return {(float_val - zero_point) * scale};  // subtract zero_point first, then scale
+  }
+
   template <
       typename U = T,
       std::enable_if_t::float_num_vecs() == 1, int> = 0>
diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_qint.h b/aten/src/ATen/cpu/vec/vec512/vec512_qint.h
index 493573ccacf10..b03da5d2c3e95 100644
--- a/aten/src/ATen/cpu/vec/vec512/vec512_qint.h
+++ b/aten/src/ATen/cpu/vec/vec512/vec512_qint.h
@@ -317,6 +317,13 @@ struct Vectorized : public Vectorizedqi {
     return {vec::fmadd(scale, Vectorized(float_vals), scale_zp_premul)};
   }
 
+  float_vec_return_type dequantize(  // two-argument overload: scale and zero_point only
+      Vectorized scale,
+      Vectorized zero_point) const {
+    __m512 float_vals = _mm512_cvtepi32_ps(vals);  // convert the packed int32 lanes of vals to float
+    return {(Vectorized(float_vals) - zero_point) * scale};  // subtract zero_point first, then scale
+  }
+
   static Vectorized quantize(
       const float_vec_return_type& rhs,
       float scale,
@@ -531,6 +538,26 @@ struct Vectorized : public Vectorizedqi {
     return {val0, val1, val2, val3};
   }
 
+  float_vec_return_type dequantize(  // two-argument overload: scale and zero_point only
+      Vectorized scale,
+      Vectorized zero_point) const {
+    __m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]);  // pack two 64-bit elements of vals, lower index in the low half
+    __m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]);
+    __m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]);
+    __m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]);
+
+    __m512 float_val0 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val0));  // cvtepi8_epi32 is defined outside this patch; presumably widens int8 to int32 - TODO confirm
+    __m512 float_val1 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val1));
+    __m512 float_val2 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val2));
+    __m512 float_val3 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val3));
+
+    auto val0 = (Vectorized(float_val0) - zero_point) * scale;  // subtract zero_point first, then scale
+    auto val1 = (Vectorized(float_val1) - zero_point) * scale;
+    auto val2 = (Vectorized(float_val2) - zero_point) * scale;
+    auto val3 = (Vectorized(float_val3) - zero_point) * scale;
+    return {val0, val1, val2, val3};  // one float vector per 128-bit chunk of vals
+  }
+
   static Vectorized quantize(
       const float_vec_return_type& rhs,
       float scale,
@@ -708,6 +735,27 @@ struct Vectorized : public Vectorizedqi {
     return {val0, val1, val2, val3};
   }
 
+  float_vec_return_type dequantize(  // two-argument overload: scale and zero_point only
+      Vectorized scale,
+      Vectorized zero_point) const {
+    __m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]);  // pack two 64-bit elements of vals, lower index in the low half
+    __m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]);
+    __m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]);
+    __m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]);
+
+    __m512 float_val0 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val0));  // cvtepu8_epi32 is defined outside this patch; presumably widens uint8 to int32 - TODO confirm
+    __m512 float_val1 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val1));
+    __m512 float_val2 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val2));
+    __m512 float_val3 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val3));
+
+    auto val0 = (Vectorized(float_val0) - zero_point) * scale;  // subtract zero_point first, then scale
+    auto val1 = (Vectorized(float_val1) - zero_point) * scale;
+    auto val2 = (Vectorized(float_val2) - zero_point) * scale;
+    auto val3 = (Vectorized(float_val3) - zero_point) * scale;
+
+    return {val0, val1, val2, val3};  // one float vector per 128-bit chunk of vals
+  }
+
   static Vectorized quantize(
       const float_vec_return_type& rhs,
       float scale,
@@ -865,6 +913,13 @@ struct VectorizedQuantizedConverter {
     return rv;
   }
 
+  float_vec_return_type dequantize(  // two-argument overload that forwards to the three-argument one
+      Vectorized scale,
+      Vectorized zero_point) const {
+    Vectorized scale_zp_premul;  // default-constructed placeholder; NOTE(review): presumably unused by the three-argument dequantize here - confirm
+    return dequantize(scale, zero_point, scale_zp_premul);
+  }
+
  protected:
   VectorizedQuantizedConverter() {}
 };
diff --git a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
index 479b023a56df0..373ef4af33da7 100644
--- a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
+++ b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
@@ -836,7 +836,6 @@ void qsigmoid_kernel(
   float scale = qx.q_scale();
   auto scale_vec = Vectorized(scale);
   auto zero_point_vec = Vectorized((float)zero_point);
-  auto scale_neg_zp_premul_vec = scale_vec * zero_point_vec.neg();
 
   AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qsigmoid", [&]() {
     float inv_output_scale = 1.0 / output_scale;
@@ -861,8 +860,7 @@ void qsigmoid_kernel(
           output_scale, output_zero_point, value_dy);
       },
       [&](Vec value_qx) -> Vec {
-        auto value_dx = value_qx.dequantize(
-          scale_vec, zero_point_vec, scale_neg_zp_premul_vec);
+        auto value_dx = value_qx.dequantize(scale_vec, zero_point_vec);  // two-argument overload; the premultiplied term is no longer passed
         for (auto & value : value_dx) {
           value = value.neg();
           value = value.exp();
diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py
index 15aaee96da3c0..59784a63d3efc 100644
--- a/test/quantization/core/test_quantized_op.py
+++ b/test/quantization/core/test_quantized_op.py
@@ -328,6 +328,23 @@ def test_sigmoid(self, X):
         ]
         self._test_activation_function(X, 'sigmoid', sigmoid_test_configs)
 
+    @skipIfNoFBGEMM
+    def test_sigmoid_dequantize_rounding_error(self):
+        # issue #107030
+        sigmoid_test_configs = [
+            {
+                'quantized_fn': [
+                    torch.ops.quantized.sigmoid
+                ],
+                'reference_fn': torch.sigmoid,
+                'output_range': (0.0, 1.0),
+                'change_zero_point': True,
+                'output_is_observed': True,
+            }
+        ]
+        X = (np.full(64, 514., dtype=np.float32), (1028.02, 255, torch.quint8))
+        self._test_activation_function(X, 'sigmoid', sigmoid_test_configs)
+
     """Tests the correctness of the quantized::hardsigmoid op."""
     @override_qengines
     def test_qhardsigmoid(self):