From ee18354b60f3404c5c551f6e6d782cf6aef79f0e Mon Sep 17 00:00:00 2001
From: Riley Dulin
Date: Fri, 15 Nov 2024 11:27:54 -0800
Subject: [PATCH] Add UInt16 support to Cadence kernels (#6893)

Summary:
In preparation for using uint16 on Cadence, add support to the quant kernels.

Same as: https://github.com/pytorch/executorch/pull/6724, but added UInt16 as well as Bits16.

Reviewed By: JacobSzwejbka, zonglinpeng

Differential Revision: D66016288
---
 backends/cadence/hifi/operators/dequantize_per_tensor.cpp | 4 +++-
 backends/cadence/hifi/operators/quantize_per_tensor.cpp | 4 +++-
 .../cadence/reference/operators/dequantize_per_tensor.cpp | 4 +++-
 backends/cadence/reference/operators/quantize_per_tensor.cpp | 4 +++-
 4 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/backends/cadence/hifi/operators/dequantize_per_tensor.cpp b/backends/cadence/hifi/operators/dequantize_per_tensor.cpp
index 996d753c59d..243eeef2e1d 100644
--- a/backends/cadence/hifi/operators/dequantize_per_tensor.cpp
+++ b/backends/cadence/hifi/operators/dequantize_per_tensor.cpp
@@ -41,7 +41,9 @@ void dequantize_per_tensor_out(
   } else if (input.scalar_type() == ScalarType::Short) {
     const int16_t* input_data = input.const_data_ptr();
     dequantize(out_data, input_data, scale, zero_point, numel);
-  } else if (input.scalar_type() == ScalarType::Bits16) {
+  } else if (
+      input.scalar_type() == ScalarType::Bits16 ||
+      input.scalar_type() == ScalarType::UInt16) {
     const uint16_t* input_data = input.const_data_ptr();
     dequantize(out_data, input_data, scale, zero_point, numel);
   } else if (input.scalar_type() == ScalarType::Int) {
diff --git a/backends/cadence/hifi/operators/quantize_per_tensor.cpp b/backends/cadence/hifi/operators/quantize_per_tensor.cpp
index 1078b5716c1..4f59ef0ea8a 100644
--- a/backends/cadence/hifi/operators/quantize_per_tensor.cpp
+++ b/backends/cadence/hifi/operators/quantize_per_tensor.cpp
@@ -44,7 +44,9 @@ void quantize_per_tensor_out(
     int16_t* out_data = out.mutable_data_ptr();
     cadence::impl::HiFi::kernels::quantize(
         out_data, input_data, 1. / scale, zero_point, numel);
-  } else if (out.scalar_type() == ScalarType::Bits16) {
+  } else if (
+      out.scalar_type() == ScalarType::Bits16 ||
+      out.scalar_type() == ScalarType::UInt16) {
     uint16_t* out_data = out.mutable_data_ptr();
     cadence::impl::HiFi::kernels::quantize(
         out_data, input_data, 1. / scale, zero_point, numel);
diff --git a/backends/cadence/reference/operators/dequantize_per_tensor.cpp b/backends/cadence/reference/operators/dequantize_per_tensor.cpp
index b49c045b94f..1f48fc78201 100644
--- a/backends/cadence/reference/operators/dequantize_per_tensor.cpp
+++ b/backends/cadence/reference/operators/dequantize_per_tensor.cpp
@@ -37,7 +37,9 @@ void dequantize_per_tensor_out(
     const int8_t* input_data = input.const_data_ptr();
     impl::reference::kernels::dequantize(
         out_data, input_data, scale, zero_point, numel);
-  } else if (input.scalar_type() == ScalarType::Bits16) {
+  } else if (
+      input.scalar_type() == ScalarType::Bits16 ||
+      input.scalar_type() == ScalarType::UInt16) {
     const uint16_t* input_data = input.const_data_ptr();
     impl::reference::kernels::dequantize(
         out_data, input_data, scale, zero_point, numel);
diff --git a/backends/cadence/reference/operators/quantize_per_tensor.cpp b/backends/cadence/reference/operators/quantize_per_tensor.cpp
index ad5fa791b51..9debf37364e 100644
--- a/backends/cadence/reference/operators/quantize_per_tensor.cpp
+++ b/backends/cadence/reference/operators/quantize_per_tensor.cpp
@@ -39,7 +39,9 @@ void quantize_per_tensor_out(
     int8_t* out_data = out.mutable_data_ptr();
     impl::reference::kernels::quantize(
         out_data, input_data, 1. / scale, zero_point, numel);
-  } else if (out.scalar_type() == ScalarType::Bits16) {
+  } else if (
+      out.scalar_type() == ScalarType::Bits16 ||
+      out.scalar_type() == ScalarType::UInt16) {
     uint16_t* out_data = out.mutable_data_ptr();
     impl::reference::kernels::quantize(
         out_data, input_data, 1. / scale, zero_point, numel);