diff --git a/backends/cadence/aot/functions.yaml b/backends/cadence/aot/functions.yaml index 196480931e0..c8e7d6cb3fc 100644 --- a/backends/cadence/aot/functions.yaml +++ b/backends/cadence/aot/functions.yaml @@ -359,6 +359,26 @@ - arg_meta: null kernel_name: impl::reference::quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_out +- func: cadence::quantized_conv1d_nchw_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::reference::quantized_conv1d_nchw_asym8sxsym8s_asym8s_per_tensor_out + +- func: cadence::quantized_conv1d_nchw_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::reference::quantized_conv1d_nchw_asym8uxsym8u_asym8u_per_tensor_out + +- func: cadence::quantized_conv1d_nhwc_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::reference::quantized_conv1d_nhwc_asym8sxsym8s_asym8s_per_tensor_out + +- func: cadence::quantized_conv1d_nhwc_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::reference::quantized_conv1d_nhwc_asym8uxsym8u_asym8u_per_tensor_out + - func: cadence::quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null diff --git a/backends/cadence/aot/functions_hifi.yaml b/backends/cadence/aot/functions_hifi.yaml index cf4c5a8fffb..1b62c215ab6 100644 --- a/backends/cadence/aot/functions_hifi.yaml +++ b/backends/cadence/aot/functions_hifi.yaml @@ -370,6 +370,26 @@ - arg_meta: null kernel_name: cadence::impl::HiFi::quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_out +- func: cadence::quantized_conv1d_nchw_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: cadence::impl::HiFi::quantized_conv1d_nchw_asym8sxsym8s_asym8s_per_tensor_out + +- func: cadence::quantized_conv1d_nchw_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: cadence::impl::HiFi::quantized_conv1d_nchw_asym8uxsym8u_asym8u_per_tensor_out + +- func: cadence::quantized_conv1d_nhwc_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: cadence::impl::HiFi::quantized_conv1d_nhwc_asym8sxsym8s_asym8s_per_tensor_out + +- func: cadence::quantized_conv1d_nhwc_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: cadence::impl::HiFi::quantized_conv1d_nhwc_asym8uxsym8u_asym8u_per_tensor_out + - func: cadence::quantized_layer_norm.out(Tensor input, Tensor in_scale, Tensor in_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py index 35b4cbf3902..efb22a9e7d6 100644 --- a/backends/cadence/aot/ops_registrations.py +++ b/backends/cadence/aot/ops_registrations.py @@ -169,6 +169,30 @@ lib.define( "quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)" ) +lib.define( + "quantized_conv1d_nchw_asym8sxsym8s_asym8s.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)" +) +lib.define( + "quantized_conv1d_nchw_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)" +) +lib.define( + "quantized_conv1d_nchw_asym8uxsym8u_asym8u.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)" +) +lib.define( + "quantized_conv1d_nchw_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)" +) +lib.define( + "quantized_conv1d_nhwc_asym8sxsym8s_asym8s.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)" +) +lib.define( + "quantized_conv1d_nhwc_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)" +) +lib.define( + "quantized_conv1d_nhwc_asym8uxsym8u_asym8u.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)" +) +lib.define( + "quantized_conv1d_nhwc_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)" +) lib.define( "quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)" ) @@ -2153,6 +2177,150 @@ def roi_align_box_processor_meta( return rois.new_empty((rois.shape[0], 80), dtype=torch.uint8) +@register_fake("cadence::quantized_conv1d_nchw_asym8sxsym8s_asym8s.per_tensor") +def quantized_conv1d_nchw_asym8sxsym8s_asym8s_per_tensor_meta( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: Tuple[int], + padding: Tuple[int], + dilation: Tuple[int], + groups: int, + in_zero_point: int, + weight_zero_point: int, + bias_scale: float, + output_scale: float, + output_zero_point: int, + out_multiplier: int, + out_shift: int, +) -> torch.Tensor: + assert input.dim() == 3 and weight.dim() == 3 + assert ( + input.dtype == torch.int8 + and weight.dtype == torch.int8 + and bias.dtype == torch.int32 + ) + out_channels, _, kernel_size = weight.shape + output_size = get_conv1d_output_size( + input.shape, + out_channels, + stride[1], + padding[1], + dilation[1], + kernel_size, + False, + ) + return input.new_empty(output_size, dtype=input.dtype) + + +@register_fake("cadence::quantized_conv1d_nchw_asym8uxsym8u_asym8u.per_tensor") +def quantized_conv1d_nchw_asym8uxsym8u_asym8u_per_tensor_meta( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: Tuple[int], + padding: Tuple[int], + dilation: Tuple[int], + groups: int, + in_zero_point: int, + weight_zero_point: int, + bias_scale: float, + output_scale: float, + output_zero_point: int, + out_multiplier: int, + out_shift: int, +) -> torch.Tensor: + assert input.dim() == 3 and weight.dim() == 3 + assert ( + input.dtype == torch.uint8 + and weight.dtype == torch.uint8 + and bias.dtype == torch.int32 + ) + out_channels, _, kernel_size = weight.shape + output_size = get_conv1d_output_size( + input.shape, + out_channels, + stride[1], + padding[1], + dilation[1], + kernel_size, + False, + ) + return input.new_empty(output_size, dtype=input.dtype) + + +@register_fake("cadence::quantized_conv1d_nhwc_asym8sxsym8s_asym8s.per_tensor") +def quantized_conv1d_nhwc_asym8sxsym8s_asym8s_per_tensor_meta( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: Tuple[int], + padding: Tuple[int], + dilation: Tuple[int], + groups: int, + in_zero_point: int, + weight_zero_point: int, + bias_scale: float, + output_scale: float, + output_zero_point: int, + out_multiplier: int, + out_shift: int, +) -> torch.Tensor: + assert input.dim() == 3 and weight.dim() == 3 + assert ( + input.dtype == torch.int8 + and weight.dtype == torch.int8 + and bias.dtype == torch.int32 + ) + out_channels, kernel_size, _ = weight.shape + output_size = get_conv1d_output_size( + input.shape, + out_channels, + stride[1], + padding[1], + dilation[1], + kernel_size, + True, + ) + return input.new_empty(output_size, dtype=input.dtype) + + +@register_fake("cadence::quantized_conv1d_nhwc_asym8uxsym8u_asym8u.per_tensor") +def quantized_conv1d_nhwc_asym8uxsym8u_asym8u_per_tensor_meta( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: Tuple[int], + padding: Tuple[int], + dilation: Tuple[int], + groups: int, + in_zero_point: int, + weight_zero_point: int, + bias_scale: float, + output_scale: float, + output_zero_point: int, + out_multiplier: int, + out_shift: int, +) -> torch.Tensor: + assert input.dim() == 3 and weight.dim() == 3 + assert ( + input.dtype == torch.uint8 + and weight.dtype == torch.uint8 + and bias.dtype == torch.int32 + ) + out_channels, kernel_size, _ = weight.shape + output_size = get_conv1d_output_size( + input.shape, + out_channels, + stride[1], + padding[1], + dilation[1], + kernel_size, + True, + ) + return input.new_empty(output_size, dtype=input.dtype) + + @register_fake("cadence::_softmax_f32_f32") def softmax_f32_f32_meta( self: torch.Tensor, diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py index 40ae6d23085..6d5b2a89c05 100644 --- a/backends/cadence/aot/ref_implementations.py +++ b/backends/cadence/aot/ref_implementations.py @@ -613,6 +613,7 @@ def quantized_conv_variant( layout: str, input_dtype: torch.dtype, weight_dtype: torch.dtype, + is_1d: bool = False, ) -> Callable[[Callable[..., torch.Tensor]], Callable[..., torch.Tensor]]: """Create a quantized conv variant with type checking.""" @@ -644,6 +645,14 @@ def variant( bias.dtype == torch.int32 ), f"Expected bias dtype int32, got {bias.dtype}" + if is_1d: + assert ( + len(input_tensor.shape) == 3 + ), f"1D convolution requires 3D input tensor, got {len(input_tensor.shape)}D" + assert ( + len(weight.shape) == 3 + ), f"1D convolution requires 3D weight tensor, got {len(weight.shape)}D" + # Call the appropriate base function match layout: case "nchw": @@ -748,6 +757,26 @@ def quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor() -> torch.Tens def quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ... +@impl(m, "quantized_conv1d_nchw_asym8sxsym8s_asym8s.per_tensor") +@quantized_conv_variant("nchw", torch.int8, torch.int8, is_1d=True) +def quantized_conv1d_nchw_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ... + + +@impl(m, "quantized_conv1d_nchw_asym8uxsym8u_asym8u.per_tensor") +@quantized_conv_variant("nchw", torch.uint8, torch.uint8, is_1d=True) +def quantized_conv1d_nchw_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ... + + +@impl(m, "quantized_conv1d_nhwc_asym8sxsym8s_asym8s.per_tensor") +@quantized_conv_variant("nhwc", torch.int8, torch.int8, is_1d=True) +def quantized_conv1d_nhwc_asym8sxsym8s_asym8s_per_tensor() -> torch.Tensor: ... + + +@impl(m, "quantized_conv1d_nhwc_asym8uxsym8u_asym8u.per_tensor") +@quantized_conv_variant("nhwc", torch.uint8, torch.uint8, is_1d=True) +def quantized_conv1d_nhwc_asym8uxsym8u_asym8u_per_tensor() -> torch.Tensor: ... + + def quantized_relu_common( X: torch.Tensor, X_zero_point: torch.Tensor | int, diff --git a/backends/cadence/aot/tests/test_type_dispatch_passes.py b/backends/cadence/aot/tests/test_type_dispatch_passes.py index 52904aecb41..704d92a3197 100644 --- a/backends/cadence/aot/tests/test_type_dispatch_passes.py +++ b/backends/cadence/aot/tests/test_type_dispatch_passes.py @@ -446,6 +446,110 @@ def test_uint8_dispatch_quantized_conv_nhwc_dilated(self) -> None: 1, ) + def test_int8_dispatch_quantized_conv_nchw_1d(self) -> None: + """Test int8 x int8 inputs for 1D conv should dispatch to 1d_asym8sxasym8s_asym8s variant for quantized_conv_nchw""" + x = torch.randint(-128, 127, (1, 3, 8), dtype=torch.int8) + w = torch.randint(-128, 127, (16, 3, 3), dtype=torch.int8) + b = torch.randint(-2147483648, 2147483647, (16,), dtype=torch.int32) + gm = single_op_builder( + placeholders=(x, w, b), + op=exir_ops.edge.cadence.quantized_conv_nchw.per_tensor, + args=(x, w, b, [1, 1], [0, 0], [1, 1], 1, 0, 0, 1.0, 1.0, 0, 1, 1), + ) + p = CompileTimeTypeDispatchPass() + gm = cast(PassResult, p(gm)).graph_module + # Original op should be replaced + self.assertEqual( + count_node(gm, exir_ops.edge.cadence.quantized_conv_nchw.per_tensor), + 0, + ) + # Should be replaced with 1D int8 specific variant + self.assertEqual( + count_node( + gm, + exir_ops.edge.cadence.quantized_conv1d_nchw_asym8sxsym8s_asym8s.per_tensor, + ), + 1, + ) + + def test_uint8_dispatch_quantized_conv_nchw_1d(self) -> None: + """Test uint8 x uint8 inputs for 1D conv should dispatch to 1d_asym8uxasym8u_asym8u variant for quantized_conv_nchw""" + x = torch.randint(0, 255, (1, 3, 8), dtype=torch.uint8) + w = torch.randint(0, 255, (16, 3, 3), dtype=torch.uint8) + b = torch.randint(-2147483648, 2147483647, (16,), dtype=torch.int32) + gm = single_op_builder( + placeholders=(x, w, b), + op=exir_ops.edge.cadence.quantized_conv_nchw.per_tensor, + args=(x, w, b, [1, 1], [0, 0], [1, 1], 1, 0, 0, 1.0, 1.0, 0, 1, 1), + ) + p = CompileTimeTypeDispatchPass() + gm = cast(PassResult, p(gm)).graph_module + # Original op should be replaced + self.assertEqual( + count_node(gm, exir_ops.edge.cadence.quantized_conv_nchw.per_tensor), + 0, + ) + # Should be replaced with 1D uint8 specific variant + self.assertEqual( + count_node( + gm, + exir_ops.edge.cadence.quantized_conv1d_nchw_asym8uxsym8u_asym8u.per_tensor, + ), + 1, + ) + + def test_int8_dispatch_quantized_conv_nhwc_1d(self) -> None: + """Test int8 x int8 inputs for 1D conv should dispatch to 1d_asym8sxasym8s_asym8s variant for quantized_conv_nhwc""" + x = torch.randint(-128, 127, (1, 8, 3), dtype=torch.int8) + w = torch.randint(-128, 127, (16, 3, 3), dtype=torch.int8) + b = torch.randint(-2147483648, 2147483647, (16,), dtype=torch.int32) + gm = single_op_builder( + placeholders=(x, w, b), + op=exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor, + args=(x, w, b, [1, 1], [0, 0], [1, 1], 1, 0, 0, 1.0, 1.0, 0, 1, 1), + ) + p = CompileTimeTypeDispatchPass() + gm = cast(PassResult, p(gm)).graph_module + # Original op should be replaced + self.assertEqual( + count_node(gm, exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor), + 0, + ) + # Should be replaced with 1D int8 specific variant + self.assertEqual( + count_node( + gm, + exir_ops.edge.cadence.quantized_conv1d_nhwc_asym8sxsym8s_asym8s.per_tensor, + ), + 1, + ) + + def test_uint8_dispatch_quantized_conv_nhwc_1d(self) -> None: + """Test uint8 x uint8 inputs for 1D conv should dispatch to 1d_asym8uxasym8u_asym8u variant for quantized_conv_nhwc""" + x = torch.randint(0, 255, (1, 8, 3), dtype=torch.uint8) + w = torch.randint(0, 255, (16, 3, 3), dtype=torch.uint8) + b = torch.randint(-2147483648, 2147483647, (16,), dtype=torch.int32) + gm = single_op_builder( + placeholders=(x, w, b), + op=exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor, + args=(x, w, b, [1, 1], [0, 0], [1, 1], 1, 0, 0, 1.0, 1.0, 0, 1, 1), + ) + p = CompileTimeTypeDispatchPass() + gm = cast(PassResult, p(gm)).graph_module + # Original op should be replaced + self.assertEqual( + count_node(gm, exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor), + 0, + ) + # Should be replaced with 1D uint8 specific variant + self.assertEqual( + count_node( + gm, + exir_ops.edge.cadence.quantized_conv1d_nhwc_asym8uxsym8u_asym8u.per_tensor, + ), + 1, + ) + def test_int8_dispatch_quantized_add(self) -> None: """Test int8 x int8 inputs should dispatch to asym8sxasym8s_asym8s variant for quantized_add""" x = torch.randint(-128, 127, (2, 3), dtype=torch.int8) diff --git a/backends/cadence/aot/type_dispatch.py b/backends/cadence/aot/type_dispatch.py index 108c4fb1a92..958a78a4808 100644 --- a/backends/cadence/aot/type_dispatch.py +++ b/backends/cadence/aot/type_dispatch.py @@ -129,6 +129,8 @@ def call_operator( type_suffix = config.type_dispatch_suffixes[dtype_key] base_name = config.base_name + typed_op_name = f"{base_name}_{type_suffix}" + if op in [ exir_ops.edge.cadence.quantized_conv_nchw.per_tensor, exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor, @@ -140,17 +142,18 @@ def call_operator( else args[0].to_tensor().shape[-1] ) is_depthwise = groups == input_channels - - dilation = args[5] # pyre-ignore[16]: None has no attribute '__iter__'. - is_dilated = any(d > 1 for d in dilation) - - if is_dilated: - type_suffix = f"dilated_{type_suffix}" - elif is_depthwise: - type_suffix = f"depthwise_{type_suffix}" - - typed_op_name = f"{base_name}_{type_suffix}" + is_dilated = any(d > 1 for d in args[5]) + is_1d = len(args[0].to_tensor().shape) == 3 + + if is_depthwise: + typed_op_name = f"{base_name}_depthwise_{type_suffix}" + elif is_dilated: + typed_op_name = f"{base_name}_dilated_{type_suffix}" + elif is_1d and groups == 1: + typed_op_name = ( + f"quantized_conv1d_{base_name.split('_')[-1]}_{type_suffix}" + ) typed_op = getattr( getattr(exir_ops.edge.cadence, typed_op_name), config.variant diff --git a/backends/cadence/hifi/operators/op_quantized_conv1d_nchw_asym8sxsym8s_asym8s_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv1d_nchw_asym8sxsym8s_asym8s_per_tensor_out.cpp new file mode 100644 index 00000000000..c1b5a1836a3 --- /dev/null +++ b/backends/cadence/hifi/operators/op_quantized_conv1d_nchw_asym8sxsym8s_asym8s_per_tensor_out.cpp @@ -0,0 +1,189 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1))) + +using Tensor = executorch::aten::Tensor; +using KernelRuntimeContext = torch::executor::KernelRuntimeContext; +using ScalarType = executorch::aten::ScalarType; +using ::executorch::aten::IntArrayRef; + +namespace cadence { +namespace impl { +namespace HiFi { +namespace native { + +// Optimized NCHW 1D convolution for int8 x int8 -> int8 +void xa_opt_quantized_conv1d_nchw_asym8sxsym8s_asym8s( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + int32_t in_zero_point, + int32_t weight_zero_point, + float bias_scale, + float output_scale, + int32_t output_zero_point, + Tensor& out) { + constexpr int kNnlibMaxDim = 3; + + WORD8* __restrict__ p_out = + (WORD8* __restrict__)out.mutable_data_ptr(); + WORD8* __restrict__ p_inp = + (WORD8* __restrict__)input.const_data_ptr(); + WORD8* __restrict__ p_kernel = + (WORD8* __restrict__)weight.const_data_ptr(); + WORD32* __restrict__ p_bias = + (WORD32* __restrict__)bias.const_data_ptr(); + + WORD32 batches = input.size(0); + WORD32 input_channels = input.size(1); + WORD32 input_width = input.size(2); + WORD32 out_channels = weight.size(0); + WORD32 kernel_channels = weight.size(1); + WORD32 kernel_width = weight.size(2); + WORD32 out_width = out.size(2); + WORD32 x_stride = stride[1]; + WORD32 x_padding = padding[1]; + WORD32 input_zero_bias = -in_zero_point; + WORD32 out_multiplier32[out_channels]; + WORD32 out_shift32[out_channels]; + WORD32 kernel_zero_bias = -weight_zero_point; + float out_scale = 1. / output_scale; + + for (int i = 0; i < out_channels; i++) { + out_multiplier32[i] = bias_scale * out_scale * 2147483648; + out_shift32[i] = 0; + } + + WORD32 out_zero_bias = output_zero_point; + WORD32 out_data_format = 1; + WORD8* ptr1 = (WORD8*)kernels::allocate_temp_memory( + ctx, ((batches * input_channels * input_width) + 8) * sizeof(WORD8)); + WORD8* ptr2 = (WORD8*)kernels::allocate_temp_memory( + ctx, + ((out_channels * kernel_channels * kernel_width) + 8) * sizeof(WORD8)); + WORD8* pin = (WORD8*)ALIGN_PTR(ptr1, 8); + WORD8* pkernel = (WORD8*)ALIGN_PTR(ptr2, 8); + + WORD32 p_inp_shape[kNnlibMaxDim]; + p_inp_shape[0] = batches; + p_inp_shape[1] = input_channels; + p_inp_shape[2] = input_width; + + WORD32 p_out_shape[kNnlibMaxDim]; + p_out_shape[0] = batches; + p_out_shape[1] = input_width; + p_out_shape[2] = input_channels; + + WORD32 p_permute_vec[kNnlibMaxDim] = {0, 2, 1}; + + xa_nn_transpose_8_8( + pin, + p_out_shape, + p_inp, + p_inp_shape, + p_permute_vec, + kNnlibMaxDim, + kNnlibMaxDim); + + WORD32 p_inp_shape1[kNnlibMaxDim]; + p_inp_shape1[0] = out_channels; + p_inp_shape1[1] = kernel_channels; + p_inp_shape1[2] = kernel_width; + + WORD32 p_out_shape1[kNnlibMaxDim]; + p_out_shape1[0] = out_channels; + p_out_shape1[1] = kernel_width; + p_out_shape1[2] = kernel_channels; + + xa_nn_transpose_8_8( + pkernel, + p_out_shape1, + p_kernel, + p_inp_shape1, + p_permute_vec, + kNnlibMaxDim, + kNnlibMaxDim); + + WORD32 scratch_size = + xa_nn_conv1d_std_getsize(kernel_width, input_width, input_channels, 8); + scratch_size = scratch_size < 0 ? 0 : scratch_size; + WORD32* ptr_scratch = + (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size); + pVOID p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); + + for (int _n = 0; _n < batches; _n++) { + WORD8* in_batch = pin + _n * input_channels * input_width; + WORD8* out_batch = p_out + _n * out_channels * out_width; + + xa_nn_conv1d_std_asym8xasym8( + out_batch, + in_batch, + pkernel, + p_bias, + 1, + input_width, + input_channels, + kernel_width, + out_channels, + x_stride, + x_padding, + out_width, + input_zero_bias, + kernel_zero_bias, + out_multiplier32, + out_shift32, + out_zero_bias, + out_data_format, + p_scratch); + } +} + +void quantized_conv1d_nchw_asym8sxsym8s_asym8s_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + __ET_UNUSED IntArrayRef dilation, + __ET_UNUSED int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + __ET_UNUSED int64_t out_multiplier, + __ET_UNUSED int64_t out_shift, + Tensor& out) { + xa_opt_quantized_conv1d_nchw_asym8sxsym8s_asym8s( + ctx, + input, + weight, + bias, + stride, + padding, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); +} + +} // namespace native +} // namespace HiFi +} // namespace impl +} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_quantized_conv1d_nchw_asym8uxsym8u_asym8u_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv1d_nchw_asym8uxsym8u_asym8u_per_tensor_out.cpp new file mode 100644 index 00000000000..fae49ec97c7 --- /dev/null +++ b/backends/cadence/hifi/operators/op_quantized_conv1d_nchw_asym8uxsym8u_asym8u_per_tensor_out.cpp @@ -0,0 +1,189 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1))) + +using Tensor = executorch::aten::Tensor; +using KernelRuntimeContext = torch::executor::KernelRuntimeContext; +using ScalarType = executorch::aten::ScalarType; +using ::executorch::aten::IntArrayRef; + +namespace cadence { +namespace impl { +namespace HiFi { +namespace native { + +// Optimized NCHW 1D convolution for uint8 x uint8 -> uint8 +void xa_opt_quantized_conv1d_nchw_asym8uxsym8u_asym8u( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + int32_t in_zero_point, + int32_t weight_zero_point, + float bias_scale, + float output_scale, + int32_t output_zero_point, + Tensor& out) { + constexpr int kNnlibMaxDim = 3; + + UWORD8* __restrict__ p_out = + (UWORD8* __restrict__)out.mutable_data_ptr(); + UWORD8* __restrict__ p_inp = + (UWORD8* __restrict__)input.const_data_ptr(); + UWORD8* __restrict__ p_kernel = + (UWORD8* __restrict__)weight.const_data_ptr(); + WORD32* __restrict__ p_bias = + (WORD32* __restrict__)bias.const_data_ptr(); + + WORD32 batches = input.size(0); + WORD32 input_channels = input.size(1); + WORD32 input_width = input.size(2); + WORD32 out_channels = weight.size(0); + WORD32 kernel_channels = weight.size(1); + WORD32 kernel_width = weight.size(2); + WORD32 out_width = out.size(2); + WORD32 x_stride = stride[1]; + WORD32 x_padding = padding[1]; + WORD32 input_zero_bias = -in_zero_point; + WORD32 out_multiplier32[out_channels]; + WORD32 out_shift32[out_channels]; + WORD32 kernel_zero_bias = -weight_zero_point; + float out_scale = 1. / output_scale; + + for (int i = 0; i < out_channels; i++) { + out_multiplier32[i] = bias_scale * out_scale * 2147483648; + out_shift32[i] = 0; + } + + WORD32 out_zero_bias = output_zero_point; + WORD32 out_data_format = 1; + UWORD8* ptr1 = (UWORD8*)kernels::allocate_temp_memory( + ctx, ((batches * input_channels * input_width) + 8) * sizeof(UWORD8)); + UWORD8* ptr2 = (UWORD8*)kernels::allocate_temp_memory( + ctx, + ((out_channels * kernel_channels * kernel_width) + 8) * sizeof(UWORD8)); + UWORD8* pin = (UWORD8*)ALIGN_PTR(ptr1, 8); + UWORD8* pkernel = (UWORD8*)ALIGN_PTR(ptr2, 8); + + WORD32 p_inp_shape[kNnlibMaxDim]; + p_inp_shape[0] = batches; + p_inp_shape[1] = input_channels; + p_inp_shape[2] = input_width; + + WORD32 p_out_shape[kNnlibMaxDim]; + p_out_shape[0] = batches; + p_out_shape[1] = input_width; + p_out_shape[2] = input_channels; + + WORD32 p_permute_vec[kNnlibMaxDim] = {0, 2, 1}; + + xa_nn_transpose_8_8( + pin, + p_out_shape, + p_inp, + p_inp_shape, + p_permute_vec, + kNnlibMaxDim, + kNnlibMaxDim); + + WORD32 p_inp_shape1[kNnlibMaxDim]; + p_inp_shape1[0] = out_channels; + p_inp_shape1[1] = kernel_channels; + p_inp_shape1[2] = kernel_width; + + WORD32 p_out_shape1[kNnlibMaxDim]; + p_out_shape1[0] = out_channels; + p_out_shape1[1] = kernel_width; + p_out_shape1[2] = kernel_channels; + + xa_nn_transpose_8_8( + pkernel, + p_out_shape1, + p_kernel, + p_inp_shape1, + p_permute_vec, + kNnlibMaxDim, + kNnlibMaxDim); + + WORD32 scratch_size = + xa_nn_conv1d_std_getsize(kernel_width, input_width, input_channels, 8); + scratch_size = scratch_size < 0 ? 0 : scratch_size; + WORD32* ptr_scratch = + (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size); + pVOID p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); + + for (int _n = 0; _n < batches; _n++) { + UWORD8* in_batch = pin + _n * input_channels * input_width; + UWORD8* out_batch = p_out + _n * out_channels * out_width; + + xa_nn_conv1d_std_asym8uxasym8u( + out_batch, + in_batch, + pkernel, + p_bias, + 1, + input_width, + input_channels, + kernel_width, + out_channels, + x_stride, + x_padding, + out_width, + input_zero_bias, + kernel_zero_bias, + out_multiplier32, + out_shift32, + out_zero_bias, + out_data_format, + p_scratch); + } +} + +void quantized_conv1d_nchw_asym8uxsym8u_asym8u_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + __ET_UNUSED IntArrayRef dilation, + __ET_UNUSED int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + __ET_UNUSED int64_t out_multiplier, + __ET_UNUSED int64_t out_shift, + Tensor& out) { + xa_opt_quantized_conv1d_nchw_asym8uxsym8u_asym8u( + ctx, + input, + weight, + bias, + stride, + padding, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); +} + +} // namespace native +} // namespace HiFi +} // namespace impl +} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_quantized_conv1d_nhwc_asym8sxsym8s_asym8s_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv1d_nhwc_asym8sxsym8s_asym8s_per_tensor_out.cpp new file mode 100644 index 00000000000..a2cb591b3a7 --- /dev/null +++ b/backends/cadence/hifi/operators/op_quantized_conv1d_nhwc_asym8sxsym8s_asym8s_per_tensor_out.cpp @@ -0,0 +1,138 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1))) + +using Tensor = executorch::aten::Tensor; +using KernelRuntimeContext = torch::executor::KernelRuntimeContext; +using ScalarType = executorch::aten::ScalarType; +using ::executorch::aten::IntArrayRef; + +namespace cadence { +namespace impl { +namespace HiFi { +namespace native { + +// Optimized NHWC 1D convolution for int8 x int8 -> int8 +void xa_opt_quantized_conv1d_nhwc_asym8sxsym8s_asym8s( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + int32_t in_zero_point, + int32_t weight_zero_point, + float bias_scale, + float output_scale, + int32_t output_zero_point, + Tensor& out) { + WORD8* __restrict__ p_out = + (WORD8* __restrict__)out.mutable_data_ptr(); + WORD8* __restrict__ p_inp = + (WORD8* __restrict__)input.const_data_ptr(); + WORD8* __restrict__ p_kernel = + (WORD8* __restrict__)weight.const_data_ptr(); + WORD32* __restrict__ p_bias = + (WORD32* __restrict__)bias.const_data_ptr(); + + WORD32 batches = input.size(0); + WORD32 input_channels = input.size(1); + WORD32 input_width = input.size(2); + WORD32 out_channels = weight.size(0); + WORD32 kernel_width = weight.size(2); + WORD32 out_width = out.size(2); + WORD32 x_stride = stride[1]; + WORD32 x_padding = padding[1]; + WORD32 input_zero_bias = -in_zero_point; + WORD32 out_multiplier32[out_channels]; + WORD32 out_shift32[out_channels]; + WORD32 kernel_zero_bias = -weight_zero_point; + float out_scale = 1. / output_scale; + + for (int i = 0; i < out_channels; i++) { + out_multiplier32[i] = bias_scale * out_scale * 2147483648; + out_shift32[i] = 0; + } + + WORD32 out_zero_bias = output_zero_point; + WORD32 out_data_format = 0; + WORD32 scratch_size = + xa_nn_conv1d_std_getsize(kernel_width, input_width, input_channels, 8); + scratch_size = scratch_size < 0 ? 0 : scratch_size; + WORD32* ptr_scratch = + (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size); + pVOID p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); + + for (int _n = 0; _n < batches; _n++) { + WORD8* in_batch = p_inp + _n * input_channels * input_width; + WORD8* out_batch = p_out + _n * out_channels * out_width; + + xa_nn_conv1d_std_asym8xasym8( + out_batch, + in_batch, + p_kernel, + p_bias, + 1, + input_width, + input_channels, + kernel_width, + out_channels, + x_stride, + x_padding, + out_width, + input_zero_bias, + kernel_zero_bias, + out_multiplier32, + out_shift32, + out_zero_bias, + out_data_format, + p_scratch); + } +} + +void quantized_conv1d_nhwc_asym8sxsym8s_asym8s_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + __ET_UNUSED IntArrayRef dilation, + __ET_UNUSED int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + __ET_UNUSED int64_t out_multiplier, + __ET_UNUSED int64_t out_shift, + Tensor& out) { + xa_opt_quantized_conv1d_nhwc_asym8sxsym8s_asym8s( + ctx, + input, + weight, + bias, + stride, + padding, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); +} + +} // namespace native +} // namespace HiFi +} // namespace impl +} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_quantized_conv1d_nhwc_asym8uxsym8u_asym8u_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv1d_nhwc_asym8uxsym8u_asym8u_per_tensor_out.cpp new file mode 100644 index 00000000000..441952ca189 --- /dev/null +++ b/backends/cadence/hifi/operators/op_quantized_conv1d_nhwc_asym8uxsym8u_asym8u_per_tensor_out.cpp @@ -0,0 +1,138 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1))) + +using Tensor = executorch::aten::Tensor; +using KernelRuntimeContext = torch::executor::KernelRuntimeContext; +using ScalarType = executorch::aten::ScalarType; +using ::executorch::aten::IntArrayRef; + +namespace cadence { +namespace impl { +namespace HiFi { +namespace native { + +// Optimized NHWC 1D convolution for uint8 x uint8 -> uint8 +void xa_opt_quantized_conv1d_nhwc_asym8uxsym8u_asym8u( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + int32_t in_zero_point, + int32_t weight_zero_point, + float bias_scale, + float output_scale, + int32_t output_zero_point, + Tensor& out) { + UWORD8* __restrict__ p_out = + (UWORD8* __restrict__)out.mutable_data_ptr(); + UWORD8* __restrict__ p_inp = + (UWORD8* __restrict__)input.const_data_ptr(); + UWORD8* __restrict__ p_kernel = + (UWORD8* __restrict__)weight.const_data_ptr(); + WORD32* __restrict__ p_bias = + (WORD32* __restrict__)bias.const_data_ptr(); + + WORD32 batches = input.size(0); + WORD32 input_channels = input.size(1); + WORD32 input_width = input.size(2); + WORD32 out_channels = weight.size(0); + WORD32 kernel_width = weight.size(2); + WORD32 out_width = out.size(2); + WORD32 x_stride = stride[1]; + WORD32 x_padding = padding[1]; + WORD32 input_zero_bias = -in_zero_point; + WORD32 out_multiplier32[out_channels]; + WORD32 out_shift32[out_channels]; + WORD32 kernel_zero_bias = -weight_zero_point; + float out_scale = 1. / output_scale; + + for (int i = 0; i < out_channels; i++) { + out_multiplier32[i] = bias_scale * out_scale * 2147483648; + out_shift32[i] = 0; + } + + WORD32 out_zero_bias = output_zero_point; + WORD32 out_data_format = 0; + WORD32 scratch_size = + xa_nn_conv1d_std_getsize(kernel_width, input_width, input_channels, 8); + scratch_size = scratch_size < 0 ? 0 : scratch_size; + WORD32* ptr_scratch = + (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size); + pVOID p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); + + for (int _n = 0; _n < batches; _n++) { + UWORD8* in_batch = p_inp + _n * input_channels * input_width; + UWORD8* out_batch = p_out + _n * out_channels * out_width; + + xa_nn_conv1d_std_asym8uxasym8u( + out_batch, + in_batch, + p_kernel, + p_bias, + 1, + input_width, + input_channels, + kernel_width, + out_channels, + x_stride, + x_padding, + out_width, + input_zero_bias, + kernel_zero_bias, + out_multiplier32, + out_shift32, + out_zero_bias, + out_data_format, + p_scratch); + } +} + +void quantized_conv1d_nhwc_asym8uxsym8u_asym8u_per_tensor_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + __ET_UNUSED IntArrayRef dilation, + __ET_UNUSED int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + __ET_UNUSED int64_t out_multiplier, + __ET_UNUSED int64_t out_shift, + Tensor& out) { + xa_opt_quantized_conv1d_nhwc_asym8uxsym8u_asym8u( + ctx, + input, + weight, + bias, + stride, + padding, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); +} + +} // namespace native +} // namespace HiFi +} // namespace impl +} // namespace cadence diff --git a/backends/cadence/hifi/operators/targets.bzl b/backends/cadence/hifi/operators/targets.bzl index d310396c262..fa263d4017c 100644 --- a/backends/cadence/hifi/operators/targets.bzl +++ b/backends/cadence/hifi/operators/targets.bzl @@ -66,6 +66,8 @@ OPERATORS = [ "quantized_conv_nchw_out", "quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor_out", "quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor_out", + "quantized_conv1d_nchw_asym8sxsym8s_asym8s_per_tensor_out", + "quantized_conv1d_nchw_asym8uxsym8u_asym8u_per_tensor_out", "quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor_out", "quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_out", "quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor_out", @@ -73,6 +75,8 @@ OPERATORS = [ "quantized_conv_nhwc_out", "quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor_out", "quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor_out", + "quantized_conv1d_nhwc_asym8sxsym8s_asym8s_per_tensor_out", + "quantized_conv1d_nhwc_asym8uxsym8u_asym8u_per_tensor_out", "quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_out", "quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_out", "quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor_out", diff --git a/backends/cadence/reference/operators/quantized_conv_nchw_out.cpp b/backends/cadence/reference/operators/quantized_conv_nchw_out.cpp index aefa75d7047..1a4faeed250 100644 --- a/backends/cadence/reference/operators/quantized_conv_nchw_out.cpp +++ b/backends/cadence/reference/operators/quantized_conv_nchw_out.cpp @@ -496,6 +496,72 @@ void quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_out( out); } +void quantized_conv1d_nchw_asym8sxsym8s_asym8s_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + __ET_UNUSED int64_t out_multiplier, + __ET_UNUSED int64_t out_shift, + Tensor& out) { + quantized_conv_nchw( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); +} + +void quantized_conv1d_nchw_asym8uxsym8u_asym8u_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + __ET_UNUSED int64_t out_multiplier, + __ET_UNUSED int64_t out_shift, + Tensor& out) { + quantized_conv_nchw( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); +} + } // namespace native } // namespace reference } // namespace impl diff --git a/backends/cadence/reference/operators/quantized_conv_nhwc_out.cpp b/backends/cadence/reference/operators/quantized_conv_nhwc_out.cpp index 26fbc86d5b0..21b17fb0724 100644 --- a/backends/cadence/reference/operators/quantized_conv_nhwc_out.cpp +++ b/backends/cadence/reference/operators/quantized_conv_nhwc_out.cpp @@ -417,6 +417,72 @@ void quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor_out( out); } +void quantized_conv1d_nhwc_asym8sxsym8s_asym8s_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + __ET_UNUSED int64_t out_multiplier, + __ET_UNUSED int64_t out_shift, + Tensor& out) { + quantized_conv_nhwc( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); +} + +void quantized_conv1d_nhwc_asym8uxsym8u_asym8u_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + __ET_UNUSED int64_t out_multiplier, + __ET_UNUSED int64_t out_shift, + Tensor& out) { + quantized_conv_nhwc( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); +} + void quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_out( __ET_UNUSED KernelRuntimeContext& ctx, const Tensor& input,