15 changes: 15 additions & 0 deletions backends/cadence/aot/functions.yaml
@@ -249,6 +249,21 @@
- arg_meta: null
kernel_name: impl::reference::quantized_relu_asym8u_asym8u_per_tensor_out

- func: cadence::quantized_add.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: impl::reference::quantized_add_per_tensor_out

- func: cadence::quantized_add_asym8sxasym8s_asym8s.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: impl::reference::quantized_add_asym8sxasym8s_asym8s_per_tensor_out

- func: cadence::quantized_add_asym8uxasym8u_asym8u.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: impl::reference::quantized_add_asym8uxasym8u_asym8u_per_tensor_out

- func: cadence::quantized_matmul.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
10 changes: 10 additions & 0 deletions backends/cadence/aot/functions_hifi.yaml
@@ -404,6 +404,16 @@
- arg_meta: null
kernel_name: cadence::impl::HiFi::quantized_relu_asym8u_asym8u_per_tensor_out

- func: cadence::quantized_add_asym8sxasym8s_asym8s.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: cadence::impl::HiFi::quantized_add_asym8sxasym8s_asym8s_per_tensor_out

- func: cadence::quantized_add_asym8uxasym8u_asym8u.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: cadence::impl::HiFi::quantized_add_asym8uxasym8u_asym8u_per_tensor_out

- func: cadence::quantized_matmul.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
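Both files bind the same typed overloads: functions.yaml maps them to the portable reference kernels (impl::reference::...), while functions_hifi.yaml maps them to the HiFi-optimized kernels (cadence::impl::HiFi::...). As a minimal sketch, assuming the cadence op library and its kernels have been registered, the out-variant is callable from Python like any other custom op (scales and zero points here are illustrative placeholders):

    import torch

    x = torch.randint(-128, 127, (2, 3), dtype=torch.int8)
    y = torch.randint(-128, 127, (2, 3), dtype=torch.int8)
    out = torch.empty(2, 3, dtype=torch.int8)
    # Writes the requantized sum into `out` via the out= keyword argument.
    torch.ops.cadence.quantized_add_asym8sxasym8s_asym8s.per_tensor_out(
        x, 1.0, 0, y, 1.0, 0, 1.0, 0, out=out
    )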
46 changes: 46 additions & 0 deletions backends/cadence/aot/ops_registrations.py
@@ -325,6 +325,22 @@
"quantized_add.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, "
"int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
"quantized_add_asym8sxasym8s_asym8s.per_tensor(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, "
"int Y_zero_point, float out_scale, int out_zero_point) -> Tensor"
)
lib.define(
"quantized_add_asym8sxasym8s_asym8s.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, "
"int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
"quantized_add_asym8uxasym8u_asym8u.per_tensor(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, "
"int Y_zero_point, float out_scale, int out_zero_point) -> Tensor"
)
lib.define(
"quantized_add_asym8uxasym8u_asym8u.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, "
"int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
"quantized_mul.out(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor Y, Tensor Y_scale, "
"Tensor Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
@@ -503,6 +519,36 @@ def quantized_add_per_tensor_meta(
return X.new_empty(out_size, dtype=X.dtype)


@register_fake("cadence::quantized_add_asym8sxasym8s_asym8s.per_tensor")
def quantized_add_asym8sxasym8s_asym8s_per_tensor_meta(
X: torch.Tensor,
X_scale: float,
X_zero_point: int,
Y: torch.Tensor,
Y_scale: float,
Y_zero_point: int,
out_scale: float,
out_zero_point: int,
) -> torch.Tensor:
out_size = torch.broadcast_shapes(X.size(), Y.size())
return X.new_empty(out_size, dtype=X.dtype)


@register_fake("cadence::quantized_add_asym8uxasym8u_asym8u.per_tensor")
def quantized_add_asym8uxasym8u_asym8u_per_tensor_meta(
X: torch.Tensor,
X_scale: float,
X_zero_point: int,
Y: torch.Tensor,
Y_scale: float,
Y_zero_point: int,
out_scale: float,
out_zero_point: int,
) -> torch.Tensor:
out_size = torch.broadcast_shapes(X.size(), Y.size())
return X.new_empty(out_size, dtype=X.dtype)


@register_fake("cadence::quantized_linear")
def quantized_linear_meta(
src: torch.Tensor,
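The @register_fake meta functions for the new overloads only compute the output shape and dtype, so the exporter can trace the ops without running a real kernel. torch.broadcast_shapes follows standard NumPy-style broadcasting; a few illustrative cases:

    import torch

    torch.broadcast_shapes((2, 3), (2, 3))     # torch.Size([2, 3])
    torch.broadcast_shapes((2, 3), (3,))       # torch.Size([2, 3]): trailing dims align
    torch.broadcast_shapes((4, 1, 3), (2, 1))  # torch.Size([4, 2, 3]): size-1 dims expand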
50 changes: 50 additions & 0 deletions backends/cadence/aot/tests/test_type_dispatch_passes.py
@@ -445,3 +445,53 @@ def test_uint8_dispatch_quantized_conv_nhwc_dilated(self) -> None:
),
1,
)

def test_int8_dispatch_quantized_add(self) -> None:
"""Test int8 x int8 inputs should dispatch to asym8sxasym8s_asym8s variant for quantized_add"""
x = torch.randint(-128, 127, (2, 3), dtype=torch.int8)
y = torch.randint(-128, 127, (2, 3), dtype=torch.int8)
gm = single_op_builder(
placeholders=(x, y),
op=exir_ops.edge.cadence.quantized_add.per_tensor,
args=(x, 1.0, 0, y, 1.0, 0, 1.0, 0),
)
p = CompileTimeTypeDispatchPass()
gm = cast(PassResult, p(gm)).graph_module
# Original op should be replaced
self.assertEqual(
count_node(gm, exir_ops.edge.cadence.quantized_add.per_tensor),
0,
)
# Should be replaced with int8 specific variant
self.assertEqual(
count_node(
gm,
exir_ops.edge.cadence.quantized_add_asym8sxasym8s_asym8s.per_tensor,
),
1,
)

def test_uint8_dispatch_quantized_add(self) -> None:
"""Test uint8 x uint8 inputs should dispatch to asym8uxasym8u_asym8u variant for quantized_add"""
x = torch.randint(0, 255, (2, 3), dtype=torch.uint8)
y = torch.randint(0, 255, (2, 3), dtype=torch.uint8)
gm = single_op_builder(
placeholders=(x, y),
op=exir_ops.edge.cadence.quantized_add.per_tensor,
args=(x, 1.0, 0, y, 1.0, 0, 1.0, 0),
)
p = CompileTimeTypeDispatchPass()
gm = cast(PassResult, p(gm)).graph_module
# Original op should be replaced
self.assertEqual(
count_node(gm, exir_ops.edge.cadence.quantized_add.per_tensor),
0,
)
# Should be replaced with uint8 specific variant
self.assertEqual(
count_node(
gm,
exir_ops.edge.cadence.quantized_add_asym8uxasym8u_asym8u.per_tensor,
),
1,
)
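For intuition about what the typed variants under test compute, here is a hypothetical pure-PyTorch reference for the int8 case, mirroring the dequantize/add/requantize arithmetic of the HiFi kernel below (the function name and tensor-level formulation are ours, not part of the PR):

    import torch

    def quantized_add_int8_reference(x, x_scale, x_zp, y, y_scale, y_zp, out_scale, out_zp):
        # Dequantize both inputs to float, add, then requantize with the output params.
        x_f = x_scale * (x.to(torch.float32) - x_zp)
        y_f = y_scale * (y.to(torch.float32) - y_zp)
        z = torch.round((x_f + y_f) / out_scale) + out_zp
        return z.clamp(-128, 127).to(torch.int8)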
8 changes: 8 additions & 0 deletions backends/cadence/aot/type_dispatch.py
@@ -85,6 +85,14 @@ class CompileTimeTypeDispatchPass(ExportPass):
(torch.uint8,): "asym8u_asym8u",
},
),
exir_ops.edge.cadence.quantized_add.per_tensor: OpConfig(
"quantized_add",
type_dispatch_suffixes={
(torch.int8, torch.int8): "asym8sxasym8s_asym8s",
(torch.uint8, torch.uint8): "asym8uxasym8u_asym8u",
},
weight_arg_idx=3,
),
}

def call_operator(
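The rewrite itself happens in call_operator (elided above): the pass looks up the incoming op in this table, reads the dtypes of the two tensor inputs (weight_arg_idx=3 points at Y in the quantized_add schema), and replaces the call with the suffixed variant. A minimal sketch of that lookup, with container and accessor names assumed for illustration:

    # Hypothetical sketch; `op_configs`, `args`, and the `to_tensor()` accessors are assumptions.
    config = op_configs[exir_ops.edge.cadence.quantized_add.per_tensor]
    in_dtype = args[0].to_tensor().dtype                       # X
    w_dtype = args[config.weight_arg_idx].to_tensor().dtype    # Y (index 3)
    suffix = config.type_dispatch_suffixes[(in_dtype, w_dtype)]
    # (int8, int8) -> "asym8sxasym8s_asym8s", so the node becomes
    # cadence.quantized_add_asym8sxasym8s_asym8s.per_tensor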
@@ -0,0 +1,179 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <algorithm>
#include <cmath>
#include <limits>

#include <executorch/backends/cadence/hifi/kernels/kernels.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace cadence {
namespace impl {
namespace HiFi {
namespace native {

using ::executorch::aten::Tensor;
using ::executorch::runtime::KernelRuntimeContext;

void quantized_add_asym8sxasym8s_asym8s_per_tensor_out(
KernelRuntimeContext& ctx,
const Tensor& X,
double X_scale,
int64_t X_zero_point,
const Tensor& Y,
double Y_scale,
int64_t Y_zero_point,
double out_scale,
int64_t out_zero_point,
Tensor& out) {
const int8_t* __restrict__ X_data = X.const_data_ptr<int8_t>();
const int8_t* __restrict__ Y_data = Y.const_data_ptr<int8_t>();
int8_t* __restrict__ out_data = out.mutable_data_ptr<int8_t>();

ssize_t Y_numel = Y.numel();
ssize_t X_numel = X.numel();
ssize_t out_numel = out.numel();

float X_scale_f = static_cast<float>(X_scale);
float Y_scale_f = static_cast<float>(Y_scale);
float out_scale_f = static_cast<float>(out_scale);
int32_t X_zero_point_i32 = static_cast<int32_t>(X_zero_point);
int32_t Y_zero_point_i32 = static_cast<int32_t>(Y_zero_point);
int32_t out_zero_point_i32 = static_cast<int32_t>(out_zero_point);

float inv_out_scale = 1.0f / out_scale_f;
constexpr float min_val =
static_cast<float>(std::numeric_limits<int8_t>::min());
constexpr float max_val =
static_cast<float>(std::numeric_limits<int8_t>::max());

/* Tensor X exactly matches Y in shape, no broadcasting */
if (X_numel == Y_numel && Y_numel == out_numel) {
for (size_t i = 0; i < X_numel; ++i) {
float x = X_scale_f * (X_data[i] - X_zero_point_i32);
float y = Y_scale_f * (Y_data[i] - Y_zero_point_i32);
float z = x + y;
float tmp = roundf(z * inv_out_scale + out_zero_point_i32);
out_data[i] =
static_cast<int8_t>(std::max(std::min(tmp, max_val), min_val));
}
} /* if Y is a scalar Tensor */
else if (Y_numel == 1) {
float y =
kernels::dequantize<int8_t>(Y_data[0], Y_scale_f, Y_zero_point_i32);
for (size_t i = 0; i < X_numel; ++i) {
float x =
kernels::dequantize<int8_t>(X_data[i], X_scale_f, X_zero_point_i32);
float z = x + y;
out_data[i] =
kernels::quantize<int8_t>(z, inv_out_scale, out_zero_point_i32);
}
} /* if X is a scalar Tensor */
else if (X_numel == 1) {
float x =
kernels::dequantize<int8_t>(X_data[0], X_scale_f, X_zero_point_i32);
for (size_t i = 0; i < Y_numel; ++i) {
float y =
kernels::dequantize<int8_t>(Y_data[i], Y_scale_f, Y_zero_point_i32);
float z = x + y;
out_data[i] =
kernels::quantize<int8_t>(z, inv_out_scale, out_zero_point_i32);
}
} /* other broadcasting cases */
else {
/* Broadcasting implementation */
ssize_t X_dim = X.dim();
ssize_t Y_dim = Y.dim();
ssize_t out_dim = out.dim();

/* Precompute strides for X and Y tensors */
/* Use a signed limit so `out_dim - max_dim` below stays signed and cannot
   wrap when out_dim < max_dim. */
constexpr ssize_t max_dim = executorch::runtime::kTensorDimensionLimit;
size_t X_strides[max_dim] = {0};
size_t Y_strides[max_dim] = {0};
size_t X_stride_val = 1;
size_t Y_stride_val = 1;

/* Calculate strides from last dimension to first */
for (int d = out_dim - 1; d >= 0 && d >= out_dim - max_dim; --d) {
int idx = out_dim - 1 - d; /* Index into the fixed-size array */
if (d >= out_dim - X_dim) {
size_t X_d = d - (out_dim - X_dim);
X_strides[idx] = X_stride_val;
X_stride_val *= X.size(X_d);
}

if (d >= out_dim - Y_dim) {
size_t Y_d = d - (out_dim - Y_dim);
Y_strides[idx] = Y_stride_val;
Y_stride_val *= Y.size(Y_d);
}
}

/* Iterate over output tensor */
for (ssize_t i = 0; i < out_numel; ++i) {
size_t out_idx = i;
size_t X_idx = 0;
size_t Y_idx = 0;

/* Compute corresponding indices in input tensors */
for (int d = out_dim - 1; d >= 0; --d) {
size_t out_dim_idx = out_idx % out.size(d);
out_idx /= out.size(d);

/* Compute X index */
if (d >= out_dim - X_dim) {
size_t X_d = d - (out_dim - X_dim);
size_t X_dim_idx = out_dim_idx % X.size(X_d);
if (d >= out_dim - max_dim) {
int idx = out_dim - 1 - d;
X_idx += X_dim_idx * X_strides[idx];
} else {
size_t X_stride = 1;
for (int k = out_dim - 1; k > d; --k) {
if (k >= out_dim - X_dim) {
size_t X_k = k - (out_dim - X_dim);
X_stride *= X.size(X_k);
}
}
X_idx += X_dim_idx * X_stride;
}
}

/* Compute Y index */
if (d >= out_dim - Y_dim) {
size_t Y_d = d - (out_dim - Y_dim);
size_t Y_dim_idx = out_dim_idx % Y.size(Y_d);
if (d >= out_dim - max_dim) {
int idx = out_dim - 1 - d;
Y_idx += Y_dim_idx * Y_strides[idx];
} else {
size_t Y_stride = 1;
for (int k = out_dim - 1; k > d; --k) {
if (k >= out_dim - Y_dim) {
size_t Y_k = k - (out_dim - Y_dim);
Y_stride *= Y.size(Y_k);
}
}
Y_idx += Y_dim_idx * Y_stride;
}
}
}

/* Apply the operation */
float x = kernels::dequantize<int8_t>(
X_data[X_idx], X_scale_f, X_zero_point_i32);
float y = kernels::dequantize<int8_t>(
Y_data[Y_idx], Y_scale_f, Y_zero_point_i32);
float z = x + y;
out_data[i] =
kernels::quantize<int8_t>(z, inv_out_scale, out_zero_point_i32);
}
}
}

} // namespace native
} // namespace HiFi
} // namespace impl
} // namespace cadence
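As a quick sanity check on the requantization above, a worked example with illustrative values (adding the integer zero point before or after rounding gives the same result):

    # X: scale 0.5, zero point 2, raw value 10 -> 0.5 * (10 - 2) = 4.0
    # Y: scale 0.25, zero point 0, raw value 4 -> 0.25 * (4 - 0) = 1.0
    # z = 4.0 + 1.0 = 5.0
    # out: scale 0.1, zero point -3 -> roundf(5.0 * (1 / 0.1)) + (-3) = 47
    # 47 lies inside [-128, 127], so the stored int8 result is 47.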