diff --git a/examples/xtensa/ops/dequantize_per_tensor.cpp b/examples/xtensa/ops/dequantize_per_tensor.cpp
index dcc4ace7898..26b6c71ca96 100644
--- a/examples/xtensa/ops/dequantize_per_tensor.cpp
+++ b/examples/xtensa/ops/dequantize_per_tensor.cpp
@@ -25,7 +25,15 @@ void dequantize_per_tensor_out(
     int64_t quant_min,
     int64_t quant_max,
     ScalarType dtype,
+    exec_aten::optional<ScalarType>& out_dtype,
     Tensor& out) {
+  if (out_dtype.has_value()) {
+    ET_CHECK_MSG(
+        out_dtype.value() == ScalarType::Float,
+        "Expected out dtype to be Float but got %hhd",
+        out_dtype.value());
+  }
+
   float* out_data = out.mutable_data_ptr<float>();
   size_t numel = out.numel();
 
diff --git a/examples/xtensa/ops/functions.yaml b/examples/xtensa/ops/functions.yaml
index 07093d3ed24..f89be7bcf9d 100644
--- a/examples/xtensa/ops/functions.yaml
+++ b/examples/xtensa/ops/functions.yaml
@@ -26,7 +26,7 @@
     - arg_meta: null
       kernel_name: impl::HiFi::quantized_linear_pt2_out
 
-- func: xtensa::dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+- func: xtensa::dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None, Tensor(a!) out) -> Tensor(a!)
   variants: function
   kernels:
     - arg_meta: null
diff --git a/kernels/quantized/cpu/op_dequantize.cpp b/kernels/quantized/cpu/op_dequantize.cpp
index 129b9a4ad5b..5d567b04ff7 100644
--- a/kernels/quantized/cpu/op_dequantize.cpp
+++ b/kernels/quantized/cpu/op_dequantize.cpp
@@ -33,6 +33,7 @@ void check_dequantize_per_tensor_args(
     int64_t quant_min,
     int64_t quant_max,
     ScalarType dtype,
+    exec_aten::optional<ScalarType>& out_dtype,
     Tensor& out) {
   ET_CHECK_MSG(
       input.scalar_type() == ScalarType::Byte ||
@@ -47,10 +48,11 @@ void check_dequantize_per_tensor_args(
       "input.scalar_type() %" PRId8 " is not matching dtype argumenta:",
       static_cast<int8_t>(input.scalar_type()));
 
-  ET_CHECK_MSG(
-      out.scalar_type() == ScalarType::Float,
-      "out.scalar_type() %" PRId8 " is not supported:",
-      static_cast<int8_t>(out.scalar_type()));
+  if (out_dtype.has_value()) {
+    ET_CHECK_MSG(
+        out.scalar_type() == out_dtype.value(),
+        "output_dtype must match the dtype of the out tensor");
+  }
 
   ET_CHECK_MSG(
       quant_min <= quant_max,
@@ -77,13 +79,15 @@ Tensor& dequantize_per_tensor_out(
     int64_t quant_min,
     int64_t quant_max,
     ScalarType dtype,
+    exec_aten::optional<ScalarType> out_dtype,
     Tensor& out) {
   torch::executor::Error err = resize_tensor(out, input.sizes());
   ET_CHECK_MSG(
       err == torch::executor::Error::Ok,
       "Failed to resize out Tensor in dequantize_per_tensor_out");
 
-  check_dequantize_per_tensor_args(input, quant_min, quant_max, dtype, out);
+  check_dequantize_per_tensor_args(
+      input, quant_min, quant_max, dtype, out_dtype, out);
 
   // calculate the dequantized output, cast scale to float to match fbgemm
   // behavior
@@ -128,6 +132,7 @@ Tensor& dequantize_per_tensor_tensor_args_out(
     int64_t quant_min,
     int64_t quant_max,
     ScalarType dtype,
+    exec_aten::optional<ScalarType> out_dtype,
     Tensor& out) {
   ET_CHECK_MSG(
       scale.scalar_type() == ScalarType::Double,
@@ -153,6 +158,7 @@ Tensor& dequantize_per_tensor_tensor_args_out(
       quant_min,
       quant_max,
       dtype,
+      out_dtype,
       out);
   return out;
 }
@@ -165,6 +171,7 @@ Tensor& dequantize_per_channel_out(
     int64_t quant_min,
     int64_t quant_max,
     ScalarType dtype,
+    exec_aten::optional<ScalarType> out_dtype,
     Tensor& out) {
   torch::executor::Error err = resize_tensor(out, input.sizes());
 
@@ -205,7 +212,8 @@ Tensor& dequantize_per_channel_out(
       ssize_t(zero_point.numel()),
       ssize_t(input.size(axis)));
 
-  check_dequantize_per_tensor_args(input, quant_min, quant_max, dtype, out);
+  check_dequantize_per_tensor_args(
+      input, quant_min, quant_max, dtype, out_dtype, out);
 
   // a list contains all dimensions except axis
   int64_t dims[input.dim() - 1];
@@ -281,10 +289,19 @@ Tensor& dequantize_per_channel_out(
     int64_t quant_min,
     int64_t quant_max,
     ScalarType dtype,
+    exec_aten::optional<ScalarType> out_dtype,
     Tensor& out) {
   (void)context;
   return dequantize_per_channel_out(
-      input, scale, zero_point, axis, quant_min, quant_max, dtype, out);
+      input,
+      scale,
+      zero_point,
+      axis,
+      quant_min,
+      quant_max,
+      dtype,
+      out_dtype,
+      out);
 }
 
 Tensor& dequantize_per_tensor_out(
@@ -295,12 +312,13 @@
     int64_t quant_min,
     int64_t quant_max,
     ScalarType dtype,
+    exec_aten::optional<ScalarType> out_dtype,
     Tensor& out) {
   // TODO(larryliu): Add a context arg to the real op function and remove this
   // wrapper
   (void)context;
   return dequantize_per_tensor_out(
-      input, scale, zero_point, quant_min, quant_max, dtype, out);
+      input, scale, zero_point, quant_min, quant_max, dtype, out_dtype, out);
 }
 
 Tensor& dequantize_per_tensor_tensor_args_out(
@@ -311,12 +329,13 @@
     int64_t quant_min,
     int64_t quant_max,
     ScalarType dtype,
+    exec_aten::optional<ScalarType> out_dtype,
     Tensor& out) {
   // TODO(larryliu): Add a context arg to the real op function and remove this
   // wrapper
   (void)context;
   return dequantize_per_tensor_tensor_args_out(
-      input, scale, zero_point, quant_min, quant_max, dtype, out);
+      input, scale, zero_point, quant_min, quant_max, dtype, out_dtype, out);
 }
 
 } // namespace native
diff --git a/kernels/quantized/quantized.yaml b/kernels/quantized/quantized.yaml
index b6103343f62..6f4a0a8fc97 100644
--- a/kernels/quantized/quantized.yaml
+++ b/kernels/quantized/quantized.yaml
@@ -10,13 +10,13 @@
     - arg_meta: null
       kernel_name: torch::executor::choose_qparams_tensor_out
 
-- func: quantized_decomposed::dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+- func: quantized_decomposed::dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None, Tensor(a!) out) -> Tensor(a!)
   variants: function
   kernels:
     - arg_meta: null
      kernel_name: torch::executor::dequantize_per_tensor_out
 
-- func: quantized_decomposed::dequantize_per_tensor.Tensor_out(Tensor input, Tensor scale, Tensor zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+- func: quantized_decomposed::dequantize_per_tensor.Tensor_out(Tensor input, Tensor scale, Tensor zero_point, int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None, Tensor(a!) out) -> Tensor(a!)
   variants: function
   kernels:
     - arg_meta: null
@@ -28,7 +28,7 @@
     - arg_meta: null
       kernel_name: torch::executor::quantize_per_channel_out
 
-- func: quantized_decomposed::dequantize_per_channel.out(Tensor input, Tensor scales, Tensor zero_points, int axis, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+- func: quantized_decomposed::dequantize_per_channel.out(Tensor input, Tensor scales, Tensor zero_points, int axis, int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None, Tensor(a!) out) -> Tensor(a!)
   variants: function
   kernels:
     - arg_meta: null
diff --git a/kernels/quantized/test/op_add_test.cpp b/kernels/quantized/test/op_add_test.cpp
index 92c6fa2052b..a48ba10c66e 100644
--- a/kernels/quantized/test/op_add_test.cpp
+++ b/kernels/quantized/test/op_add_test.cpp
@@ -20,6 +20,7 @@
 
 using namespace ::testing;
 using exec_aten::ArrayRef;
+using exec_aten::optional;
 using exec_aten::RuntimeContext;
 using exec_aten::Scalar;
 using exec_aten::ScalarType;
@@ -190,6 +191,8 @@ TEST(OpQuantizeAddTest, ConsitencyWithReferencePattern) {
   Tensor qinput2 = tfo.zeros({3, 5});
   Tensor qoutput = tfo.zeros({3, 5});
 
+  optional<ScalarType> out_dtype = optional<ScalarType>();
+
   RuntimeContext context{};
   // q -> qadd -> dq
   // 3.5 / 0.5 + 1 = 8
@@ -235,6 +238,7 @@ TEST(OpQuantizeAddTest, ConsitencyWithReferencePattern) {
       quant_min,
       quant_max,
       ScalarType::Byte,
+      out_dtype,
       reference_op_output);
 
   // now get results for q -> dq -> fp add -> q -> dq
@@ -245,6 +249,7 @@
       quant_min,
       quant_max,
       ScalarType::Byte,
+      out_dtype,
       dq_input1);
 
   dequantize_per_tensor_out(
@@ -254,6 +259,7 @@
       quant_min,
       quant_max,
       ScalarType::Byte,
+      out_dtype,
       dq_input2);
 
   add_out(context, dq_input1, dq_input2, 1.0, fp_output);
@@ -274,6 +280,7 @@
       quant_min,
       quant_max,
       ScalarType::Byte,
+      out_dtype,
       reference_pattern_output);
 
   Tensor expected = tf.full({3, 5}, 7.0);
diff --git a/kernels/quantized/test/op_dequantize_test.cpp b/kernels/quantized/test/op_dequantize_test.cpp
index 57225e45767..1004126d045 100644
--- a/kernels/quantized/test/op_dequantize_test.cpp
+++ b/kernels/quantized/test/op_dequantize_test.cpp
@@ -18,6 +18,7 @@
 
 using namespace ::testing;
 using exec_aten::ArrayRef;
+using exec_aten::optional;
 using exec_aten::Scalar;
 using exec_aten::ScalarType;
 using exec_aten::Tensor;
@@ -43,7 +44,14 @@ void test_dtype() {
   // (100 - 30) * 0.5
   Tensor expected = tfo.full({3, 5}, 35);
   dequantize_per_tensor_out(
-      input, scale, zero_point, quant_min, quant_max, DTYPE, out);
+      input,
+      scale,
+      zero_point,
+      quant_min,
+      quant_max,
+      DTYPE,
+      optional<ScalarType>(),
+      out);
 
   EXPECT_TENSOR_EQ(out, expected);
 }
@@ -66,7 +74,14 @@ TEST(OpDequantizeOutTest, NonWholeNumbers) {
   // (100 - 30) * 0.5
   Tensor expected = tfo.full({3, 5}, 31.5);
   dequantize_per_tensor_out(
-      input, scale, zero_point, quant_min, quant_max, ScalarType::Byte, out);
+      input,
+      scale,
+      zero_point,
+      quant_min,
+      quant_max,
+      ScalarType::Byte,
+      optional<ScalarType>(),
+      out);
 
   EXPECT_TENSOR_EQ(out, expected);
 }
@@ -87,7 +102,14 @@ TEST(OpDequantizeOutTest, TensorArgOverload) {
   // (100 - 30) * 0.5
   Tensor expected = tfo.full({3, 5}, 31.5);
   dequantize_per_tensor_tensor_args_out(
-      input, scale, zero_point, quant_min, quant_max, ScalarType::Byte, out);
+      input,
+      scale,
+      zero_point,
+      quant_min,
+      quant_max,
+      ScalarType::Byte,
+      optional<ScalarType>(),
+      out);
 
   EXPECT_TENSOR_EQ(out, expected);
 }
@@ -116,6 +138,7 @@ TEST(OpDequantizeOutTest, DequantizePerChannel) {
       quant_min,
       quant_max,
       ScalarType::Byte,
+      optional<ScalarType>(),
       out);
 
   EXPECT_TENSOR_EQ(out, expected);
@@ -136,6 +159,7 @@
       quant_min,
       quant_max,
       ScalarType::Byte,
+      optional<ScalarType>(),
       out);
 
   EXPECT_TENSOR_EQ(out, expected);
diff --git a/kernels/quantized/test/op_embedding_test.cpp b/kernels/quantized/test/op_embedding_test.cpp
index 49605977cc3..76114561e53 100644
--- a/kernels/quantized/test/op_embedding_test.cpp
+++ b/kernels/quantized/test/op_embedding_test.cpp
@@ -20,6 +20,7 @@
 
 using namespace ::testing;
 using exec_aten::ArrayRef;
+using exec_aten::optional;
 using exec_aten::RuntimeContext;
 using exec_aten::Scalar;
 using exec_aten::ScalarType;
@@ -149,6 +150,7 @@ TEST(OpQuantizedEmbeddingTest, ConsitencyWithReferencePattern) {
       quant_min,
       quant_max,
       ScalarType::Byte,
+      optional<ScalarType>(),
       weight);
 
   embedding_out(
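
The contract these hunks introduce: `out_dtype` is optional everywhere and defaults to `None` in the schemas; when it is provided, the portable kernel checks that it matches the dtype of `out` (the xtensa kernel additionally requires `Float`), and the dequantized value is still computed as `(q - zero_point) * scale`. Below is a minimal standalone C++ sketch of that contract, with `std::optional` and raw buffers standing in for `exec_aten::optional` and `Tensor`; every name in it is illustrative, not ExecuTorch API.

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <optional>

// Illustrative stand-in for exec_aten::ScalarType.
enum class ScalarType : std::int8_t { Byte, Float };

// Mirrors the patched check: when out_dtype is provided, it must match the
// dtype of the destination buffer.
void check_out_dtype(std::optional<ScalarType> out_dtype, ScalarType out_type) {
  if (out_dtype.has_value() && out_dtype.value() != out_type) {
    std::fprintf(stderr, "output_dtype must match the dtype of the out tensor\n");
    std::abort();
  }
}

// out[i] = (q[i] - zero_point) * scale, with scale cast to float to match
// the fbgemm behavior noted in op_dequantize.cpp.
void dequantize_per_tensor(
    const std::uint8_t* q,
    std::size_t numel,
    double scale,
    std::int64_t zero_point,
    std::optional<ScalarType> out_dtype,
    float* out) {
  check_out_dtype(out_dtype, ScalarType::Float); // out buffer is float here
  for (std::size_t i = 0; i < numel; ++i) {
    out[i] = (static_cast<std::int32_t>(q[i]) - zero_point) *
        static_cast<float>(scale);
  }
}

int main() {
  std::uint8_t q[3] = {100, 100, 100};
  float out[3];
  // Same numbers as the tests above: (100 - 30) * 0.5 = 35.
  dequantize_per_tensor(q, 3, /*scale=*/0.5, /*zero_point=*/30, std::nullopt, out);
  std::printf("%.1f\n", out[0]); // 35.0
  // An explicit out_dtype is accepted when it matches the out buffer's dtype.
  dequantize_per_tensor(q, 3, 0.5, 30, ScalarType::Float, out);
  return 0;
}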