NOTE(review): reconstructed from a copy of this patch whose line breaks,
indentation, trailing whitespace and <...> spans had been stripped.
  * The whitespace-only hunks in functions_hifi.yaml could not be recovered
    exactly (their '-' and '+' lines now read the same); regenerate them
    against the real file before applying.
  * The #include targets in the new .cpp are a best guess -- confirm.
  * The `index` post-image hashes of the .cpp and targets.bzl are stale
    because their content was revised during review.

diff --git a/backends/cadence/aot/functions_hifi.yaml b/backends/cadence/aot/functions_hifi.yaml
index 61cf07c4d5d..d79ac92e285 100644
--- a/backends/cadence/aot/functions_hifi.yaml
+++ b/backends/cadence/aot/functions_hifi.yaml
@@ -71,11 +71,11 @@
   kernels:
     - arg_meta: null
       kernel_name: cadence::impl::HiFi::full_out
-
+
 - op: gt.Scalar_out
   kernels:
     - arg_meta: null
-      kernel_name: torch::executor::gt_scalar_out
+      kernel_name: torch::executor::gt_scalar_out
 
 - op: gelu.out
   kernels:
@@ -100,7 +100,7 @@
 - op: mean.out
   kernels:
     - arg_meta: null
-      kernel_name: cadence::impl::HiFi::mean_dim_out
+      kernel_name: cadence::impl::HiFi::mean_dim_out
 
 - op: minimum.out
   kernels:
@@ -213,3 +213,13 @@
   kernels:
     - arg_meta: null
       kernel_name: cadence::impl::HiFi::quantized_linear_per_tensor_out
