diff --git a/backends/cadence/aot/functions.yaml b/backends/cadence/aot/functions.yaml index 6891dd52c6b..3968f215602 100644 --- a/backends/cadence/aot/functions.yaml +++ b/backends/cadence/aot/functions.yaml @@ -249,6 +249,21 @@ - arg_meta: null kernel_name: impl::reference::quantized_relu_asym8u_asym8u_per_tensor_out +- func: cadence::quantized_add.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::reference::quantized_add_per_tensor_out + +- func: cadence::quantized_add_asym8sxasym8s_asym8s.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::reference::quantized_add_asym8sxasym8s_asym8s_per_tensor_out + +- func: cadence::quantized_add_asym8uxasym8u_asym8u.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::reference::quantized_add_asym8uxasym8u_asym8u_per_tensor_out + - func: cadence::quantized_matmul.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null diff --git a/backends/cadence/aot/functions_hifi.yaml b/backends/cadence/aot/functions_hifi.yaml index 7e6bfaadcc7..19249ef50a5 100644 --- a/backends/cadence/aot/functions_hifi.yaml +++ b/backends/cadence/aot/functions_hifi.yaml @@ -404,6 +404,16 @@ - arg_meta: null kernel_name: cadence::impl::HiFi::quantized_relu_asym8u_asym8u_per_tensor_out +- func: cadence::quantized_add_asym8sxasym8s_asym8s.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: cadence::impl::HiFi::quantized_add_asym8sxasym8s_asym8s_per_tensor_out + +- func: cadence::quantized_add_asym8uxasym8u_asym8u.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: cadence::impl::HiFi::quantized_add_asym8uxasym8u_asym8u_per_tensor_out + - func: cadence::quantized_matmul.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py index a98fedd22ea..52b688490b2 100644 --- a/backends/cadence/aot/ops_registrations.py +++ b/backends/cadence/aot/ops_registrations.py @@ -325,6 +325,22 @@ "quantized_add.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, " "int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) 
out) -> Tensor(a!)" ) +lib.define( + "quantized_add_asym8sxasym8s_asym8s.per_tensor(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, " + "int Y_zero_point, float out_scale, int out_zero_point) -> Tensor" +) +lib.define( + "quantized_add_asym8sxasym8s_asym8s.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, " + "int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)" +) +lib.define( + "quantized_add_asym8uxasym8u_asym8u.per_tensor(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, " + "int Y_zero_point, float out_scale, int out_zero_point) -> Tensor" +) +lib.define( + "quantized_add_asym8uxasym8u_asym8u.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, " + "int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)" +) lib.define( "quantized_mul.out(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor Y, Tensor Y_scale, " "Tensor Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)" @@ -503,6 +519,36 @@ def quantized_add_per_tensor_meta( return X.new_empty(out_size, dtype=X.dtype) +@register_fake("cadence::quantized_add_asym8sxasym8s_asym8s.per_tensor") +def quantized_add_asym8sxasym8s_asym8s_per_tensor_meta( + X: torch.Tensor, + X_scale: float, + X_zero_point: int, + Y: torch.Tensor, + Y_scale: float, + Y_zero_point: int, + out_scale: float, + out_zero_point: int, +) -> torch.Tensor: + out_size = torch.broadcast_shapes(X.size(), Y.size()) + return X.new_empty(out_size, dtype=X.dtype) + + +@register_fake("cadence::quantized_add_asym8uxasym8u_asym8u.per_tensor") +def quantized_add_asym8uxasym8u_asym8u_per_tensor_meta( + X: torch.Tensor, + X_scale: float, + X_zero_point: int, + Y: torch.Tensor, + Y_scale: float, + Y_zero_point: int, + out_scale: float, + out_zero_point: int, +) -> torch.Tensor: + out_size = torch.broadcast_shapes(X.size(), Y.size()) + return X.new_empty(out_size, dtype=X.dtype) + + @register_fake("cadence::quantized_linear") def quantized_linear_meta( src: torch.Tensor, diff --git a/backends/cadence/aot/tests/test_type_dispatch_passes.py b/backends/cadence/aot/tests/test_type_dispatch_passes.py index 2b12a188cf6..1deebdfbb1c 100644 --- a/backends/cadence/aot/tests/test_type_dispatch_passes.py +++ b/backends/cadence/aot/tests/test_type_dispatch_passes.py @@ -445,3 +445,53 @@ def test_uint8_dispatch_quantized_conv_nhwc_dilated(self) -> None: ), 1, ) + + def test_int8_dispatch_quantized_add(self) -> None: + """Test int8 x int8 inputs should dispatch to asym8sxasym8s_asym8s variant for quantized_add""" + x = torch.randint(-128, 127, (2, 3), dtype=torch.int8) + y = torch.randint(-128, 127, (2, 3), dtype=torch.int8) + gm = single_op_builder( + placeholders=(x, y), + op=exir_ops.edge.cadence.quantized_add.per_tensor, + args=(x, 1.0, 0, y, 1.0, 0, 1.0, 0), + ) + p = CompileTimeTypeDispatchPass() + gm = cast(PassResult, p(gm)).graph_module + # Original op should be replaced + self.assertEqual( + count_node(gm, exir_ops.edge.cadence.quantized_add.per_tensor), + 0, + ) + # Should be replaced with int8 specific variant + self.assertEqual( + count_node( + gm, + exir_ops.edge.cadence.quantized_add_asym8sxasym8s_asym8s.per_tensor, + ), + 1, + ) + + def test_uint8_dispatch_quantized_add(self) -> None: + """Test uint8 x uint8 inputs should dispatch to asym8uxasym8u_asym8u variant for quantized_add""" + x = torch.randint(0, 255, (2, 3), dtype=torch.uint8) + y = torch.randint(0, 255, (2, 3), 
dtype=torch.uint8) + gm = single_op_builder( + placeholders=(x, y), + op=exir_ops.edge.cadence.quantized_add.per_tensor, + args=(x, 1.0, 0, y, 1.0, 0, 1.0, 0), + ) + p = CompileTimeTypeDispatchPass() + gm = cast(PassResult, p(gm)).graph_module + # Original op should be replaced + self.assertEqual( + count_node(gm, exir_ops.edge.cadence.quantized_add.per_tensor), + 0, + ) + # Should be replaced with uint8 specific variant + self.assertEqual( + count_node( + gm, + exir_ops.edge.cadence.quantized_add_asym8uxasym8u_asym8u.per_tensor, + ), + 1, + ) diff --git a/backends/cadence/aot/type_dispatch.py b/backends/cadence/aot/type_dispatch.py index a0443b69b9b..c53f62a45b7 100644 --- a/backends/cadence/aot/type_dispatch.py +++ b/backends/cadence/aot/type_dispatch.py @@ -85,6 +85,14 @@ class CompileTimeTypeDispatchPass(ExportPass): (torch.uint8,): "asym8u_asym8u", }, ), + exir_ops.edge.cadence.quantized_add.per_tensor: OpConfig( + "quantized_add", + type_dispatch_suffixes={ + (torch.int8, torch.int8): "asym8sxasym8s_asym8s", + (torch.uint8, torch.uint8): "asym8uxasym8u_asym8u", + }, + weight_arg_idx=3, + ), } def call_operator( diff --git a/backends/cadence/hifi/operators/op_quantized_add_asym8sxasym8s_asym8s_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_add_asym8sxasym8s_asym8s_per_tensor_out.cpp new file mode 100644 index 00000000000..fa84a877c56 --- /dev/null +++ b/backends/cadence/hifi/operators/op_quantized_add_asym8sxasym8s_asym8s_per_tensor_out.cpp @@ -0,0 +1,179 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +namespace cadence { +namespace impl { +namespace HiFi { +namespace native { + +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +void quantized_add_asym8sxasym8s_asym8s_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& X, + double X_scale, + int64_t X_zero_point, + const Tensor& Y, + double Y_scale, + int64_t Y_zero_point, + double out_scale, + int64_t out_zero_point, + Tensor& out) { + const int8_t* __restrict__ X_data = X.const_data_ptr(); + const int8_t* __restrict__ Y_data = Y.const_data_ptr(); + int8_t* __restrict__ out_data = out.mutable_data_ptr(); + + ssize_t Y_numel = Y.numel(); + ssize_t X_numel = X.numel(); + ssize_t out_numel = out.numel(); + + float X_scale_f = static_cast(X_scale); + float Y_scale_f = static_cast(Y_scale); + float out_scale_f = static_cast(out_scale); + int32_t X_zero_point_i32 = static_cast(X_zero_point); + int32_t Y_zero_point_i32 = static_cast(Y_zero_point); + int32_t out_zero_point_i32 = static_cast(out_zero_point); + + float inv_out_scale = 1.0f / out_scale_f; + constexpr float min_val = + static_cast(std::numeric_limits::min()); + constexpr float max_val = + static_cast(std::numeric_limits::max()); + + /* Tensor X exactly matches Y in shape, no broadcasting */ + if (X_numel == Y_numel && Y_numel == out_numel) { + for (size_t i = 0; i < X_numel; ++i) { + float x = X_scale_f * (X_data[i] - X_zero_point_i32); + float y = Y_scale_f * (Y_data[i] - Y_zero_point_i32); + float z = x + y; + float tmp = roundf(z * inv_out_scale + out_zero_point_i32); + out_data[i] = + static_cast(std::max(std::min(tmp, max_val), min_val)); + } + } /* if Y is a scalar Tensor */ + else if (Y_numel == 1) { + float y = + kernels::dequantize(Y_data[0], Y_scale_f, Y_zero_point_i32); + for (size_t i = 0; i < 
X_numel; ++i) { + float x = + kernels::dequantize(X_data[i], X_scale_f, X_zero_point_i32); + float z = x + y; + out_data[i] = + kernels::quantize(z, inv_out_scale, out_zero_point_i32); + } + } /* if X is a scalar Tensor */ + else if (X_numel == 1) { + float x = + kernels::dequantize(X_data[0], X_scale_f, X_zero_point_i32); + for (size_t i = 0; i < Y_numel; ++i) { + float y = + kernels::dequantize(Y_data[i], Y_scale_f, Y_zero_point_i32); + float z = x + y; + out_data[i] = + kernels::quantize(z, inv_out_scale, out_zero_point_i32); + } + } /* other broadcasting cases */ + else { + /* Broadcasting implementation */ + ssize_t X_dim = X.dim(); + ssize_t Y_dim = Y.dim(); + ssize_t out_dim = out.dim(); + + /* Precompute strides for X and Y tensors */ + constexpr size_t max_dim = executorch::runtime::kTensorDimensionLimit; + size_t X_strides[max_dim] = {0}; + size_t Y_strides[max_dim] = {0}; + size_t X_stride_val = 1; + size_t Y_stride_val = 1; + + /* Calculate strides from last dimension to first */ + for (int d = out_dim - 1; d >= 0 && d >= out_dim - max_dim; --d) { + int idx = out_dim - 1 - d; /* Index into the fixed-size array */ + if (d >= out_dim - X_dim) { + size_t X_d = d - (out_dim - X_dim); + X_strides[idx] = X_stride_val; + X_stride_val *= X.size(X_d); + } + + if (d >= out_dim - Y_dim) { + size_t Y_d = d - (out_dim - Y_dim); + Y_strides[idx] = Y_stride_val; + Y_stride_val *= Y.size(Y_d); + } + } + + /* Iterate over output tensor */ + for (ssize_t i = 0; i < out_numel; ++i) { + size_t out_idx = i; + size_t X_idx = 0; + size_t Y_idx = 0; + + /* Compute corresponding indices in input tensors */ + for (int d = out_dim - 1; d >= 0; --d) { + size_t out_dim_idx = out_idx % out.size(d); + out_idx /= out.size(d); + + /* Compute X index */ + if (d >= out_dim - X_dim) { + size_t X_d = d - (out_dim - X_dim); + size_t X_dim_idx = out_dim_idx % X.size(X_d); + if (d >= out_dim - max_dim) { + int idx = out_dim - 1 - d; + X_idx += X_dim_idx * X_strides[idx]; + } else { + size_t X_stride = 1; + for (int k = out_dim - 1; k > d; --k) { + if (k >= out_dim - X_dim) { + size_t X_k = k - (out_dim - X_dim); + X_stride *= X.size(X_k); + } + } + X_idx += X_dim_idx * X_stride; + } + } + + /* Compute Y index */ + if (d >= out_dim - Y_dim) { + size_t Y_d = d - (out_dim - Y_dim); + size_t Y_dim_idx = out_dim_idx % Y.size(Y_d); + if (d >= out_dim - max_dim) { + int idx = out_dim - 1 - d; + Y_idx += Y_dim_idx * Y_strides[idx]; + } else { + size_t Y_stride = 1; + for (int k = out_dim - 1; k > d; --k) { + if (k >= out_dim - Y_dim) { + size_t Y_k = k - (out_dim - Y_dim); + Y_stride *= Y.size(Y_k); + } + } + Y_idx += Y_dim_idx * Y_stride; + } + } + } + + /* Apply the operation */ + float x = kernels::dequantize( + X_data[X_idx], X_scale_f, X_zero_point_i32); + float y = kernels::dequantize( + Y_data[Y_idx], Y_scale_f, Y_zero_point_i32); + float z = x + y; + out_data[i] = + kernels::quantize(z, inv_out_scale, out_zero_point_i32); + } + } +} + +} // namespace native +} // namespace HiFi +} // namespace impl +} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_quantized_add_asym8uxasym8u_asym8u_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_add_asym8uxasym8u_asym8u_per_tensor_out.cpp new file mode 100644 index 00000000000..b7c453dda2b --- /dev/null +++ b/backends/cadence/hifi/operators/op_quantized_add_asym8uxasym8u_asym8u_per_tensor_out.cpp @@ -0,0 +1,179 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +namespace cadence { +namespace impl { +namespace HiFi { +namespace native { + +using ::executorch::aten::Tensor; +using ::executorch::runtime::KernelRuntimeContext; + +void quantized_add_asym8uxasym8u_asym8u_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& X, + double X_scale, + int64_t X_zero_point, + const Tensor& Y, + double Y_scale, + int64_t Y_zero_point, + double out_scale, + int64_t out_zero_point, + Tensor& out) { + const uint8_t* __restrict__ X_data = X.const_data_ptr(); + const uint8_t* __restrict__ Y_data = Y.const_data_ptr(); + uint8_t* __restrict__ out_data = out.mutable_data_ptr(); + + ssize_t Y_numel = Y.numel(); + ssize_t X_numel = X.numel(); + ssize_t out_numel = out.numel(); + + float X_scale_f = static_cast(X_scale); + float Y_scale_f = static_cast(Y_scale); + float out_scale_f = static_cast(out_scale); + int32_t X_zero_point_i32 = static_cast(X_zero_point); + int32_t Y_zero_point_i32 = static_cast(Y_zero_point); + int32_t out_zero_point_i32 = static_cast(out_zero_point); + + float inv_out_scale = 1.0f / out_scale_f; + constexpr float min_val = + static_cast(std::numeric_limits::min()); + constexpr float max_val = + static_cast(std::numeric_limits::max()); + + /* Tensor X exactly matches Y in shape, no broadcasting */ + if (X_numel == Y_numel && Y_numel == out_numel) { + for (size_t i = 0; i < X_numel; ++i) { + float x = X_scale_f * (X_data[i] - X_zero_point_i32); + float y = Y_scale_f * (Y_data[i] - Y_zero_point_i32); + float z = x + y; + float tmp = roundf(z * inv_out_scale + out_zero_point_i32); + out_data[i] = + static_cast(std::max(std::min(tmp, max_val), min_val)); + } + } /* if Y is a scalar Tensor */ + else if (Y_numel == 1) { + float y = + kernels::dequantize(Y_data[0], Y_scale_f, Y_zero_point_i32); + for (size_t i = 0; i < X_numel; ++i) { + float x = + kernels::dequantize(X_data[i], X_scale_f, X_zero_point_i32); + float z = x + y; + out_data[i] = + kernels::quantize(z, inv_out_scale, out_zero_point_i32); + } + } /* if X is a scalar Tensor */ + else if (X_numel == 1) { + float x = + kernels::dequantize(X_data[0], X_scale_f, X_zero_point_i32); + for (size_t i = 0; i < Y_numel; ++i) { + float y = + kernels::dequantize(Y_data[i], Y_scale_f, Y_zero_point_i32); + float z = x + y; + out_data[i] = + kernels::quantize(z, inv_out_scale, out_zero_point_i32); + } + } /* other broadcasting cases */ + else { + /* Broadcasting implementation */ + ssize_t X_dim = X.dim(); + ssize_t Y_dim = Y.dim(); + ssize_t out_dim = out.dim(); + + /* Precompute strides for X and Y tensors */ + constexpr size_t max_dim = executorch::runtime::kTensorDimensionLimit; + size_t X_strides[max_dim] = {0}; + size_t Y_strides[max_dim] = {0}; + size_t X_stride_val = 1; + size_t Y_stride_val = 1; + + /* Calculate strides from last dimension to first */ + for (int d = out_dim - 1; d >= 0 && d >= out_dim - max_dim; --d) { + int idx = out_dim - 1 - d; /* Index into the fixed-size array */ + if (d >= out_dim - X_dim) { + size_t X_d = d - (out_dim - X_dim); + X_strides[idx] = X_stride_val; + X_stride_val *= X.size(X_d); + } + + if (d >= out_dim - Y_dim) { + size_t Y_d = d - (out_dim - Y_dim); + Y_strides[idx] = Y_stride_val; + Y_stride_val *= Y.size(Y_d); + } + } + + /* Iterate over output tensor */ + for (ssize_t i = 0; i < out_numel; ++i) { + size_t out_idx = i; + size_t X_idx = 0; + size_t Y_idx = 0; + + /* Compute corresponding 
indices in input tensors */ + for (int d = out_dim - 1; d >= 0; --d) { + size_t out_dim_idx = out_idx % out.size(d); + out_idx /= out.size(d); + + /* Compute X index */ + if (d >= out_dim - X_dim) { + size_t X_d = d - (out_dim - X_dim); + size_t X_dim_idx = out_dim_idx % X.size(X_d); + if (d >= out_dim - max_dim) { + int idx = out_dim - 1 - d; + X_idx += X_dim_idx * X_strides[idx]; + } else { + size_t X_stride = 1; + for (int k = out_dim - 1; k > d; --k) { + if (k >= out_dim - X_dim) { + size_t X_k = k - (out_dim - X_dim); + X_stride *= X.size(X_k); + } + } + X_idx += X_dim_idx * X_stride; + } + } + + /* Compute Y index */ + if (d >= out_dim - Y_dim) { + size_t Y_d = d - (out_dim - Y_dim); + size_t Y_dim_idx = out_dim_idx % Y.size(Y_d); + if (d >= out_dim - max_dim) { + int idx = out_dim - 1 - d; + Y_idx += Y_dim_idx * Y_strides[idx]; + } else { + size_t Y_stride = 1; + for (int k = out_dim - 1; k > d; --k) { + if (k >= out_dim - Y_dim) { + size_t Y_k = k - (out_dim - Y_dim); + Y_stride *= Y.size(Y_k); + } + } + Y_idx += Y_dim_idx * Y_stride; + } + } + } + + /* Apply the operation */ + float x = kernels::dequantize( + X_data[X_idx], X_scale_f, X_zero_point_i32); + float y = kernels::dequantize( + Y_data[Y_idx], Y_scale_f, Y_zero_point_i32); + float z = x + y; + out_data[i] = + kernels::quantize(z, inv_out_scale, out_zero_point_i32); + } + } +} + +} // namespace native +} // namespace HiFi +} // namespace impl +} // namespace cadence diff --git a/backends/cadence/hifi/operators/operators.h b/backends/cadence/hifi/operators/operators.h index c30242c144b..5b8a1e253c1 100644 --- a/backends/cadence/hifi/operators/operators.h +++ b/backends/cadence/hifi/operators/operators.h @@ -168,6 +168,30 @@ ::executorch::aten::Tensor& permute_copy_out( ::executorch::aten::IntArrayRef dims, ::executorch::aten::Tensor& out); +void quantized_add_asym8sxasym8s_asym8s_per_tensor_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& X, + double X_scale, + int64_t X_zero_point, + const ::executorch::aten::Tensor& Y, + double Y_scale, + int64_t Y_zero_point, + double out_scale, + int64_t out_zero_point, + ::executorch::aten::Tensor& out); + +void quantized_add_asym8uxasym8u_asym8u_per_tensor_out( + ::executorch::runtime::KernelRuntimeContext& ctx, + const ::executorch::aten::Tensor& X, + double X_scale, + int64_t X_zero_point, + const ::executorch::aten::Tensor& Y, + double Y_scale, + int64_t Y_zero_point, + double out_scale, + int64_t out_zero_point, + ::executorch::aten::Tensor& out); + } // namespace native } // namespace HiFi } // namespace impl diff --git a/backends/cadence/reference/operators/quantized_add_out.cpp b/backends/cadence/reference/operators/quantized_add_out.cpp new file mode 100644 index 00000000000..2a33f69632a --- /dev/null +++ b/backends/cadence/reference/operators/quantized_add_out.cpp @@ -0,0 +1,192 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include + +namespace impl { +namespace reference { +namespace native { + +using executorch::aten::Tensor; +using executorch::runtime::KernelRuntimeContext; + +template +void quantized_add_per_tensor_impl( + const Tensor& X, + double X_scale, + int64_t X_zero_point, + const Tensor& Y, + double Y_scale, + int64_t Y_zero_point, + double out_scale, + int64_t out_zero_point, + Tensor& out) { + const T* __restrict__ X_data = X.const_data_ptr(); + const T* __restrict__ Y_data = Y.const_data_ptr(); + T* __restrict__ out_data = out.mutable_data_ptr(); + + ssize_t Y_numel = Y.numel(); + ssize_t X_numel = X.numel(); + ssize_t out_numel = out.numel(); + + float X_scale_f = static_cast(X_scale); + float Y_scale_f = static_cast(Y_scale); + float out_scale_f = static_cast(out_scale); + int32_t X_zero_point_i32 = static_cast(X_zero_point); + int32_t Y_zero_point_i32 = static_cast(Y_zero_point); + int32_t out_zero_point_i32 = static_cast(out_zero_point); + + float inv_out_scale = 1.0f / out_scale_f; + + // Simple case: tensors have the same shape, no broadcasting + if (X_numel == Y_numel && Y_numel == out_numel) { + for (size_t i = 0; i < X_numel; ++i) { + float x = kernels::dequantize(X_data[i], X_scale_f, X_zero_point_i32); + float y = kernels::dequantize(Y_data[i], Y_scale_f, Y_zero_point_i32); + float z = x + y; + out_data[i] = kernels::quantize(z, inv_out_scale, out_zero_point_i32); + } + } + // Y is a scalar tensor + else if (Y_numel == 1) { + float y = kernels::dequantize(Y_data[0], Y_scale_f, Y_zero_point_i32); + for (size_t i = 0; i < X_numel; ++i) { + float x = kernels::dequantize(X_data[i], X_scale_f, X_zero_point_i32); + float z = x + y; + out_data[i] = kernels::quantize(z, inv_out_scale, out_zero_point_i32); + } + } + // X is a scalar tensor + else if (X_numel == 1) { + float x = kernels::dequantize(X_data[0], X_scale_f, X_zero_point_i32); + for (size_t i = 0; i < Y_numel; ++i) { + float y = kernels::dequantize(Y_data[i], Y_scale_f, Y_zero_point_i32); + float z = x + y; + out_data[i] = kernels::quantize(z, inv_out_scale, out_zero_point_i32); + } + } + // General broadcasting case - simplified implementation + else { + for (ssize_t i = 0; i < out_numel; ++i) { + // Simple broadcasting: repeat elements as needed + size_t x_idx = (X_numel == 1) ? 0 : i % X_numel; + size_t y_idx = (Y_numel == 1) ? 
0 : i % Y_numel; + + float x = + kernels::dequantize(X_data[x_idx], X_scale_f, X_zero_point_i32); + float y = + kernels::dequantize(Y_data[y_idx], Y_scale_f, Y_zero_point_i32); + float z = x + y; + out_data[i] = kernels::quantize(z, inv_out_scale, out_zero_point_i32); + } + } +} + +// Generic quantized add with type dispatch +void quantized_add_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& X, + double X_scale, + int64_t X_zero_point, + const Tensor& Y, + double Y_scale, + int64_t Y_zero_point, + double out_scale, + int64_t out_zero_point, + Tensor& out) { + (void)ctx; + + executorch::aten::ScalarType dtype = X.scalar_type(); + switch (dtype) { + case executorch::aten::ScalarType::Byte: + quantized_add_per_tensor_impl( + X, + X_scale, + X_zero_point, + Y, + Y_scale, + Y_zero_point, + out_scale, + out_zero_point, + out); + break; + case executorch::aten::ScalarType::Char: + quantized_add_per_tensor_impl( + X, + X_scale, + X_zero_point, + Y, + Y_scale, + Y_zero_point, + out_scale, + out_zero_point, + out); + break; + default: + ET_CHECK_MSG( + false, "Unhandled input dtype %hhd", static_cast(dtype)); + } +} + +// int8-specific quantized add +void quantized_add_asym8sxasym8s_asym8s_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& X, + double X_scale, + int64_t X_zero_point, + const Tensor& Y, + double Y_scale, + int64_t Y_zero_point, + double out_scale, + int64_t out_zero_point, + Tensor& out) { + (void)ctx; + + quantized_add_per_tensor_impl( + X, + X_scale, + X_zero_point, + Y, + Y_scale, + Y_zero_point, + out_scale, + out_zero_point, + out); +} + +// uint8-specific quantized add +void quantized_add_asym8uxasym8u_asym8u_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& X, + double X_scale, + int64_t X_zero_point, + const Tensor& Y, + double Y_scale, + int64_t Y_zero_point, + double out_scale, + int64_t out_zero_point, + Tensor& out) { + (void)ctx; + + quantized_add_per_tensor_impl( + X, + X_scale, + X_zero_point, + Y, + Y_scale, + Y_zero_point, + out_scale, + out_zero_point, + out); +} + +} // namespace native +} // namespace reference +} // namespace impl
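
For reference, here is a minimal standalone sketch (not part of this patch; all names below are illustrative) of the per-tensor math that the reference and HiFi kernels above implement: dequantize both inputs, add in float, then requantize with round-to-nearest and clamp to the output dtype's range.

```python
# Hypothetical emulation of the quantized_add.per_tensor math in this diff.
import torch


def quantized_add_per_tensor_ref(
    X: torch.Tensor,
    X_scale: float,
    X_zero_point: int,
    Y: torch.Tensor,
    Y_scale: float,
    Y_zero_point: int,
    out_scale: float,
    out_zero_point: int,
) -> torch.Tensor:
    # Dequantize: real_value = scale * (quantized_value - zero_point)
    x = X_scale * (X.to(torch.float32) - X_zero_point)
    y = Y_scale * (Y.to(torch.float32) - Y_zero_point)
    z = x + y
    # Requantize: q = round(real / out_scale) + out_zero_point, clamped to the
    # output dtype's range (the kernels multiply by 1/out_scale instead).
    info = torch.iinfo(X.dtype)
    q = torch.round(z / out_scale) + out_zero_point
    return q.clamp(info.min, info.max).to(X.dtype)


if __name__ == "__main__":
    a = torch.randint(-128, 127, (2, 3), dtype=torch.int8)
    b = torch.randint(-128, 127, (2, 3), dtype=torch.int8)
    print(quantized_add_per_tensor_ref(a, 0.5, 0, b, 0.25, 0, 0.75, 0))
```

As in the registered fake kernels, the output takes its dtype and (broadcast) shape from the inputs.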
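
The `type_dispatch.py` change keys the rewrite on the dtypes of the two tensor inputs (`weight_arg_idx=3` points at `Y`). A minimal sketch of that mapping, separate from the actual `CompileTimeTypeDispatchPass`, with an illustrative helper name:

```python
# Sketch of the dtype -> typed-variant mapping added for quantized_add.per_tensor.
import torch

QUANTIZED_ADD_SUFFIXES = {
    (torch.int8, torch.int8): "asym8sxasym8s_asym8s",
    (torch.uint8, torch.uint8): "asym8uxasym8u_asym8u",
}


def typed_quantized_add_name(x_dtype: torch.dtype, y_dtype: torch.dtype) -> str:
    """Return the typed op name the pass would rewrite to, or raise if unsupported."""
    try:
        suffix = QUANTIZED_ADD_SUFFIXES[(x_dtype, y_dtype)]
    except KeyError:
        raise RuntimeError(f"unsupported dtype combination: {x_dtype} x {y_dtype}")
    return f"cadence::quantized_add_{suffix}.per_tensor"


assert typed_quantized_add_name(torch.int8, torch.int8) == (
    "cadence::quantized_add_asym8sxasym8s_asym8s.per_tensor"
)
assert typed_quantized_add_name(torch.uint8, torch.uint8) == (
    "cadence::quantized_add_asym8uxasym8u_asym8u.per_tensor"
)
```

This is what the two new tests exercise: after the pass runs, the generic `quantized_add.per_tensor` node is gone and exactly one typed variant remains.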
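
The broadcasting branch in the HiFi kernels maps each flat output index back to flat indices into `X` and `Y` using right-aligned, stride-based index arithmetic. A simplified Python model of that arithmetic (assuming contiguous tensors and standard broadcasting; function names are illustrative):

```python
# Models the index mapping in the HiFi broadcasting branch.
from typing import Sequence, Tuple


def contiguous_strides(shape: Sequence[int]) -> Tuple[int, ...]:
    """Row-major strides for a contiguous tensor of the given shape."""
    strides = [1] * len(shape)
    for d in range(len(shape) - 2, -1, -1):
        strides[d] = strides[d + 1] * shape[d + 1]
    return tuple(strides)


def broadcast_input_index(
    flat_out_idx: int, out_shape: Sequence[int], in_shape: Sequence[int]
) -> int:
    """Map a flat output index to the flat index of a (possibly broadcast) input."""
    in_strides = contiguous_strides(in_shape)
    offset = len(out_shape) - len(in_shape)  # input dims are right-aligned
    in_idx = 0
    rem = flat_out_idx
    for d in range(len(out_shape) - 1, -1, -1):
        out_dim_idx = rem % out_shape[d]
        rem //= out_shape[d]
        if d >= offset:
            in_d = d - offset
            # A size-1 input dimension repeats along that axis, so wrap with modulo.
            in_idx += (out_dim_idx % in_shape[in_d]) * in_strides[in_d]
    return in_idx


# Example: X of shape (2, 3) plus Y of shape (3,) broadcast to output (2, 3).
out_shape, x_shape, y_shape = (2, 3), (2, 3), (3,)
for i in range(6):
    print(i, broadcast_input_index(i, out_shape, x_shape),
          broadcast_input_index(i, out_shape, y_shape))
```

Note that the reference `quantized_add_per_tensor_impl` deliberately uses a simpler `i % numel` fallback for the general broadcasting case, so the HiFi and reference kernels only agree exactly on same-shape and scalar-broadcast inputs.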