Merged
20 changes: 20 additions & 0 deletions backends/cadence/aot/functions.yaml
@@ -359,6 +359,26 @@
    - arg_meta: null
      kernel_name: impl::reference::quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_out

- func: cadence::quantized_conv1d_nchw_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
  kernels:
    - arg_meta: null
      kernel_name: impl::reference::quantized_conv1d_nchw_asym8sxsym8s_asym8s_per_tensor_out

- func: cadence::quantized_conv1d_nchw_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
  kernels:
    - arg_meta: null
      kernel_name: impl::reference::quantized_conv1d_nchw_asym8uxsym8u_asym8u_per_tensor_out

- func: cadence::quantized_conv1d_nhwc_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
  kernels:
    - arg_meta: null
      kernel_name: impl::reference::quantized_conv1d_nhwc_asym8sxsym8s_asym8s_per_tensor_out

- func: cadence::quantized_conv1d_nhwc_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
  kernels:
    - arg_meta: null
      kernel_name: impl::reference::quantized_conv1d_nhwc_asym8uxsym8u_asym8u_per_tensor_out

- func: cadence::quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
  kernels:
    - arg_meta: null
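Each entry above binds an op schema (`func`) to the reference C++ kernel that implements it (`kernel_name`); the HiFi bindings below follow the same pattern. For orientation, a hedged sketch of how one of these out-variants would be called once the op library is registered — every tensor and scalar value here is a placeholder, not taken from this PR:

```python
# Illustrative only: calling a registered out-variant through torch.ops.
# Assumes the Cadence op library has been loaded; all values are placeholders.
import torch

x = torch.zeros(1, 3, 8, dtype=torch.int8)     # (N, C, L) activation
w = torch.zeros(16, 3, 3, dtype=torch.int8)    # (out_channels, C, kernel) weight
b = torch.zeros(16, dtype=torch.int32)
out = torch.empty(1, 16, 6, dtype=torch.int8)  # L_out = (8 - (3 - 1) - 1) // 1 + 1 = 6

torch.ops.cadence.quantized_conv1d_nchw_asym8sxsym8s_asym8s.per_tensor_out(
    x, w, b,
    [1, 1], [0, 0], [1, 1],  # stride, padding, dilation; width sits at index 1
    1,                       # groups
    0, 0,                    # input_zero_point, weight_zero_point
    1.0, 1.0, 0,             # bias_scale, out_scale, out_zero_point
    1, 1,                    # out_multiplier, out_shift
    out=out,
)
```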
20 changes: 20 additions & 0 deletions backends/cadence/aot/functions_hifi.yaml
@@ -370,6 +370,26 @@
    - arg_meta: null
      kernel_name: cadence::impl::HiFi::quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_out

- func: cadence::quantized_conv1d_nchw_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
  kernels:
    - arg_meta: null
      kernel_name: cadence::impl::HiFi::quantized_conv1d_nchw_asym8sxsym8s_asym8s_per_tensor_out

- func: cadence::quantized_conv1d_nchw_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
  kernels:
    - arg_meta: null
      kernel_name: cadence::impl::HiFi::quantized_conv1d_nchw_asym8uxsym8u_asym8u_per_tensor_out

- func: cadence::quantized_conv1d_nhwc_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
  kernels:
    - arg_meta: null
      kernel_name: cadence::impl::HiFi::quantized_conv1d_nhwc_asym8sxsym8s_asym8s_per_tensor_out

- func: cadence::quantized_conv1d_nhwc_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)
  kernels:
    - arg_meta: null
      kernel_name: cadence::impl::HiFi::quantized_conv1d_nhwc_asym8uxsym8u_asym8u_per_tensor_out

- func: cadence::quantized_layer_norm.out(Tensor input, Tensor in_scale, Tensor in_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor(a!)
  kernels:
    - arg_meta: null
168 changes: 168 additions & 0 deletions backends/cadence/aot/ops_registrations.py
@@ -169,6 +169,30 @@
lib.define(
    "quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
    "quantized_conv1d_nchw_asym8sxsym8s_asym8s.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)"
)
lib.define(
    "quantized_conv1d_nchw_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
    "quantized_conv1d_nchw_asym8uxsym8u_asym8u.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)"
)
lib.define(
    "quantized_conv1d_nchw_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
    "quantized_conv1d_nhwc_asym8sxsym8s_asym8s.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)"
)
lib.define(
    "quantized_conv1d_nhwc_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
    "quantized_conv1d_nhwc_asym8uxsym8u_asym8u.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)"
)
lib.define(
    "quantized_conv1d_nhwc_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)"
)
lib.define(
    "quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)"
)
@@ -2153,6 +2177,150 @@ def roi_align_box_processor_meta(
    return rois.new_empty((rois.shape[0], 80), dtype=torch.uint8)


@register_fake("cadence::quantized_conv1d_nchw_asym8sxsym8s_asym8s.per_tensor")
def quantized_conv1d_nchw_asym8sxsym8s_asym8s_per_tensor_meta(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    stride: Tuple[int],
    padding: Tuple[int],
    dilation: Tuple[int],
    groups: int,
    in_zero_point: int,
    weight_zero_point: int,
    bias_scale: float,
    output_scale: float,
    output_zero_point: int,
    out_multiplier: int,
    out_shift: int,
) -> torch.Tensor:
    assert input.dim() == 3 and weight.dim() == 3
    assert (
        input.dtype == torch.int8
        and weight.dtype == torch.int8
        and bias.dtype == torch.int32
    )
    out_channels, _, kernel_size = weight.shape
    output_size = get_conv1d_output_size(
        input.shape,
        out_channels,
        stride[1],
        padding[1],
        dilation[1],
        kernel_size,
        False,
    )
    return input.new_empty(output_size, dtype=input.dtype)


@register_fake("cadence::quantized_conv1d_nchw_asym8uxsym8u_asym8u.per_tensor")
def quantized_conv1d_nchw_asym8uxsym8u_asym8u_per_tensor_meta(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    stride: Tuple[int],
    padding: Tuple[int],
    dilation: Tuple[int],
    groups: int,
    in_zero_point: int,
    weight_zero_point: int,
    bias_scale: float,
    output_scale: float,
    output_zero_point: int,
    out_multiplier: int,
    out_shift: int,
) -> torch.Tensor:
    assert input.dim() == 3 and weight.dim() == 3
    assert (
        input.dtype == torch.uint8
        and weight.dtype == torch.uint8
        and bias.dtype == torch.int32
    )
    out_channels, _, kernel_size = weight.shape
    output_size = get_conv1d_output_size(
        input.shape,
        out_channels,
        stride[1],
        padding[1],
        dilation[1],
        kernel_size,
        False,
    )
    return input.new_empty(output_size, dtype=input.dtype)


@register_fake("cadence::quantized_conv1d_nhwc_asym8sxsym8s_asym8s.per_tensor")
def quantized_conv1d_nhwc_asym8sxsym8s_asym8s_per_tensor_meta(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    stride: Tuple[int],
    padding: Tuple[int],
    dilation: Tuple[int],
    groups: int,
    in_zero_point: int,
    weight_zero_point: int,
    bias_scale: float,
    output_scale: float,
    output_zero_point: int,
    out_multiplier: int,
    out_shift: int,
) -> torch.Tensor:
    assert input.dim() == 3 and weight.dim() == 3
    assert (
        input.dtype == torch.int8
        and weight.dtype == torch.int8
        and bias.dtype == torch.int32
    )
    out_channels, kernel_size, _ = weight.shape
    output_size = get_conv1d_output_size(
        input.shape,
        out_channels,
        stride[1],
        padding[1],
        dilation[1],
        kernel_size,
        True,
    )
    return input.new_empty(output_size, dtype=input.dtype)


@register_fake("cadence::quantized_conv1d_nhwc_asym8uxsym8u_asym8u.per_tensor")
def quantized_conv1d_nhwc_asym8uxsym8u_asym8u_per_tensor_meta(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    stride: Tuple[int],
    padding: Tuple[int],
    dilation: Tuple[int],
    groups: int,
    in_zero_point: int,
    weight_zero_point: int,
    bias_scale: float,
    output_scale: float,
    output_zero_point: int,
    out_multiplier: int,
    out_shift: int,
) -> torch.Tensor:
    assert input.dim() == 3 and weight.dim() == 3
    assert (
        input.dtype == torch.uint8
        and weight.dtype == torch.uint8
        and bias.dtype == torch.int32
    )
    out_channels, kernel_size, _ = weight.shape
    output_size = get_conv1d_output_size(
        input.shape,
        out_channels,
        stride[1],
        padding[1],
        dilation[1],
        kernel_size,
        True,
    )
    return input.new_empty(output_size, dtype=input.dtype)


@register_fake("cadence::_softmax_f32_f32")
def softmax_f32_f32_meta(
    self: torch.Tensor,
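The four fake kernels registered above differ only in dtype and memory layout; each validates rank-3 inputs and defers the shape math to get_conv1d_output_size. That helper is defined elsewhere in the Cadence AoT sources; the sketch below reconstructs what it presumably computes from the call sites here (the signature and layout handling are assumptions), using the standard convolution output-length formula:

```python
# Hypothetical reconstruction of get_conv1d_output_size, inferred from its
# call sites above; the real helper is defined elsewhere in the repo.
from typing import Sequence

import torch

def conv1d_output_size_sketch(
    in_shape: Sequence[int],  # (N, C, L), or (N, L, C) when channel_last
    out_channels: int,
    stride: int,
    padding: int,
    dilation: int,
    kernel_size: int,
    channel_last: bool,
) -> torch.Size:
    n = in_shape[0]
    length = in_shape[1] if channel_last else in_shape[2]
    # Standard convolution arithmetic for the output length.
    l_out = (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1
    return torch.Size((n, l_out, out_channels) if channel_last else (n, out_channels, l_out))

# A (1, 3, 8) NCHW-style input with a size-3 kernel and 16 output channels -> (1, 16, 6).
assert conv1d_output_size_sketch((1, 3, 8), 16, 1, 0, 1, 3, False) == torch.Size([1, 16, 6])
```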
29 changes: 29 additions & 0 deletions backends/cadence/aot/ref_implementations.py
@@ -613,6 +613,7 @@ def quantized_conv_variant(
    layout: str,
    input_dtype: torch.dtype,
    weight_dtype: torch.dtype,
    is_1d: bool = False,
) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
    """Create a quantized conv variant with type checking."""

@@ -644,6 +645,14 @@ def variant(
                bias.dtype == torch.int32
            ), f"Expected bias dtype int32, got {bias.dtype}"

            if is_1d:
                assert (
                    len(input_tensor.shape) == 3
                ), f"1D convolution requires 3D input tensor, got {len(input_tensor.shape)}D"
                assert (
                    len(weight.shape) == 3
                ), f"1D convolution requires 3D weight tensor, got {len(weight.shape)}D"

            # Call the appropriate base function
            match layout:
                case "nchw":
@@ -748,6 +757,26 @@ def quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...
def quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...


@impl(m, "quantized_conv1d_nchw_asym8sxsym8s_asym8s.per_tensor")
@quantized_conv_variant("nchw", torch.int8, torch.int8, is_1d=True)
def quantized_conv1d_nchw_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...


@impl(m, "quantized_conv1d_nchw_asym8uxsym8u_asym8u.per_tensor")
@quantized_conv_variant("nchw", torch.uint8, torch.uint8, is_1d=True)
def quantized_conv1d_nchw_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...


@impl(m, "quantized_conv1d_nhwc_asym8sxsym8s_asym8s.per_tensor")
@quantized_conv_variant("nhwc", torch.int8, torch.int8, is_1d=True)
def quantized_conv1d_nhwc_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ...


@impl(m, "quantized_conv1d_nhwc_asym8uxsym8u_asym8u.per_tensor")
@quantized_conv_variant("nhwc", torch.uint8, torch.uint8, is_1d=True)
def quantized_conv1d_nhwc_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ...


def quantized_relu_common(
    X: torch.Tensor,
    X_zero_point: torch.Tensor | int,
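The `...` bodies above are deliberate: quantized_conv_variant is a decorator factory that supplies the entire implementation, so each decorated stub contributes only its registered name. A minimal self-contained sketch of that pattern follows — the base-kernel table and all names here are hypothetical stand-ins, not the PR's actual internals:

```python
from typing import Callable

import torch

# Hypothetical stand-ins for the layout-specific base kernels.
_base_kernels: dict[str, Callable[..., torch.Tensor]] = {
    "nchw": lambda inp, w, b, *rest: torch.zeros(1, dtype=inp.dtype),
    "nhwc": lambda inp, w, b, *rest: torch.zeros(1, dtype=inp.dtype),
}

def conv_variant_sketch(
    layout: str,
    input_dtype: torch.dtype,
    weight_dtype: torch.dtype,
    is_1d: bool = False,
) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]:
    def decorator(_stub: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
        def variant(inp: torch.Tensor, w: torch.Tensor, b: torch.Tensor, *rest: object) -> torch.Tensor:
            # Type and (for 1D) rank checks, mirroring the asserts in the diff.
            assert inp.dtype == input_dtype and w.dtype == weight_dtype
            if is_1d:
                assert inp.dim() == 3 and w.dim() == 3
            return _base_kernels[layout](inp, w, b, *rest)
        return variant
    return decorator

@conv_variant_sketch("nchw", torch.int8, torch.int8, is_1d=True)
def my_conv1d_nchw_int8() -> torch.Tensor: ...  # stub body never runs
```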
104 changes: 104 additions & 0 deletions backends/cadence/aot/tests/test_type_dispatch_passes.py
@@ -446,6 +446,110 @@ def test_uint8_dispatch_quantized_conv_nhwc_dilated(self) -> None:
            1,
        )

    def test_int8_dispatch_quantized_conv_nchw_1d(self) -> None:
        """Test that int8 x int8 inputs to a 1D conv dispatch to the conv1d asym8sxsym8s_asym8s variant of quantized_conv_nchw"""
        x = torch.randint(-128, 127, (1, 3, 8), dtype=torch.int8)
        w = torch.randint(-128, 127, (16, 3, 3), dtype=torch.int8)
        b = torch.randint(-2147483648, 2147483647, (16,), dtype=torch.int32)
        gm = single_op_builder(
            placeholders=(x, w, b),
            op=exir_ops.edge.cadence.quantized_conv_nchw.per_tensor,
            args=(x, w, b, [1, 1], [0, 0], [1, 1], 1, 0, 0, 1.0, 1.0, 0, 1, 1),
        )
        p = CompileTimeTypeDispatchPass()
        gm = cast(PassResult, p(gm)).graph_module
        # Original op should be replaced
        self.assertEqual(
            count_node(gm, exir_ops.edge.cadence.quantized_conv_nchw.per_tensor),
            0,
        )
        # Should be replaced with 1D int8 specific variant
        self.assertEqual(
            count_node(
                gm,
                exir_ops.edge.cadence.quantized_conv1d_nchw_asym8sxsym8s_asym8s.per_tensor,
            ),
            1,
        )

    def test_uint8_dispatch_quantized_conv_nchw_1d(self) -> None:
        """Test that uint8 x uint8 inputs to a 1D conv dispatch to the conv1d asym8uxsym8u_asym8u variant of quantized_conv_nchw"""
        x = torch.randint(0, 255, (1, 3, 8), dtype=torch.uint8)
        w = torch.randint(0, 255, (16, 3, 3), dtype=torch.uint8)
        b = torch.randint(-2147483648, 2147483647, (16,), dtype=torch.int32)
        gm = single_op_builder(
            placeholders=(x, w, b),
            op=exir_ops.edge.cadence.quantized_conv_nchw.per_tensor,
            args=(x, w, b, [1, 1], [0, 0], [1, 1], 1, 0, 0, 1.0, 1.0, 0, 1, 1),
        )
        p = CompileTimeTypeDispatchPass()
        gm = cast(PassResult, p(gm)).graph_module
        # Original op should be replaced
        self.assertEqual(
            count_node(gm, exir_ops.edge.cadence.quantized_conv_nchw.per_tensor),
            0,
        )
        # Should be replaced with 1D uint8 specific variant
        self.assertEqual(
            count_node(
                gm,
                exir_ops.edge.cadence.quantized_conv1d_nchw_asym8uxsym8u_asym8u.per_tensor,
            ),
            1,
        )

    def test_int8_dispatch_quantized_conv_nhwc_1d(self) -> None:
        """Test that int8 x int8 inputs to a 1D conv dispatch to the conv1d asym8sxsym8s_asym8s variant of quantized_conv_nhwc"""
        x = torch.randint(-128, 127, (1, 8, 3), dtype=torch.int8)
        w = torch.randint(-128, 127, (16, 3, 3), dtype=torch.int8)
        b = torch.randint(-2147483648, 2147483647, (16,), dtype=torch.int32)
        gm = single_op_builder(
            placeholders=(x, w, b),
            op=exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor,
            args=(x, w, b, [1, 1], [0, 0], [1, 1], 1, 0, 0, 1.0, 1.0, 0, 1, 1),
        )
        p = CompileTimeTypeDispatchPass()
        gm = cast(PassResult, p(gm)).graph_module
        # Original op should be replaced
        self.assertEqual(
            count_node(gm, exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor),
            0,
        )
        # Should be replaced with 1D int8 specific variant
        self.assertEqual(
            count_node(
                gm,
                exir_ops.edge.cadence.quantized_conv1d_nhwc_asym8sxsym8s_asym8s.per_tensor,
            ),
            1,
        )

    def test_uint8_dispatch_quantized_conv_nhwc_1d(self) -> None:
        """Test that uint8 x uint8 inputs to a 1D conv dispatch to the conv1d asym8uxsym8u_asym8u variant of quantized_conv_nhwc"""
        x = torch.randint(0, 255, (1, 8, 3), dtype=torch.uint8)
        w = torch.randint(0, 255, (16, 3, 3), dtype=torch.uint8)
        b = torch.randint(-2147483648, 2147483647, (16,), dtype=torch.int32)
        gm = single_op_builder(
            placeholders=(x, w, b),
            op=exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor,
            args=(x, w, b, [1, 1], [0, 0], [1, 1], 1, 0, 0, 1.0, 1.0, 0, 1, 1),
        )
        p = CompileTimeTypeDispatchPass()
        gm = cast(PassResult, p(gm)).graph_module
        # Original op should be replaced
        self.assertEqual(
            count_node(gm, exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor),
            0,
        )
        # Should be replaced with 1D uint8 specific variant
        self.assertEqual(
            count_node(
                gm,
                exir_ops.edge.cadence.quantized_conv1d_nhwc_asym8uxsym8u_asym8u.per_tensor,
            ),
            1,
        )

    def test_int8_dispatch_quantized_add(self) -> None:
        """Test int8 x int8 inputs should dispatch to asym8sxasym8s_asym8s variant for quantized_add"""
        x = torch.randint(-128, 127, (2, 3), dtype=torch.int8)
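Read together, the four new tests pin down the dispatch rule: CompileTimeTypeDispatchPass appears to key on activation rank (rank 3 implies a 1D convolution) and on the input/weight dtype pair to select the typed variant. A hedged reconstruction of the name mapping implied by the assertions above — inferred, not quoted from the pass, and ignoring the depthwise/dilated 2D cases covered by the neighbouring tests:

```python
import torch

def conv_variant_name_sketch(layout: str, dtype: torch.dtype, input_rank: int) -> str:
    suffix = {
        torch.int8: "asym8sxsym8s_asym8s",
        torch.uint8: "asym8uxsym8u_asym8u",
    }[dtype]
    # Rank-3 activations indicate a 1D convolution.
    base = "quantized_conv1d" if input_rank == 3 else "quantized_conv"
    return f"{base}_{layout}_{suffix}.per_tensor"

assert (
    conv_variant_name_sketch("nhwc", torch.uint8, 3)
    == "quantized_conv1d_nhwc_asym8uxsym8u_asym8u.per_tensor"
)
```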