// Patch summary (leading text before the first "diff --git" header).
//
// Files touched:
//   kernels/quantized/cpu/op_dequantize.cpp
//   kernels/quantized/test/op_dequantize_test.cpp
//
// op_dequantize.cpp:
//   * Adds get_scale(const Tensor& scale, size_t channel_ix), returning float.
//     It asserts with ET_CHECK_MSG that scale.scalar_type() is
//     ScalarType::Double or ScalarType::Float, then reads element channel_ix
//     from scale.const_data_ptr(). The Double branch wraps the read in a
//     static_cast; the Float branch returns the element directly.
//   * dequantize_per_channel_out drops its Float-only scalar_type check on
//     scale and drops the cached `scale_data` pointer.
//   * Both loops in the dequantize macro now obtain the scale through
//     get_scale(scale, ix): the single-dimension path indexes with
//     current_ix = base_ix * stride + i, the general path with channel_ix.
//     The apply_over_dim lambda captures `scale` by reference instead of
//     capturing `scale_data`.
//
// op_dequantize_test.cpp:
//   * Adds one #include and calls et_pal_init() as the first statement of
//     AllDtypesSupported, NonWholeNumbers, TensorArgOverload and
//     DequantizePerChannel.
//   * DequantizePerChannel builds its three scale tensors with tf_double
//     instead of tf_float.
//
// NOTE(review): in this copy the text between angle brackets appears to have
// been stripped (bare "#include", "static_cast(", "const_data_ptr()",
// "const optional dim", "TensorFactory tf"). The patch is presumably not
// applicable as shown; restore the original text before applying.
// NOTE(review): the scale dtype check used to run once before the loops and
// now runs inside get_scale on every call (per element on the single-dimension
// path, per channel otherwise). Presumably it also never fires when the loops
// execute zero iterations. Consider validating once up front; confirm intent.
// NOTE(review): in the visible test hunks the per-channel test only uses a
// Double scale after this change, so the Float branch of get_scale looks
// unexercised here. Verify against the rest of the test file.
// NOTE(review): get_scale is declared without `static`; whether it sits in an
// anonymous namespace cannot be seen from this hunk. Confirm its linkage.
diff --git a/kernels/quantized/cpu/op_dequantize.cpp b/kernels/quantized/cpu/op_dequantize.cpp index 836af474015..f4e11bda5c4 100644 --- a/kernels/quantized/cpu/op_dequantize.cpp +++ b/kernels/quantized/cpu/op_dequantize.cpp @@ -168,6 +168,19 @@ Tensor& dequantize_per_tensor_tensor_args_out( return out; } +float get_scale(const Tensor& scale, size_t channel_ix) { + ET_CHECK_MSG( + (scale.scalar_type() == ScalarType::Double) || + (scale.scalar_type() == ScalarType::Float), + "scale.scalar_type() %" PRId8 " is not double or float type", + static_cast(scale.scalar_type())); + if (scale.scalar_type() == ScalarType::Double) { + return static_cast(scale.const_data_ptr()[channel_ix]); + } else { + return scale.const_data_ptr()[channel_ix]; + } +} + Tensor& dequantize_per_channel_out( const Tensor& input, const Tensor& scale, @@ -195,11 +208,6 @@ Tensor& dequantize_per_channel_out( err == torch::executor::Error::Ok, "Failed to resize out Tensor in dequantize_per_channel_out"); - ET_CHECK_MSG( - scale.scalar_type() == ScalarType::Float, - "scale.scalar_type() %" PRId8 " is not float type", - static_cast(scale.scalar_type())); - ET_CHECK_MSG( scale.numel() == input.size(axis), "scale.numel() %zd != input.size(axis) %zd", @@ -232,7 +240,6 @@ Tensor& dequantize_per_channel_out( dims[i] = i + 1; } } - const float* scale_data = scale.const_data_ptr(); const int64_t* zero_point_data; if (opt_zero_points.has_value()) { zero_point_data = opt_zero_points.value().const_data_ptr(); @@ -260,11 +267,11 @@ Tensor& dequantize_per_channel_out( axis == 0, "Axis must be 0 for a single dimensional tensors"); \ const optional dim; \ apply_over_dim( \ - [input_data_ptr, out_data_ptr, scale_data, zero_point_data]( \ + [input_data_ptr, out_data_ptr, zero_point_data, &scale]( \ size_t numel, size_t stride, size_t base_ix) { \ for (size_t i = 0; i < numel; i++) { \ size_t current_ix = base_ix * stride + i; \ - float _scale = scale_data[current_ix]; \ + float _scale = get_scale(scale, current_ix); \ 
int64_t zero_point = 0; \ if (zero_point_data != nullptr) { \ zero_point = zero_point_data[current_ix]; \ @@ -280,7 +287,7 @@ Tensor& dequantize_per_channel_out( break; \ } \ for (size_t channel_ix = 0; channel_ix < input.size(axis); ++channel_ix) { \ - float _scale = scale_data[channel_ix]; \ + float _scale = get_scale(scale, channel_ix); \ int64_t _zero_point = 0; \ if (zero_point_data != nullptr) { \ _zero_point = zero_point_data[channel_ix]; \ diff --git a/kernels/quantized/test/op_dequantize_test.cpp b/kernels/quantized/test/op_dequantize_test.cpp index 4cdebb662b7..8aa17772baf 100644 --- a/kernels/quantized/test/op_dequantize_test.cpp +++ b/kernels/quantized/test/op_dequantize_test.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include @@ -57,10 +58,12 @@ void test_dtype() { } TEST(OpDequantizeOutTest, AllDtypesSupported) { + et_pal_init(); test_dtype(); } TEST(OpDequantizeOutTest, NonWholeNumbers) { + et_pal_init(); TensorFactory tf; Tensor input = tf.full({3, 5}, 100); @@ -87,6 +90,7 @@ TEST(OpDequantizeOutTest, NonWholeNumbers) { } TEST(OpDequantizeOutTest, TensorArgOverload) { + et_pal_init(); TensorFactory tf_byte; TensorFactory tf_double; TensorFactory tf_long; @@ -115,12 +119,13 @@ TEST(OpDequantizeOutTest, TensorArgOverload) { } TEST(OpDequantizeOutTest, DequantizePerChannel) { + et_pal_init(); TensorFactory tf_byte; - TensorFactory tf_float; + TensorFactory tf_double; TensorFactory tf_long; Tensor input = tf_byte.full({3, 2}, 100); - Tensor scale = tf_float.make({2}, {0.5, 1}); + Tensor scale = tf_double.make({2}, {0.5, 1}); Tensor zero_point = tf_long.make({2}, {30, 60}); int64_t quant_min = 0; int64_t quant_max = 255; @@ -145,7 +150,7 @@ TEST(OpDequantizeOutTest, DequantizePerChannel) { // Test with a different axis out = tfo.zeros({3, 2}); - scale = tf_float.make({3}, {0.5, 0.75, 1}); + scale = tf_double.make({3}, {0.5, 0.75, 1}); zero_point = tf_long.make({3}, {30, 50, 60}); // (100 - 30) * 0.5 // (100 - 50) * 0.75 @@ -167,7 
+172,7 @@ TEST(OpDequantizeOutTest, DequantizePerChannel) { // Test with a different axis out = tfo.zeros({3}); input = tf_byte.make({3}, {100, 100, 100}); - scale = tf_float.make({3}, {0.5, 0.75, 1}); + scale = tf_double.make({3}, {0.5, 0.75, 1}); zero_point = tf_long.make({3}, {30, 50, 60}); // (100 - 30) * 0.5 // (100 - 50) * 0.75