diff --git a/backends/cadence/aot/functions.yaml b/backends/cadence/aot/functions.yaml index 71246df8684..1ca6735b803 100644 --- a/backends/cadence/aot/functions.yaml +++ b/backends/cadence/aot/functions.yaml @@ -154,6 +154,10 @@ kernels: - arg_meta: null kernel_name: impl::reference::quantized_layer_norm_out +- func: cadence::quantized_layer_norm.per_tensor_out(Tensor input, float in_scale, int in_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::reference::quantized_layer_norm_per_tensor_out - func: cadence::quantized_linear.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!) kernels: diff --git a/backends/cadence/aot/functions_hifi.yaml b/backends/cadence/aot/functions_hifi.yaml index 627f211557c..84c07be78ca 100644 --- a/backends/cadence/aot/functions_hifi.yaml +++ b/backends/cadence/aot/functions_hifi.yaml @@ -125,6 +125,10 @@ kernels: - arg_meta: null kernel_name: cadence::impl::HiFi::quantized_layer_norm_out +- func: cadence::quantized_layer_norm.per_tensor_out(Tensor input, float in_scale, int in_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: cadence::impl::HiFi::quantized_layer_norm_per_tensor_out - func: cadence::quantized_linear.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!) 
kernels: diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py index e860a2bfcc6..d47ea3f21af 100644 --- a/backends/cadence/aot/ops_registrations.py +++ b/backends/cadence/aot/ops_registrations.py @@ -36,6 +36,12 @@ lib.define( "quantized_layer_norm.out(Tensor X, Tensor X_scale, Tensor X_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor (a!)" ) +lib.define( + "quantized_layer_norm.per_tensor(Tensor X, float X_scale, int X_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point) -> (Tensor Y)" +) +lib.define( + "quantized_layer_norm.per_tensor_out(Tensor X, float X_scale, int X_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor (a!)" +) lib.define( "quantized_linear(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? 
offset) -> (Tensor Z)" @@ -180,6 +186,21 @@ def quantized_layer_norm_meta( return input.new_empty(input.size(), dtype=input.dtype) +@register_fake("cadence::quantized_layer_norm.per_tensor") +def quantized_layer_norm_per_tensor_meta( + input: torch.Tensor, + X_scale: float, + X_zero_point: int, + normalized_shape: int, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + output_scale: float, + output_zero_point: int, +) -> torch.Tensor: + return input.new_empty(input.size(), dtype=input.dtype) + + @register_fake("cadence::quantized_relu") def quantized_relu_meta( X: torch.Tensor, diff --git a/backends/cadence/hifi/operators/quantized_layer_norm.cpp b/backends/cadence/hifi/operators/quantized_layer_norm.cpp index bdde5dfb22b..0606fb28523 100644 --- a/backends/cadence/hifi/operators/quantized_layer_norm.cpp +++ b/backends/cadence/hifi/operators/quantized_layer_norm.cpp @@ -27,7 +27,7 @@ namespace native { // Compute quantized layer_norm. The current implementation assumes that the // input is per-tensor quantized. 
template -void quantized_layer_norm_( +void quantized_layer_norm_per_tensor_( const Tensor& input, float input_scale, int64_t input_zero_point, @@ -107,7 +107,7 @@ void quantized_layer_norm_( int64_t input_zero_point = in_zero_point.const_data_ptr()[0]; // Call other overload - quantized_layer_norm_( + quantized_layer_norm_per_tensor_( input, input_scale, input_zero_point, @@ -120,7 +120,7 @@ void quantized_layer_norm_( } void quantized_layer_norm_out( - KernelRuntimeContext& ctx, + __ET_UNUSED KernelRuntimeContext& ctx, const Tensor& input, const Tensor& in_scale, const Tensor& in_zero_point, @@ -157,6 +157,44 @@ void quantized_layer_norm_out( #undef typed_quantized_layer_norm } +void quantized_layer_norm_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + double in_scale, + int64_t in_zero_point, + __ET_UNUSED const IntArrayRef normalized_shape, + const Tensor& weight, + const Tensor& bias, + double eps, + double output_scale, + int64_t output_zero_point, + Tensor& out) { +#define typed_quantized_layer_norm(ctype, dtype) \ + case ScalarType::dtype: { \ + quantized_layer_norm_per_tensor_( \ + input, \ + in_scale, \ + in_zero_point, \ + weight, \ + bias, \ + eps, \ + output_scale, \ + output_zero_point, \ + out); \ + break; \ + } + + ScalarType dtype = input.scalar_type(); + switch (dtype) { + ET_FORALL_CADENCE_QUANTIZED_TYPES(typed_quantized_layer_norm) + default: + ET_DCHECK_MSG( + false, "Unhandled dtype %s", torch::executor::toString(dtype)); + } + +#undef typed_quantized_layer_norm +} + }; // namespace native }; // namespace HiFi }; // namespace impl diff --git a/backends/cadence/reference/operators/quantized_layer_norm.cpp b/backends/cadence/reference/operators/quantized_layer_norm.cpp index 92b1edf3dde..efddecc3f23 100644 --- a/backends/cadence/reference/operators/quantized_layer_norm.cpp +++ b/backends/cadence/reference/operators/quantized_layer_norm.cpp @@ -11,9 +11,11 @@ #include -using executorch::aten::Tensor; -using 
executorch::runtime::getLeadingDims; -using executorch::runtime::KernelRuntimeContext; +using ::executorch::aten::IntArrayRef; +using ::executorch::aten::ScalarType; +using ::executorch::aten::Tensor; +using ::executorch::runtime::getLeadingDims; +using ::executorch::runtime::KernelRuntimeContext; namespace impl { namespace reference { @@ -22,7 +24,7 @@ namespace native { // Compute quantized layer_norm. The current implementation assumes that the // input is per-tensor quantized. template -void quantized_layer_norm_( +void quantized_layer_norm_per_tensor_( const Tensor& input, double input_scale, int64_t input_zero_point, @@ -98,7 +100,7 @@ void quantized_layer_norm_( int64_t input_zero_point = in_zero_point.const_data_ptr()[0]; // Call other overload - quantized_layer_norm_( + quantized_layer_norm_per_tensor_( input, input_scale, input_zero_point, @@ -111,11 +113,11 @@ void quantized_layer_norm_( } void quantized_layer_norm_out( - KernelRuntimeContext& ctx, + __ET_UNUSED KernelRuntimeContext& ctx, const Tensor& input, const Tensor& in_scale, const Tensor& in_zero_point, - const executorch::aten::IntArrayRef normalized_shape, + __ET_UNUSED const executorch::aten::IntArrayRef normalized_shape, const Tensor& weight, const Tensor& bias, double eps, @@ -152,6 +154,48 @@ void quantized_layer_norm_out( } } +void quantized_layer_norm_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + double in_scale, + int64_t in_zero_point, + __ET_UNUSED const executorch::aten::IntArrayRef normalized_shape, + const Tensor& weight, + const Tensor& bias, + double eps, + double output_scale, + int64_t output_zero_point, + Tensor& out) { + if (input.scalar_type() == executorch::aten::ScalarType::Byte) { + quantized_layer_norm_per_tensor_( + input, + in_scale, + in_zero_point, + weight, + bias, + eps, + output_scale, + output_zero_point, + out); + } else if (input.scalar_type() == executorch::aten::ScalarType::Char) { + quantized_layer_norm_per_tensor_( + 
input, + in_scale, + in_zero_point, + weight, + bias, + eps, + output_scale, + output_zero_point, + out); + } else { + ET_CHECK_MSG( + false, + "Unhandled input dtype %hhd", + static_cast<int8_t>(input.scalar_type())); + } +} + }; // namespace native }; // namespace reference }; // namespace impl