10 changes: 10 additions & 0 deletions backends/cadence/aot/functions.yaml
@@ -229,6 +229,16 @@
- arg_meta: null
kernel_name: impl::reference::quantized_linear_per_tensor_out

- func: cadence::quantized_linear_asym8sxasym8s_asym8s.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: impl::reference::quantized_linear_asym8sxasym8s_asym8s_per_tensor_out

- func: cadence::quantized_linear_asym8uxasym8u_asym8u.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: impl::reference::quantized_linear_asym8uxasym8u_asym8u_per_tensor_out

- func: cadence::im2row.out(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, Tensor in_zero_point, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
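Note: the type suffixes encode the operand dtypes as input x weight -> output. `asym8s` is asymmetrically quantized signed 8-bit (int8) and `asym8u` is asymmetrically quantized unsigned 8-bit (uint8), so `asym8sxasym8s_asym8s` is the int8 x int8 -> int8 variant and `asym8uxasym8u_asym8u` the uint8 x uint8 -> uint8 variant.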
10 changes: 10 additions & 0 deletions backends/cadence/aot/functions_hifi.yaml
@@ -314,6 +314,16 @@
- arg_meta: null
kernel_name: cadence::impl::HiFi::quantized_linear_per_tensor_out

- func: cadence::quantized_linear_asym8sxasym8s_asym8s.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: cadence::impl::HiFi::quantized_linear_asym8sxasym8s_asym8s_per_tensor_out

- func: cadence::quantized_linear_asym8uxasym8u_asym8u.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
kernel_name: cadence::impl::HiFi::quantized_linear_asym8uxasym8u_asym8u_per_tensor_out

- func: cadence::quantized_relu_per_tensor.out(Tensor X, Tensor X_zero_point, int out_zero_point, Tensor out_multiplier, Tensor out_shift, *, Tensor(a!) out) -> Tensor(a!)
kernels:
- arg_meta: null
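Both YAML files register the same two schemas: `functions.yaml` binds them to the portable reference kernels (`impl::reference::...`), while `functions_hifi.yaml` binds them to the NNLib-backed HiFi kernels (`cadence::impl::HiFi::...`).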
60 changes: 60 additions & 0 deletions backends/cadence/aot/ops_registrations.py
@@ -56,10 +56,26 @@
lib.define(
"quantized_linear.per_tensor_out(Tensor src, Tensor weight, Tensor bias, SymInt src_zero_point, SymInt weight_zero_point, SymInt out_multiplier, SymInt out_shift, SymInt out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
"quantized_linear_asym8sxasym8s_asym8s.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
"int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
"quantized_linear_asym8uxasym8u_asym8u.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
"int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
"quantized_linear.per_tensor(Tensor src, Tensor weight, Tensor bias, SymInt src_zero_point, "
"SymInt weight_zero_point, SymInt out_multiplier, SymInt out_shift, SymInt out_zero_point, Tensor? offset) -> Tensor"
)
lib.define(
"quantized_linear_asym8sxasym8s_asym8s.per_tensor(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
"int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset) -> (Tensor Z)"
)
lib.define(
"quantized_linear_asym8uxasym8u_asym8u.per_tensor(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
"int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset) -> (Tensor Z)"
)

lib.define(
"quantized_relu(Tensor X, Tensor X_zero_point, int out_zero_point, Tensor out_multiplier, Tensor out_shift) -> (Tensor Y)"
@@ -446,6 +462,50 @@ def quantized_linear_per_tensor_meta(
return src.new_empty(out_size, dtype=src.dtype)


@register_fake("cadence::quantized_linear_asym8sxasym8s_asym8s.per_tensor")
def quantized_linear_asym8sxasym8s_asym8s_per_tensor_meta(
src: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
in_zero_point: int,
weight_zero_point: int,
out_multiplier: int,
out_shift: int,
out_zero_point: int,
offset: Optional[torch.Tensor],
) -> torch.Tensor:
# src comes in shape [leading_dims, in_dim]
# weight comes in shape [out_dim, in_dim]
# output comes in empty with shape [leading_dims, out_dim]
out_size = list(src.size())
weight_size = list(weight.size())
assert len(weight_size) == 2
out_size[-1] = weight_size[0]
return src.new_empty(out_size, dtype=src.dtype)


@register_fake("cadence::quantized_linear_asym8uxasym8u_asym8u.per_tensor")
def quantized_linear_asym8uxasym8u_asym8u_per_tensor_meta(
src: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
in_zero_point: int,
weight_zero_point: int,
out_multiplier: int,
out_shift: int,
out_zero_point: int,
offset: Optional[torch.Tensor],
) -> torch.Tensor:
# src comes in shape [leading_dims, in_dim]
# weight comes in shape [out_dim, in_dim]
# output comes in empty with shape [leading_dims, out_dim]
out_size = list(src.size())
weight_size = list(weight.size())
assert len(weight_size) == 2
out_size[-1] = weight_size[0]
return src.new_empty(out_size, dtype=src.dtype)


@register_fake("cadence::quantized_conv")
def quantized_conv_meta(
input: torch.Tensor,
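Each typed op is defined in two flavors: a functional `.per_tensor` overload used at export time, whose output shape is derived by the `@register_fake` meta kernels above, and a `.per_tensor_out` overload that writes into a preallocated `out` tensor at runtime. A minimal sketch of exercising the new meta kernel under `FakeTensorMode` (illustrative only; it assumes `ops_registrations` has been imported so the `cadence` library and its fake kernels are registered):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    src = torch.empty(2, 3, dtype=torch.int8)     # [leading_dims, in_dim]
    weight = torch.empty(4, 3, dtype=torch.int8)  # [out_dim, in_dim]
    bias = torch.empty(4, dtype=torch.int32)
    out = torch.ops.cadence.quantized_linear_asym8sxasym8s_asym8s.per_tensor(
        src, weight, bias, 0, 0, 1, 0, 0, None
    )
    assert list(out.shape) == [2, 4]  # [leading_dims, out_dim]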
56 changes: 54 additions & 2 deletions backends/cadence/aot/tests/test_type_dispatch_passes.py
@@ -18,7 +18,7 @@


class TestTypeDispatchPasses(unittest.TestCase):
-    def test_int8_dispatch(self) -> None:
+    def test_int8_dispatch_quantized_fully_connected(self) -> None:
"""Test int8 x int8 inputs should dispatch to asym8sxasym8s_asym8s variant"""
x = torch.randint(-128, 127, (2, 3), dtype=torch.int8)
w = torch.randint(-128, 127, (4, 3), dtype=torch.int8)
@@ -44,7 +44,7 @@ def test_int8_dispatch(self) -> None:
1,
)

-    def test_uint8_dispatch(self) -> None:
+    def test_uint8_dispatch_quantized_fully_connected(self) -> None:
"""Test uint8 x uint8 inputs should dispatch to asym8uxasym8u_asym8u variant"""
x = torch.randint(0, 255, (2, 3), dtype=torch.uint8)
w = torch.randint(0, 255, (4, 3), dtype=torch.uint8)
@@ -70,6 +70,58 @@ def test_uint8_dispatch(self) -> None:
1,
)

def test_int8_dispatch_quantized_linear(self) -> None:
"""Test int8 x int8 inputs should dispatch to asym8sxasym8s_asym8s variant for quantized_linear"""
x = torch.randint(-128, 127, (2, 3), dtype=torch.int8)
w = torch.randint(-128, 127, (4, 3), dtype=torch.int8)
b = torch.randint(-2147483648, 2147483647, (4,), dtype=torch.int32)
gm = single_op_builder(
placeholders=(x, w, b),
op=exir_ops.edge.cadence.quantized_linear.per_tensor,
args=(x, w, b, 0, 0, 1, 0, 0, None),
)
p = CompileTimeTypeDispatchPass()
gm = cast(PassResult, p(gm)).graph_module
# Original op should be replaced
self.assertEqual(
count_node(gm, exir_ops.edge.cadence.quantized_linear.per_tensor),
0,
)
# Should be replaced with int8 specific variant
self.assertEqual(
count_node(
gm,
exir_ops.edge.cadence.quantized_linear_asym8sxasym8s_asym8s.per_tensor,
),
1,
)

    def test_uint8_dispatch_quantized_linear(self) -> None:
        """Test that uint8 x uint8 inputs dispatch to the asym8uxasym8u_asym8u variant for quantized_linear."""
x = torch.randint(0, 255, (2, 3), dtype=torch.uint8)
w = torch.randint(0, 255, (4, 3), dtype=torch.uint8)
b = torch.randint(-2147483648, 2147483647, (4,), dtype=torch.int32)
gm = single_op_builder(
placeholders=(x, w, b),
op=exir_ops.edge.cadence.quantized_linear.per_tensor,
args=(x, w, b, 0, 0, 1, 0, 0, None),
)
p = CompileTimeTypeDispatchPass()
gm = cast(PassResult, p(gm)).graph_module
# Original op should be replaced
self.assertEqual(
count_node(gm, exir_ops.edge.cadence.quantized_linear.per_tensor),
0,
)
# Should be replaced with uint8 specific variant
self.assertEqual(
count_node(
gm,
exir_ops.edge.cadence.quantized_linear_asym8uxasym8u_asym8u.per_tensor,
),
1,
)

def test_mixed_types_error(self) -> None:
"""Test mixed int8/uint8 inputs should raise RuntimeError"""
x = torch.randint(-128, 127, (2, 3), dtype=torch.int8)
52 changes: 26 additions & 26 deletions backends/cadence/aot/type_dispatch.py
@@ -23,40 +23,40 @@ class CompileTimeTypeDispatchPass(ExportPass):
Replaces generic ops with ops that have explicit types.
"""

+    _TYPE_DISPATCH_MAP: dict[tuple[torch.dtype, torch.dtype], str] = {
+        (torch.int8, torch.int8): "asym8sxasym8s_asym8s",
+        (torch.uint8, torch.uint8): "asym8uxasym8u_asym8u",
+    }
+
+    _SUPPORTED_OPS: dict[OpOverload, str] = {
+        exir_ops.edge.cadence.quantized_fully_connected.per_tensor: "quantized_fully_connected",
+        exir_ops.edge.cadence.quantized_linear.per_tensor: "quantized_linear",
+    }
+
    def call_operator(
        self,
        op: OpOverload,
        args: tuple[Argument, ...],
        kwargs: dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
-        if op not in {
-            exir_ops.edge.cadence.quantized_fully_connected.per_tensor,
-        }:
+        if op not in self._SUPPORTED_OPS:
            return super().call_operator(op, args, kwargs, meta)

-        if (
-            # pyre-ignore[16]: None has no attribute `to_tensor`.
-            args[0].to_tensor().dtype == torch.int8
-            and args[1].to_tensor().dtype == torch.int8
-        ):
-            return super().call_operator(
-                exir_ops.edge.cadence.quantized_fully_connected_asym8sxasym8s_asym8s.per_tensor,
-                args,
-                kwargs,
-                meta,
-            )
-        elif (
-            args[0].to_tensor().dtype == torch.uint8
-            and args[1].to_tensor().dtype == torch.uint8
-        ):
-            return super().call_operator(
-                exir_ops.edge.cadence.quantized_fully_connected_asym8uxasym8u_asym8u.per_tensor,
-                args,
-                kwargs,
-                meta,
-            )
-        else:
+        # pyre-ignore[16]: None has no attribute `to_tensor`.
+        input_dtype = args[0].to_tensor().dtype
+        weight_dtype = args[1].to_tensor().dtype
+        dtype_pair = (input_dtype, weight_dtype)
+
+        if dtype_pair not in self._TYPE_DISPATCH_MAP:
            raise RuntimeError(
-                f"Unsupported input types for {op}: {args[0].to_tensor().dtype} and {args[1].to_tensor().dtype}"
+                f"Unsupported input types for {op}: {input_dtype} and {weight_dtype}"
            )
+
+        base_op_name = self._SUPPORTED_OPS[op]
+        type_suffix = self._TYPE_DISPATCH_MAP[dtype_pair]
+
+        typed_op_name = f"{base_op_name}_{type_suffix}"
+        typed_op = getattr(exir_ops.edge.cadence, typed_op_name).per_tensor
+
+        return super().call_operator(typed_op, args, kwargs, meta)
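The rewrite replaces the hard-coded if/elif chain with table-driven dispatch: supporting another base op or dtype pair now only takes a new entry in `_SUPPORTED_OPS` or `_TYPE_DISPATCH_MAP`. A standalone sketch of the name composition the pass performs (illustrative, not part of the PR):

import torch

_TYPE_DISPATCH_MAP = {
    (torch.int8, torch.int8): "asym8sxasym8s_asym8s",
    (torch.uint8, torch.uint8): "asym8uxasym8u_asym8u",
}

def typed_variant(base_op_name: str, input_dtype: torch.dtype, weight_dtype: torch.dtype) -> str:
    # Compose e.g. "quantized_linear" + "_" + "asym8sxasym8s_asym8s"; the pass
    # then resolves the name via getattr(exir_ops.edge.cadence, name).per_tensor.
    return f"{base_op_name}_{_TYPE_DISPATCH_MAP[(input_dtype, weight_dtype)]}"

assert typed_variant("quantized_linear", torch.int8, torch.int8) == (
    "quantized_linear_asym8sxasym8s_asym8s"
)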
@@ -0,0 +1,75 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/cadence/hifi/kernels/kernels.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <xa_nnlib_kernels_api.h>

namespace cadence {
namespace impl {
namespace HiFi {
namespace native {

using ::executorch::aten::Tensor;
using ::executorch::runtime::getLeadingDims;
using ::executorch::runtime::KernelRuntimeContext;
using std::optional;

void quantized_linear_asym8sxasym8s_asym8s_per_tensor_out(
__ET_UNUSED KernelRuntimeContext& ctx,
const Tensor& in,
const Tensor& weight,
const Tensor& bias,
int64_t in_zero_point,
int64_t weight_zero_point,
int64_t out_multiplier,
int64_t out_shift,
int64_t out_zero_point,
__ET_UNUSED const optional<Tensor>& offset,
Tensor& out) {
// input comes in shape [leading_dims, in_dim]
// weight comes in shape [out_dim, in_dim]
// output comes in empty with shape [leading_dims, out_dim]
// Perform matrix multiply (M x N) x (N x P)' => M x P
const int64_t leading_dims = getLeadingDims(in, in.dim() - 1);
const int64_t out_dim = weight.size(0); // = out_dim
const int64_t in_dim = weight.size(1); // = in_dim

const int8_t* __restrict__ in_data = in.const_data_ptr<int8_t>();
const int8_t* __restrict__ weight_data = weight.const_data_ptr<int8_t>();
const int32_t* __restrict__ bias_data = bias.const_data_ptr<int32_t>();
int8_t* __restrict__ out_data = out.mutable_data_ptr<int8_t>();

  const int32_t out_multiplier_int32 = static_cast<int32_t>(out_multiplier);
const int32_t out_shift_int32 = static_cast<int32_t>(out_shift);

// The nnlib kernel to compute quantized linear via matmul.
const int32_t ret = xa_nn_matmul_asym8sxasym8s_asym8s(
out_data, // p_out
weight_data, // p_mat1,
in_data, // p_mat2,
bias_data, // p_bias
out_dim, // rows of p_mat1
in_dim, // cols of p_mat1
in_dim, // row_stride of p_mat1
leading_dims, // vec_count, i.e., rows of p_mat2
in_dim, // vec_offset of p_mat2.
out_dim, // out_offset, i.e., offset of next output element written
1, // out_stride, i.e., stride to go to next output row
-weight_zero_point, // mat1_zero_bias
-in_zero_point, // mat2_zero_bias
      out_multiplier_int32, // out_multiplier
out_shift_int32, // out_shift
out_zero_point); // out_zero_bias
  ET_DCHECK_MSG(ret == 0, "HiFi quantized_linear_asym8sxasym8s_asym8s_per_tensor failed");
}

} // namespace native
} // namespace HiFi
} // namespace impl
} // namespace cadence
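Up to rounding details, the kernel asks nnlib to compute the standard asymmetric quantized linear: accumulate (in - in_zero_point) x (weight - weight_zero_point)^T in int32, add the int32 bias, then requantize with out_multiplier/out_shift and add out_zero_point. A hedged NumPy reference (not the nnlib implementation; it assumes out_multiplier is a Q31 fixed-point multiplier and out_shift a signed left-shift count, the common nnlib convention, and it ignores the exact rounding mode):

import numpy as np

def quantized_linear_ref(src, weight, bias, src_zp, weight_zp,
                         out_multiplier, out_shift, out_zp):
    # src: [M, in_dim] int8, weight: [out_dim, in_dim] int8, bias: [out_dim] int32
    acc = (src.astype(np.int32) - src_zp) @ (weight.astype(np.int32) - weight_zp).T
    acc = acc + bias
    # Requantize: Q31 multiply, shift, then add the output zero point and saturate.
    scaled = (acc.astype(np.int64) * out_multiplier) >> (31 - out_shift)
    return np.clip(scaled + out_zp, -128, 127).astype(np.int8)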