diff --git a/kernels/quantized/cpu/op_dequantize.cpp b/kernels/quantized/cpu/op_dequantize.cpp
index 8973d69cc31..9f8a365b9c0 100644
--- a/kernels/quantized/cpu/op_dequantize.cpp
+++ b/kernels/quantized/cpu/op_dequantize.cpp
@@ -38,6 +38,7 @@ void check_dequantize_per_tensor_args(
   ET_CHECK_MSG(
       input.scalar_type() == ScalarType::Byte ||
           input.scalar_type() == ScalarType::Char ||
+          input.scalar_type() == ScalarType::Bits16 ||
           input.scalar_type() == ScalarType::Short ||
           input.scalar_type() == ScalarType::Int,
       "input.scalar_type() %" PRId8 " is not supported:",
@@ -118,6 +119,7 @@ Tensor& dequantize_per_tensor_out(
   switch (input.scalar_type()) {
     ET_FORALL_INT_TYPES(CALCULATE_INT_TYPE);
+    CALCULATE_INT_TYPE(uint16_t, Bits16);
     default:
       ET_CHECK_MSG(
           false,
@@ -312,6 +314,7 @@ Tensor& dequantize_per_channel_out(
   switch (input.scalar_type()) {
     ET_FORALL_INT_TYPES(CALCULATE_FLOAT_TYPE);
+    CALCULATE_INT_TYPE(uint16_t, Bits16);
     default:
       ET_CHECK_MSG(
           false,
diff --git a/kernels/quantized/cpu/op_quantize.cpp b/kernels/quantized/cpu/op_quantize.cpp
index 9e95b11d592..74654ffafee 100644
--- a/kernels/quantized/cpu/op_quantize.cpp
+++ b/kernels/quantized/cpu/op_quantize.cpp
@@ -57,6 +57,9 @@ void check_quantize_per_tensor_args(
         static_cast<int32_t>(std::numeric_limits<int8_t>::min());
     quant_max_upper_bound =
         static_cast<int32_t>(std::numeric_limits<int8_t>::max());
+  } else if (dtype == ScalarType::Bits16) {
+    quant_min_lower_bound = std::numeric_limits<uint16_t>::min();
+    quant_max_upper_bound = std::numeric_limits<uint16_t>::max();
   } else if (dtype == ScalarType::Short) {
     quant_min_lower_bound = std::numeric_limits<int16_t>::min();
     quant_max_upper_bound = std::numeric_limits<int16_t>::max();
@@ -135,6 +138,7 @@ Tensor& quantize_per_tensor_out(
   case ScalarType::in_dtype:                                  \
     switch (out.scalar_type()) {                              \
       ET_FORALL_INT_TYPES_WITH(IN_CTYPE, QUANTIZE_IMPL);      \
+      QUANTIZE_IMPL(IN_CTYPE, uint16_t, Bits16)               \
       default:                                                \
         ET_CHECK_MSG(                                         \
             false,                                            \
@@ -329,6 +333,7 @@ Tensor& quantize_per_channel_out(
   case ScalarType::in_dtype:                                  \
     switch (out.scalar_type()) {                              \
       ET_FORALL_INT_TYPES_WITH(CTYPE_IN, QUANTIZE_IMPL);      \
+      QUANTIZE_IMPL(CTYPE_IN, uint16_t, Bits16)               \
       default:                                                \
         ET_CHECK_MSG(                                         \
             false,                                            \
diff --git a/kernels/quantized/test/op_dequantize_test.cpp b/kernels/quantized/test/op_dequantize_test.cpp
index 8aa17772baf..10126264450 100644
--- a/kernels/quantized/test/op_dequantize_test.cpp
+++ b/kernels/quantized/test/op_dequantize_test.cpp
@@ -60,6 +60,10 @@ void test_dtype() {
 TEST(OpDequantizeOutTest, AllDtypesSupported) {
   et_pal_init();
   test_dtype<ScalarType::Byte>();
+  test_dtype<ScalarType::Char>();
+  test_dtype<ScalarType::Bits16>();
+  test_dtype<ScalarType::Short>();
+  test_dtype<ScalarType::Int>();
 }
 
 TEST(OpDequantizeOutTest, NonWholeNumbers) {
diff --git a/kernels/quantized/test/op_quantize_test.cpp b/kernels/quantized/test/op_quantize_test.cpp
index be79fe09980..ce81186099b 100644
--- a/kernels/quantized/test/op_quantize_test.cpp
+++ b/kernels/quantized/test/op_quantize_test.cpp
@@ -35,14 +35,14 @@ void test_dtype() {
   Tensor input = tf.full({3, 5}, 4);
   double scale = 0.5;
 
-  int64_t zero_point = 127;
+  int64_t zero_point = 108;
   int64_t quant_min = 0;
-  int64_t quant_max = 255;
+  int64_t quant_max = 127;
 
   TensorFactory<DTYPE> tfo;
   Tensor out = tfo.zeros({3, 5});
   // 4 / 0.5 + 127
-  Tensor expected = tfo.full({3, 5}, 135);
+  Tensor expected = tfo.full({3, 5}, 116);
   quantize_per_tensor_out(
       input, scale, zero_point, quant_min, quant_max, DTYPE, out);
 
@@ -51,6 +51,10 @@ TEST(OpQuantizeOutTest, AllDtypesSupported) {
   test_dtype<ScalarType::Byte>();
+  test_dtype<ScalarType::Char>();
+  test_dtype<ScalarType::Bits16>();
+  test_dtype<ScalarType::Short>();
+  test_dtype<ScalarType::Int>();
 }
 
 TEST(OpQuantizeOutTest, TensorArgOverload) {
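Note: the updated test constants follow the usual affine quantization rule q = round(x / scale) + zero_point, clamped to [quant_min, quant_max], so 4 / 0.5 + 108 = 116; the narrower [0, 127] range is presumably chosen so the same test body stays valid for every output dtype exercised above, including signed 8-bit Char and the new unsigned 16-bit Bits16. Below is a minimal standalone sketch of that mapping for reference. It is not the kernel implementation, and quantize_val/dequantize_val are made-up helper names:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>

// Hypothetical helpers mirroring the affine quantize/dequantize math used in the tests.
int64_t quantize_val(double x, double scale, int64_t zero_point,
                     int64_t quant_min, int64_t quant_max) {
  int64_t q = static_cast<int64_t>(std::nearbyint(x / scale)) + zero_point;
  return std::min(std::max(q, quant_min), quant_max);  // clamp to [quant_min, quant_max]
}

double dequantize_val(int64_t q, double scale, int64_t zero_point) {
  return static_cast<double>(q - zero_point) * scale;
}

int main() {
  // Mirrors op_quantize_test.cpp: input 4, scale 0.5, zero_point 108 -> expected 116.
  int64_t q = quantize_val(4.0, 0.5, 108, 0, 127);
  std::cout << q << "\n";  // 116

  // Dequantizing recovers the original value for an exactly representable input.
  std::cout << dequantize_val(q, 0.5, 108) << "\n";  // 4

  // The same quantized value fits comfortably in uint16_t (Bits16), range [0, 65535].
  uint16_t q16 = static_cast<uint16_t>(q);
  std::cout << dequantize_val(q16, 0.5, 108) << "\n";  // 4
  return 0;
}
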
diff --git a/runtime/core/exec_aten/testing_util/tensor_factory.h b/runtime/core/exec_aten/testing_util/tensor_factory.h
index 3045af55819..9037b074a0d 100644
--- a/runtime/core/exec_aten/testing_util/tensor_factory.h
+++ b/runtime/core/exec_aten/testing_util/tensor_factory.h
@@ -645,6 +645,13 @@ struct ScalarTypeToCppTypeWrapper {
   using ctype = uint8_t;
 };
 
+// Use a C type of `uint16_t` instead of `Bits16` to simplify code reuse when
+// testing multiple integer types.
+template <>
+struct ScalarTypeToCppTypeWrapper<torch::executor::ScalarType::Bits16> {
+  using ctype = uint16_t;
+};
+
 // To allow implicit conversion between simple types to `ctype`
 #define SPECIALIZE_ScalarTypeToCppTypeWrapper(CTYPE, DTYPE) \
   template <>                                               \
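Note: the tensor_factory.h change extends the header's existing trait pattern, where ScalarTypeToCppTypeWrapper maps a ScalarType enum value to the C type the test TensorFactory uses; mapping Bits16 to plain uint16_t lets the templated tests treat it like any other integer dtype. The snippet below is a simplified, standalone sketch of that trait-specialization pattern; the enum and trait names are illustrative stand-ins, not the real ExecuTorch definitions:

#include <cstdint>
#include <type_traits>

// Illustrative stand-in for the real ScalarType enum.
enum class ScalarType { Byte, Char, Bits16, Short, Int };

// Primary template: each supported dtype gets a specialization naming its C type.
template <ScalarType DTYPE>
struct ScalarTypeToCpp;

template <> struct ScalarTypeToCpp<ScalarType::Byte>   { using ctype = uint8_t; };
template <> struct ScalarTypeToCpp<ScalarType::Char>   { using ctype = int8_t; };
// Bits16 is represented with a plain uint16_t, matching the specialization added above.
template <> struct ScalarTypeToCpp<ScalarType::Bits16> { using ctype = uint16_t; };
template <> struct ScalarTypeToCpp<ScalarType::Short>  { using ctype = int16_t; };
template <> struct ScalarTypeToCpp<ScalarType::Int>    { using ctype = int32_t; };

// A templated test helper can now work uniformly across dtypes.
template <ScalarType DTYPE>
typename ScalarTypeToCpp<DTYPE>::ctype make_value(int v) {
  return static_cast<typename ScalarTypeToCpp<DTYPE>::ctype>(v);
}

static_assert(std::is_same_v<ScalarTypeToCpp<ScalarType::Bits16>::ctype, uint16_t>);

int main() {
  return make_value<ScalarType::Bits16>(116) == 116 ? 0 : 1;
}

Representing Bits16 with an ordinary uint16_t keeps arithmetic and comparisons in the tests plain integer operations, which is the code reuse the new comment in the diff refers to.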