diff --git a/backends/cadence/aot/functions.yaml b/backends/cadence/aot/functions.yaml index 68146760d9b..c43aa5ba4e9 100644 --- a/backends/cadence/aot/functions.yaml +++ b/backends/cadence/aot/functions.yaml @@ -229,6 +229,16 @@ - arg_meta: null kernel_name: impl::reference::quantized_linear_per_tensor_out +- func: cadence::quantized_linear_asym8sxasym8s_asym8s.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::reference::quantized_linear_asym8sxasym8s_asym8s_per_tensor_out + +- func: cadence::quantized_linear_asym8uxasym8u_asym8u.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::reference::quantized_linear_asym8uxasym8u_asym8u_per_tensor_out + - func: cadence::im2row.out(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, Tensor in_zero_point, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null diff --git a/backends/cadence/aot/functions_hifi.yaml b/backends/cadence/aot/functions_hifi.yaml index 7a9000b530b..a706d251bd2 100644 --- a/backends/cadence/aot/functions_hifi.yaml +++ b/backends/cadence/aot/functions_hifi.yaml @@ -314,6 +314,16 @@ - arg_meta: null kernel_name: cadence::impl::HiFi::quantized_linear_per_tensor_out +- func: cadence::quantized_linear_asym8sxasym8s_asym8s.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: cadence::impl::HiFi::quantized_linear_asym8sxasym8s_asym8s_per_tensor_out + +- func: cadence::quantized_linear_asym8uxasym8u_asym8u.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: cadence::impl::HiFi::quantized_linear_asym8uxasym8u_asym8u_per_tensor_out + - func: cadence::quantized_relu_per_tensor.out(Tensor X, Tensor X_zero_point, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py index 91ed3560a04..542d1fb2a30 100644 --- a/backends/cadence/aot/ops_registrations.py +++ b/backends/cadence/aot/ops_registrations.py @@ -56,10 +56,26 @@ lib.define( "quantized_linear.per_tensor_out(Tensor src, Tensor weight, Tensor bias, SymInt src_zero_point, SymInt weight_zero_point, SymInt out_multiplier, SymInt out_shift, SymInt out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)" ) +lib.define( + "quantized_linear_asym8sxasym8s_asym8s.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, " + "int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)" +) +lib.define( + "quantized_linear_asym8uxasym8u_asym8u.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, " + "int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)" +) lib.define( "quantized_linear.per_tensor(Tensor src, Tensor weight, Tensor bias, SymInt src_zero_point, " "SymInt weight_zero_point, SymInt out_multiplier, SymInt out_shift, SymInt out_zero_point, Tensor? offset) -> Tensor" ) +lib.define( + "quantized_linear_asym8sxasym8s_asym8s.per_tensor(Tensor src, Tensor weight, Tensor bias, int src_zero_point, " + "int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset) -> (Tensor Z)" +) +lib.define( + "quantized_linear_asym8uxasym8u_asym8u.per_tensor(Tensor src, Tensor weight, Tensor bias, int src_zero_point, " + "int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset) -> (Tensor Z)" +) lib.define( "quantized_relu(Tensor X, Tensor X_zero_point, int out_zero_point, Tensor out_multiplier, Tensor out_shift) -> (Tensor Y)" @@ -446,6 +462,50 @@ def quantized_linear_per_tensor_meta( return src.new_empty(out_size, dtype=src.dtype) +@register_fake("cadence::quantized_linear_asym8sxasym8s_asym8s.per_tensor") +def quantized_linear_asym8sxasym8s_asym8s_per_tensor_meta( + src: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + in_zero_point: int, + weight_zero_point: int, + out_multiplier: int, + out_shift: int, + out_zero_point: int, + offset: Optional[torch.Tensor], +) -> torch.Tensor: + # src comes in shape [leading_dims, in_dim] + # weight comes in shape [out_dim, in_dim] + # output comes in empty with shape [leading_dims, out_dim] + out_size = list(src.size()) + weight_size = list(weight.size()) + assert len(weight_size) == 2 + out_size[-1] = weight_size[0] + return src.new_empty(out_size, dtype=src.dtype) + + +@register_fake("cadence::quantized_linear_asym8uxasym8u_asym8u.per_tensor") +def quantized_linear_asym8uxasym8u_asym8u_per_tensor_meta( + src: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + in_zero_point: int, + weight_zero_point: int, + out_multiplier: int, + out_shift: int, + out_zero_point: int, + offset: Optional[torch.Tensor], +) -> torch.Tensor: + # src comes in shape [leading_dims, in_dim] + # weight comes in shape [out_dim, in_dim] + # output comes in empty with shape [leading_dims, out_dim] + out_size = list(src.size()) + weight_size = list(weight.size()) + assert len(weight_size) == 2 + out_size[-1] = weight_size[0] + return src.new_empty(out_size, dtype=src.dtype) + + @register_fake("cadence::quantized_conv") def quantized_conv_meta( input: torch.Tensor, diff --git a/backends/cadence/aot/tests/test_type_dispatch_passes.py b/backends/cadence/aot/tests/test_type_dispatch_passes.py index f29a13a5bf8..29ddfb1ed53 100644 --- a/backends/cadence/aot/tests/test_type_dispatch_passes.py +++ b/backends/cadence/aot/tests/test_type_dispatch_passes.py @@ -18,7 +18,7 @@ class TestTypeDispatchPasses(unittest.TestCase): - def test_int8_dispatch(self) -> None: + def test_int8_dispatch_quantized_fully_connected(self) -> None: """Test int8 x int8 inputs should dispatch to asym8sxasym8s_asym8s variant""" x = torch.randint(-128, 127, (2, 3), dtype=torch.int8) w = torch.randint(-128, 127, (4, 3), dtype=torch.int8) @@ -44,7 +44,7 @@ def test_int8_dispatch(self) -> None: 1, ) - def test_uint8_dispatch(self) -> None: + def test_uint8_dispatch_quantized_fully_connected(self) -> None: """Test uint8 x uint8 inputs should dispatch to asym8uxasym8u_asym8u variant""" x = torch.randint(0, 255, (2, 3), dtype=torch.uint8) w = torch.randint(0, 255, (4, 3), dtype=torch.uint8) @@ -70,6 +70,58 @@ def test_uint8_dispatch(self) -> None: 1, ) + def test_int8_dispatch_quantized_linear(self) -> None: + """Test int8 x int8 inputs should dispatch to asym8sxasym8s_asym8s variant for quantized_linear""" + x = torch.randint(-128, 127, (2, 3), dtype=torch.int8) + w = torch.randint(-128, 127, (4, 3), dtype=torch.int8) + b = torch.randint(-2147483648, 2147483647, (4,), dtype=torch.int32) + gm = single_op_builder( + placeholders=(x, w, b), + op=exir_ops.edge.cadence.quantized_linear.per_tensor, + args=(x, w, b, 0, 0, 1, 0, 0, None), + ) + p = CompileTimeTypeDispatchPass() + gm = cast(PassResult, p(gm)).graph_module + # Original op should be replaced + self.assertEqual( + count_node(gm, exir_ops.edge.cadence.quantized_linear.per_tensor), + 0, + ) + # Should be replaced with int8 specific variant + self.assertEqual( + count_node( + gm, + exir_ops.edge.cadence.quantized_linear_asym8sxasym8s_asym8s.per_tensor, + ), + 1, + ) + + def test_uint8_quantized_linear_dispatch(self) -> None: + """Test uint8 x uint8 inputs should dispatch to asym8uxasym8u_asym8u variant for quantized_linear""" + x = torch.randint(0, 255, (2, 3), dtype=torch.uint8) + w = torch.randint(0, 255, (4, 3), dtype=torch.uint8) + b = torch.randint(-2147483648, 2147483647, (4,), dtype=torch.int32) + gm = single_op_builder( + placeholders=(x, w, b), + op=exir_ops.edge.cadence.quantized_linear.per_tensor, + args=(x, w, b, 0, 0, 1, 0, 0, None), + ) + p = CompileTimeTypeDispatchPass() + gm = cast(PassResult, p(gm)).graph_module + # Original op should be replaced + self.assertEqual( + count_node(gm, exir_ops.edge.cadence.quantized_linear.per_tensor), + 0, + ) + # Should be replaced with uint8 specific variant + self.assertEqual( + count_node( + gm, + exir_ops.edge.cadence.quantized_linear_asym8uxasym8u_asym8u.per_tensor, + ), + 1, + ) + def test_mixed_types_error(self) -> None: """Test mixed int8/uint8 inputs should raise RuntimeError""" x = torch.randint(-128, 127, (2, 3), dtype=torch.int8) diff --git a/backends/cadence/aot/type_dispatch.py b/backends/cadence/aot/type_dispatch.py index 431fcd4a0f2..ae30fe01086 100644 --- a/backends/cadence/aot/type_dispatch.py +++ b/backends/cadence/aot/type_dispatch.py @@ -23,6 +23,16 @@ class CompileTimeTypeDispatchPass(ExportPass): Replaces generic ops with ops that have explicit types. """ + _TYPE_DISPATCH_MAP: dict[tuple[torch.dtype, torch.dtype], str] = { + (torch.int8, torch.int8): "asym8sxasym8s_asym8s", + (torch.uint8, torch.uint8): "asym8uxasym8u_asym8u", + } + + _SUPPORTED_OPS: dict[OpOverload, str] = { + exir_ops.edge.cadence.quantized_fully_connected.per_tensor: "quantized_fully_connected", + exir_ops.edge.cadence.quantized_linear.per_tensor: "quantized_linear", + } + def call_operator( self, op: OpOverload, @@ -30,33 +40,23 @@ def call_operator( kwargs: dict[str, Argument], meta: NodeMetadata, ) -> ProxyValue: - if op not in { - exir_ops.edge.cadence.quantized_fully_connected.per_tensor, - }: + if op not in self._SUPPORTED_OPS: return super().call_operator(op, args, kwargs, meta) - if ( - # pyre-ignore[16]: None has no attribute `to_tensor`. - args[0].to_tensor().dtype == torch.int8 - and args[1].to_tensor().dtype == torch.int8 - ): - return super().call_operator( - exir_ops.edge.cadence.quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor, - args, - kwargs, - meta, - ) - elif ( - args[0].to_tensor().dtype == torch.uint8 - and args[1].to_tensor().dtype == torch.uint8 - ): - return super().call_operator( - exir_ops.edge.cadence.quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor, - args, - kwargs, - meta, - ) - else: + # pyre-ignore[16]: None has no attribute `to_tensor`. + input_dtype = args[0].to_tensor().dtype + weight_dtype = args[1].to_tensor().dtype + dtype_pair = (input_dtype, weight_dtype) + + if dtype_pair not in self._TYPE_DISPATCH_MAP: raise RuntimeError( - f"Unsupported input types for {op}: {args[0].to_tensor().dtype} and {args[1].to_tensor().dtype}" + f"Unsupported input types for {op}: {input_dtype} and {weight_dtype}" ) + + base_op_name = self._SUPPORTED_OPS[op] + type_suffix = self._TYPE_DISPATCH_MAP[dtype_pair] + + typed_op_name = f"{base_op_name}_{type_suffix}" + typed_op = getattr(exir_ops.edge.cadence, typed_op_name).per_tensor + + return super().call_operator(typed_op, args, kwargs, meta) diff --git a/backends/cadence/hifi/operators/op_quantized_linear_asym8sxasym8s_asym8s_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_linear_asym8sxasym8s_asym8s_per_tensor_out.cpp new file mode 100644 index 00000000000..7b8ab8e91b9 --- /dev/null +++ b/backends/cadence/hifi/operators/op_quantized_linear_asym8sxasym8s_asym8s_per_tensor_out.cpp @@ -0,0 +1,75 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +namespace cadence { +namespace impl { +namespace HiFi { +namespace native { + +using ::executorch::aten::Tensor; +using ::executorch::runtime::getLeadingDims; +using ::executorch::runtime::KernelRuntimeContext; +using std::optional; + +void quantized_linear_asym8sxasym8s_asym8s_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& in, + const Tensor& weight, + const Tensor& bias, + int64_t in_zero_point, + int64_t weight_zero_point, + int64_t out_multiplier, + int64_t out_shift, + int64_t out_zero_point, + __ET_UNUSED const optional& offset, + Tensor& out) { + // input comes in shape [leading_dims, in_dim] + // weight comes in shape [out_dim, in_dim] + // output comes in empty with shape [leading_dims, out_dim] + // Perform matrix multiply (M x N) x (N x P)' => M x P + const int64_t leading_dims = getLeadingDims(in, in.dim() - 1); + const int64_t out_dim = weight.size(0); // = out_dim + const int64_t in_dim = weight.size(1); // = in_dim + + const int8_t* __restrict__ in_data = in.const_data_ptr(); + const int8_t* __restrict__ weight_data = weight.const_data_ptr(); + const int32_t* __restrict__ bias_data = bias.const_data_ptr(); + int8_t* __restrict__ out_data = out.mutable_data_ptr(); + + const int32_t out_multipler_int32 = static_cast(out_multiplier); + const int32_t out_shift_int32 = static_cast(out_shift); + + // The nnlib kernel to compute quantized linear via matmul. + const int32_t ret = xa_nn_matmul_asym8sxasym8s_asym8s( + out_data, // p_out + weight_data, // p_mat1, + in_data, // p_mat2, + bias_data, // p_bias + out_dim, // rows of p_mat1 + in_dim, // cols of p_mat1 + in_dim, // row_stride of p_mat1 + leading_dims, // vec_count, i.e., rows of p_mat2 + in_dim, // vec_offset of p_mat2. + out_dim, // out_offset, i.e., offset of next output element written + 1, // out_stride, i.e., stride to go to next output row + -weight_zero_point, // mat1_zero_bias + -in_zero_point, // mat2_zero_bias + out_multipler_int32, // out_multiplier + out_shift_int32, // out_shift + out_zero_point); // out_zero_bias + ET_DCHECK_MSG(ret == 0, "HiFi quantized::linear_per_tensor failed"); +} + +} // namespace native +} // namespace HiFi +} // namespace impl +} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_quantized_linear_asym8uxasym8u_asym8u_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_linear_asym8uxasym8u_asym8u_per_tensor_out.cpp new file mode 100644 index 00000000000..e9632e77eeb --- /dev/null +++ b/backends/cadence/hifi/operators/op_quantized_linear_asym8uxasym8u_asym8u_per_tensor_out.cpp @@ -0,0 +1,75 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +namespace cadence { +namespace impl { +namespace HiFi { +namespace native { + +using ::executorch::aten::Tensor; +using ::executorch::runtime::getLeadingDims; +using ::executorch::runtime::KernelRuntimeContext; +using std::optional; + +void quantized_linear_asym8uxasym8u_asym8u_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& in, + const Tensor& weight, + const Tensor& bias, + int64_t in_zero_point, + int64_t weight_zero_point, + int64_t out_multiplier, + int64_t out_shift, + int64_t out_zero_point, + __ET_UNUSED const optional& offset, + Tensor& out) { + // input comes in shape [leading_dims, in_dim] + // weight comes in shape [out_dim, in_dim] + // output comes in empty with shape [leading_dims, out_dim] + // Perform matrix multiply (M x N) x (N x P)' => M x P + const int64_t leading_dims = getLeadingDims(in, in.dim() - 1); + const int64_t out_dim = weight.size(0); // = out_dim + const int64_t in_dim = weight.size(1); // = in_dim + + const uint8_t* __restrict__ in_data = in.const_data_ptr(); + const uint8_t* __restrict__ weight_data = weight.const_data_ptr(); + const int32_t* __restrict__ bias_data = bias.const_data_ptr(); + uint8_t* __restrict__ out_data = out.mutable_data_ptr(); + + const int32_t out_multipler_int32 = static_cast(out_multiplier); + const int32_t out_shift_int32 = static_cast(out_shift); + + // The nnlib kernel to compute quantized linear via matmul. + const int32_t ret = xa_nn_matmul_asym8uxasym8u_asym8u( + out_data, // p_out + weight_data, // p_mat1, + in_data, // p_mat2, + bias_data, // p_bias + out_dim, // rows of p_mat1 + in_dim, // cols of p_mat1 + in_dim, // row_stride of p_mat1 + leading_dims, // vec_count, i.e., rows of p_mat2 + in_dim, // vec_offset of p_mat2. + out_dim, // out_offset, i.e., offset of next output element written + 1, // out_stride, i.e., stride to go to next output row + -weight_zero_point, // mat1_zero_bias + -in_zero_point, // mat2_zero_bias + out_multipler_int32, // out_multiplier + out_shift_int32, // out_shift + out_zero_point); // out_zero_bias + ET_DCHECK_MSG(ret == 0, "HiFi quantized::linear_per_tensor failed"); +} + +} // namespace native +} // namespace HiFi +} // namespace impl +} // namespace cadence diff --git a/backends/cadence/hifi/operators/targets.bzl b/backends/cadence/hifi/operators/targets.bzl index 9a797874cef..f8f25443e09 100644 --- a/backends/cadence/hifi/operators/targets.bzl +++ b/backends/cadence/hifi/operators/targets.bzl @@ -69,6 +69,8 @@ OPERATORS = [ "quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_out", "quantized_layer_norm", "quantized_linear_out", + "quantized_linear_asym8sxasym8s_asym8s_per_tensor_out", + "quantized_linear_asym8uxasym8u_asym8u_per_tensor_out", "quantized_matmul_out", "quantized_relu_out", "quantize_per_tensor", diff --git a/backends/cadence/reference/operators/quantized_linear_out.cpp b/backends/cadence/reference/operators/quantized_linear_out.cpp index edd8634d56e..f60c98e5875 100644 --- a/backends/cadence/reference/operators/quantized_linear_out.cpp +++ b/backends/cadence/reference/operators/quantized_linear_out.cpp @@ -154,6 +154,80 @@ void quantized_linear_per_tensor_out( #undef typed_quantized_linear_per_tensor } +void quantized_linear_asym8sxasym8s_asym8s_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& src, + const Tensor& weight, + const Tensor& bias, + const int64_t src_zero_point, + const int64_t weight_zero_point, + const int64_t out_multiplier, + const int64_t out_shift, + const int64_t out_zero_point, + __ET_UNUSED const std::optional& offset, + Tensor& out) { +#define typed_quantized_linear_per_tensor(ctype, dtype) \ + case executorch::aten::ScalarType::dtype: { \ + quantized_linear_per_tensor_( \ + src, \ + weight, \ + bias, \ + src_zero_point, \ + weight_zero_point, \ + out_multiplier, \ + out_shift, \ + out_zero_point, \ + out); \ + break; \ + } + + executorch::aten::ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_linear_per_tensor); + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", executorch::runtime::toString(dtype)); + } +#undef typed_quantized_linear_per_tensor +} + +void quantized_linear_asym8uxasym8u_asym8u_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& src, + const Tensor& weight, + const Tensor& bias, + const int64_t src_zero_point, + const int64_t weight_zero_point, + const int64_t out_multiplier, + const int64_t out_shift, + const int64_t out_zero_point, + __ET_UNUSED const std::optional& offset, + Tensor& out) { +#define typed_quantized_linear_per_tensor(ctype, dtype) \ + case executorch::aten::ScalarType::dtype: { \ + quantized_linear_per_tensor_( \ + src, \ + weight, \ + bias, \ + src_zero_point, \ + weight_zero_point, \ + out_multiplier, \ + out_shift, \ + out_zero_point, \ + out); \ + break; \ + } + + executorch::aten::ScalarType dtype = out.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_linear_per_tensor); + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", executorch::runtime::toString(dtype)); + } +#undef typed_quantized_linear_per_tensor +} + }; // namespace native }; // namespace reference }; // namespace impl