Make the CUDA fake quantize logic consistent with CPU fake quantize logic #49808

Closed · wants to merge 2 commits
18 changes: 10 additions & 8 deletions aten/src/ATen/native/quantized/cuda/affine_quantizer.cu
@@ -25,14 +25,16 @@ void quantize_tensor_per_tensor_affine_cuda(
         .add_input(qtensor)
         .build();
 
-    gpu_kernel(iter,
-      [=] GPU_LAMBDA (float raw_val, scalar_t quantized_val) -> scalar_t {
-        int64_t qvalue = static_cast<int64_t>(nearbyint(raw_val / scale + zero_point));
-        qvalue = std::max<int64_t>(qvalue, qmin);
-        qvalue = std::min<int64_t>(qvalue, qmax);
-        quantized_val.val_ = qvalue;
-        return quantized_val;
-      });
+    gpu_kernel(
+        iter,
+        [=] GPU_LAMBDA(float raw_val, scalar_t quantized_val) -> scalar_t {
+          int64_t qvalue =
+              static_cast<int64_t>(nearbyint(raw_val / scale) + zero_point);
+          qvalue = std::max<int64_t>(qvalue, qmin);
+          qvalue = std::min<int64_t>(qvalue, qmax);
+          quantized_val.val_ = qvalue;
+          return quantized_val;
+        });
   });
 }
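Note on the semantics: std::nearbyint rounds ties to even under the default rounding mode, so folding zero_point into the rounded argument can change the result whenever the scaled input lands exactly on a half-integer and zero_point is odd. A minimal host-side sketch of the discrepancy (illustrative values, not code from this PR):

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  float v = 0.5f;          // raw_val / scale lands exactly on a tie
  int64_t zero_point = 1;

  // Old CUDA order: shift first, then round. nearbyint(1.5) -> 2 (ties to even).
  int64_t old_q = static_cast<int64_t>(std::nearbyint(v + zero_point));

  // New, CPU-consistent order: round first, then shift. nearbyint(0.5) -> 0, +1 -> 1.
  int64_t new_q = static_cast<int64_t>(std::nearbyint(v)) + zero_point;

  std::printf("old=%lld new=%lld\n", (long long)old_q, (long long)new_q);
  // prints: old=2 new=1
}
```

Beyond the tie cases, adding zero_point in floating point before rounding can also discard low-order bits of the scaled value, while the integer addition after rounding is exact.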

39 changes: 19 additions & 20 deletions aten/src/ATen/native/quantized/cuda/fake_quantize_core.cu
@@ -34,17 +34,16 @@ void fake_quantize_tensor_kernel_cuda(
       .add_output(output)
       .add_input(input)
       .build();
-  gpu_kernel(iter,
-    [=] GPU_LAMBDA (float input_val) -> float {
-      return (fminf(
+  gpu_kernel(iter, [=] GPU_LAMBDA(float input_val) -> float {
+    return (fminf(
                 quant_max,
                 fmaxf(
                     quant_min,
-                    static_cast<int64_t>(std::nearbyint(
-                        input_val * inv_scale + zero_point)))) -
+                    static_cast<int64_t>(
+                        std::nearbyint(input_val * inv_scale) + zero_point))) -
             zero_point) *
-          scale;
-    });
+        scale;
+  });
 }
 
 void fake_quantize_grad_tensor_kernel_cuda(
@@ -63,11 +62,10 @@ void fake_quantize_grad_tensor_kernel_cuda(
       .add_input(output_grad)
       .add_input(input)
       .build();
-  gpu_kernel(iter,
-    [=] GPU_LAMBDA (float dy, float x) -> float {
-      int64_t Xq = std::nearbyint(x * inv_scale + zero_point);
-      return (Xq >= quant_min && Xq <= quant_max) * dy;
-    });
+  gpu_kernel(iter, [=] GPU_LAMBDA(float dy, float x) -> float {
+    int64_t Xq = std::nearbyint(x * inv_scale) + zero_point;
+    return (Xq >= quant_min && Xq <= quant_max) * dy;
+  });
 }
 
 void _fake_quantize_grad_learnable_tensor_kernel_cuda(
@@ -82,7 +80,7 @@ void _fake_quantize_grad_learnable_tensor_kernel_cuda(
   gpu_kernel_multiple_outputs(
     iter, [=] GPU_LAMBDA (float XInput, float dYInput) -> thrust::tuple<float, float, float> {
       float dXOutput, dZeroPointOutput, dScaleOutput;
-      int64_t xq = std::nearbyint(zero_point + XInput * inv_scale);
+      int64_t xq = std::nearbyint(XInput * inv_scale) + zero_point;
       dXOutput = dYInput * (xq >= quant_min && xq <= quant_max);
       xq = std::max(std::min(xq, quant_max), quant_min);
       float xfq = static_cast<float>((xq - zero_point) * scale);
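The backward kernels implement a straight-through estimator: the upstream gradient dy passes through only where the rounded value lies inside [quant_min, quant_max], so the round-then-shift order now also decides which elements receive a zero gradient. A hypothetical host-side equivalent of the per-element mask:

```cpp
#include <cmath>
#include <cstdint>

// Straight-through-estimator backward for fake quantization, using the
// CPU-consistent order: round x * inv_scale first, then add zero_point.
float fake_quantize_grad(float dy, float x, float inv_scale,
                         int64_t zero_point, int64_t quant_min,
                         int64_t quant_max) {
  int64_t xq =
      static_cast<int64_t>(std::nearbyint(x * inv_scale)) + zero_point;
  return (xq >= quant_min && xq <= quant_max) ? dy : 0.0f;
}
```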
@@ -108,12 +106,13 @@ void fake_quant_per_channel_cuda(TensorIterator &iter, int64_t quant_min, int64_t quant_max) {
     [=] GPU_LAMBDA (float input_val, float scale, int64_t zero_point) -> float {
       float inv_scale = 1.0f / scale;
       return (fminf(
-                quant_max,
-                fmaxf(
-                    quant_min,
-                    static_cast<int64_t>(std::nearbyint(
-                        input_val * inv_scale + zero_point)))) -
-            zero_point) *
+                  quant_max,
+                  fmaxf(
+                      quant_min,
+                      static_cast<int64_t>(
+                          std::nearbyint(input_val * inv_scale) +
+                          zero_point))) -
+              zero_point) *
           scale;
     });
 }
@@ -122,7 +121,7 @@ void fake_quant_grad_per_channel_cuda(TensorIterator &iter, int64_t quant_min, int64_t quant_max) {
   gpu_kernel(iter,
     [=] GPU_LAMBDA (float x, float dy, float scale, int64_t zero_point) -> float {
       float inv_scale = 1.0f / scale;
-      int64_t Xq = std::nearbyint(x * inv_scale + zero_point);
+      int64_t Xq = std::nearbyint(x * inv_scale) + zero_point;
       return (Xq >= quant_min && Xq <= quant_max) * dy;
     });
 }
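For reference, all of the touched kernels now agree on one forward formula; written as a plain host function (a sketch of the round-then-shift convention this PR adopts, not code from the diff):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// Fake-quantize forward: quantize with round-to-nearest (ties to even),
// clamp to the quantized range, then dequantize back to float.
float fake_quantize(float x, float scale, int64_t zero_point,
                    int64_t quant_min, int64_t quant_max) {
  float inv_scale = 1.0f / scale;
  int64_t q =
      static_cast<int64_t>(std::nearbyint(x * inv_scale)) + zero_point;
  q = std::min(std::max(q, quant_min), quant_max);
  return static_cast<float>(q - zero_point) * scale;
}
```

For example, with scale = 0.5f, zero_point = 0, quant_min = -128, quant_max = 127, an input of 1.3f quantizes to q = 3 and dequantizes to 1.5f.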