diff --git a/backends/cadence/aot/functions.yaml b/backends/cadence/aot/functions.yaml index 3968f215602..196480931e0 100644 --- a/backends/cadence/aot/functions.yaml +++ b/backends/cadence/aot/functions.yaml @@ -339,6 +339,26 @@ - arg_meta: null kernel_name: impl::reference::quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor_out +- func: cadence::quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::reference::quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor_out + +- func: cadence::quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::reference::quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_out + +- func: cadence::quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::reference::quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_out + +- func: cadence::quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: impl::reference::quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_out + - func: cadence::quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null diff --git a/backends/cadence/aot/functions_hifi.yaml b/backends/cadence/aot/functions_hifi.yaml index 19249ef50a5..cf4c5a8fffb 100644 --- a/backends/cadence/aot/functions_hifi.yaml +++ b/backends/cadence/aot/functions_hifi.yaml @@ -350,6 +350,26 @@ - arg_meta: null kernel_name: cadence::impl::HiFi::quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor_out +- func: cadence::quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) 
+ kernels: + - arg_meta: null + kernel_name: cadence::impl::HiFi::quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor_out + +- func: cadence::quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: cadence::impl::HiFi::quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_out + +- func: cadence::quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: cadence::impl::HiFi::quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_out + +- func: cadence::quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!) + kernels: + - arg_meta: null + kernel_name: cadence::impl::HiFi::quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_out + - func: cadence::quantized_layer_norm.out(Tensor input, Tensor in_scale, Tensor in_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor(a!) kernels: - arg_meta: null diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py index 52b688490b2..b88564e3ba5 100644 --- a/backends/cadence/aot/ops_registrations.py +++ b/backends/cadence/aot/ops_registrations.py @@ -168,6 +168,30 @@ lib.define( "quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)" ) +lib.define( + "quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)" +) +lib.define( + "quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) 
out) -> Tensor(a!)" +) +lib.define( + "quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)" +) +lib.define( + "quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)" +) +lib.define( + "quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)" +) +lib.define( + "quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)" +) +lib.define( + "quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift) -> (Tensor Z)" +) +lib.define( + "quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, *, Tensor(a!) out) -> Tensor(a!)" +) lib.define( "quantized_matmul_asym8uxasym8u_asym8u(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? 
bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed=False) -> (Tensor Z)" ) @@ -1165,6 +1189,182 @@ def quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor_meta( return input.new_empty(output_size, dtype=input.dtype) +@register_fake("cadence::quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor") +def quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor_meta( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: Tuple[int], + padding: Tuple[int], + dilation: Tuple[int], + groups: int, + in_zero_point: int, + weight_zero_point: int, + bias_scale: float, + output_scale: float, + output_zero_point: int, + out_multiplier: int, + out_shift: int, +) -> torch.Tensor: + out_channels, _, *kernel_size = weight.shape + + in_size = input.shape + # Assert that the input tensor has at least 3 dimensions, and at most 6 + assert len(in_size) > 2 + assert len(in_size) < 6 + + # Compute the output tensor size + output_size = ( + get_conv1d_output_size( + in_size, + out_channels, + stride[1], + padding[1], + dilation[1], + kernel_size[0], + False, + ) + if len(in_size) == 3 + else get_conv2d_output_size( + in_size, out_channels, stride, padding, dilation, kernel_size, False + ) + ) + + return input.new_empty(output_size, dtype=input.dtype) + + +@register_fake("cadence::quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor") +def quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_meta( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: Tuple[int], + padding: Tuple[int], + dilation: Tuple[int], + groups: int, + in_zero_point: int, + weight_zero_point: int, + bias_scale: float, + output_scale: float, + output_zero_point: int, + out_multiplier: int, + out_shift: int, +) -> torch.Tensor: + out_channels, _, *kernel_size = weight.shape + + in_size = input.shape + # Assert that the input tensor has at least 3 dimensions, and at most 6 + assert len(in_size) > 2 + assert len(in_size) < 6 + + # Compute the output tensor size + output_size = ( + get_conv1d_output_size( + in_size, + out_channels, + stride[1], + padding[1], + dilation[1], + kernel_size[0], + False, + ) + if len(in_size) == 3 + else get_conv2d_output_size( + in_size, out_channels, stride, padding, dilation, kernel_size, False + ) + ) + + return input.new_empty(output_size, dtype=input.dtype) + + +@register_fake("cadence::quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor") +def quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_meta( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: Tuple[int], + padding: Tuple[int], + dilation: Tuple[int], + groups: int, + in_zero_point: int, + weight_zero_point: int, + bias_scale: float, + output_scale: float, + output_zero_point: int, + out_multiplier: int, + out_shift: int, +) -> torch.Tensor: + out_channels, *kernel_size, _ = weight.shape + + in_size = input.shape + # Assert that the input tensor has at least 3 dimensions, and at most 6 + assert len(in_size) > 2 + assert len(in_size) < 6 + + # Compute the output tensor size + output_size = ( + get_conv1d_output_size( + in_size, + out_channels, + stride[1], + padding[1], + dilation[1], + kernel_size[0], + True, + ) + if len(in_size) == 3 + else get_conv2d_output_size( + in_size, out_channels, stride, padding, dilation, kernel_size, True + ) + ) + + return input.new_empty(output_size, dtype=input.dtype) + + +@register_fake("cadence::quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor") +def 
quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_meta( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + stride: Tuple[int], + padding: Tuple[int], + dilation: Tuple[int], + groups: int, + in_zero_point: int, + weight_zero_point: int, + bias_scale: float, + output_scale: float, + output_zero_point: int, + out_multiplier: int, + out_shift: int, +) -> torch.Tensor: + out_channels, *kernel_size, _ = weight.shape + + in_size = input.shape + # Assert that the input tensor has at least 3 dimensions, and at most 6 + assert len(in_size) > 2 + assert len(in_size) < 6 + + # Compute the output tensor size + output_size = ( + get_conv1d_output_size( + in_size, + out_channels, + stride[1], + padding[1], + dilation[1], + kernel_size[0], + True, + ) + if len(in_size) == 3 + else get_conv2d_output_size( + in_size, out_channels, stride, padding, dilation, kernel_size, True + ) + ) + + return input.new_empty(output_size, dtype=input.dtype) + + @register_fake("cadence::quantized_layer_norm") def quantized_layer_norm_meta( input: torch.Tensor, diff --git a/backends/cadence/aot/tests/test_type_dispatch_passes.py b/backends/cadence/aot/tests/test_type_dispatch_passes.py index 1deebdfbb1c..f180c138ca4 100644 --- a/backends/cadence/aot/tests/test_type_dispatch_passes.py +++ b/backends/cadence/aot/tests/test_type_dispatch_passes.py @@ -495,3 +495,179 @@ def test_uint8_dispatch_quantized_add(self) -> None: ), 1, ) + + def test_int8_dispatch_quantized_conv_nchw_depthwise(self) -> None: + """Test int8 x int8 inputs with depthwise should dispatch to depthwise_asym8sxsym8s_asym8s variant for quantized_conv_nchw""" + # Depthwise convolution: groups == input_channels + x = torch.randint(-128, 127, (1, 3, 8, 8), dtype=torch.int8) + w = torch.randint( + -128, 127, (3, 1, 3, 3), dtype=torch.int8 + ) # groups=3, input_channels=3 + b = torch.randint(-2147483648, 2147483647, (3,), dtype=torch.int32) + gm = single_op_builder( + placeholders=(x, w, b), + op=exir_ops.edge.cadence.quantized_conv_nchw.per_tensor, + args=( + x, + w, + b, + [1, 1], + [0, 0], + [1, 1], + 3, + 0, + 0, + 1.0, + 1.0, + 0, + 1, + 1, + ), # groups=3 + ) + p = CompileTimeTypeDispatchPass() + gm = cast(PassResult, p(gm)).graph_module + # Original op should be replaced + self.assertEqual( + count_node(gm, exir_ops.edge.cadence.quantized_conv_nchw.per_tensor), + 0, + ) + # Should be replaced with int8 depthwise specific variant + self.assertEqual( + count_node( + gm, + exir_ops.edge.cadence.quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s.per_tensor, + ), + 1, + ) + + def test_uint8_dispatch_quantized_conv_nchw_depthwise(self) -> None: + """Test uint8 x uint8 inputs with depthwise should dispatch to depthwise_asym8uxasym8u_asym8u variant for quantized_conv_nchw""" + # Depthwise convolution: groups == input_channels + x = torch.randint(0, 255, (1, 3, 8, 8), dtype=torch.uint8) + w = torch.randint( + 0, 255, (3, 1, 3, 3), dtype=torch.uint8 + ) # groups=3, input_channels=3 + b = torch.randint(-2147483648, 2147483647, (3,), dtype=torch.int32) + gm = single_op_builder( + placeholders=(x, w, b), + op=exir_ops.edge.cadence.quantized_conv_nchw.per_tensor, + args=( + x, + w, + b, + [1, 1], + [0, 0], + [1, 1], + 3, + 0, + 0, + 1.0, + 1.0, + 0, + 1, + 1, + ), # groups=3 + ) + p = CompileTimeTypeDispatchPass() + gm = cast(PassResult, p(gm)).graph_module + # Original op should be replaced + self.assertEqual( + count_node(gm, exir_ops.edge.cadence.quantized_conv_nchw.per_tensor), + 0, + ) + # Should be replaced with uint8 depthwise specific 
variant + self.assertEqual( + count_node( + gm, + exir_ops.edge.cadence.quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u.per_tensor, + ), + 1, + ) + + def test_int8_dispatch_quantized_conv_nhwc_depthwise(self) -> None: + """Test int8 x int8 inputs with depthwise should dispatch to depthwise_asym8sxsym8s_asym8s variant for quantized_conv_nhwc""" + # Depthwise convolution: groups == input_channels + x = torch.randint(-128, 127, (1, 8, 8, 3), dtype=torch.int8) + w = torch.randint( + -128, 127, (3, 3, 3, 1), dtype=torch.int8 + ) # groups=3, input_channels=3 + b = torch.randint(-2147483648, 2147483647, (3,), dtype=torch.int32) + gm = single_op_builder( + placeholders=(x, w, b), + op=exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor, + args=( + x, + w, + b, + [1, 1], + [0, 0], + [1, 1], + 3, + 0, + 0, + 1.0, + 1.0, + 0, + 1, + 1, + ), # groups=3 + ) + p = CompileTimeTypeDispatchPass() + gm = cast(PassResult, p(gm)).graph_module + # Original op should be replaced + self.assertEqual( + count_node(gm, exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor), + 0, + ) + # Should be replaced with int8 depthwise specific variant + self.assertEqual( + count_node( + gm, + exir_ops.edge.cadence.quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor, + ), + 1, + ) + + def test_uint8_dispatch_quantized_conv_nhwc_depthwise(self) -> None: + """Test uint8 x uint8 inputs with depthwise should dispatch to depthwise_asym8uxasym8u_asym8u variant for quantized_conv_nhwc""" + # Depthwise convolution: groups == input_channels + x = torch.randint(0, 255, (1, 8, 8, 3), dtype=torch.uint8) + w = torch.randint( + 0, 255, (3, 3, 3, 1), dtype=torch.uint8 + ) # groups=3, input_channels=3 + b = torch.randint(-2147483648, 2147483647, (3,), dtype=torch.int32) + gm = single_op_builder( + placeholders=(x, w, b), + op=exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor, + args=( + x, + w, + b, + [1, 1], + [0, 0], + [1, 1], + 3, + 0, + 0, + 1.0, + 1.0, + 0, + 1, + 1, + ), # groups=3 + ) + p = CompileTimeTypeDispatchPass() + gm = cast(PassResult, p(gm)).graph_module + # Original op should be replaced + self.assertEqual( + count_node(gm, exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor), + 0, + ) + # Should be replaced with uint8 depthwise specific variant + self.assertEqual( + count_node( + gm, + exir_ops.edge.cadence.quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u.per_tensor, + ), + 1, + ) diff --git a/backends/cadence/aot/type_dispatch.py b/backends/cadence/aot/type_dispatch.py index c53f62a45b7..ec9cecb03ed 100644 --- a/backends/cadence/aot/type_dispatch.py +++ b/backends/cadence/aot/type_dispatch.py @@ -126,12 +126,22 @@ def call_operator( exir_ops.edge.cadence.quantized_conv_nchw.per_tensor, exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor, ]: + groups = args[6] + input_channels = ( + args[0].to_tensor().shape[1] + if op == exir_ops.edge.cadence.quantized_conv_nchw.per_tensor + else args[0].to_tensor().shape[-1] + ) + is_depthwise = groups == input_channels + dilation = args[5] # pyre-ignore[16]: None has no attribute '__iter__'. 
is_dilated = any(d > 1 for d in dilation) if is_dilated: type_suffix = f"dilated_{type_suffix}" + elif is_depthwise: + type_suffix = f"depthwise_{type_suffix}" typed_op_name = f"{base_name}_{type_suffix}" diff --git a/backends/cadence/hifi/operators/op_quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor_out.cpp index 2f60b249c94..6e09b995126 100644 --- a/backends/cadence/hifi/operators/op_quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor_out.cpp +++ b/backends/cadence/hifi/operators/op_quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor_out.cpp @@ -209,95 +209,8 @@ void xa_opt_quantized_conv_nchw_asym8sxsym8s_asym8s( return; } - if (groups == input_channels) { - WORD32 channels_multiplier = out_channels / input_channels; - - scratch_size = xa_nn_conv2d_depthwise_getsize( - input_height, - input_width, - input_channels, - kernel_height, - kernel_width, - channels_multiplier, - x_stride, - y_stride, - x_padding, - y_padding, - out_height, - out_width, - inp_precision, - 1); // NCHW - - scratch_size = scratch_size < 0 ? 0 : scratch_size; - - ptr_scratch = (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size); - - p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); - - WORD8* ptr1 = (WORD8*)kernels::allocate_temp_memory( - ctx, - ((batches * out_channels * out_height * out_width) + 8) * - sizeof(WORD8)); - - WORD8* p_out_temp = (WORD8*)ALIGN_PTR(ptr1, 8); - - for (int _n = 0; _n < batches; _n++) { - WORD8* in_batch = - p_inp + _n * input_channels * input_height * input_width; - WORD8* out_batch = - p_out_temp + _n * out_channels * out_height * out_width; - - xa_nn_conv2d_depthwise_per_chan_sym8sxasym8s( - out_batch, - p_kernel, - in_batch, - p_bias, - input_height, - input_width, - input_channels, - kernel_height, - kernel_width, - channels_multiplier, - x_stride, - y_stride, - x_padding, - y_padding, - out_height, - out_width, - input_zero_bias, - out_multiplier32, - out_shift32, - out_zero_bias, - 1, // NCHW - 0, // NHWC - p_scratch); - } - - WORD32 p_inp_shape[kNnlibMaxDim]; - p_inp_shape[0] = batches; - p_inp_shape[1] = out_height; - p_inp_shape[2] = out_width; - p_inp_shape[3] = out_channels; - - WORD32 p_out_shape[kNnlibMaxDim]; - p_out_shape[0] = batches; - p_out_shape[1] = out_channels; - p_out_shape[2] = out_height; - p_out_shape[3] = out_width; - - WORD32 p_permute_vec[kNnlibMaxDim] = {0, 3, 1, 2}; - - xa_nn_transpose_8_8( - p_out, - p_out_shape, - p_out_temp, - p_inp_shape, - p_permute_vec, - kNnlibMaxDim, - kNnlibMaxDim); - - return; - } + // Depthwise convolutions are now handled by specialized operators + ET_CHECK_MSG(groups == 1, "Only groups=1 supported for regular convolution"); } void quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor_out( diff --git a/backends/cadence/hifi/operators/op_quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor_out.cpp index 6b5fd72d3fc..ccbf70e1d2d 100644 --- a/backends/cadence/hifi/operators/op_quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor_out.cpp +++ b/backends/cadence/hifi/operators/op_quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor_out.cpp @@ -209,95 +209,8 @@ void xa_opt_quantized_conv_nchw_asym8uxsym8u_asym8u( return; } - if (groups == input_channels) { - WORD32 channels_multiplier = out_channels / input_channels; - - scratch_size = xa_nn_conv2d_depthwise_getsize( - input_height, - input_width, - input_channels, - kernel_height, - kernel_width, - 
channels_multiplier, - x_stride, - y_stride, - x_padding, - y_padding, - out_height, - out_width, - inp_precision, - 1); // NCHW - - scratch_size = scratch_size < 0 ? 0 : scratch_size; - - ptr_scratch = (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size); - - p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); - - UWORD8* ptr1 = (UWORD8*)kernels::allocate_temp_memory( - ctx, - ((batches * out_channels * out_height * out_width) + 8) * - sizeof(UWORD8)); - - UWORD8* p_out_temp = (UWORD8*)ALIGN_PTR(ptr1, 8); - - for (int _n = 0; _n < batches; _n++) { - UWORD8* in_batch = - p_inp + _n * input_channels * input_height * input_width; - UWORD8* out_batch = - p_out_temp + _n * out_channels * out_height * out_width; - - xa_nn_conv2d_depthwise_per_chan_sym8sxasym8s( - (WORD8*)out_batch, - (WORD8*)p_kernel, - (WORD8*)in_batch, - p_bias, - input_height, - input_width, - input_channels, - kernel_height, - kernel_width, - channels_multiplier, - x_stride, - y_stride, - x_padding, - y_padding, - out_height, - out_width, - input_zero_bias, - out_multiplier32, - out_shift32, - out_zero_bias, - 1, // NCHW - 0, // NHWC - p_scratch); - } - - WORD32 p_inp_shape[kNnlibMaxDim]; - p_inp_shape[0] = batches; - p_inp_shape[1] = out_height; - p_inp_shape[2] = out_width; - p_inp_shape[3] = out_channels; - - WORD32 p_out_shape[kNnlibMaxDim]; - p_out_shape[0] = batches; - p_out_shape[1] = out_channels; - p_out_shape[2] = out_height; - p_out_shape[3] = out_width; - - WORD32 p_permute_vec[kNnlibMaxDim] = {0, 3, 1, 2}; - - xa_nn_transpose_8_8( - (WORD8*)p_out, - p_out_shape, - (WORD8*)p_out_temp, - p_inp_shape, - p_permute_vec, - kNnlibMaxDim, - kNnlibMaxDim); - - return; - } + // Depthwise convolutions are now handled by specialized operators + ET_CHECK_MSG(groups == 1, "Only groups=1 supported for regular convolution"); } void quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor_out( diff --git a/backends/cadence/hifi/operators/op_quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor_out.cpp new file mode 100644 index 00000000000..3e2c9c58401 --- /dev/null +++ b/backends/cadence/hifi/operators/op_quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor_out.cpp @@ -0,0 +1,203 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include + +#define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1))) + +using Tensor = executorch::aten::Tensor; +using KernelRuntimeContext = torch::executor::KernelRuntimeContext; +using ScalarType = executorch::aten::ScalarType; +using ::executorch::aten::IntArrayRef; + +namespace cadence { +namespace impl { +namespace HiFi { +namespace native { + +// Specialized depthwise NCHW convolution for int8 x int8 -> int8 +void xa_opt_quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int16_t groups, + int32_t in_zero_point, + int32_t weight_zero_point, + float bias_scale, + float output_scale, + int32_t output_zero_point, + Tensor& out) { + bool conv1d = input.dim() == 3; + constexpr int kNnlibMaxDim = 4; + + WORD8* __restrict__ p_out = + (WORD8* __restrict__)out.mutable_data_ptr(); + WORD8* __restrict__ p_inp = + (WORD8* __restrict__)input.const_data_ptr(); + WORD8* __restrict__ p_kernel = + (WORD8* __restrict__)weight.const_data_ptr(); + WORD32* __restrict__ p_bias = + (WORD32* __restrict__)bias.const_data_ptr(); + + WORD32 input_height = conv1d ? 1 : input.size(2); + WORD32 input_width = conv1d ? input.size(2) : input.size(3); + WORD32 input_channels = input.size(1); + WORD32 kernel_height = conv1d ? 1 : weight.size(2); + WORD32 kernel_width = conv1d ? weight.size(2) : weight.size(3); + WORD32 out_channels = weight.size(0); + WORD32 out_height = conv1d ? 1 : out.size(2); + WORD32 out_width = conv1d ? out.size(2) : out.size(3); + WORD32 batches = input.size(0); + + WORD32 x_stride = stride[1]; + WORD32 y_stride = stride[0]; + WORD32 x_padding = padding[1]; + WORD32 y_padding = padding[0]; + + WORD32 input_zero_bias = -in_zero_point; + WORD32 out_zero_bias = output_zero_point; + WORD32 inp_precision = 8; + + WORD32 channels_multiplier = out_channels / input_channels; + + WORD32 out_multiplier32[out_channels]; + WORD32 out_shift32[out_channels]; + + float out_scale = 1. / output_scale; + + for (int i = 0; i < out_channels; i++) { + out_multiplier32[i] = bias_scale * out_scale * 2147483648; + out_shift32[i] = 0; + } + + WORD32 scratch_size = xa_nn_conv2d_depthwise_getsize( + input_height, + input_width, + input_channels, + kernel_height, + kernel_width, + channels_multiplier, + x_stride, + y_stride, + x_padding, + y_padding, + out_height, + out_width, + inp_precision, + 1); // NCHW + + scratch_size = scratch_size < 0 ? 
0 : scratch_size; + + WORD32* ptr_scratch = + (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size); + pVOID p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); + + WORD8* ptr1 = (WORD8*)kernels::allocate_temp_memory( + ctx, + ((batches * out_channels * out_height * out_width) + 8) * sizeof(WORD8)); + + WORD8* p_out_temp = (WORD8*)ALIGN_PTR(ptr1, 8); + + for (int _n = 0; _n < batches; _n++) { + WORD8* in_batch = p_inp + _n * input_channels * input_height * input_width; + WORD8* out_batch = p_out_temp + _n * out_channels * out_height * out_width; + + xa_nn_conv2d_depthwise_per_chan_sym8sxasym8s( + out_batch, + p_kernel, + in_batch, + p_bias, + input_height, + input_width, + input_channels, + kernel_height, + kernel_width, + channels_multiplier, + x_stride, + y_stride, + x_padding, + y_padding, + out_height, + out_width, + input_zero_bias, + out_multiplier32, + out_shift32, + out_zero_bias, + 1, // NCHW + 0, // NHWC + p_scratch); + } + + WORD32 p_inp_shape[kNnlibMaxDim]; + p_inp_shape[0] = batches; + p_inp_shape[1] = out_height; + p_inp_shape[2] = out_width; + p_inp_shape[3] = out_channels; + + WORD32 p_out_shape[kNnlibMaxDim]; + p_out_shape[0] = batches; + p_out_shape[1] = out_channels; + p_out_shape[2] = out_height; + p_out_shape[3] = out_width; + + WORD32 p_permute_vec[kNnlibMaxDim] = {0, 3, 1, 2}; + + xa_nn_transpose_8_8( + p_out, + p_out_shape, + p_out_temp, + p_inp_shape, + p_permute_vec, + kNnlibMaxDim, + kNnlibMaxDim); +} + +void quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + __ET_UNUSED int64_t out_multiplier, + __ET_UNUSED int64_t out_shift, + Tensor& out) { + xa_opt_quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s( + ctx, + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); +} + +} // namespace native +} // namespace HiFi +} // namespace impl +} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_out.cpp new file mode 100644 index 00000000000..103ce9568c5 --- /dev/null +++ b/backends/cadence/hifi/operators/op_quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_out.cpp @@ -0,0 +1,203 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include + +#define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1))) + +using Tensor = executorch::aten::Tensor; +using KernelRuntimeContext = torch::executor::KernelRuntimeContext; +using ScalarType = executorch::aten::ScalarType; +using ::executorch::aten::IntArrayRef; + +namespace cadence { +namespace impl { +namespace HiFi { +namespace native { + +// Specialized depthwise NCHW convolution for uint8 x uint8 -> uint8 +void xa_opt_quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int16_t groups, + int32_t in_zero_point, + int32_t weight_zero_point, + float bias_scale, + float output_scale, + int32_t output_zero_point, + Tensor& out) { + bool conv1d = input.dim() == 3; + constexpr int kNnlibMaxDim = 4; + + UWORD8* __restrict__ p_out = + (UWORD8* __restrict__)out.mutable_data_ptr(); + UWORD8* __restrict__ p_inp = + (UWORD8* __restrict__)input.const_data_ptr(); + UWORD8* __restrict__ p_kernel = + (UWORD8* __restrict__)weight.const_data_ptr(); + WORD32* __restrict__ p_bias = + (WORD32* __restrict__)bias.const_data_ptr(); + + WORD32 input_height = conv1d ? 1 : input.size(2); + WORD32 input_width = conv1d ? input.size(2) : input.size(3); + WORD32 input_channels = input.size(1); + WORD32 kernel_height = conv1d ? 1 : weight.size(2); + WORD32 kernel_width = conv1d ? weight.size(2) : weight.size(3); + WORD32 out_channels = weight.size(0); + WORD32 out_height = conv1d ? 1 : out.size(2); + WORD32 out_width = conv1d ? out.size(2) : out.size(3); + WORD32 batches = input.size(0); + + WORD32 x_stride = stride[1]; + WORD32 y_stride = stride[0]; + WORD32 x_padding = padding[1]; + WORD32 y_padding = padding[0]; + + WORD32 input_zero_bias = -in_zero_point; + WORD32 out_zero_bias = output_zero_point; + WORD32 inp_precision = 8; + + WORD32 channels_multiplier = out_channels / input_channels; + + WORD32 out_multiplier32[out_channels]; + WORD32 out_shift32[out_channels]; + + float out_scale = 1. / output_scale; + + for (int i = 0; i < out_channels; i++) { + out_multiplier32[i] = bias_scale * out_scale * 2147483648; + out_shift32[i] = 0; + } + + WORD32 scratch_size = xa_nn_conv2d_depthwise_getsize( + input_height, + input_width, + input_channels, + kernel_height, + kernel_width, + channels_multiplier, + x_stride, + y_stride, + x_padding, + y_padding, + out_height, + out_width, + inp_precision, + 1); // NCHW + + scratch_size = scratch_size < 0 ? 
0 : scratch_size; + + WORD32* ptr_scratch = + (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size); + pVOID p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); + + UWORD8* ptr1 = (UWORD8*)kernels::allocate_temp_memory( + ctx, + ((batches * out_channels * out_height * out_width) + 8) * sizeof(UWORD8)); + + UWORD8* p_out_temp = (UWORD8*)ALIGN_PTR(ptr1, 8); + + for (int _n = 0; _n < batches; _n++) { + UWORD8* in_batch = p_inp + _n * input_channels * input_height * input_width; + UWORD8* out_batch = p_out_temp + _n * out_channels * out_height * out_width; + + xa_nn_conv2d_depthwise_per_chan_sym8sxasym8s( + (WORD8*)out_batch, + (WORD8*)p_kernel, + (WORD8*)in_batch, + p_bias, + input_height, + input_width, + input_channels, + kernel_height, + kernel_width, + channels_multiplier, + x_stride, + y_stride, + x_padding, + y_padding, + out_height, + out_width, + input_zero_bias, + out_multiplier32, + out_shift32, + out_zero_bias, + 1, // NCHW + 0, // NHWC + p_scratch); + } + + WORD32 p_inp_shape[kNnlibMaxDim]; + p_inp_shape[0] = batches; + p_inp_shape[1] = out_height; + p_inp_shape[2] = out_width; + p_inp_shape[3] = out_channels; + + WORD32 p_out_shape[kNnlibMaxDim]; + p_out_shape[0] = batches; + p_out_shape[1] = out_channels; + p_out_shape[2] = out_height; + p_out_shape[3] = out_width; + + WORD32 p_permute_vec[kNnlibMaxDim] = {0, 3, 1, 2}; + + xa_nn_transpose_8_8( + (WORD8*)p_out, + p_out_shape, + (WORD8*)p_out_temp, + p_inp_shape, + p_permute_vec, + kNnlibMaxDim, + kNnlibMaxDim); +} + +void quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + __ET_UNUSED int64_t out_multiplier, + __ET_UNUSED int64_t out_shift, + Tensor& out) { + xa_opt_quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u( + ctx, + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); +} + +} // namespace native +} // namespace HiFi +} // namespace impl +} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor_out.cpp index ea30acd81dc..9416b8b7fd2 100644 --- a/backends/cadence/hifi/operators/op_quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor_out.cpp +++ b/backends/cadence/hifi/operators/op_quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor_out.cpp @@ -153,63 +153,8 @@ void xa_opt_quantized_conv_nhwc_asym8sxsym8s_asym8s( return; } - if (groups == input_channels) { - WORD32 channels_multiplier = out_channels / input_channels; - - scratch_size = xa_nn_conv2d_depthwise_getsize( - input_height, - input_width, - input_channels, - kernel_height, - kernel_width, - channels_multiplier, - x_stride, - y_stride, - x_padding, - y_padding, - out_height, - out_width, - inp_precision, - 0); // NHWC - - scratch_size = scratch_size < 0 ? 
0 : scratch_size; - - ptr_scratch = (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size); - - p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); - - for (int _n = 0; _n < batches; _n++) { - WORD8* in_batch = - p_inp + _n * input_channels * input_height * input_width; - WORD8* out_batch = p_out + _n * out_channels * out_height * out_width; - - xa_nn_conv2d_depthwise_per_chan_sym8sxasym8s( - out_batch, - p_kernel, - in_batch, - p_bias, - input_height, - input_width, - input_channels, - kernel_height, - kernel_width, - channels_multiplier, - x_stride, - y_stride, - x_padding, - y_padding, - out_height, - out_width, - input_zero_bias, - out_multiplier32, - out_shift32, - out_zero_bias, - 0, // NHWC - 0, // NHWC - p_scratch); - } - return; - } + // Depthwise convolutions are now handled by specialized operators + ET_CHECK_MSG(groups == 1, "Only groups=1 supported for regular convolution"); } void quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor_out( diff --git a/backends/cadence/hifi/operators/op_quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor_out.cpp index 96ca8049989..97f7967a2ba 100644 --- a/backends/cadence/hifi/operators/op_quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor_out.cpp +++ b/backends/cadence/hifi/operators/op_quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor_out.cpp @@ -153,63 +153,8 @@ void xa_opt_quantized_conv_nhwc_asym8uxsym8u_asym8u( return; } - if (groups == input_channels) { - WORD32 channels_multiplier = out_channels / input_channels; - - scratch_size = xa_nn_conv2d_depthwise_getsize( - input_height, - input_width, - input_channels, - kernel_height, - kernel_width, - channels_multiplier, - x_stride, - y_stride, - x_padding, - y_padding, - out_height, - out_width, - inp_precision, - 0); // NHWC - - scratch_size = scratch_size < 0 ? 0 : scratch_size; - - ptr_scratch = (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size); - - p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); - - for (int _n = 0; _n < batches; _n++) { - UWORD8* in_batch = - p_inp + _n * input_channels * input_height * input_width; - UWORD8* out_batch = p_out + _n * out_channels * out_height * out_width; - - xa_nn_conv2d_depthwise_per_chan_sym8sxasym8s( - (WORD8*)out_batch, - (WORD8*)p_kernel, - (WORD8*)in_batch, - p_bias, - input_height, - input_width, - input_channels, - kernel_height, - kernel_width, - channels_multiplier, - x_stride, - y_stride, - x_padding, - y_padding, - out_height, - out_width, - input_zero_bias, - out_multiplier32, - out_shift32, - out_zero_bias, - 0, // NHWC - 0, // NHWC - p_scratch); - } - return; - } + // Depthwise convolutions are now handled by specialized operators + ET_CHECK_MSG(groups == 1, "Only groups=1 supported for regular convolution"); } void quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor_out( diff --git a/backends/cadence/hifi/operators/op_quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_out.cpp new file mode 100644 index 00000000000..6512622f221 --- /dev/null +++ b/backends/cadence/hifi/operators/op_quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_out.cpp @@ -0,0 +1,173 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include + +#define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1))) + +using Tensor = executorch::aten::Tensor; +using KernelRuntimeContext = torch::executor::KernelRuntimeContext; +using ScalarType = executorch::aten::ScalarType; +using ::executorch::aten::IntArrayRef; + +namespace cadence { +namespace impl { +namespace HiFi { +namespace native { + +// Specialized depthwise NHWC convolution for int8 x int8 -> int8 +void xa_opt_quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int16_t groups, + int32_t in_zero_point, + int32_t weight_zero_point, + float bias_scale, + float output_scale, + int32_t output_zero_point, + Tensor& out) { + bool conv1d = input.dim() == 3; + + WORD8* __restrict__ p_out = + (WORD8* __restrict__)out.mutable_data_ptr(); + WORD8* __restrict__ p_inp = + (WORD8* __restrict__)input.const_data_ptr(); + WORD8* __restrict__ p_kernel = + (WORD8* __restrict__)weight.const_data_ptr(); + WORD32* __restrict__ p_bias = + (WORD32* __restrict__)bias.const_data_ptr(); + + WORD32 input_height = conv1d ? 1 : input.size(2); + WORD32 input_width = conv1d ? input.size(2) : input.size(3); + WORD32 input_channels = input.size(1); + WORD32 kernel_height = conv1d ? 1 : weight.size(2); + WORD32 kernel_width = conv1d ? weight.size(2) : weight.size(3); + WORD32 out_channels = weight.size(0); + WORD32 out_height = conv1d ? 1 : out.size(2); + WORD32 out_width = conv1d ? out.size(2) : out.size(3); + WORD32 batches = input.size(0); + + WORD32 x_stride = stride[1]; + WORD32 y_stride = stride[0]; + WORD32 x_padding = padding[1]; + WORD32 y_padding = padding[0]; + + WORD32 input_zero_bias = -in_zero_point; + WORD32 out_zero_bias = output_zero_point; + WORD32 inp_precision = 8; + + WORD32 channels_multiplier = out_channels / input_channels; + + WORD32 out_multiplier32[out_channels]; + WORD32 out_shift32[out_channels]; + + float out_scale = 1. / output_scale; + + for (int i = 0; i < out_channels; i++) { + out_multiplier32[i] = bias_scale * out_scale * 2147483648; + out_shift32[i] = 0; + } + + WORD32 scratch_size = xa_nn_conv2d_depthwise_getsize( + input_height, + input_width, + input_channels, + kernel_height, + kernel_width, + channels_multiplier, + x_stride, + y_stride, + x_padding, + y_padding, + out_height, + out_width, + inp_precision, + 0); // NHWC + + scratch_size = scratch_size < 0 ? 
0 : scratch_size; + + WORD32* ptr_scratch = + (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size); + pVOID p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); + + for (int _n = 0; _n < batches; _n++) { + WORD8* in_batch = p_inp + _n * input_channels * input_height * input_width; + WORD8* out_batch = p_out + _n * out_channels * out_height * out_width; + + xa_nn_conv2d_depthwise_per_chan_sym8sxasym8s( + out_batch, + p_kernel, + in_batch, + p_bias, + input_height, + input_width, + input_channels, + kernel_height, + kernel_width, + channels_multiplier, + x_stride, + y_stride, + x_padding, + y_padding, + out_height, + out_width, + input_zero_bias, + out_multiplier32, + out_shift32, + out_zero_bias, + 0, // NHWC + 0, // NHWC + p_scratch); + } +} + +void quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + __ET_UNUSED int64_t out_multiplier, + __ET_UNUSED int64_t out_shift, + Tensor& out) { + xa_opt_quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s( + ctx, + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); +} + +} // namespace native +} // namespace HiFi +} // namespace impl +} // namespace cadence diff --git a/backends/cadence/hifi/operators/op_quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_out.cpp b/backends/cadence/hifi/operators/op_quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_out.cpp new file mode 100644 index 00000000000..d41a9c8d4b7 --- /dev/null +++ b/backends/cadence/hifi/operators/op_quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_out.cpp @@ -0,0 +1,173 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1))) + +using Tensor = executorch::aten::Tensor; +using KernelRuntimeContext = torch::executor::KernelRuntimeContext; +using ScalarType = executorch::aten::ScalarType; +using ::executorch::aten::IntArrayRef; + +namespace cadence { +namespace impl { +namespace HiFi { +namespace native { + +// Specialized depthwise NHWC convolution for uint8 x uint8 -> uint8 +void xa_opt_quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int16_t groups, + int32_t in_zero_point, + int32_t weight_zero_point, + float bias_scale, + float output_scale, + int32_t output_zero_point, + Tensor& out) { + bool conv1d = input.dim() == 3; + + UWORD8* __restrict__ p_out = + (UWORD8* __restrict__)out.mutable_data_ptr(); + UWORD8* __restrict__ p_inp = + (UWORD8* __restrict__)input.const_data_ptr(); + UWORD8* __restrict__ p_kernel = + (UWORD8* __restrict__)weight.const_data_ptr(); + WORD32* __restrict__ p_bias = + (WORD32* __restrict__)bias.const_data_ptr(); + + WORD32 input_height = conv1d ? 1 : input.size(2); + WORD32 input_width = conv1d ? 
input.size(2) : input.size(3); + WORD32 input_channels = input.size(1); + WORD32 kernel_height = conv1d ? 1 : weight.size(2); + WORD32 kernel_width = conv1d ? weight.size(2) : weight.size(3); + WORD32 out_channels = weight.size(0); + WORD32 out_height = conv1d ? 1 : out.size(2); + WORD32 out_width = conv1d ? out.size(2) : out.size(3); + WORD32 batches = input.size(0); + + WORD32 x_stride = stride[1]; + WORD32 y_stride = stride[0]; + WORD32 x_padding = padding[1]; + WORD32 y_padding = padding[0]; + + WORD32 input_zero_bias = -in_zero_point; + WORD32 out_zero_bias = output_zero_point; + WORD32 inp_precision = 8; + + WORD32 channels_multiplier = out_channels / input_channels; + + WORD32 out_multiplier32[out_channels]; + WORD32 out_shift32[out_channels]; + + float out_scale = 1. / output_scale; + + for (int i = 0; i < out_channels; i++) { + out_multiplier32[i] = bias_scale * out_scale * 2147483648; + out_shift32[i] = 0; + } + + WORD32 scratch_size = xa_nn_conv2d_depthwise_getsize( + input_height, + input_width, + input_channels, + kernel_height, + kernel_width, + channels_multiplier, + x_stride, + y_stride, + x_padding, + y_padding, + out_height, + out_width, + inp_precision, + 0); // NHWC + + scratch_size = scratch_size < 0 ? 0 : scratch_size; + + WORD32* ptr_scratch = + (WORD32*)kernels::allocate_temp_memory(ctx, scratch_size); + pVOID p_scratch = (pVOID)ALIGN_PTR(ptr_scratch, 8); + + for (int _n = 0; _n < batches; _n++) { + UWORD8* in_batch = p_inp + _n * input_channels * input_height * input_width; + UWORD8* out_batch = p_out + _n * out_channels * out_height * out_width; + + xa_nn_conv2d_depthwise_per_chan_sym8sxasym8s( + (WORD8*)out_batch, + (WORD8*)p_kernel, + (WORD8*)in_batch, + p_bias, + input_height, + input_width, + input_channels, + kernel_height, + kernel_width, + channels_multiplier, + x_stride, + y_stride, + x_padding, + y_padding, + out_height, + out_width, + input_zero_bias, + out_multiplier32, + out_shift32, + out_zero_bias, + 0, // NHWC + 0, // NHWC + p_scratch); + } +} + +void quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + __ET_UNUSED int64_t out_multiplier, + __ET_UNUSED int64_t out_shift, + Tensor& out) { + xa_opt_quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u( + ctx, + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); +} + +} // namespace native +} // namespace HiFi +} // namespace impl +} // namespace cadence diff --git a/backends/cadence/hifi/operators/targets.bzl b/backends/cadence/hifi/operators/targets.bzl index ebed546117e..3dc09b21ae2 100644 --- a/backends/cadence/hifi/operators/targets.bzl +++ b/backends/cadence/hifi/operators/targets.bzl @@ -66,11 +66,15 @@ OPERATORS = [ "quantized_conv_nchw_out", "quantized_conv_nchw_asym8sxsym8s_asym8s_per_tensor_out", "quantized_conv_nchw_asym8uxsym8u_asym8u_per_tensor_out", + "quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor_out", + "quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_out", "quantized_conv_nchw_dilated_asym8sxsym8s_asym8s_per_tensor_out", "quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor_out", "quantized_conv_nhwc_out", 
"quantized_conv_nhwc_asym8sxsym8s_asym8s_per_tensor_out", "quantized_conv_nhwc_asym8uxsym8u_asym8u_per_tensor_out", + "quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_out", + "quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_out", "quantized_conv_nhwc_dilated_asym8sxsym8s_asym8s_per_tensor_out", "quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor_out", "quantized_fully_connected_out", diff --git a/backends/cadence/reference/operators/quantized_conv_nchw_out.cpp b/backends/cadence/reference/operators/quantized_conv_nchw_out.cpp index 6979d8664b2..aefa75d7047 100644 --- a/backends/cadence/reference/operators/quantized_conv_nchw_out.cpp +++ b/backends/cadence/reference/operators/quantized_conv_nchw_out.cpp @@ -430,6 +430,72 @@ void quantized_conv_nchw_dilated_asym8uxsym8u_asym8u_per_tensor_out( out); } +void quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + __ET_UNUSED int64_t out_multiplier, + __ET_UNUSED int64_t out_shift, + Tensor& out) { + quantized_conv_nchw( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); +} + +void quantized_conv_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + __ET_UNUSED int64_t out_multiplier, + __ET_UNUSED int64_t out_shift, + Tensor& out) { + quantized_conv_nchw( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); +} + } // namespace native } // namespace reference } // namespace impl diff --git a/backends/cadence/reference/operators/quantized_conv_nhwc_out.cpp b/backends/cadence/reference/operators/quantized_conv_nhwc_out.cpp index 1a1642f5fa6..26fbc86d5b0 100644 --- a/backends/cadence/reference/operators/quantized_conv_nhwc_out.cpp +++ b/backends/cadence/reference/operators/quantized_conv_nhwc_out.cpp @@ -417,6 +417,72 @@ void quantized_conv_nhwc_dilated_asym8uxsym8u_asym8u_per_tensor_out( out); } +void quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + __ET_UNUSED int64_t out_multiplier, + __ET_UNUSED int64_t out_shift, + Tensor& out) { + quantized_conv_nhwc( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); +} + +void quantized_conv_nhwc_depthwise_asym8uxsym8u_asym8u_per_tensor_out( + __ET_UNUSED KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& weight, + const Tensor& bias, + 
IntArrayRef stride, + IntArrayRef padding, + IntArrayRef dilation, + int64_t groups, + int64_t in_zero_point, + int64_t weight_zero_point, + double bias_scale, + double output_scale, + int64_t output_zero_point, + __ET_UNUSED int64_t out_multiplier, + __ET_UNUSED int64_t out_shift, + Tensor& out) { + quantized_conv_nhwc( + input, + weight, + bias, + stride, + padding, + dilation, + groups, + in_zero_point, + weight_zero_point, + bias_scale, + output_scale, + output_zero_point, + out); +} + } // namespace native } // namespace reference } // namespace impl
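
A few illustrative notes on the changes above. First, the `CompileTimeTypeDispatchPass` update selects the new depthwise variants when `groups` equals the input channel count (dim 1 for NCHW, the last dim for NHWC), with dilation still taking precedence. A minimal standalone Python sketch of that selection rule — `pick_conv_variant` and its arguments are illustrative names, not ExecuTorch APIs:

```python
# Standalone sketch of the variant-selection rule added in type_dispatch.py.
# `pick_conv_variant` and its arguments are illustrative, not ExecuTorch APIs.
from typing import Sequence


def pick_conv_variant(
    layout: str,                  # "nchw" or "nhwc"
    input_shape: Sequence[int],
    groups: int,
    dilation: Sequence[int],
    type_suffix: str,             # e.g. "asym8sxsym8s_asym8s"
) -> str:
    """Return the specialized quantized_conv op name for the given call."""
    # The channel dimension depends on the memory layout.
    input_channels = input_shape[1] if layout == "nchw" else input_shape[-1]
    is_depthwise = groups == input_channels
    is_dilated = any(d > 1 for d in dilation)

    # Dilated takes priority over depthwise, mirroring the if/elif in the pass.
    if is_dilated:
        suffix = f"dilated_{type_suffix}"
    elif is_depthwise:
        suffix = f"depthwise_{type_suffix}"
    else:
        suffix = type_suffix
    return f"quantized_conv_{layout}_{suffix}"


# Example mirroring the tests: a groups=3 conv over a (1, 3, 8, 8) NCHW int8
# input resolves to the new depthwise variant.
assert (
    pick_conv_variant("nchw", (1, 3, 8, 8), 3, (1, 1), "asym8sxsym8s_asym8s")
    == "quantized_conv_nchw_depthwise_asym8sxsym8s_asym8s"
)
```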
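
Second, the `register_fake` meta kernels only need to produce output shapes; they defer to `get_conv1d_output_size`/`get_conv2d_output_size`. Assuming those helpers implement the standard convolution shape arithmetic, here is a hedged sketch of the 2-D NCHW case (the helper below is a stand-in, not the real signature):

```python
# Sketch of the shape arithmetic the fake (meta) kernels rely on; the helper
# names below are illustrative, not the actual get_conv2d_output_size API.
from typing import Sequence, Tuple


def conv_out_dim(in_dim: int, kernel: int, stride: int, padding: int, dilation: int) -> int:
    # Standard convolution output-size formula for one spatial dimension.
    return (in_dim + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def conv2d_output_size_nchw(
    in_size: Sequence[int],      # (N, C_in, H, W)
    out_channels: int,
    stride: Sequence[int],
    padding: Sequence[int],
    dilation: Sequence[int],
    kernel_size: Sequence[int],  # (kH, kW)
) -> Tuple[int, int, int, int]:
    n, _, h, w = in_size
    out_h = conv_out_dim(h, kernel_size[0], stride[0], padding[0], dilation[0])
    out_w = conv_out_dim(w, kernel_size[1], stride[1], padding[1], dilation[1])
    return (n, out_channels, out_h, out_w)


# Matches the test setup: (1, 3, 8, 8) input, 3x3 depthwise kernel, stride 1,
# no padding, no dilation -> (1, 3, 6, 6) output.
assert conv2d_output_size_nchw((1, 3, 8, 8), 3, (1, 1), (0, 0), (1, 1), (3, 3)) == (1, 3, 6, 6)
```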
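
Finally, the new HiFi depthwise kernels fold `bias_scale` and `output_scale` into a per-channel Q31 multiplier with a zero shift (`out_multiplier32[i] = bias_scale * out_scale * 2147483648`). A rough Python sketch of that requantization idea follows; it glosses over the NNLib kernel's exact rounding and int32 saturation behavior, and the function names are illustrative:

```python
# Rough sketch of Q31 requantization as used by the depthwise kernels above.
# Names are illustrative; the NNLib kernel handles rounding/saturation itself.


def q31_multiplier(bias_scale: float, output_scale: float) -> int:
    """Combined scale (bias_scale / output_scale) expressed in Q31, shift = 0."""
    combined = bias_scale * (1.0 / output_scale)
    # 2147483648 == 2**31, matching the constant used in the kernel loop.
    return int(combined * 2147483648)


def requantize(acc: int, multiplier: int, out_zero_point: int) -> int:
    """Apply the Q31 multiplier to an int32 accumulator, add zero point, clamp to int8."""
    scaled = (acc * multiplier) >> 31
    return max(-128, min(127, scaled + out_zero_point))


# With bias_scale = output_scale = 1.0 (as in the tests), the effective scale is
# 1.0 and the accumulator passes through unchanged apart from saturation.
assert requantize(42, q31_multiplier(1.0, 1.0), 0) == 42
```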