From bf26f4cbd0f7ddc9b9779b951eec106a819e2e50 Mon Sep 17 00:00:00 2001 From: Matthias Cremon Date: Mon, 9 Mar 2026 16:01:51 -0700 Subject: [PATCH 1/3] Add quantizer support for quantized_conv1d_ncl Summary: Update quantizer patterns to route conv1d operations to quantized_conv1d_ncl instead of quantized_conv2d_nchw: - Conv1dPattern.replacement_op() now returns quantized_conv1d_ncl - Conv1dReluPattern0 and Conv1dReluPattern1 now return quantized_conv1d_ncl - Added fused_conv1d_relu test and _build_conv1d_relu_graph helper Differential Revision: D95279330 --- backends/cadence/aot/quantizer/patterns.py | 8 ++- .../cadence/aot/tests/test_quantizer_ops.py | 55 +++++++++++++++++++ 2 files changed, 62 insertions(+), 1 deletion(-) diff --git a/backends/cadence/aot/quantizer/patterns.py b/backends/cadence/aot/quantizer/patterns.py index 7a11541b601..741fae68da7 100644 --- a/backends/cadence/aot/quantizer/patterns.py +++ b/backends/cadence/aot/quantizer/patterns.py @@ -267,7 +267,7 @@ def get_anchors( ) def replacement_op(self) -> OpOverload: - return torch.ops.cadence.quantized_conv2d_nchw.per_tensor + return torch.ops.cadence.quantized_conv1d_ncl.per_tensor class Conv2dPattern(QuantizationPattern): @@ -507,12 +507,18 @@ class Conv1dReluPattern0(ConvReluBasePattern): def partition_types(self) -> List[OpOverload]: return [torch.ops.aten.conv1d.default, torch.ops.aten.relu.default] + def replacement_op(self) -> OpOverload: + return torch.ops.cadence.quantized_conv1d_ncl.per_tensor + # Conv1d + alternate relu op fusion class Conv1dReluPattern1(ConvReluBasePattern): def partition_types(self) -> List[OpOverload]: return [torch.ops.aten.conv1d.default, torch.ops.aten.relu_.default] + def replacement_op(self) -> OpOverload: + return torch.ops.cadence.quantized_conv1d_ncl.per_tensor + # Conv2d + regular relu op fusion class Conv2dReluPattern0(ConvReluBasePattern): diff --git a/backends/cadence/aot/tests/test_quantizer_ops.py b/backends/cadence/aot/tests/test_quantizer_ops.py index 831ab3b95b6..683df8e3e2b 100644 --- a/backends/cadence/aot/tests/test_quantizer_ops.py +++ b/backends/cadence/aot/tests/test_quantizer_ops.py @@ -207,6 +207,15 @@ [None, qconfig_A8W8.input_activation, qconfig_A8W8.weight], ), # CadenceFusedConvReluQuantizer test cases + ( + "fused_conv1d_relu_A8W8sym", + lambda self: self._build_conv1d_relu_graph(), + CadenceFusedConvReluQuantizer(), + torch.ops.aten.relu.default, + qconfig_A8W8sym.output_activation, + # For fused conv1d+relu: [input_activation, weight] from conv1d node + [qconfig_A8W8sym.input_activation, qconfig_A8W8sym.weight], + ), ( "fused_conv2d_relu_A8W8sym", lambda self: self._build_conv2d_relu_graph(), @@ -503,6 +512,52 @@ def _build_conv2d_relu_graph( return gm, relu_nodes[0], conv2d_nodes[0] + def _build_conv1d_relu_graph( + self, + ) -> tuple[torch.fx.GraphModule, torch.fx.Node, torch.fx.Node]: + """Build a graph with a conv1d followed by relu (fused pattern). + + Returns: + A tuple of (graph_module, relu_node, conv_node). + The relu_node is the target node where the annotation is placed. + The conv_node is the input source node whose args contain the quantized inputs. 
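+
+        The graph mirrors the conv2d helper above (_build_conv2d_relu_graph),
+        but with 3D [N, C, L] tensors and torch.ops.aten.conv1d.default.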
+ """ + builder = GraphBuilder() + # Input shape: (batch, in_channels, length) + x = builder.placeholder("x", torch.randn(1, 3, 10)) + # Weight shape: (out_channels, in_channels, kernel_size) + weight = builder.placeholder("weight", torch.randn(6, 3, 3)) + conv1d = builder.call_operator( + op=torch.ops.aten.conv1d.default, + args=(x, weight), + meta=NodeMetadata( + {"source_fn_stack": [("conv1d", torch.ops.aten.conv1d.default)]} + ), + ) + relu = builder.call_operator( + op=torch.ops.aten.relu.default, + args=(conv1d,), + meta=NodeMetadata( + {"source_fn_stack": [("relu", torch.ops.aten.relu.default)]} + ), + ) + builder.output([relu]) + gm = builder.get_graph_module() + + relu_nodes = gm.graph.find_nodes( + op="call_function", + target=torch.ops.aten.relu.default, + ) + self.assertEqual(len(relu_nodes), 1, "Should find exactly one relu node") + + conv1d_nodes = gm.graph.find_nodes( + op="call_function", + target=torch.ops.aten.conv1d.default, + ) + self.assertEqual(len(conv1d_nodes), 1, "Should find exactly one conv1d node") + + return gm, relu_nodes[0], conv1d_nodes[0] + @parameterized.expand(QUANTIZER_ANNOTATION_TEST_CASES) def test_quantizer_annotation( self, From b0b27740c3d71570b6ce2ca07c094a0d9816e2cc Mon Sep 17 00:00:00 2001 From: Matthias Cremon Date: Mon, 9 Mar 2026 16:01:51 -0700 Subject: [PATCH 2/3] Enable layout change and replacements of convs for 1D cases Differential Revision: D95492648 --- backends/cadence/aot/replace_ops.py | 105 ++++--- .../aot/tests/test_replace_ops_passes.py | 286 ++++++++++++++++-- 2 files changed, 323 insertions(+), 68 deletions(-) diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py index 14a35c01baf..47c4b51499a 100644 --- a/backends/cadence/aot/replace_ops.py +++ b/backends/cadence/aot/replace_ops.py @@ -1023,6 +1023,7 @@ def targets(self) -> list[EdgeOpOverload]: exir_ops.edge.cadence.conv1d.default, exir_ops.edge.cadence.conv2d.default, exir_ops.edge.cadence.conv3d.default, + exir_ops.edge.cadence.quantized_conv1d_ncl.per_tensor, exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor, ] @@ -1081,71 +1082,70 @@ def _change_depthwise_weight_to_hwc( inp_data_format=0 (NHWC), but the standard NCHW->NHWC permutation produces [OC, KH, KW, 1]. This function applies the correct permutation for depthwise convolution weights. + + For the 1D case, the C++ NLC kernels expect 3D weights in [OC, K, IC/groups] + format, so we use the standard NCHW->NHWC transpose (swap dims 1 and -1) + to produce [OC, K, 1]. """ - # For depthwise: input shape is either: + if not is_2d: + # 1D case: use standard transpose [OC, 1, K] -> [OC, K, 1] + # The C++ NLC kernels expect 3D weight in [OC, K, IC/groups] format. 
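+            # E.g., a depthwise weight of shape [8, 1, 3] (OC, 1, K) becomes
+            # [8, 3, 1] (OC, K, IC/groups), as asserted in the tests below.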
+ return self._change_nchw_to_nhwc(graph, node) + # 2D case: [OC, 1, KH, KW], target is [KH, KW, OC] # Permute [0, 1, 2, 3] -> [2, 3, 0, 1] gives [KH, KW, OC, 1] # Then squeeze the last dim (which is 1) to get [KH, KW, OC] - if is_2d: - permute_indices = [2, 3, 0, 1] - permute_node = graph.call_function( - exir_ops.edge.aten.permute_copy.default, (node, permute_indices), {} - ) - permute_node.meta = node.meta - - # Squeeze the last dimension (which has size 1) - squeeze_node = graph.call_function( - exir_ops.edge.aten.squeeze_copy.dim, (permute_node, -1), {} - ) - squeeze_node.meta = node.meta - return squeeze_node - else: - # 1D case: [OC, 1, K], target is [K, OC] - # Permute [0, 1, 2] -> [1, 2, 0] gives [1, K, OC] - permute_indices = [1, 2, 0] - permute_node = graph.call_function( - exir_ops.edge.aten.permute_copy.default, (node, permute_indices), {} - ) - permute_node.meta = node.meta + permute_indices = [2, 3, 0, 1] + permute_node = graph.call_function( + exir_ops.edge.aten.permute_copy.default, (node, permute_indices), {} + ) + permute_node.meta = node.meta - # Squeeze the first dimension (which has size 1) - squeeze_node = graph.call_function( - exir_ops.edge.aten.squeeze_copy.dim, (permute_node, 0), {} - ) - squeeze_node.meta = node.meta - return squeeze_node + # Squeeze the last dimension (which has size 1) + squeeze_node = graph.call_function( + exir_ops.edge.aten.squeeze_copy.dim, (permute_node, -1), {} + ) + squeeze_node.meta = node.meta + return squeeze_node def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: assert isinstance(node.target, EdgeOpOverload) - quantized_op = ( - node.target == exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor - ) + quantized_op = node.target in { + exir_ops.edge.cadence.quantized_conv1d_ncl.per_tensor, + exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor, + } - # Check if already in NHWC layout + # Check if already in NHWC/NLC layout if not quantized_op and len(node.args) == 8 and node.args[-1] is True: return False + # Get input shape to determine if it's 1D or 2D + input_node = cast(torch.fx.Node, node.args[0]) + input_shape = input_node.meta["val"].shape + is_2d = len(input_shape) == 4 + # Determine the new op target if quantized_op: - new_op = exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor + if is_2d: + new_op = exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor + else: + new_op = exir_ops.edge.cadence.quantized_conv1d_nlc.per_tensor else: new_op = node.target graph = node.graph - # Get input and weight nodes - input_node = cast(torch.fx.Node, node.args[0]) + # Get weight node weight_node = cast(torch.fx.Node, node.args[1]) # Check if this is a depthwise convolution (groups == input_channels) # and weight is 4D with shape [OC, 1, KH, KW] groups = cast(int, node.args[6]) - input_shape = input_node.meta["val"].shape weight_shape = weight_node.meta["val"].shape input_channels = input_shape[1] # NCHW format, channels at index 1 # NCHW: also verify weight IC dim == 1. 
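+        # Depthwise weights are [OC, 1, K] (1D) or [OC, 1, KH, KW] (2D), so the
+        # weight's IC dim at index 1 must be 1 in both cases.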
depthwise = is_depthwise_conv(groups, input_channels) and weight_shape[1] == 1 - is_2d = len(input_shape) == 4 + # Insert transpose operations before the node with graph.inserting_before(node): # Convert input from NCHW to NHWC @@ -1275,10 +1275,26 @@ class ReplaceConvWithIm2RowAndLinear(RemoveOrReplacePassInterface): exir_ops.edge.cadence.conv1d.default: exir_ops.edge.aten.linear.default, exir_ops.edge.cadence.conv2d.default: exir_ops.edge.aten.linear.default, exir_ops.edge.cadence.conv3d.default: exir_ops.edge.aten.linear.default, + exir_ops.edge.cadence.quantized_conv1d_ncl.per_tensor: exir_ops.edge.cadence.quantized_linear.per_tensor, + exir_ops.edge.cadence.quantized_conv1d_nlc.per_tensor: exir_ops.edge.cadence.quantized_linear.per_tensor, exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor: exir_ops.edge.cadence.quantized_linear.per_tensor, exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor: exir_ops.edge.cadence.quantized_linear.per_tensor, } + # Set of quantized conv ops + quantized_conv_ops: frozenset[EdgeOpOverload] = frozenset({ + exir_ops.edge.cadence.quantized_conv1d_ncl.per_tensor, + exir_ops.edge.cadence.quantized_conv1d_nlc.per_tensor, + exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor, + exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor, + }) + + # Set of channel-last conv ops (NHWC for 2D, NLC for 1D) + channel_last_conv_ops: frozenset[EdgeOpOverload] = frozenset({ + exir_ops.edge.cadence.quantized_conv1d_nlc.per_tensor, + exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor, + }) + @property def targets(self) -> list[EdgeOpOverload]: return list(self.conv_op_to_linear_op.keys()) @@ -1286,10 +1302,7 @@ def targets(self) -> list[EdgeOpOverload]: def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: # Get the relevant args from convolution node. assert isinstance(node.target, EdgeOpOverload) - quantized_op = ( - node.target == exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor - or node.target == exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor - ) + quantized_op = node.target in self.quantized_conv_ops assert (len(node.args) == 7 and not quantized_op) or ( len(node.args) >= 12 and quantized_op ), "Inconsistent args for convolution" @@ -1325,13 +1338,9 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: out_shape = node.meta["val"].shape assert None not in {weight_shape, out_shape} - # Determine if the convolution is NCHW or NHWC. The NHWC, i.e., the - # channel_last layout is specified by the channel_last arg of conv - # op, which is either the last argument (15th) or implicitely False - # if the op is quantized, or the last argument if not. - channel_last = ( - node.target == exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor - ) + # Determine if the convolution is NCHW or NHWC (for 2D) / NCL or NLC (for 1D). + # Channel-last layouts are NHWC for 2D and NLC for 1D. + channel_last = node.target in self.channel_last_conv_ops # The weight tensor is [out_channels, in_channels, X] for NCHW layout, # and [out_channels, X, in_channels] for NHWC layout. 
Here, X is the # kernel_width for conv1d, and X = kernel_height * kernel_width for diff --git a/backends/cadence/aot/tests/test_replace_ops_passes.py b/backends/cadence/aot/tests/test_replace_ops_passes.py index 95d470644a0..e55892af8ff 100644 --- a/backends/cadence/aot/tests/test_replace_ops_passes.py +++ b/backends/cadence/aot/tests/test_replace_ops_passes.py @@ -1331,6 +1331,126 @@ def test_replace_conv2d_with_im2row_and_linear(self) -> None: count_node(graph_after_passes, exir_ops.edge.aten.linear.default), 1 ) + @torch.no_grad() + def test_replace_quantized_conv1d_ncl_with_im2row_and_linear(self) -> None: + # Input shape: [N, C, L] = [1, 2, 10] + x = torch.randint(-128, 127, (1, 2, 10), dtype=torch.int8) + # Weight shape: [out_channels, in_channels, kernel_size] = [3, 2, 4] + weights = torch.randint(-128, 127, (3, 2, 4), dtype=torch.int8) + bias = torch.randint(-1000, 1000, (3,), dtype=torch.int32) + original_gm = single_op_builder( + placeholders=(x, weights, bias), + op=exir_ops.edge.cadence.quantized_conv1d_ncl.per_tensor, + args=( + x, + weights, + bias, + [1], # stride + [0], # padding + [1], # dilation + 1, # groups + 0, # input_zero_point + 0, # weight_zero_point + 0.01, # bias_scale + 0.02, # out_scale + 0, # out_zero_point + 1, # out_multiplier + 0, # out_shift + ), + ) + + gm_before = copy.deepcopy(original_gm) + p = ReplaceConvWithIm2RowAndLinear() + result = cast(PassResult, p(original_gm)) + self.assertTrue(result.modified) + graph_after_passes = result.graph_module + + # Validate numerical accuracy + inputs = [x, weights, bias] + validate( + gm_before, + graph_after_passes, + inputs, + "ReplaceConvWithIm2RowAndLinear_conv1d_ncl", + ) + + # Assert that the convolution is converted to im2row + quantized_linear + self.assertEqual( + count_node( + graph_after_passes, exir_ops.edge.cadence.quantized_conv1d_ncl.per_tensor + ), + 0, + ) + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.cadence.im2row.per_tensor), 1 + ) + self.assertEqual( + count_node( + graph_after_passes, exir_ops.edge.cadence.quantized_linear.per_tensor + ), + 1, + ) + + @torch.no_grad() + def test_replace_quantized_conv1d_nlc_with_im2row_and_linear(self) -> None: + # Input shape: [N, L, C] = [1, 10, 2] (channel-last / NLC format) + x = torch.randint(-128, 127, (1, 10, 2), dtype=torch.int8) + # Weight shape: [out_channels, in_channels, kernel_size] = [3, 2, 4] + weights = torch.randint(-128, 127, (3, 2, 4), dtype=torch.int8) + bias = torch.randint(-1000, 1000, (3,), dtype=torch.int32) + original_gm = single_op_builder( + placeholders=(x, weights, bias), + op=exir_ops.edge.cadence.quantized_conv1d_nlc.per_tensor, + args=( + x, + weights, + bias, + [1], # stride + [0], # padding + [1], # dilation + 1, # groups + 0, # input_zero_point + 0, # weight_zero_point + 0.01, # bias_scale + 0.02, # out_scale + 0, # out_zero_point + 1, # out_multiplier + 0, # out_shift + ), + ) + + gm_before = copy.deepcopy(original_gm) + p = ReplaceConvWithIm2RowAndLinear() + result = cast(PassResult, p(original_gm)) + self.assertTrue(result.modified) + graph_after_passes = result.graph_module + + # Validate numerical accuracy + inputs = [x, weights, bias] + validate( + gm_before, + graph_after_passes, + inputs, + "ReplaceConvWithIm2RowAndLinear_conv1d_nlc", + ) + + # Assert that the convolution is converted to im2row + quantized_linear + self.assertEqual( + count_node( + graph_after_passes, exir_ops.edge.cadence.quantized_conv1d_nlc.per_tensor + ), + 0, + ) + self.assertEqual( + count_node(graph_after_passes, 
exir_ops.edge.cadence.im2row.per_tensor), 1
+        )
+        self.assertEqual(
+            count_node(
+                graph_after_passes, exir_ops.edge.cadence.quantized_linear.per_tensor
+            ),
+            1,
+        )
+
     @expand(
         [
             [(3, 1, 5), 1, 0],
@@ -1951,6 +2071,69 @@ def create_quantized_convolution_graph_module(
             args=args,
         )
 
+    def create_quantized_conv1d_graph_module(
+        self, channels_last: Optional[bool] = None
+    ) -> tuple[tuple[torch.Tensor, ...], torch.fx.GraphModule]:
+        """Helper to create a graph module with a single quantized conv1d node.
+
+        quantized_conv1d_ncl/nlc.per_tensor(
+            Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding,
+            int[] dilation, int groups, int input_zero_point, int weight_zero_point,
+            float bias_scale, float out_scale, int out_zero_point, int out_multiplier,
+            int out_shift) -> (Tensor Z)
+        """
+        # NCL: (N, C, L) format
+        # NLC: (N, L, C) format
+        in_channels = 3
+        out_channels = 16
+        kernel_size = 4
+        if channels_last:
+            x = torch.randint(0, 100, (1, 224, in_channels), dtype=torch.int8)  # NLC
+            w = torch.randint(
+                -128, 127, (out_channels, kernel_size, in_channels), dtype=torch.int8
+            )
+        else:
+            x = torch.randint(0, 100, (1, in_channels, 224), dtype=torch.int8)  # NCL
+            w = torch.randint(
+                -128, 127, (out_channels, in_channels, kernel_size), dtype=torch.int8
+            )
+        b = torch.randint(-1000, 1000, (out_channels,), dtype=torch.int32)
+        stride = [2]
+        padding = [0]
+        dilation = [1]
+        groups = 1
+        input_zero_point = 0
+        w_zero_point = 0
+        b_scale = 0.01
+        out_scale = 0.02
+        out_zero_point = 0
+        out_multiplier = 1
+        out_shift = 0
+        args = (
+            x,
+            w,
+            b,
+            stride,
+            padding,
+            dilation,
+            groups,
+            input_zero_point,
+            w_zero_point,
+            b_scale,
+            out_scale,
+            out_zero_point,
+            out_multiplier,
+            out_shift,
+        )
+        if channels_last:
+            op = exir_ops.edge.cadence.quantized_conv1d_nlc.per_tensor
+        else:
+            op = exir_ops.edge.cadence.quantized_conv1d_ncl.per_tensor
+
+        placeholders = (x, w, b)
+
+        return placeholders, single_op_builder(
+            placeholders=placeholders,
+            op=op,
+            args=args,
+        )
+
     def test_quantized_convolution_default_channel_last(self) -> None:
         # Create a graph with a single convolution node.
         placeholders, gm = self.create_quantized_convolution_graph_module()
@@ -1988,6 +2171,76 @@ def test_quantized_convolution_default_channel_last(self) -> None:
             "ReplaceConvWithChannelLastConvPass",
         )
 
+    def test_convert_quantized_conv1d_ncl_to_nlc(self) -> None:
+        # Create a graph with a quantized_conv1d_ncl node
+        placeholders, gm = self.create_quantized_conv1d_graph_module(
+            channels_last=False
+        )
+        original = copy.deepcopy(gm)
+        # Check that we start with quantized_conv1d_ncl
+        self.assertEqual(
+            count_node(gm, exir_ops.edge.cadence.quantized_conv1d_ncl.per_tensor), 1
+        )
+        self.assertEqual(count_node(gm, exir_ops.edge.aten.permute_copy.default), 0)
+
+        # Apply replacement pass
+        p = ReplaceConvWithChannelLastConvPass()
+        gm_after_replacement = p.call(gm).graph_module
+
+        # Verify the quantized_conv1d_nlc node exists
+        self.assertEqual(
+            count_node(
+                gm_after_replacement,
+                exir_ops.edge.cadence.quantized_conv1d_nlc.per_tensor,
+            ),
+            1,
+        )
+        # For 1D conv, the pass uses transpose_copy.int (swap dims 1 and -1)
+        # for input, weight, and output: 3 transpose_copy ops total
+        self.assertEqual(
+            count_node(gm_after_replacement, exir_ops.edge.aten.transpose_copy.int),
+            3,
+        )
+
+        # Validate numerical accuracy
+        validate(
+            original,
+            gm_after_replacement,
+            placeholders,
+            "ReplaceConvWithChannelLastConvPass",
+        )
+
+    def test_no_transpose_if_already_quantized_conv1d_channel_last(self) -> None:
+        # Create a graph with a quantized_conv1d_nlc node (already channel-last)
+        placeholders, gm = self.create_quantized_conv1d_graph_module(
+            channels_last=True
+        )
+        original = copy.deepcopy(gm)
+        # Check if graph module has quantized_conv1d_nlc
+        self.assertEqual(
+            count_node(gm, exir_ops.edge.cadence.quantized_conv1d_nlc.per_tensor), 1
+        )
+
+        # Apply replacement pass
+        p = ReplaceConvWithChannelLastConvPass()
+        gm_after_replacement = p.call(gm).graph_module
+
+        # No replacement expected: nlc is not a target of this pass, so the
+        # graph should remain unchanged
+        self.assertEqual(
+            count_node(
+                gm_after_replacement,
+                exir_ops.edge.cadence.quantized_conv1d_nlc.per_tensor,
+            ),
+            1,
+        )
+        # No permutes should be added
+        self.assertEqual(
+            count_node(gm_after_replacement, exir_ops.edge.aten.permute_copy.default), 0
+        )
+        validate(
+            gm_after_replacement,
+            original,
+            placeholders,
+            "ReplaceConvWithChannelLastConvPass",
+        )
+
     def test_no_transpose_if_already_quantized_conv_channel_last(self) -> None:
         # Create a graph with a single im2row node.
         placeholders, gm = self.create_quantized_convolution_graph_module(
@@ -2325,10 +2578,11 @@ def create_1d_depthwise_convolution_graph_module(
         return placeholders, gm
 
     def test_1d_depthwise_convolution_weight_shape(self) -> None:
-        """Test that 1D depthwise conv weight is transformed to [K, OC] format.
+        """Test that 1D depthwise conv weight is transformed to [OC, K, 1] format.
 
         For 1D depthwise conv with groups == in_channels > 1, the weight should be
-        transformed from [OC, 1, K] to [K, OC] (2D) via permute_copy + squeeze_copy.
+        transformed from [OC, 1, K] to [OC, K, 1] (3D) via transpose_copy.int,
+        matching the standard NLC weight format expected by C++ kernels.
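+
+        For example, the [8, 1, 3] (OC, 1, K) weight built by the helper above
+        comes out as [8, 3, 1] (OC, K, IC/groups) after the transpose.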
""" placeholders, gm = self.create_1d_depthwise_convolution_graph_module() self.assertEqual( @@ -2347,40 +2601,32 @@ def test_1d_depthwise_convolution_weight_shape(self) -> None: ) # For 1D depthwise: - # - Input/output: transpose_copy.int (2 ops, for 3D NCHW<->NHWC) - # - Weight: permute_copy.default + squeeze_copy.dim (depthwise layout) + # - Input/output/weight all use transpose_copy.int (3 ops) + # - No squeeze_copy or permute_copy needed self.assertEqual( count_node(gm_after_replacement, exir_ops.edge.aten.transpose_copy.int), - 2, - ) - self.assertEqual( - count_node(gm_after_replacement, exir_ops.edge.aten.permute_copy.default), - 1, + 3, ) self.assertEqual( count_node(gm_after_replacement, exir_ops.edge.aten.squeeze_copy.dim), - 1, + 0, ) for node in gm_after_replacement.graph.nodes: if node.target != exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor: continue weight_node = node.args[1] - self.assertEqual( - weight_node.target, - exir_ops.edge.aten.squeeze_copy.dim, - "1D depthwise conv weight should be processed by squeeze_copy", - ) weight_shape = weight_node.meta["val"].shape self.assertEqual( len(weight_shape), - 2, - f"1D depthwise weight should be 2D [K, OC], got {len(weight_shape)}D", + 3, + f"1D depthwise weight should be 3D [OC, K, 1], got {len(weight_shape)}D", ) # Original weight: [8, 1, 3] (OC, 1, K) - # Expected after depthwise transform: [3, 8] (K, OC) - self.assertEqual(weight_shape[0], 3) # K - self.assertEqual(weight_shape[1], 8) # OC + # Expected after standard NLC transform: [8, 3, 1] (OC, K, IC/groups) + self.assertEqual(weight_shape[0], 8) # OC + self.assertEqual(weight_shape[1], 3) # K + self.assertEqual(weight_shape[2], 1) # IC/groups validate( gm, From e4c08a3d63c7cde5fa8e33cb7e435851ec8eadcd Mon Sep 17 00:00:00 2001 From: Reza Sajadiany Date: Tue, 17 Mar 2026 16:59:38 -0700 Subject: [PATCH 3/3] Conv1d channel last bug fix Summary: Fixes Conv1d channel las indexing issue. Differential Revision: D97009237 --- .../hifi/operators/op_quantized_conv1d_nlc.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/backends/cadence/hifi/operators/op_quantized_conv1d_nlc.cpp b/backends/cadence/hifi/operators/op_quantized_conv1d_nlc.cpp index e40cca4a88a..376e06e4a8c 100644 --- a/backends/cadence/hifi/operators/op_quantized_conv1d_nlc.cpp +++ b/backends/cadence/hifi/operators/op_quantized_conv1d_nlc.cpp @@ -47,14 +47,14 @@ void xa_opt_quantized_conv1d_nlc_asym8sxsym8s_asym8s( (WORD32* __restrict__)bias.const_data_ptr(); WORD32 batches = input.size(0); - WORD32 input_channels = input.size(1); - WORD32 input_width = input.size(2); + WORD32 input_channels = input.size(2); + WORD32 input_width = input.size(1); WORD32 input_height = 1; WORD32 kernel_height = 1; WORD32 out_channels = weight.size(0); - WORD32 kernel_channels = weight.size(1); - WORD32 kernel_width = weight.size(2); - WORD32 out_width = out.size(2); + WORD32 kernel_channels = weight.size(2); + WORD32 kernel_width = weight.size(1); + WORD32 out_width = out.size(1); WORD32 out_height = 1; WORD32 x_stride = stride[1]; WORD32 y_stride = stride[0]; @@ -70,7 +70,7 @@ void xa_opt_quantized_conv1d_nlc_asym8sxsym8s_asym8s( WORD32 out_zero_bias = output_zero_point; - WORD32 out_data_format = 1; + WORD32 out_data_format = 0; WORD32 p_out_multiplier32[out_channels]; WORD32 p_out_shift32[out_channels];