From 74dad77ad03f09db7090bfe2ad7c1c04c58149cc Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Fri, 20 Sep 2024 15:50:30 -0700 Subject: [PATCH] Fix dequantize per channel to handle double scale type Differential Revision: [D62301839](https://our.internmc.facebook.com/intern/diff/D62301839/) [ghstack-poisoned] --- kernels/quantized/cpu/op_dequantize.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/kernels/quantized/cpu/op_dequantize.cpp b/kernels/quantized/cpu/op_dequantize.cpp index 836af474015..3b77b4c23f2 100644 --- a/kernels/quantized/cpu/op_dequantize.cpp +++ b/kernels/quantized/cpu/op_dequantize.cpp @@ -196,7 +196,7 @@ Tensor& dequantize_per_channel_out( "Failed to resize out Tensor in dequantize_per_channel_out"); ET_CHECK_MSG( - scale.scalar_type() == ScalarType::Float, + scale.scalar_type() == ScalarType::Double, "scale.scalar_type() %" PRId8 " is not float type", static_cast(scale.scalar_type())); @@ -232,7 +232,7 @@ Tensor& dequantize_per_channel_out( dims[i] = i + 1; } } - const float* scale_data = scale.const_data_ptr(); + const double* scale_data = scale.const_data_ptr(); const int64_t* zero_point_data; if (opt_zero_points.has_value()) { zero_point_data = opt_zero_points.value().const_data_ptr(); @@ -264,7 +264,7 @@ Tensor& dequantize_per_channel_out( size_t numel, size_t stride, size_t base_ix) { \ for (size_t i = 0; i < numel; i++) { \ size_t current_ix = base_ix * stride + i; \ - float _scale = scale_data[current_ix]; \ + float _scale = static_cast(scale_data[current_ix]); \ int64_t zero_point = 0; \ if (zero_point_data != nullptr) { \ zero_point = zero_point_data[current_ix]; \ @@ -280,7 +280,7 @@ Tensor& dequantize_per_channel_out( break; \ } \ for (size_t channel_ix = 0; channel_ix < input.size(axis); ++channel_ix) { \ - float _scale = scale_data[channel_ix]; \ + float _scale = static_cast(scale_data[channel_ix]); \ int64_t _zero_point = 0; \ if (zero_point_data != nullptr) { \ _zero_point = zero_point_data[channel_ix]; \