From 91acd1d61cdd1ecac9556911b672376b0d7289e4 Mon Sep 17 00:00:00 2001
From: kausik
Date: Mon, 11 Mar 2024 17:50:31 -0700
Subject: [PATCH] Modify signature of dequantize ops for decomposed quantized
 Tensor (#2308)

Summary:
X-link: https://github.com/pytorch/pytorch/pull/121450

Note: The initial purpose of this PR is to draw suggestions and feedback
regarding better alternatives, if any.

At present, the dequantize ops for the decomposed quantized Tensor
representation, e.g. dequantize_per_tensor(), assume the output dtype is
torch.float and hence do not carry an output dtype in their operator
argument lists. This op signature becomes unusable when the assumption
breaks: if the output dtype is different from torch.float, there is no
way to specify it during dequantization.

This change generalizes the signature of dequantize ops like
dequantize_per_tensor() for wider use-cases where the output dtype can be
different from torch.float and needs to be passed during dequantization.
The proposal is to add an optional argument named 'out_dtype' to solve
the problem. We would also welcome suggestions and feedback regarding any
better alternative that can be used instead.

cc jerryzh168 jianyuh raghuramank100 jamesr66a vkuzo jgong5 Xia-Weiwen
leslie-fang-intel

X-link: https://github.com/pytorch/pytorch/pull/119173

Reviewed By: digantdesai

Differential Revision: D53590486

Pulled By: manuelcandales
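
For illustration, a minimal sketch of the C++ call-site change, modeled
on the updated tests below. The TensorFactory setup is assumed from the
existing test harness, and the factory names and quant_min/quant_max
values are illustrative, not part of this diff:

    torch::executor::testing::TensorFactory<ScalarType::Byte> tf_in;
    torch::executor::testing::TensorFactory<ScalarType::Float> tf_out;
    Tensor input = tf_in.full({3, 5}, 100); // quantized input values
    Tensor out = tf_out.zeros({3, 5});

    // out_dtype is the new optional argument. An empty optional (None at
    // the schema level) preserves the existing behavior of dequantizing
    // to float.
    dequantize_per_tensor_out(
        input,
        /*scale=*/0.5,
        /*zero_point=*/30,
        /*quant_min=*/0,
        /*quant_max=*/255,
        ScalarType::Byte,
        exec_aten::optional<ScalarType>(), // out_dtype = None
        out);
    // Every element of out is now (100 - 30) * 0.5 = 35.0f.

Since the schema default is out_dtype=None, previously exported graphs
that never set the argument should be unaffected; only C++ callers of the
kernel entry points, as in the updated tests, pass the extra argument
explicitly.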
---
 examples/xtensa/ops/dequantize_per_tensor.cpp |  8 ++++
 examples/xtensa/ops/functions.yaml            |  2 +-
 kernels/quantized/cpu/op_dequantize.cpp       | 37 ++++++++++++++-----
 kernels/quantized/quantized.yaml              |  6 +--
 kernels/quantized/test/op_add_test.cpp        |  7 ++++
 kernels/quantized/test/op_dequantize_test.cpp | 30 +++++++++++++--
 kernels/quantized/test/op_embedding_test.cpp  |  2 +
 7 files changed, 76 insertions(+), 16 deletions(-)

diff --git a/examples/xtensa/ops/dequantize_per_tensor.cpp b/examples/xtensa/ops/dequantize_per_tensor.cpp
index dcc4ace7898..26b6c71ca96 100644
--- a/examples/xtensa/ops/dequantize_per_tensor.cpp
+++ b/examples/xtensa/ops/dequantize_per_tensor.cpp
@@ -25,7 +25,15 @@ void dequantize_per_tensor_out(
     int64_t quant_min,
     int64_t quant_max,
     ScalarType dtype,
+    exec_aten::optional<ScalarType>& out_dtype,
     Tensor& out) {
+  if (out_dtype.has_value()) {
+    ET_CHECK_MSG(
+        out_dtype.value() == ScalarType::Float,
+        "Expected out dtype to be Float but got %hhd",
+        out_dtype.value());
+  }
+
   float* out_data = out.mutable_data_ptr<float>();
   size_t numel = out.numel();
 

diff --git a/examples/xtensa/ops/functions.yaml b/examples/xtensa/ops/functions.yaml
index 07093d3ed24..f89be7bcf9d 100644
--- a/examples/xtensa/ops/functions.yaml
+++ b/examples/xtensa/ops/functions.yaml
@@ -26,7 +26,7 @@
   - arg_meta: null
     kernel_name: impl::HiFi::quantized_linear_pt2_out
 
-- func: xtensa::dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+- func: xtensa::dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None, Tensor(a!) out) -> Tensor(a!)
   variants: function
   kernels:
   - arg_meta: null

diff --git a/kernels/quantized/cpu/op_dequantize.cpp b/kernels/quantized/cpu/op_dequantize.cpp
index 129b9a4ad5b..5d567b04ff7 100644
--- a/kernels/quantized/cpu/op_dequantize.cpp
+++ b/kernels/quantized/cpu/op_dequantize.cpp
@@ -33,6 +33,7 @@ void check_dequantize_per_tensor_args(
     int64_t quant_min,
     int64_t quant_max,
     ScalarType dtype,
+    exec_aten::optional<ScalarType>& out_dtype,
     Tensor& out) {
   ET_CHECK_MSG(
       input.scalar_type() == ScalarType::Byte ||
@@ -47,10 +48,11 @@ void check_dequantize_per_tensor_args(
       "input.scalar_type() %" PRId8 " is not matching dtype argumenta:",
       static_cast<int8_t>(input.scalar_type()));
 
-  ET_CHECK_MSG(
-      out.scalar_type() == ScalarType::Float,
-      "out.scalar_type() %" PRId8 " is not supported:",
-      static_cast<int8_t>(out.scalar_type()));
+  if (out_dtype.has_value()) {
+    ET_CHECK_MSG(
+        out.scalar_type() == out_dtype.value(),
+        "output_dtype must match the dtype of the out tensor");
+  }
 
   ET_CHECK_MSG(
       quant_min <= quant_max,
@@ -77,13 +79,15 @@ Tensor& dequantize_per_tensor_out(
     int64_t quant_min,
     int64_t quant_max,
     ScalarType dtype,
+    exec_aten::optional<ScalarType> out_dtype,
     Tensor& out) {
   torch::executor::Error err = resize_tensor(out, input.sizes());
   ET_CHECK_MSG(
       err == torch::executor::Error::Ok,
       "Failed to resize out Tensor in dequantize_per_tensor_out");
 
-  check_dequantize_per_tensor_args(input, quant_min, quant_max, dtype, out);
+  check_dequantize_per_tensor_args(
+      input, quant_min, quant_max, dtype, out_dtype, out);
 
   // calculate the dequantized output, cast scale to float to match fbgemm
   // behavior
@@ -128,6 +132,7 @@ Tensor& dequantize_per_tensor_tensor_args_out(
     int64_t quant_min,
     int64_t quant_max,
     ScalarType dtype,
+    exec_aten::optional<ScalarType> out_dtype,
     Tensor& out) {
   ET_CHECK_MSG(
       scale.scalar_type() == ScalarType::Double,
@@ -153,6 +158,7 @@ Tensor& dequantize_per_tensor_tensor_args_out(
       quant_min,
       quant_max,
       dtype,
+      out_dtype,
       out);
   return out;
 }
@@ -165,6 +171,7 @@ Tensor& dequantize_per_channel_out(
     int64_t quant_min,
     int64_t quant_max,
     ScalarType dtype,
+    exec_aten::optional<ScalarType> out_dtype,
     Tensor& out) {
   torch::executor::Error err = resize_tensor(out, input.sizes());
 
@@ -205,7 +212,8 @@ Tensor& dequantize_per_channel_out(
       ssize_t(zero_point.numel()),
       ssize_t(input.size(axis)));
 
-  check_dequantize_per_tensor_args(input, quant_min, quant_max, dtype, out);
+  check_dequantize_per_tensor_args(
+      input, quant_min, quant_max, dtype, out_dtype, out);
 
   // a list contains all dimensions except axis
   int64_t dims[input.dim() - 1];
@@ -281,10 +289,19 @@ Tensor& dequantize_per_channel_out(
     int64_t quant_min,
     int64_t quant_max,
     ScalarType dtype,
+    exec_aten::optional<ScalarType> out_dtype,
     Tensor& out) {
   (void)context;
   return dequantize_per_channel_out(
-      input, scale, zero_point, axis, quant_min, quant_max, dtype, out);
+      input,
+      scale,
+      zero_point,
+      axis,
+      quant_min,
+      quant_max,
+      dtype,
+      out_dtype,
+      out);
 }
 
 Tensor& dequantize_per_tensor_out(
@@ -295,12 +312,13 @@ Tensor& dequantize_per_tensor_out(
     int64_t quant_min,
     int64_t quant_max,
     ScalarType dtype,
+    exec_aten::optional<ScalarType> out_dtype,
     Tensor& out) {
   // TODO(larryliu): Add a context arg to the real op function and remove this
   // wrapper
   (void)context;
   return dequantize_per_tensor_out(
-      input, scale, zero_point, quant_min, quant_max, dtype, out);
+      input, scale, zero_point, quant_min, quant_max, dtype, out_dtype, out);
 }
 
 Tensor& dequantize_per_tensor_tensor_args_out(
@@ -311,12 +329,13 @@ Tensor& dequantize_per_tensor_tensor_args_out(
     int64_t quant_min,
     int64_t quant_max,
     ScalarType dtype,
+    exec_aten::optional<ScalarType> out_dtype,
     Tensor& out) {
   // TODO(larryliu): Add a context arg to the real op function and remove this
   // wrapper
   (void)context;
   return dequantize_per_tensor_tensor_args_out(
-      input, scale, zero_point, quant_min, quant_max, dtype, out);
+      input, scale, zero_point, quant_min, quant_max, dtype, out_dtype, out);
 }
 
 } // namespace native

diff --git a/kernels/quantized/quantized.yaml b/kernels/quantized/quantized.yaml
index b6103343f62..6f4a0a8fc97 100644
--- a/kernels/quantized/quantized.yaml
+++ b/kernels/quantized/quantized.yaml
@@ -10,13 +10,13 @@
   - arg_meta: null
     kernel_name: torch::executor::choose_qparams_tensor_out
 
-- func: quantized_decomposed::dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+- func: quantized_decomposed::dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None, Tensor(a!) out) -> Tensor(a!)
   variants: function
   kernels:
   - arg_meta: null
    kernel_name: torch::executor::dequantize_per_tensor_out
 
-- func: quantized_decomposed::dequantize_per_tensor.Tensor_out(Tensor input, Tensor scale, Tensor zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+- func: quantized_decomposed::dequantize_per_tensor.Tensor_out(Tensor input, Tensor scale, Tensor zero_point, int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None, Tensor(a!) out) -> Tensor(a!)
   variants: function
   kernels:
   - arg_meta: null
@@ -28,7 +28,7 @@
   - arg_meta: null
     kernel_name: torch::executor::quantize_per_channel_out
 
-- func: quantized_decomposed::dequantize_per_channel.out(Tensor input, Tensor scales, Tensor zero_points, int axis, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
+- func: quantized_decomposed::dequantize_per_channel.out(Tensor input, Tensor scales, Tensor zero_points, int axis, int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None, Tensor(a!) out) -> Tensor(a!)
   variants: function
   kernels:
   - arg_meta: null

diff --git a/kernels/quantized/test/op_add_test.cpp b/kernels/quantized/test/op_add_test.cpp
index 92c6fa2052b..a48ba10c66e 100644
--- a/kernels/quantized/test/op_add_test.cpp
+++ b/kernels/quantized/test/op_add_test.cpp
@@ -20,6 +20,7 @@
 
 using namespace ::testing;
 using exec_aten::ArrayRef;
+using exec_aten::optional;
 using exec_aten::RuntimeContext;
 using exec_aten::Scalar;
 using exec_aten::ScalarType;
@@ -190,6 +191,8 @@ TEST(OpQuantizeAddTest, ConsitencyWithReferencePattern) {
   Tensor qinput2 = tfo.zeros({3, 5});
   Tensor qoutput = tfo.zeros({3, 5});
 
+  optional<ScalarType> out_dtype = optional<ScalarType>();
+
   RuntimeContext context{};
   // q -> qadd -> dq
   // 3.5 / 0.5 + 1 = 8
@@ -235,6 +238,7 @@ TEST(OpQuantizeAddTest, ConsitencyWithReferencePattern) {
       quant_min,
       quant_max,
       ScalarType::Byte,
+      out_dtype,
       reference_op_output);
 
   // now get results for q -> dq -> fp add -> q -> dq
@@ -245,6 +249,7 @@ TEST(OpQuantizeAddTest, ConsitencyWithReferencePattern) {
       quant_min,
       quant_max,
       ScalarType::Byte,
+      out_dtype,
       dq_input1);
 
   dequantize_per_tensor_out(
@@ -254,6 +259,7 @@ TEST(OpQuantizeAddTest, ConsitencyWithReferencePattern) {
       quant_min,
       quant_max,
       ScalarType::Byte,
+      out_dtype,
       dq_input2);
 
   add_out(context, dq_input1, dq_input2, 1.0, fp_output);
@@ -274,6 +280,7 @@ TEST(OpQuantizeAddTest, ConsitencyWithReferencePattern) {
       quant_min,
       quant_max,
       ScalarType::Byte,
+      out_dtype,
       reference_pattern_output);
 
   Tensor expected = tf.full({3, 5}, 7.0);

diff --git a/kernels/quantized/test/op_dequantize_test.cpp b/kernels/quantized/test/op_dequantize_test.cpp
index 57225e45767..1004126d045 100644
--- a/kernels/quantized/test/op_dequantize_test.cpp
+++ b/kernels/quantized/test/op_dequantize_test.cpp
@@ -18,6 +18,7 @@
 
 using namespace ::testing;
 using exec_aten::ArrayRef;
+using exec_aten::optional;
 using exec_aten::Scalar;
 using exec_aten::ScalarType;
 using exec_aten::Tensor;
@@ -43,7 +44,14 @@ void test_dtype() {
   // (100 - 30) * 0.5
   Tensor expected = tfo.full({3, 5}, 35);
   dequantize_per_tensor_out(
-      input, scale, zero_point, quant_min, quant_max, DTYPE, out);
+      input,
+      scale,
+      zero_point,
+      quant_min,
+      quant_max,
+      DTYPE,
+      optional<ScalarType>(),
+      out);
 
   EXPECT_TENSOR_EQ(out, expected);
 }
@@ -66,7 +74,14 @@ TEST(OpDequantizeOutTest, NonWholeNumbers) {
   // (100 - 30) * 0.5
   Tensor expected = tfo.full({3, 5}, 31.5);
   dequantize_per_tensor_out(
-      input, scale, zero_point, quant_min, quant_max, ScalarType::Byte, out);
+      input,
+      scale,
+      zero_point,
+      quant_min,
+      quant_max,
+      ScalarType::Byte,
+      optional<ScalarType>(),
+      out);
 
   EXPECT_TENSOR_EQ(out, expected);
 }
@@ -87,7 +102,14 @@ TEST(OpDequantizeOutTest, TensorArgOverload) {
   // (100 - 30) * 0.5
   Tensor expected = tfo.full({3, 5}, 31.5);
   dequantize_per_tensor_tensor_args_out(
-      input, scale, zero_point, quant_min, quant_max, ScalarType::Byte, out);
+      input,
+      scale,
+      zero_point,
+      quant_min,
+      quant_max,
+      ScalarType::Byte,
+      optional<ScalarType>(),
+      out);
 
   EXPECT_TENSOR_EQ(out, expected);
 }
@@ -116,6 +138,7 @@ TEST(OpDequantizeOutTest, DequantizePerChannel) {
       quant_min,
       quant_max,
       ScalarType::Byte,
+      optional<ScalarType>(),
       out);
 
   EXPECT_TENSOR_EQ(out, expected);
@@ -136,6 +159,7 @@ TEST(OpDequantizeOutTest, DequantizePerChannel) {
       quant_min,
       quant_max,
       ScalarType::Byte,
+      optional<ScalarType>(),
       out);
 
   EXPECT_TENSOR_EQ(out, expected);

diff --git a/kernels/quantized/test/op_embedding_test.cpp b/kernels/quantized/test/op_embedding_test.cpp
index 49605977cc3..76114561e53 100644
--- a/kernels/quantized/test/op_embedding_test.cpp
+++ b/kernels/quantized/test/op_embedding_test.cpp
@@ -20,6 +20,7 @@
 
 using namespace ::testing;
 using exec_aten::ArrayRef;
+using exec_aten::optional;
 using exec_aten::RuntimeContext;
 using exec_aten::Scalar;
 using exec_aten::ScalarType;
@@ -149,6 +150,7 @@ TEST(OpQuantizedEmbeddingTest, ConsitencyWithReferencePattern) {
       quant_min,
       quant_max,
       ScalarType::Byte,
+      optional<ScalarType>(),
       weight);
 
   embedding_out(