+
+- func: cadence::quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::HiFi::quantized_fully_connected_out
+
+- func: cadence::quantized_fully_connected.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::HiFi::quantized_fully_connected_per_tensor_out
diff --git a/backends/cadence/hifi/operators/op_quantized_fully_connected_out.cpp b/backends/cadence/hifi/operators/op_quantized_fully_connected_out.cpp
new file mode 100644
index 00000000000..a428c27d0ce
--- /dev/null
+++ b/backends/cadence/hifi/operators/op_quantized_fully_connected_out.cpp
@@ -0,0 +1,294 @@
+// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+// NOTE(review): the #include targets below were stripped from the copy under
+// review and are a best guess -- confirm them before landing.
+#include <executorch/backends/cadence/hifi/kernels/kernels.h>
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+#include <algorithm>
+#include <cmath>
+
+namespace cadence {
+namespace impl {
+namespace HiFi {
+namespace native {
+
+using ::executorch::aten::ArrayRef;
+using ::executorch::aten::IntArrayRef;
+using ::executorch::aten::optional;
+using ::executorch::aten::Scalar;
+using ::executorch::aten::ScalarType;
+using ::executorch::aten::SizesType;
+using ::executorch::aten::Tensor;
+using ::executorch::runtime::KernelRuntimeContext;
+
+// Conventions shared by every kernel in this file:
+//  - `weight` is read as [out_dim, in_dim]: size(0) rows, size(1) columns.
+//  - Only in_dim and out_dim are forwarded to xa_nn_fully_connected_*; no
+//    leading/batch dimension is passed along.
+//  - The input and weight zero points are negated before the call.
+//  - `offset` is part of the operator schema but is never read here.
+
+// uint8 kernel taking tensor-valued quantization parameters. Only element 0 of
+// weight_zero_point, out_multiplier and out_shift is read.
+inline void _quantized_fully_connected_asym8u(
+    const Tensor& in,
+    const Tensor& weight,
+    const Tensor& bias,
+    int64_t in_zero_point,
+    const Tensor& weight_zero_point,
+    const Tensor& out_multiplier,
+    const Tensor& out_shift,
+    int64_t out_zero_point,
+    __ET_UNUSED const optional<Tensor>& offset,
+    Tensor& out) {
+  // input comes in shape [leading_dims, in_dim]
+  // weight comes in shape [out_dim, in_dim]
+  // output comes in empty with shape [leading_dims, out_dim]
+  // Perform matrix multiply (M x N) x (N x P)' => M x P
+  // NOTE(review): presumably leading_dims == 1 here -- confirm with callers.
+  const auto out_dim = static_cast<int32_t>(weight.size(0));
+  const auto in_dim = static_cast<int32_t>(weight.size(1));
+
+  const uint8_t* __restrict__ in_data = in.const_data_ptr<uint8_t>();
+  const uint8_t* __restrict__ weight_data = weight.const_data_ptr<uint8_t>();
+  const int32_t* __restrict__ bias_data = bias.const_data_ptr<int32_t>();
+  uint8_t* __restrict__ out_data = out.mutable_data_ptr<uint8_t>();
+
+  const int32_t ret = xa_nn_fully_connected_asym8uxasym8u_asym8u(
+      out_data,
+      weight_data,
+      in_data,
+      bias_data,
+      in_dim, // weight_depth, number of columns in weight
+      out_dim, // out_depth, number of rows in weight
+      static_cast<int32_t>(-in_zero_point),
+      -weight_zero_point.const_data_ptr<int32_t>()[0],
+      out_multiplier.const_data_ptr<int32_t>()[0],
+      out_shift.const_data_ptr<int32_t>()[0],
+      static_cast<int32_t>(out_zero_point));
+  ET_DCHECK_MSG(ret == 0, "HiFi quantized::fully_connected failed");
+  (void)ret; // keeps `ret` referenced if the debug check compiles away
+}
+
+// int8 kernel taking tensor-valued quantization parameters. Only element 0 of
+// weight_zero_point, out_multiplier and out_shift is read.
+inline void _quantized_fully_connected_asym8s(
+    const Tensor& in,
+    const Tensor& weight,
+    const Tensor& bias,
+    int64_t in_zero_point,
+    const Tensor& weight_zero_point,
+    const Tensor& out_multiplier,
+    const Tensor& out_shift,
+    int64_t out_zero_point,
+    __ET_UNUSED const optional<Tensor>& offset,
+    Tensor& out) {
+  // input comes in shape [leading_dims, in_dim]
+  // weight comes in shape [out_dim, in_dim]
+  // output comes in empty with shape [leading_dims, out_dim]
+  // Perform matrix multiply (M x N) x (N x P)' => M x P
+  // NOTE(review): presumably leading_dims == 1 here -- confirm with callers.
+  const auto out_dim = static_cast<int32_t>(weight.size(0));
+  const auto in_dim = static_cast<int32_t>(weight.size(1));
+
+  const int8_t* __restrict__ in_data = in.const_data_ptr<int8_t>();
+  const int8_t* __restrict__ weight_data = weight.const_data_ptr<int8_t>();
+  const int32_t* __restrict__ bias_data = bias.const_data_ptr<int32_t>();
+  int8_t* __restrict__ out_data = out.mutable_data_ptr<int8_t>();
+
+  const int32_t ret = xa_nn_fully_connected_asym8sxasym8s_asym8s(
+      out_data,
+      weight_data,
+      in_data,
+      bias_data,
+      in_dim, // weight_depth, number of columns in weight
+      out_dim, // out_depth, number of rows in weight
+      static_cast<int32_t>(-in_zero_point),
+      -weight_zero_point.const_data_ptr<int32_t>()[0],
+      out_multiplier.const_data_ptr<int32_t>()[0],
+      out_shift.const_data_ptr<int32_t>()[0],
+      static_cast<int32_t>(out_zero_point));
+  ET_DCHECK_MSG(ret == 0, "HiFi quantized::fully_connected failed");
+  (void)ret; // keeps `ret` referenced if the debug check compiles away
+}
+
+// Operator entry point (tensor-valued quantization parameters). Dispatches on
+// the dtype of `out`: Byte -> uint8 kernel, Char -> int8 kernel; any other
+// dtype fails ET_CHECK_MSG.
+void quantized_fully_connected_out(
+    __ET_UNUSED KernelRuntimeContext& ctx,
+    const Tensor& in,
+    const Tensor& weight,
+    const Tensor& bias,
+    int64_t in_zero_point,
+    const Tensor& weight_zero_point,
+    const Tensor& out_multiplier,
+    const Tensor& out_shift,
+    int64_t out_zero_point,
+    __ET_UNUSED const optional<Tensor>& offset,
+    Tensor& out) {
+  if (out.scalar_type() == ScalarType::Byte) {
+    _quantized_fully_connected_asym8u(
+        in,
+        weight,
+        bias,
+        in_zero_point,
+        weight_zero_point,
+        out_multiplier,
+        out_shift,
+        out_zero_point,
+        offset,
+        out);
+  } else if (out.scalar_type() == ScalarType::Char) {
+    _quantized_fully_connected_asym8s(
+        in,
+        weight,
+        bias,
+        in_zero_point,
+        weight_zero_point,
+        out_multiplier,
+        out_shift,
+        out_zero_point,
+        offset,
+        out);
+  } else {
+    ET_CHECK_MSG(
+        false,
+        "quantized fully connected only supported for uint8 and int8 dtypes");
+  }
+}
+
+// uint8 kernel taking scalar (per-tensor) quantization parameters, each of
+// which is narrowed to int32_t before being handed to the NNLib routine.
+inline void _quantized_fully_connected_per_tensor_asym8u(
+    const Tensor& in,
+    const Tensor& weight,
+    const Tensor& bias,
+    int64_t in_zero_point,
+    int64_t weight_zero_point,
+    int64_t out_multiplier,
+    int64_t out_shift,
+    int64_t out_zero_point,
+    __ET_UNUSED const optional<Tensor>& offset,
+    Tensor& out) {
+  // input comes in shape [leading_dims, in_dim]
+  // weight comes in shape [out_dim, in_dim]
+  // output comes in empty with shape [leading_dims, out_dim]
+  // Perform matrix multiply (M x N) x (N x P)' => M x P
+  // NOTE(review): presumably leading_dims == 1 here -- confirm with callers.
+  const auto out_dim = static_cast<int32_t>(weight.size(0));
+  const auto in_dim = static_cast<int32_t>(weight.size(1));
+
+  const uint8_t* __restrict__ in_data = in.const_data_ptr<uint8_t>();
+  const uint8_t* __restrict__ weight_data = weight.const_data_ptr<uint8_t>();
+  const int32_t* __restrict__ bias_data = bias.const_data_ptr<int32_t>();
+  uint8_t* __restrict__ out_data = out.mutable_data_ptr<uint8_t>();
+
+  const int32_t ret = xa_nn_fully_connected_asym8uxasym8u_asym8u(
+      out_data,
+      weight_data,
+      in_data,
+      bias_data,
+      in_dim, // weight_depth, number of columns in weight
+      out_dim, // out_depth, number of rows in weight
+      static_cast<int32_t>(-in_zero_point),
+      -static_cast<int32_t>(weight_zero_point),
+      static_cast<int32_t>(out_multiplier),
+      static_cast<int32_t>(out_shift),
+      static_cast<int32_t>(out_zero_point));
+  ET_DCHECK_MSG(ret == 0, "HiFi quantized::fully_connected failed");
+  (void)ret; // keeps `ret` referenced if the debug check compiles away
+}
+
+// int8 kernel taking scalar (per-tensor) quantization parameters, each of
+// which is narrowed to int32_t before being handed to the NNLib routine.
+inline void _quantized_fully_connected_per_tensor_asym8s(
+    const Tensor& in,
+    const Tensor& weight,
+    const Tensor& bias,
+    int64_t in_zero_point,
+    int64_t weight_zero_point,
+    int64_t out_multiplier,
+    int64_t out_shift,
+    int64_t out_zero_point,
+    __ET_UNUSED const optional<Tensor>& offset,
+    Tensor& out) {
+  // input comes in shape [leading_dims, in_dim]
+  // weight comes in shape [out_dim, in_dim]
+  // output comes in empty with shape [leading_dims, out_dim]
+  // Perform matrix multiply (M x N) x (N x P)' => M x P
+  // NOTE(review): presumably leading_dims == 1 here -- confirm with callers.
+  const auto out_dim = static_cast<int32_t>(weight.size(0));
+  const auto in_dim = static_cast<int32_t>(weight.size(1));
+
+  const int8_t* __restrict__ in_data = in.const_data_ptr<int8_t>();
+  const int8_t* __restrict__ weight_data = weight.const_data_ptr<int8_t>();
+  const int32_t* __restrict__ bias_data = bias.const_data_ptr<int32_t>();
+  int8_t* __restrict__ out_data = out.mutable_data_ptr<int8_t>();
+
+  const int32_t ret = xa_nn_fully_connected_asym8sxasym8s_asym8s(
+      out_data,
+      weight_data,
+      in_data,
+      bias_data,
+      in_dim, // weight_depth, number of columns in weight
+      out_dim, // out_depth, number of rows in weight
+      static_cast<int32_t>(-in_zero_point),
+      -static_cast<int32_t>(weight_zero_point),
+      static_cast<int32_t>(out_multiplier),
+      static_cast<int32_t>(out_shift),
+      static_cast<int32_t>(out_zero_point));
+  ET_DCHECK_MSG(ret == 0, "HiFi quantized::fully_connected failed");
+  (void)ret; // keeps `ret` referenced if the debug check compiles away
+}
+
+// Operator entry point (scalar, per-tensor quantization parameters).
+// Dispatches on the dtype of `out`: Byte -> uint8 kernel, Char -> int8
+// kernel; any other dtype fails ET_CHECK_MSG.
+void quantized_fully_connected_per_tensor_out(
+    __ET_UNUSED KernelRuntimeContext& ctx,
+    const Tensor& in,
+    const Tensor& weight,
+    const Tensor& bias,
+    int64_t in_zero_point,
+    int64_t weight_zero_point,
+    int64_t out_multiplier,
+    int64_t out_shift,
+    int64_t out_zero_point,
+    __ET_UNUSED const optional<Tensor>& offset,
+    Tensor& out) {
+  if (out.scalar_type() == ScalarType::Byte) {
+    _quantized_fully_connected_per_tensor_asym8u(
+        in,
+        weight,
+        bias,
+        in_zero_point,
+        weight_zero_point,
+        out_multiplier,
+        out_shift,
+        out_zero_point,
+        offset,
+        out);
+  } else if (out.scalar_type() == ScalarType::Char) {
+    _quantized_fully_connected_per_tensor_asym8s(
+        in,
+        weight,
+        bias,
+        in_zero_point,
+        weight_zero_point,
+        out_multiplier,
+        out_shift,
+        out_zero_point,
+        offset,
+        out);
+  } else {
+    ET_CHECK_MSG(
+        false,
+        "quantized fully connected only supported for uint8 and int8 dtypes");
+  }
+}
+
+} // namespace native
+} // namespace HiFi
+} // namespace impl
+} // namespace cadence
diff --git a/backends/cadence/hifi/operators/targets.bzl b/backends/cadence/hifi/operators/targets.bzl
index 1c2b481410d..c94b9a3b19e 100644
--- a/backends/cadence/hifi/operators/targets.bzl
+++ b/backends/cadence/hifi/operators/targets.bzl
@@ -45,6 +45,7 @@ OPERATORS = [
     "mul",
     "permute_copy",
     "pow",
     "quantize_per_tensor",
+    "quantized_fully_connected_out",
     "quantized_layer_norm",
     "quantized_linear_out",