
Vectorized operation on quantized tensors returns wrong values (different rounding) #107030

Closed
Flamefire opened this issue Aug 11, 2023 · 6 comments
Labels: oncall: quantization (Quantization support in PyTorch), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

Flamefire (Collaborator) commented on Aug 11, 2023

🐛 Describe the bug

The following code fails:

import numpy as np
import torch

X = torch.from_numpy(np.full(64+1, 514., dtype=np.float32))
(scale, zero_point, torch_type) = (1028.02, 255, torch.quint8)

assert X.is_contiguous(memory_format=torch.contiguous_format)
qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
                               dtype=torch_type)

f_min, f_max = 0.0, 1.0
q_min, q_max = torch.iinfo(torch_type).min, torch.iinfo(torch_type).max
output_scale = (f_max - f_min) / (q_max - q_min + 1.0)

qY = torch.ops.quantized.sigmoid(qX, output_scale=output_scale, output_zero_point=0)
print(qY)
assert qY[0] == qY[-1]

In particular, the first 64 values are 0.5039 while the remaining ones are 0.5000. This happens whenever the number of elements leaves a remainder that does not fit into the 64-element chunks handled by the vectorized code path.
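For reference, here is a minimal sketch (reusing the parameters of the repro above; the variable names are mine, not from the test suite) of what the expected output looks like when the reference is computed from the dequantized tensor:

import numpy as np
import torch

X = torch.from_numpy(np.full(64 + 1, 514., dtype=np.float32))
qX = torch.quantize_per_tensor(X, scale=1028.02, zero_point=255, dtype=torch.quint8)

# All inputs quantize to 255 (== zero_point), so dequantize() returns exactly 0.0
# and the expected result is sigmoid(0.0) == 0.5 for every element.
ref = torch.quantize_per_tensor(torch.sigmoid(qX.dequantize()),
                                scale=1.0 / 256, zero_point=0, dtype=torch.quint8)
print(ref)  # 0.5000 for all 65 elements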

Found by reducing an example of a failing test in test_quantization:

======================================================================
FAIL: test_sigmoid (quantization.core.test_quantized_op.TestQuantizedOps)
----------------------------------------------------------------------
Traceback (most recent call last):
<snip>
AssertionError: Quantized tensor-likes are not close!

Mismatched elements: 63 / 75 (84.0%)
Greatest absolute difference: 0.00390625 at index (0, 0, 1) (up to 1e-05 allowed)
Greatest relative difference: 0.0078125 at index (0, 0, 1) (up to 1.3e-06 allowed) : sigmoid - quantized.sigmoid failed: (tensor([[[0.0000, 0.5039, 0.5039, 0.5039, 0.5039],
         [0.5039, 0.5039, 0.5039, 0.5039, 0.5039],
         [0.5039, 0.5039, 0.5039, 0.5039, 0.5039],
         [0.5039, 0.5039, 0.5039, 0.5039, 0.5039],
         [0.5039, 0.5039, 0.5039, 0.5039, 0.5039]],

        [[0.5039, 0.5039, 0.5039, 0.5039, 0.5039],
         [0.5039, 0.5039, 0.5039, 0.5039, 0.5039],
         [0.5039, 0.5039, 0.5039, 0.5039, 0.5039],
         [0.5039, 0.5039, 0.5039, 0.5039, 0.5039],
         [0.5039, 0.5039, 0.5039, 0.5039, 0.5039]],

        [[0.5039, 0.5039, 0.5039, 0.5039, 0.5039],
         [0.5039, 0.5039, 0.5039, 0.5039, 0.5039],
         [0.5039, 0.5039, 0.5039, 0.5039, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000, 0.5000]]], size=(3, 5, 5),
       dtype=torch.quint8, quantization_scheme=torch.per_tensor_affine,
       scale=0.00390625, zero_point=0) vs. tensor([[[0.0000, 0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000, 0.5000]],

        [[0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000, 0.5000]],

        [[0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
         [0.5000, 0.5000, 0.5000, 0.5000, 0.5000]]], size=(3, 5, 5),
       dtype=torch.quint8, quantization_scheme=torch.per_tensor_affine,
       scale=0.00390625, zero_point=0))
Falsifying example: test_sigmoid(
    X=(array([[[-261630.,     514.,     514.,     514.,     514.],
             [    514.,     514.,     514.,     514.,     514.],
             [    514.,     514.,     514.,     514.,     514.],
             [    514.,     514.,     514.,     514.,     514.],
             [    514.,     514.,     514.,     514.,     514.]],

            [[    514.,     514.,     514.,     514.,     514.],
             [    514.,     514.,     514.,     514.,     514.],
             [    514.,     514.,     514.,     514.,     514.],
             [    514.,     514.,     514.,     514.,     514.],
             [    514.,     514.,     514.,     514.,     514.]],

            [[    514.,     514.,     514.,     514.,     514.],
             [    514.,     514.,     514.,     514.,     514.],
             [    514.,     514.,     514.,     514.,     514.],
             [    514.,     514.,     514.,     514.,     514.],
             [    514.,     514.,     514.,     514.,     514.]]],
           dtype=float32), (1028.0156862745098, 255, torch.quint8)),
    self=<quantization.core.test_quantized_op.TestQuantizedOps testMethod=test_sigmoid>,
)

----------------------------------------------------------------------
Ran 942 tests in 656.469s

FAILED (failures=2, errors=1, skipped=72)

This seems to happen for all PyTorch versions so far and does not depend on the host CPU. I reproduced this even on ppc64le.

Versions

PyTorch version: 2.0.1+cu117
Is debug build: False

OS: CentOS Linux release 7.9.2009 (Core) (x86_64)
GCC version: (GCC) 11.3.0
Clang version: Could not collect
CMake version: version 3.27.1
Libc version: glibc-2.17

Python version: 3.10.4 (main, Oct 6 2022, 14:14:40) [GCC 11.3.0] (64-bit runtime)
Python platform: Linux-3.10.0-1160.11.1.el7.x86_64-x86_64-with-glibc2.17

Is XNNPACK available: True

CPU:
Architecture: x86_64
Model name: AMD EPYC 7352 24-Core Processor
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc art rep_good nopl nonstop_tsc extd_apicid aperfmperf eagerfpu pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_l2 cpb cat_l3 cdp_l3 hw_pstate sme retpoline_amd ssbd ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif umip overflow_recov succor smca

Versions of relevant libraries:
[pip3] numpy==1.25.2
[pip3] torch==2.0.1
[pip3] triton==2.0.0
[conda] Could not collect

cc @jerryzh168 @jianyuh @raghuramank100 @jamesr66a @vkuzo @jgong5 @Xia-Weiwen @leslie-fang-intel

Flamefire changed the title from "Vectorized operation on quantized tensors returns wrong values" to "Vectorized operation on quantized tensors returns wrong values (different rounding)" on Aug 11, 2023
Flamefire (Collaborator, Author) commented on Aug 11, 2023

I traced the issue to a difference in the de-quantization:

In the current case we have a clipped quantized value: the zero_point is 255, which is the maximum of the quint8 range, so any positive input is quantized to 255 (qX = round(dX / scale) + zp, clamped to [0, 255]).

For regular de-quantization the formula is dX = scale * (qX - zp), and since qX == zp here, the result is exactly zero.
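A quick scalar check of these two formulas with the values from the repro (a plain-Python sketch):

scale, zp = 1028.02, 255
qX = min(255, max(0, round(514.0 / scale) + zp))  # quantize: round(0.4999...) + 255 -> 255 (== zp)
print(scale * (qX - zp))                          # scalar dequantize: exactly 0.0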

However the vectorized de-quantization doesn't use this exact formula:

float_vec_return_type dequantize(
    Vectorized<float> scale,
    Vectorized<float> /*zero_point*/,
    Vectorized<float> scale_zp_premul) const {
  __m128i int_val0 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 0));
  __m128i int_val1 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 1));
  __m128i int_val2 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 2));
  __m128i int_val3 = _mm_set1_epi64x(_mm256_extract_epi64(vals, 3));

  __m256 float_val0 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val0));
  __m256 float_val1 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val1));
  __m256 float_val2 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val2));
  __m256 float_val3 = _mm256_cvtepi32_ps(cvtepu8_epi32(int_val3));

  auto val0 =
      vec::fmadd(scale, Vectorized<float>(float_val0), scale_zp_premul);
  auto val1 =
      vec::fmadd(scale, Vectorized<float>(float_val1), scale_zp_premul);
  auto val2 =
      vec::fmadd(scale, Vectorized<float>(float_val2), scale_zp_premul);
  auto val3 =
      vec::fmadd(scale, Vectorized<float>(float_val3), scale_zp_premul);
  return {val0, val1, val2, val3};
}

So what is actually computed is fmadd(scale, qX, scale * -zp) = scale * qX + (scale * -zp), with the term scale * -zp precomputed and rounded to float.

While mathematically equivalent, the result differs due to rounding: the precomputed float(scale * -zp) carries a rounding error that no longer cancels against scale * qX.

Hence, while the non-vectorized path yields the correct de-quantized value 0, the vectorized path yields 0.0112305, which is exactly the rounding error and can be expressed as float(scale * qX) + float(scale * -zp).
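The effect can be reproduced without the vectorized kernel by emulating the single-rounding fmadd in NumPy (a sketch; the variable names are mine, and the FMA is emulated by forming the exact product in float64 and rounding the sum back to float32 once):

import numpy as np

scale = np.float32(1028.02)
zp = np.float32(255.0)
qX = np.float32(255.0)  # the clipped quantized value, converted to float as in the kernel

# Scalar path: scale * (qX - zp) -> exactly 0.0
scalar = scale * (qX - zp)

# Vectorized path: fmadd(scale, qX, premul) with premul = float32(scale * -zp)
premul = scale * (-zp)                                # rounded to float32 here
fma = np.float32(np.float64(scale) * np.float64(qX)   # exact product ...
                 + np.float64(premul))                # ... plus addend, one final rounding

print(scalar)  # 0.0
print(fma)     # a small non-zero value: the float32 rounding error of scale * zp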

The test tries to compensate for similar errors by computing the reference on the dequantized values of the quantized tensor.

This could have caught the issue; however, tensor.dequantize uses either FBGEMM or the non-vectorized dequantize_val, both of which yield the correct values, so the reference does not reproduce (and hence does not compensate for) this particular error.

For reference FBGEMM uses:

template <typename T>
float Dequantize(T src, const TensorQuantizationParams& qparams) {
  return qparams.scale * (src - qparams.zero_point);
}

and the vector/multi-element version simply iterates over this, so it does not suffer from the rounding issue either. I'd therefore argue this is a bug in the vectorized dequantization implementation in PyTorch. Especially suspicious is

auto zero_point_vec = Vectorized<float>((float)zero_point);

I.e. an int64_t is converted to a float, which raises two questions: why is it a 64-bit value in the first place (FBGEMM uses 32-bit, and dequantize_val also casts it to 32-bit), and doesn't casting a 64-bit int to a 32-bit float lose quite some information?
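As an illustration of that concern (a sketch; zero points of this magnitude could only occur for qint32, not for the 8-bit types):

import numpy as np

zp = np.int64(2**24 + 1)   # representable exactly in int32/int64 ...
print(np.float32(zp))      # ... but not in float32: prints 16777216.0, the +1 is lost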

So I'd guess that expanding the quantized vector to Vectorized<int32_t>, subtracting the zero point, then converting to Vectorized<float> and multiplying by the scale would be the most correct and would best match the scalar path. Using Vectorized<float> for the values and the zero point would also work for the 8-bit quantized values (no loss), but might be an issue for some qint32 values.
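A NumPy stand-in for that suggested order of operations (a sketch with hypothetical names, not the actual Vectorized<T> code):

import numpy as np

def dequantize_suggested(q_vals, scale, zero_point):
    # widen to int32, subtract the zero point exactly, then convert to float and scale
    diff = q_vals.astype(np.int32) - np.int32(zero_point)
    return np.float32(scale) * diff.astype(np.float32)

q = np.full(65, 255, dtype=np.uint8)            # the clipped values from the repro
print(dequantize_suggested(q, 1028.02, 255))    # exactly 0.0 everywhere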

mingfeima (Collaborator) commented:

@Xia-Weiwen could you please take a look at this one?

colesbury added the oncall: quantization label on Aug 14, 2023
Xia-Weiwen (Collaborator) replied:

I will take a look later.

jerryzh168 added the triaged label on Sep 1, 2023
Flamefire (Collaborator, Author) commented:

Any updates here? This is still an issue in 2.1.

jgong5 (Collaborator) commented on Nov 17, 2023:

@Xia-Weiwen Any comments?

Xia-Weiwen (Collaborator) commented:

Hi @Flamefire. Sorry for the late reply. We plan to fix it by PyTorch 2.2. For the main branch, it will be earlier than that. Thanks!

xunsongh pushed a commit to xunsongh/pytorch that referenced this issue Nov 24, 2023
…gmoid (pytorch#114098)

**Description**
Fix pytorch#107030
Dequantize X by `(x_val - zp) * scale` instead of `x_val * scale + (-zp * scale)` to eliminate rounding error.
Now this overload is used for sigmoid only.

Performance impact: see the image at https://github.com/pytorch/pytorch/assets/12522207/655abd16-7d9d-4a9a-8c59-327ebf39157a (measured on an Intel(R) Xeon(R) Platinum 8358 CPU @ 2.60GHz, Ice Lake).

**Test plan**
`python test_quantization.py TestQuantizedOps.test_sigmoid_dequantize_rounding_error`

Pull Request resolved: pytorch#114098
Approved by: https://github.com/jgong5, https://github.com/jerryzh168