diff --git a/backends/cadence/hifi/operators/operators.h b/backends/cadence/hifi/operators/operators.h
new file mode 100644
index 00000000000..8c47468b049
--- /dev/null
+++ b/backends/cadence/hifi/operators/operators.h
@@ -0,0 +1,7 @@
+// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
+
+#pragma once
+
+#define ET_FORALL_CADENCE_QUANTIZED_TYPES(_) \
+  _(uint8_t, Byte)                           \
+  _(int8_t, Char)
diff --git a/backends/cadence/hifi/operators/quantized_layer_norm.cpp b/backends/cadence/hifi/operators/quantized_layer_norm.cpp
index 439bb594f50..bdde5dfb22b 100644
--- a/backends/cadence/hifi/operators/quantized_layer_norm.cpp
+++ b/backends/cadence/hifi/operators/quantized_layer_norm.cpp
@@ -7,14 +7,17 @@
  */
 
 #include <executorch/backends/cadence/hifi/kernels/kernels.h>
+#include <executorch/backends/cadence/hifi/operators/operators.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 
 #include <algorithm>
 #include <cmath>
 
-using executorch::aten::Tensor;
-using executorch::runtime::getLeadingDims;
-using executorch::runtime::KernelRuntimeContext;
+using ::executorch::aten::IntArrayRef;
+using ::executorch::aten::ScalarType;
+using ::executorch::aten::Tensor;
+using ::executorch::runtime::getLeadingDims;
+using ::executorch::runtime::KernelRuntimeContext;
 
 namespace cadence {
 namespace impl {
@@ -77,10 +80,10 @@ void quantized_layer_norm_(
     for (size_t j = 0; j < last_dim; ++j) {
       // Since X is quantized, we dequantize it, compute fp32 result, and
       // quantize the result to an int8/uint8 value.
-      float val = cadence::impl::HiFi::kernels::dequantize<float>(
+      float val = ::cadence::impl::HiFi::kernels::dequantize<float>(
           x[j], input_scale, input_zero_point);
       val = (val - mean) * inv_std * weight_data[j] + bias_data[j];
-      y[j] = cadence::impl::HiFi::kernels::quantize<T>(
+      y[j] = ::cadence::impl::HiFi::kernels::quantize<T>(
           val, output_inv_scale, output_zero_point);
     }
   }
@@ -121,38 +124,37 @@ void quantized_layer_norm_out(
     const Tensor& input,
     const Tensor& in_scale,
     const Tensor& in_zero_point,
-    const executorch::aten::IntArrayRef normalized_shape,
+    __ET_UNUSED const IntArrayRef normalized_shape,
     const Tensor& weight,
     const Tensor& bias,
     double eps,
     double output_scale,
     int64_t output_zero_point,
     Tensor& out) {
-  if (input.scalar_type() == executorch::aten::ScalarType::Byte) {
-    quantized_layer_norm_<uint8_t>(
-        input,
-        in_scale,
-        in_zero_point,
-        weight,
-        bias,
-        eps,
-        output_scale,
-        output_zero_point,
-        out);
-  } else if (input.scalar_type() == executorch::aten::ScalarType::Char) {
-    quantized_layer_norm_<int8_t>(
-        input,
-        in_scale,
-        in_zero_point,
-        weight,
-        bias,
-        eps,
-        output_scale,
-        output_zero_point,
-        out);
-  } else {
-    ET_CHECK_MSG(false, "Unhandled input dtype %hhd", input.scalar_type());
-  }
+#define typed_quantized_layer_norm(ctype, dtype) \
+  case ScalarType::dtype: {                      \
+    quantized_layer_norm_<ctype>(                \
+        input,                                   \
+        in_scale,                                \
+        in_zero_point,                           \
+        weight,                                  \
+        bias,                                    \
+        eps,                                     \
+        output_scale,                            \
+        output_zero_point,                       \
+        out);                                    \
+    break;                                       \
+  }
+
+  ScalarType dtype = input.scalar_type();
+  switch (dtype) {
+    ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_layer_norm)
+    default:
+      ET_DCHECK_MSG(
+          false, "Unhandled dtype %s", torch::executor::toString(dtype));
+  }
+
+#undef typed_quantized_layer_norm
 
 }
 }; // namespace native
diff --git a/backends/cadence/hifi/operators/targets.bzl b/backends/cadence/hifi/operators/targets.bzl
index 1b6a405118a..96f063728c8 100644
--- a/backends/cadence/hifi/operators/targets.bzl
+++ b/backends/cadence/hifi/operators/targets.bzl
@@ -15,6 +15,7 @@ def define_common_targets():
         srcs = glob([
             "*.cpp",
         ]),
+        exported_headers = glob(["*.h"]),
         platforms = CXX,
         deps = [
             "//executorch/kernels/portable/cpu/util:all_deps",