From fab6e5f390b6926637ac6396bab0542641095065 Mon Sep 17 00:00:00 2001
From: Manuel Candales
Date: Tue, 12 Mar 2024 05:37:54 -0700
Subject: [PATCH] optional zero points on dequantize per channel

Summary:
bypass-github-export-checks

Reviewed By: mikekgfb

Differential Revision: D54709217
---
 kernels/quantized/cpu/op_dequantize.cpp | 39 ++++++++++++++++---------
 kernels/quantized/quantized.yaml        |  2 +-
 2 files changed, 26 insertions(+), 15 deletions(-)

diff --git a/kernels/quantized/cpu/op_dequantize.cpp b/kernels/quantized/cpu/op_dequantize.cpp
index 5d567b04ff7..b90adccaf24 100644
--- a/kernels/quantized/cpu/op_dequantize.cpp
+++ b/kernels/quantized/cpu/op_dequantize.cpp
@@ -166,7 +166,7 @@ Tensor& dequantize_per_tensor_tensor_args_out(
 Tensor& dequantize_per_channel_out(
     const Tensor& input,
     const Tensor& scale,
-    const Tensor& zero_point,
+    const optional<Tensor>& opt_zero_points,
     int64_t axis,
     int64_t quant_min,
     int64_t quant_max,
@@ -201,16 +201,19 @@ Tensor& dequantize_per_channel_out(
       ssize_t(scale.numel()),
       ssize_t(input.size(axis)));
 
-  ET_CHECK_MSG(
-      zero_point.scalar_type() == ScalarType::Long,
-      "zero_point.scalar_type() %" PRId8 " is not integer type",
-      static_cast<int8_t>(zero_point.scalar_type()));
+  if (opt_zero_points.has_value()) {
+    auto zero_point = opt_zero_points.value();
+    ET_CHECK_MSG(
+        zero_point.scalar_type() == ScalarType::Long,
+        "zero_point.scalar_type() %" PRId8 " is not integer type",
+        static_cast<int8_t>(zero_point.scalar_type()));
 
-  ET_CHECK_MSG(
-      zero_point.numel() == input.size(axis),
-      "zero_point.numel() %zd != input.size(axis) %zd",
-      ssize_t(zero_point.numel()),
-      ssize_t(input.size(axis)));
+    ET_CHECK_MSG(
+        zero_point.numel() == input.size(axis),
+        "zero_point.numel() %zd != input.size(axis) %zd",
+        ssize_t(zero_point.numel()),
+        ssize_t(input.size(axis)));
+  }
 
   check_dequantize_per_tensor_args(
       input, quant_min, quant_max, dtype, out_dtype, out);
@@ -225,7 +228,12 @@ Tensor& dequantize_per_channel_out(
     }
   }
   const double* scale_data = scale.const_data_ptr<double>();
-  const int64_t* zero_point_data = zero_point.const_data_ptr<int64_t>();
+  const int64_t* zero_point_data;
+  if (opt_zero_points.has_value()) {
+    zero_point_data = opt_zero_points.value().const_data_ptr<int64_t>();
+  } else {
+    zero_point_data = nullptr;
+  }
 
   exec_aten::optional<exec_aten::ArrayRef<int64_t>> optional_dim_list{
       exec_aten::ArrayRef<int64_t>{dims, size_t(input.dim() - 1)}};
@@ -242,7 +250,10 @@ Tensor& dequantize_per_channel_out(
   case ScalarType::out_dtype:                                                  \
     for (size_t channel_ix = 0; channel_ix < input.size(axis); ++channel_ix) { \
       double _scale = scale_data[channel_ix];                                  \
-      int64_t _zero_point = zero_point_data[channel_ix];                       \
+      int64_t _zero_point = 0;                                                 \
+      if (zero_point_data != nullptr) {                                        \
+        _zero_point = zero_point_data[channel_ix];                             \
+      }                                                                        \
       apply_over_dim_list(                                                     \
           [input, out, _scale, _zero_point](size_t in_ix) {                    \
             out.mutable_data_ptr<CTYPE_OUT>()[in_ix] = static_cast<CTYPE_OUT>( \
@@ -284,7 +295,7 @@ Tensor& dequantize_per_channel_out(
     RuntimeContext& context,
     const Tensor& input,
     const Tensor& scale,
-    const Tensor& zero_point,
+    const optional<Tensor>& opt_zero_points,
    int64_t axis,
     int64_t quant_min,
     int64_t quant_max,
@@ -295,7 +306,7 @@ Tensor& dequantize_per_channel_out(
   return dequantize_per_channel_out(
       input,
       scale,
-      zero_point,
+      opt_zero_points,
       axis,
       quant_min,
       quant_max,
diff --git a/kernels/quantized/quantized.yaml b/kernels/quantized/quantized.yaml
index 6f4a0a8fc97..484641318b4 100644
--- a/kernels/quantized/quantized.yaml
+++ b/kernels/quantized/quantized.yaml
@@ -28,7 +28,7 @@
     - arg_meta: null
       kernel_name: torch::executor::quantize_per_channel_out
 
-- func: quantized_decomposed::dequantize_per_channel.out(Tensor input, Tensor scales, Tensor zero_points, int axis, int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None, Tensor(a!) out) -> Tensor(a!)
+- func: quantized_decomposed::dequantize_per_channel.out(Tensor input, Tensor scales, Tensor? zero_points, int axis, int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None, Tensor(a!) out) -> Tensor(a!)
   variants: function
   kernels:
     - arg_meta: null
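
Note (illustration, not part of the patch): the sketch below shows the behavior this change enables, namely that an absent zero_points tensor dequantizes each channel as if its zero point were 0, i.e. out = (q - zero_point) * scale. It uses std::optional and std::vector instead of the exec_aten Tensor API, assumes a flattened [channels x elems] layout (per-channel axis 0), and the helper name dequantize_per_channel is hypothetical.

// Standalone sketch of per-channel dequantization with optional zero points.
// Mirrors the patched kernel's fallback: missing zero points behave as 0.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

std::vector<float> dequantize_per_channel(
    const std::vector<int8_t>& q,                          // flattened [channels x elems]
    const std::vector<double>& scales,                     // one scale per channel
    const std::optional<std::vector<int64_t>>& opt_zero_points,
    size_t channels) {
  const size_t per_channel = q.size() / channels;
  std::vector<float> out(q.size());
  for (size_t c = 0; c < channels; ++c) {
    const double scale = scales[c];
    // Same fallback as the patched kernel: no zero points means zero_point = 0.
    const int64_t zero_point =
        opt_zero_points.has_value() ? (*opt_zero_points)[c] : 0;
    for (size_t i = 0; i < per_channel; ++i) {
      const size_t ix = c * per_channel + i;
      out[ix] = static_cast<float>((q[ix] - zero_point) * scale);
    }
  }
  return out;
}

int main() {
  std::vector<int8_t> q = {10, 20, 30, 40};
  std::vector<double> scales = {0.5, 0.25};

  // With explicit per-channel zero points.
  auto with_zp = dequantize_per_channel(q, scales, std::vector<int64_t>{10, 20}, 2);
  // Without zero points: equivalent to passing all-zero zero points.
  auto without_zp = dequantize_per_channel(q, scales, std::nullopt, 2);

  for (float v : with_zp) std::cout << v << ' ';     // 0 5 2.5 5
  std::cout << '\n';
  for (float v : without_zp) std::cout << v << ' ';  // 5 10 7.5 10
  std::cout << '\n';
  return 0;
}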