
Commit

Overload vec::dequantize to eliminate rounding error for quantized sigmoid

[ghstack-poisoned]
Xia-Weiwen committed Nov 20, 2023
1 parent 5a96a42 commit a7ba41b
Showing 6 changed files with 151 additions and 3 deletions.
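
Background for the change: the existing three-argument dequantize premultiplies scale * (-zero_point) and then applies a fused multiply-add, fmadd(scale, q, scale * (-zero_point)). The premultiplied term is rounded to float before the fma, so a lane whose quantized value equals the zero point dequantizes to that rounding residue rather than to exactly 0.0f; sigmoid then lands slightly off 0.5 and can requantize to a neighboring value. The new two-argument overload computes (q - zero_point) * scale instead, which is exactly 0.0f whenever q equals the zero point. A minimal standalone sketch of the difference follows (not code from this commit; scalar std::fma stands in for the vectorized fmadd, the numbers mirror the regression test added below, and the effect only appears where fmadd is a true fused multiply-add):

#include <cmath>
#include <cstdio>

int main() {
  const float scale = 1028.02f;     // scale from the regression test
  const float zero_point = 255.0f;  // zero point from the regression test
  const float q = 255.0f;           // quantized lane equal to the zero point

  // Old path: premultiply scale * (-zero_point), then fused multiply-add.
  // The premultiplied term is rounded to float first, so the exact product
  // scale * q inside the fma does not cancel it completely.
  const float premul = scale * -zero_point;
  const float old_dq = std::fma(scale, q, premul);

  // New path: subtract the zero point first, then scale.
  // q == zero_point yields exactly 0.0f, and sigmoid(0.0f) == 0.5f exactly.
  const float new_dq = (q - zero_point) * scale;

  std::printf("old path: %.9g\nnew path: %.9g\n", old_dq, new_dq);
  return 0;
}
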
54 changes: 54 additions & 0 deletions aten/src/ATen/cpu/vec/vec256/vec256_qint.h
@@ -305,6 +305,13 @@ struct Vectorized<c10::qint32> : public Vectorizedqi {
return {vec::fmadd(scale, Vectorized<float>(float_vals), scale_zp_premul)};
}

float_vec_return_type dequantize(
Vectorized<float> scale,
Vectorized<float> zero_point) const {
__m256 float_vals = _mm256_cvtepi32_ps(vals);
return {(Vectorized<float>(float_vals) - zero_point) * scale};
}

static Vectorized<c10::qint32> quantize(
const float_vec_return_type& rhs,
float scale,
@@ -520,6 +527,26 @@ struct Vectorized<c10::qint8> : public Vectorizedqi {
return {val0, val1, val2, val3};
}

float_vec_return_type dequantize(
Vectorized<float> scale,
Vectorized<float> zero_point) const {
__m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0));
__m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1));
__m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2));
__m128i int_val3 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 3));

__m256 float_val0 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val0));
__m256 float_val1 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val1));
__m256 float_val2 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val2));
__m256 float_val3 = _mm256_cvtepi32_ps(cvtepi8_epi32(int_val3));

auto val0 = (Vectorized<float>(float_val0) - zero_point) * scale;
auto val1 = (Vectorized<float>(float_val1) - zero_point) * scale;
auto val2 = (Vectorized<float>(float_val2) - zero_point) * scale;
auto val3 = (Vectorized<float>(float_val3) - zero_point) * scale;
return {val0, val1, val2, val3};
}

static Vectorized<c10::qint8> quantize(
const float_vec_return_type& rhs,
float /*scale*/,
@@ -698,6 +725,26 @@ struct Vectorized<c10::quint8> : public Vectorizedqi {
return {val0, val1, val2, val3};
}

float_vec_return_type dequantize(
Vectorized<float> scale,
Vectorized<float> zero_point) const {
__m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0));
__m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1));
__m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2));
__m128i int_val3 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 3));

__m256 float_val0 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val0));
__m256 float_val1 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val1));
__m256 float_val2 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val2));
__m256 float_val3 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val3));

auto val0 = (Vectorized<float>(float_val0) - zero_point) * scale;
auto val1 = (Vectorized<float>(float_val1) - zero_point) * scale;
auto val2 = (Vectorized<float>(float_val2) - zero_point) * scale;
auto val3 = (Vectorized<float>(float_val3) - zero_point) * scale;
return {val0, val1, val2, val3};
}

static Vectorized<c10::quint8> quantize(
const float_vec_return_type& rhs,
float /*scale*/,
@@ -853,6 +900,13 @@ struct VectorizedQuantizedConverter {
return rv;
}

float_vec_return_type dequantize(
Vectorized<float> scale,
Vectorized<float> zero_point) const {
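// The element-wise fallback's three-argument dequantize does not use
// scale_zp_premul, so a default-constructed vector is sufficient here.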
Vectorized<float> scale_zp_premul;
return dequantize(scale, zero_point, scale_zp_premul);
}

protected:
VectorizedQuantizedConverter() {}
};
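
For orientation, a usage sketch of the new two-argument overload (an illustration assuming the usual at::vec API — loadu, store, and an std::array return of float_num_vecs() vectors — not code from this commit):

#include <cstdint>
#include <ATen/cpu/vec/vec.h>
#include <c10/util/qint8.h>

// Dequantize one Vectorized<c10::qint8> worth of elements into floats.
void dequantize_block(const c10::qint8* src, float scale, int32_t zero_point, float* dst) {
  using qvec = at::vec::Vectorized<c10::qint8>;
  using fvec = at::vec::Vectorized<float>;

  auto qx = qvec::loadu(src);
  auto scale_v = fvec(scale);
  auto zp_v = fvec(static_cast<float>(zero_point));

  // New overload: (q - zero_point) * scale per lane, with no premultiplied
  // scale * zero_point term, so q == zero_point dequantizes to exactly 0.0f.
  auto dq = qx.dequantize(scale_v, zp_v);

  for (size_t i = 0; i < dq.size(); ++i) {
    dq[i].store(dst + i * fvec::size());
  }
}
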
14 changes: 14 additions & 0 deletions aten/src/ATen/cpu/vec/vec256/vsx/vec256_qint32_vsx.h
@@ -121,6 +121,20 @@ struct Vectorized<c10::qint32> {
vec_madd(scale_vec1, float_vals1, scale_zp_premul1)}};
}

float_vec_return_type dequantize(
Vectorized<float> scale,
Vectorized<float> zero_point) const {
vfloat32 float_vals0 = vec_float(_vec0);
vfloat32 float_vals1 = vec_float(_vec1);
vfloat32 scale_vec0 = scale.vec0();
vfloat32 scale_vec1 = scale.vec1();
vfloat32 zero_point0 = zero_point.vec0();
vfloat32 zero_point1 = zero_point.vec1();
return {Vectorized<float>{
(float_vals0 - zero_point0) * scale_vec0,
(float_vals1 - zero_point1) * scale_vec1}};
}

static Vectorized<c10::qint32> quantize(
const float_vec_return_type& rhs,
float scale,
10 changes: 10 additions & 0 deletions aten/src/ATen/cpu/vec/vec256/zarch/vec256_zarch.h
@@ -1730,6 +1730,16 @@ struct Vectorized<T, std::enable_if_t<is_zarch_implemented_quant<T>()>> {
return {fmadd(scale, float_val, scale_zp_premul)};
}

template <
typename U = T,
std::enable_if_t<Vectorized<U>::float_num_vecs() == 1, int> = 0>
float_vec_return_type dequantize(
Vectorized<float> scale,
Vectorized<float> zero_point) const {
auto float_val = convert_to_float(_vec);
return {(float_val - zero_point) * scale};
}

template <
typename U = T,
std::enable_if_t<Vectorized<U>::float_num_vecs() == 1, int> = 0>
55 changes: 55 additions & 0 deletions aten/src/ATen/cpu/vec/vec512/vec512_qint.h
@@ -317,6 +317,13 @@ struct Vectorized<c10::qint32> : public Vectorizedqi {
return {vec::fmadd(scale, Vectorized<float>(float_vals), scale_zp_premul)};
}

float_vec_return_type dequantize(
Vectorized<float> scale,
Vectorized<float> zero_point) const {
__m512 float_vals = _mm512_cvtepi32_ps(vals);
return {(Vectorized<float>(float_vals) - zero_point) * scale};
}

static Vectorized<c10::qint32> quantize(
const float_vec_return_type& rhs,
float scale,
@@ -531,6 +538,26 @@ struct Vectorized<c10::qint8> : public Vectorizedqi {
return {val0, val1, val2, val3};
}

float_vec_return_type dequantize(
Vectorized<float> scale,
Vectorized<float> zero_point) const {
__m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]);
__m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]);
__m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]);
__m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]);

__m512 float_val0 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val0));
__m512 float_val1 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val1));
__m512 float_val2 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val2));
__m512 float_val3 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val3));

auto val0 = (Vectorized<float>(float_val0) - zero_point) * scale;
auto val1 = (Vectorized<float>(float_val1) - zero_point) * scale;
auto val2 = (Vectorized<float>(float_val2) - zero_point) * scale;
auto val3 = (Vectorized<float>(float_val3) - zero_point) * scale;
return {val0, val1, val2, val3};
}

static Vectorized<c10::qint8> quantize(
const float_vec_return_type& rhs,
float scale,
@@ -708,6 +735,27 @@ struct Vectorized<c10::quint8> : public Vectorizedqi {
return {val0, val1, val2, val3};
}

float_vec_return_type dequantize(
Vectorized<float> scale,
Vectorized<float> zero_point) const {
__m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]);
__m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]);
__m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]);
__m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]);

__m512 float_val0 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val0));
__m512 float_val1 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val1));
__m512 float_val2 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val2));
__m512 float_val3 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val3));

auto val0 = (Vectorized<float>(float_val0) - zero_point) * scale;
auto val1 = (Vectorized<float>(float_val1) - zero_point) * scale;
auto val2 = (Vectorized<float>(float_val2) - zero_point) * scale;
auto val3 = (Vectorized<float>(float_val3) - zero_point) * scale;

return {val0, val1, val2, val3};
}

static Vectorized<c10::quint8> quantize(
const float_vec_return_type& rhs,
float scale,
@@ -865,6 +913,13 @@ struct VectorizedQuantizedConverter {
return rv;
}

float_vec_return_type dequantize(
Vectorized<float> scale,
Vectorized<float> zero_point) const {
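// As in the AVX2 header, the element-wise fallback ignores scale_zp_premul,
// so a default-constructed vector is sufficient here.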
Vectorized<float> scale_zp_premul;
return dequantize(scale, zero_point, scale_zp_premul);
}

protected:
VectorizedQuantizedConverter() {}
};
4 changes: 1 addition & 3 deletions aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
@@ -836,7 +836,6 @@ void qsigmoid_kernel(
float scale = qx.q_scale();
auto scale_vec = Vectorized<float>(scale);
auto zero_point_vec = Vectorized<float>((float)zero_point);
- auto scale_neg_zp_premul_vec = scale_vec * zero_point_vec.neg();

AT_DISPATCH_QINT_TYPES(qx.scalar_type(), "qsigmoid", [&]() {
float inv_output_scale = 1.0 / output_scale;
@@ -861,8 +860,7 @@
output_scale, output_zero_point, value_dy);
},
[&](Vec value_qx) -> Vec {
- auto value_dx = value_qx.dequantize(
- scale_vec, zero_point_vec, scale_neg_zp_premul_vec);
+ auto value_dx = value_qx.dequantize(scale_vec, zero_point_vec);
for (auto & value : value_dx) {
value = value.neg();
value = value.exp();
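
With the premultiplied vector gone, qsigmoid_kernel dequantizes each lane as (q - zero_point) * scale, so a lane equal to the zero point enters the sigmoid as exactly 0.0f and yields exactly 0.5 before requantization.
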
17 changes: 17 additions & 0 deletions test/quantization/core/test_quantized_op.py
@@ -328,6 +328,23 @@ def test_sigmoid(self, X):
]
self._test_activation_function(X, 'sigmoid', sigmoid_test_configs)

@skipIfNoFBGEMM
def test_sigmoid_dequantize_rounding_error(self):
# issue #107030
sigmoid_test_configs = [
{
'quantized_fn': [
torch.ops.quantized.sigmoid
],
'reference_fn': torch.sigmoid,
'output_range': (0.0, 1.0),
'change_zero_point': True,
'output_is_observed': True,
}
]
X = (np.full(64, 514., dtype=np.float32), (1028.02, 255, torch.quint8))
self._test_activation_function(X, 'sigmoid', sigmoid_test_configs)

"""Tests the correctness of the quantized::hardsigmoid op."""
@override_qengines
def test_qhardsigmoid(self):
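
On the chosen test values: 514.0 / 1028.02 ≈ 0.49999 rounds to 0, so with zero point 255 every element of X quantizes to exactly the zero point. The two-argument dequantize then produces exactly 0.0 and sigmoid gives exactly 0.5, matching the reference; the premultiplied fmadd path can instead leave a small nonzero residue that shifts the requantized result by one step (see issue #107030 referenced in the test).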
