8 changes: 7 additions & 1 deletion backends/cadence/aot/quantizer/patterns.py
@@ -267,7 +267,7 @@ def get_anchors(
)

def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_conv2d_nchw.per_tensor
return torch.ops.cadence.quantized_conv1d_ncl.per_tensor


class Conv2dPattern(QuantizationPattern):
@@ -507,12 +507,18 @@ class Conv1dReluPattern0(ConvReluBasePattern):
def partition_types(self) -> List[OpOverload]:
return [torch.ops.aten.conv1d.default, torch.ops.aten.relu.default]

def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_conv1d_ncl.per_tensor


# Conv1d + alternate relu op fusion
class Conv1dReluPattern1(ConvReluBasePattern):
def partition_types(self) -> List[OpOverload]:
return [torch.ops.aten.conv1d.default, torch.ops.aten.relu_.default]

def replacement_op(self) -> OpOverload:
return torch.ops.cadence.quantized_conv1d_ncl.per_tensor


# Conv2d + regular relu op fusion
class Conv2dReluPattern0(ConvReluBasePattern):
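Note (example, not part of the diff): the `_ncl` suffix on the new replacement op refers to the channel-first 1D activation layout, i.e. inputs of shape [N, C, L] with weights [OC, IC/groups, K], the same shapes that plain torch.nn.functional.conv1d consumes. A minimal shape sketch:

import torch

# NCL layout for 1D convolution: input [N, C, L], weight [OC, IC/groups, K].
x = torch.randn(1, 3, 10)  # batch=1, channels=3, length=10
w = torch.randn(6, 3, 3)   # out_channels=6, in_channels=3, kernel=3
y = torch.nn.functional.conv1d(x, w)
print(y.shape)             # torch.Size([1, 6, 8]), output stays N, C, L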
105 changes: 57 additions & 48 deletions backends/cadence/aot/replace_ops.py
@@ -1023,6 +1023,7 @@ def targets(self) -> list[EdgeOpOverload]:
exir_ops.edge.cadence.conv1d.default,
exir_ops.edge.cadence.conv2d.default,
exir_ops.edge.cadence.conv3d.default,
exir_ops.edge.cadence.quantized_conv1d_ncl.per_tensor,
exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor,
]

@@ -1081,71 +1082,70 @@ def _change_depthwise_weight_to_hwc(
inp_data_format=0 (NHWC), but the standard NCHW->NHWC permutation produces
[OC, KH, KW, 1]. This function applies the correct permutation for depthwise
convolution weights.

For the 1D case, the C++ NLC kernels expect 3D weights in [OC, K, IC/groups]
format, so we use the standard NCHW->NHWC transpose (swap dims 1 and -1)
to produce [OC, K, 1].
"""
# For depthwise: input shape is either:
if not is_2d:
# 1D case: use standard transpose [OC, 1, K] -> [OC, K, 1]
# The C++ NLC kernels expect 3D weight in [OC, K, IC/groups] format.
return self._change_nchw_to_nhwc(graph, node)

# 2D case: [OC, 1, KH, KW], target is [KH, KW, OC]
# Permute [0, 1, 2, 3] -> [2, 3, 0, 1] gives [KH, KW, OC, 1]
# Then squeeze the last dim (which is 1) to get [KH, KW, OC]
if is_2d:
permute_indices = [2, 3, 0, 1]
permute_node = graph.call_function(
exir_ops.edge.aten.permute_copy.default, (node, permute_indices), {}
)
permute_node.meta = node.meta

# Squeeze the last dimension (which has size 1)
squeeze_node = graph.call_function(
exir_ops.edge.aten.squeeze_copy.dim, (permute_node, -1), {}
)
squeeze_node.meta = node.meta
return squeeze_node
else:
# 1D case: [OC, 1, K], target is [K, OC]
# Permute [0, 1, 2] -> [1, 2, 0] gives [1, K, OC]
permute_indices = [1, 2, 0]
permute_node = graph.call_function(
exir_ops.edge.aten.permute_copy.default, (node, permute_indices), {}
)
permute_node.meta = node.meta
permute_indices = [2, 3, 0, 1]
permute_node = graph.call_function(
exir_ops.edge.aten.permute_copy.default, (node, permute_indices), {}
)
permute_node.meta = node.meta

# Squeeze the first dimension (which has size 1)
squeeze_node = graph.call_function(
exir_ops.edge.aten.squeeze_copy.dim, (permute_node, 0), {}
)
squeeze_node.meta = node.meta
return squeeze_node
# Squeeze the last dimension (which has size 1)
squeeze_node = graph.call_function(
exir_ops.edge.aten.squeeze_copy.dim, (permute_node, -1), {}
)
squeeze_node.meta = node.meta
return squeeze_node

def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
assert isinstance(node.target, EdgeOpOverload)
quantized_op = (
node.target == exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor
)
quantized_op = node.target in {
exir_ops.edge.cadence.quantized_conv1d_ncl.per_tensor,
exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor,
}

# Check if already in NHWC layout
# Check if already in NHWC/NLC layout
if not quantized_op and len(node.args) == 8 and node.args[-1] is True:
return False

# Get input shape to determine if it's 1D or 2D
input_node = cast(torch.fx.Node, node.args[0])
input_shape = input_node.meta["val"].shape
is_2d = len(input_shape) == 4

# Determine the new op target
if quantized_op:
new_op = exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor
if is_2d:
new_op = exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor
else:
new_op = exir_ops.edge.cadence.quantized_conv1d_nlc.per_tensor
else:
new_op = node.target

graph = node.graph

# Get input and weight nodes
input_node = cast(torch.fx.Node, node.args[0])
# Get weight node
weight_node = cast(torch.fx.Node, node.args[1])

# Check if this is a depthwise convolution (groups == input_channels)
# and weight is 4D with shape [OC, 1, KH, KW]
groups = cast(int, node.args[6])
input_shape = input_node.meta["val"].shape
weight_shape = weight_node.meta["val"].shape
input_channels = input_shape[1] # NCHW format, channels at index 1
# NCHW: also verify weight IC dim == 1.
depthwise = is_depthwise_conv(groups, input_channels) and weight_shape[1] == 1
is_2d = len(input_shape) == 4

# Insert transpose operations before the node
with graph.inserting_before(node):
# Convert input from NCHW to NHWC
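Note (example, not part of the diff): the depthwise weight permutations described in the docstring above can be checked shape-for-shape in plain PyTorch; the pass builds the equivalent permute_copy/squeeze_copy graph nodes rather than eager calls.

import torch

# Depthwise 2D weight: [OC, 1, KH, KW] -> permute [2, 3, 0, 1] -> [KH, KW, OC, 1]
# -> squeeze the trailing unit dim -> [KH, KW, OC]
w2d = torch.randn(8, 1, 3, 3)
w2d_hwc = w2d.permute(2, 3, 0, 1).squeeze(-1)
print(w2d_hwc.shape)  # torch.Size([3, 3, 8])

# Depthwise 1D weight: [OC, 1, K] -> swap dims 1 and -1 (the standard NCHW->NHWC
# transpose) -> [OC, K, 1], matching the [OC, K, IC/groups] layout the NLC kernels expect.
w1d = torch.randn(8, 1, 5)
w1d_nlc = w1d.transpose(1, -1)
print(w1d_nlc.shape)  # torch.Size([8, 5, 1])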
@@ -1275,21 +1275,34 @@ class ReplaceConvWithIm2RowAndLinear(RemoveOrReplacePassInterface):
exir_ops.edge.cadence.conv1d.default: exir_ops.edge.aten.linear.default,
exir_ops.edge.cadence.conv2d.default: exir_ops.edge.aten.linear.default,
exir_ops.edge.cadence.conv3d.default: exir_ops.edge.aten.linear.default,
exir_ops.edge.cadence.quantized_conv1d_ncl.per_tensor: exir_ops.edge.cadence.quantized_linear.per_tensor,
exir_ops.edge.cadence.quantized_conv1d_nlc.per_tensor: exir_ops.edge.cadence.quantized_linear.per_tensor,
exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor: exir_ops.edge.cadence.quantized_linear.per_tensor,
exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor: exir_ops.edge.cadence.quantized_linear.per_tensor,
}

# Set of quantized conv ops
quantized_conv_ops: frozenset[EdgeOpOverload] = frozenset({
exir_ops.edge.cadence.quantized_conv1d_ncl.per_tensor,
exir_ops.edge.cadence.quantized_conv1d_nlc.per_tensor,
exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor,
exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor,
})

# Set of channel-last conv ops (NHWC for 2D, NLC for 1D)
channel_last_conv_ops: frozenset[EdgeOpOverload] = frozenset({
exir_ops.edge.cadence.quantized_conv1d_nlc.per_tensor,
exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor,
})

@property
def targets(self) -> list[EdgeOpOverload]:
return list(self.conv_op_to_linear_op.keys())

def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
# Get the relevant args from convolution node.
assert isinstance(node.target, EdgeOpOverload)
quantized_op = (
node.target == exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor
or node.target == exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor
)
quantized_op = node.target in self.quantized_conv_ops
assert (len(node.args) == 7 and not quantized_op) or (
len(node.args) >= 12 and quantized_op
), "Inconsistent args for convolution"
@@ -1325,13 +1338,9 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
out_shape = node.meta["val"].shape
assert None not in {weight_shape, out_shape}

# Determine if the convolution is NCHW or NHWC. The NHWC, i.e., the
# channel_last layout is specified by the channel_last arg of conv
# op, which is either the last argument (15th) or implicitely False
# if the op is quantized, or the last argument if not.
channel_last = (
node.target == exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor
)
# Determine if the convolution is NCHW or NHWC (for 2D) / NCL or NLC (for 1D).
# Channel-last layouts are NHWC for 2D and NLC for 1D.
channel_last = node.target in self.channel_last_conv_ops
# The weight tensor is [out_channels, in_channels, X] for NCHW layout,
# and [out_channels, X, in_channels] for NHWC layout. Here, X is the
# kernel_width for conv1d, and X = kernel_height * kernel_width for
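Note (example, not part of the diff): the im2row + linear lowering this pass performs rests on the standard conv-as-matmul identity, which can be sanity-checked in plain PyTorch with F.unfold standing in for im2row:

import torch
import torch.nn.functional as F

# conv2d expressed as im2row (unfold) followed by a linear over the flattened weight.
x = torch.randn(1, 3, 8, 8)
w = torch.randn(4, 3, 3, 3)        # NCHW weight: [OC, IC, KH, KW]

cols = F.unfold(x, kernel_size=3)  # [1, IC*KH*KW, num_patches] = [1, 27, 36]
out = w.reshape(4, -1) @ cols      # flattened weight [OC, IC*KH*KW] applied to the patches
y_linear = out.reshape(1, 4, 6, 6)

y_conv = F.conv2d(x, w)
print(torch.allclose(y_linear, y_conv, atol=1e-5))  # True

For the channel-last variants (NHWC for 2D, NLC for 1D) the weight is [OC, X, IC] rather than [OC, IC, X], so only the flattening order of the kernel elements changes.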
55 changes: 55 additions & 0 deletions backends/cadence/aot/tests/test_quantizer_ops.py
@@ -207,6 +207,15 @@
[None, qconfig_A8W8.input_activation, qconfig_A8W8.weight],
),
# CadenceFusedConvReluQuantizer test cases
(
"fused_conv1d_relu_A8W8sym",
lambda self: self._build_conv1d_relu_graph(),
CadenceFusedConvReluQuantizer(),
torch.ops.aten.relu.default,
qconfig_A8W8sym.output_activation,
# For fused conv1d+relu: [input_activation, weight] from conv1d node
[qconfig_A8W8sym.input_activation, qconfig_A8W8sym.weight],
),
(
"fused_conv2d_relu_A8W8sym",
lambda self: self._build_conv2d_relu_graph(),
@@ -503,6 +512,52 @@ def _build_conv2d_relu_graph(

return gm, relu_nodes[0], conv2d_nodes[0]

def _build_conv1d_relu_graph(
self,
) -> tuple[torch.fx.GraphModule, torch.fx.Node, torch.fx.Node]:
"""Build a graph with a conv1d followed by relu (fused pattern).

Returns:
A tuple of (graph_module, relu_node, conv_node).
The relu_node is the target node where the annotation is placed.
The conv_node is the input source node whose args contain the quantized inputs.
"""
builder = GraphBuilder()
# Input shape: (batch, in_channels, length)
x = builder.placeholder("x", torch.randn(1, 3, 10))
# Weight shape: (out_channels, in_channels, kernel_size)
weight = builder.placeholder("weight", torch.randn(6, 3, 3))
conv1d = builder.call_operator(
op=torch.ops.aten.conv1d.default,
args=(x, weight),
meta=NodeMetadata(
{"source_fn_stack": [("conv1d", torch.ops.aten.conv1d.default)]}
),
)
relu = builder.call_operator(
op=torch.ops.aten.relu.default,
args=(conv1d,),
meta=NodeMetadata(
{"source_fn_stack": [("relu", torch.ops.aten.relu.default)]}
),
)
builder.output([relu])
gm = builder.get_graph_module()

relu_nodes = gm.graph.find_nodes(
op="call_function",
target=torch.ops.aten.relu.default,
)
self.assertEqual(len(relu_nodes), 1, "Should find exactly one relu node")

conv1d_nodes = gm.graph.find_nodes(
op="call_function",
target=torch.ops.aten.conv1d.default,
)
self.assertEqual(len(conv1d_nodes), 1, "Should find exactly one conv1d node")

return gm, relu_nodes[0], conv1d_nodes[0]

@parameterized.expand(QUANTIZER_ANNOTATION_TEST_CASES)
def test_quantizer_annotation(
self,
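Note (sketch, not part of the diff): the graph built by _build_conv1d_relu_graph corresponds to an eager conv1d followed by relu, and the annotation path these tests exercise feeds the usual PT2E flow. The import path for CadenceFusedConvReluQuantizer below is assumed, and prepare_pt2e/convert_pt2e are the standard PT2E entry points; treat this as an illustration rather than the exact test harness.

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

# Import path assumed for illustration; the class itself is what the tests above exercise.
from executorch.backends.cadence.aot.quantizer.quantizer import (
    CadenceFusedConvReluQuantizer,
)


class Conv1dRelu(torch.nn.Module):
    """Eager twin of the GraphBuilder graph: conv1d over a [1, 3, 10] input, then relu."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv1d(3, 6, 3, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.conv(x))


example = (torch.randn(1, 3, 10),)
module = torch.export.export(Conv1dRelu(), example).module()
prepared = prepare_pt2e(module, CadenceFusedConvReluQuantizer())  # annotation happens here
prepared(*example)                                                # calibration pass
quantized = convert_pt2e(prepared)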