diff --git a/backends/cadence/aot/compiler_funcs.py b/backends/cadence/aot/compiler_funcs.py index 02dcde7fd39..cec3cb7d016 100644 --- a/backends/cadence/aot/compiler_funcs.py +++ b/backends/cadence/aot/compiler_funcs.py @@ -14,6 +14,7 @@ import torch from torch._inductor.decomposition import remove_decompositions from torch.fx import GraphModule +from torch.fx.passes.infra.pass_base import PassBase, PassResult from torchao.quantization.pt2e.quantize_pt2e import prepare_pt2e, prepare_qat_pt2e from torchao.quantization.pt2e.quantizer import Quantizer @@ -607,3 +608,32 @@ def sink_input_dequant_through_transparent_ops( graph_module.recompile() return modified + + +class QuantFusionPass(PassBase): + """ + Iterates patterns, finds anchor ops in the converted graph, and calls + pattern.fuse() to replace dq-op-q subgraphs with fused ops. + """ + + def __init__(self, patterns: Sequence[object]) -> None: + super().__init__() + self.patterns = patterns + + def call(self, graph_module: GraphModule) -> Optional[PassResult]: + changed = False + for pattern in self.patterns: + pattern_changed = False + for target in pattern.anchor_ops(): # pyre-ignore[16] + for node in graph_module.graph.find_nodes( + op="call_function", target=target + ): + result = pattern.fuse(graph_module, node) # pyre-ignore[16] + if result is not None: + changed = True + pattern_changed = True + if pattern_changed: + graph_module.graph.eliminate_dead_code() + if changed: + graph_module.recompile() + return PassResult(graph_module, changed) diff --git a/backends/cadence/aot/quantizer/BUCK b/backends/cadence/aot/quantizer/BUCK index 34fec2556f8..956bf700bd7 100644 --- a/backends/cadence/aot/quantizer/BUCK +++ b/backends/cadence/aot/quantizer/BUCK @@ -14,6 +14,21 @@ fbcode_target(_kind = runtime.python_library, ], ) +fbcode_target(_kind = runtime.python_library, + name = "pattern_utils", + srcs = [ + "pattern_utils.py", + ], + typing = True, + deps = [ + ":utils", + "//caffe2:torch", + "//executorch/backends/cadence/aot:compiler_utils", + "//executorch/backends/cadence/aot:pass_utils", + "//executorch/backends/cadence/aot:utils", + ], +) + fbcode_target(_kind = runtime.python_library, name = "patterns", srcs = [ @@ -21,8 +36,10 @@ fbcode_target(_kind = runtime.python_library, ], typing = True, deps = [ + ":pattern_utils", ":utils", "//caffe2:torch", + "//executorch/backends/cadence/aot:pass_utils", ], ) diff --git a/backends/cadence/aot/quantizer/pattern_utils.py b/backends/cadence/aot/quantizer/pattern_utils.py new file mode 100644 index 00000000000..7f60b23d278 --- /dev/null +++ b/backends/cadence/aot/quantizer/pattern_utils.py @@ -0,0 +1,194 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import operator +from typing import Any, Optional + +import torch +from executorch.backends.cadence.aot.pass_utils import get_arg +from executorch.backends.cadence.aot.quantizer.utils import ( + create_zero_bias_int32, + quantize_tensor_multiplier, +) +from executorch.backends.cadence.aot.utils import is_depthwise_conv +from torch import fx +from torch._ops import OpOverload + +DQ_PER_TENSOR: OpOverload = torch.ops.quantized_decomposed.dequantize_per_tensor.default +Q_PER_TENSOR: OpOverload = torch.ops.quantized_decomposed.quantize_per_tensor.default + + +def find_quant_user(node: fx.Node) -> Optional[fx.Node]: + """Find the first quantize_per_tensor user of ``node``, traversing through getitem.""" + users = list(node.users) + if not users: + return None + user = users[0] + if user.target is operator.getitem: + if len(user.args) >= 2 and user.args[1] == 0: + users = list(user.users) + if not users: + return None + user = users[0] + else: + return None + if user.target == Q_PER_TENSOR: + return user + return None + + +def replace_with_op( + gm: fx.GraphModule, + insert_after: fx.Node, + replacement_op: OpOverload, + args: tuple[Any, ...], + kwargs: dict[str, Any], + node_to_replace: fx.Node, +) -> fx.Node: + """Insert ``replacement_op`` after ``insert_after`` and replace all uses of + ``node_to_replace`` with the new node.""" + with gm.graph.inserting_after(insert_after): + new_node = gm.graph.call_function(replacement_op, args, kwargs) + new_node.meta = node_to_replace.meta + node_to_replace.replace_all_uses_with(new_node) + return new_node + + +def fuse_conv( + pattern: object, + gm: fx.GraphModule, + conv_node: fx.Node, + dq_input: fx.Node, + dq_weight: fx.Node, + quant_node: fx.Node, +) -> fx.Node: + """Fuse a dq→conv→q chain into a single quantized conv op.""" + dq_bias = None + if len(conv_node.args) > 2 and conv_node.args[2] is not None: + bias_arg = conv_node.args[2] + assert isinstance(bias_arg, fx.Node) + dq_bias = bias_arg if bias_arg.target == DQ_PER_TENSOR else None + weight_scale = get_arg(dq_weight, "scale", float) + input_scale = get_arg(dq_input, "scale", float) + # pyre-fixme[58] + bias_scale = input_scale * weight_scale + if dq_bias is not None: + bias_q = get_arg(dq_bias, "input", fx.Node) + else: + weight_node = get_arg(dq_weight, "input", fx.Node) + bias_q = create_zero_bias_int32( + gm, weight_node, bias_scale, insert_before=conv_node + ) + requantize_scale = bias_scale / get_arg(quant_node, "scale", float) + requantize_scale_t = torch.tensor([requantize_scale]) + out_multiplier, out_shift = quantize_tensor_multiplier(requantize_scale_t) + args = ( + get_arg(dq_input, "input", fx.Node), + get_arg(dq_weight, "input", fx.Node), + bias_q, + ) + groups = get_arg(conv_node, "groups", int) + kwargs = { + "stride": get_arg(conv_node, "stride", list[int]), + "padding": get_arg(conv_node, "padding", list[int]), + "dilation": get_arg(conv_node, "dilation", list[int]), + "groups": groups, + "input_zero_point": get_arg(dq_input, "zero_point", int), + "weight_zero_point": get_arg(dq_weight, "zero_point", int), + "bias_scale": bias_scale, + "out_scale": get_arg(quant_node, "scale", float), + "out_zero_point": get_arg(quant_node, "zero_point", int), + "out_multiplier": out_multiplier[0].item(), + "out_shift": out_shift[0].item(), + } + replacement_op = pattern.replacement_op() # pyre-ignore[16] + if replacement_op == torch.ops.cadence.quantized_conv1d_ncl.per_tensor: + input_node = get_arg(dq_input, "input", fx.Node) + in_channels = input_node.meta["val"].shape[1] + if is_depthwise_conv(groups, in_channels): + replacement_op = torch.ops.cadence.quantized_depthwise_conv1d_ncl.per_tensor + return replace_with_op(gm, conv_node, replacement_op, args, kwargs, quant_node) + + +def fuse_linear( + gm: fx.GraphModule, + dq_input: fx.Node, + dq_weight: fx.Node, + dq_bias: Optional[fx.Node], + quant_node: fx.Node, + op_node: fx.Node, + replacement_op: OpOverload, + weight_q: Optional[fx.Node] = None, +) -> fx.Node: + """Fuse a dq→linear→q chain into a single quantized linear op.""" + assert op_node.target in ( + torch.ops.aten.linear.default, + torch.ops.aten.addmm.default, + ), f"Expected linear/addmm, got {op_node.target}" + weight_scale = get_arg(dq_weight, "scale", float) + input_scale = get_arg(dq_input, "scale", float) + # pyre-fixme[58] + bias_scale = input_scale * weight_scale + requantize_scale = bias_scale / get_arg(quant_node, "scale", float) + requantize_scale_t = torch.tensor([requantize_scale]) + out_multiplier, out_shift = quantize_tensor_multiplier(requantize_scale_t) + if dq_bias is not None: + bias_q = get_arg(dq_bias, "input", fx.Node) + else: + weight_node = get_arg(dq_weight, "input", fx.Node) + bias_q = create_zero_bias_int32( + gm, weight_node, bias_scale, insert_before=op_node + ) + final_weight = ( + weight_q if weight_q is not None else get_arg(dq_weight, "input", fx.Node) + ) + args = (get_arg(dq_input, "input", fx.Node), final_weight, bias_q) + kwargs = { + "src_zero_point": get_arg(dq_input, "zero_point", int), + "weight_zero_point": get_arg(dq_weight, "zero_point", int), + "out_multiplier": out_multiplier[0].item(), + "out_shift": out_shift[0].item(), + "out_zero_point": get_arg(quant_node, "zero_point", int), + "offset": None, + } + return replace_with_op(gm, op_node, replacement_op, args, kwargs, quant_node) + + +def fuse_matmul( + gm: fx.GraphModule, + anchor_node: fx.Node, + dq0: fx.Node, + dq1: fx.Node, + quant_node: fx.Node, + replacement_op: OpOverload, +) -> fx.Node: + """Fuse a dq→matmul→q chain into a single quantized matmul op.""" + assert anchor_node.target in ( + torch.ops.aten.bmm.default, + torch.ops.aten.matmul.default, + ), f"Expected bmm/matmul, got {anchor_node.target}" + scale0 = get_arg(dq0, "scale", float) + scale1 = get_arg(dq1, "scale", float) + # pyre-ignore[58] + requantize_scale = (scale0 * scale1) / get_arg(quant_node, "scale", float) + requantize_scale_t = torch.tensor([requantize_scale]) + out_multiplier, out_shift = quantize_tensor_multiplier(requantize_scale_t) + args = ( + get_arg(dq0, "input", fx.Node), + get_arg(dq0, "zero_point", int), + get_arg(dq1, "input", fx.Node), + get_arg(dq1, "zero_point", int), + None, + ) + kwargs = { + "out_multiplier": out_multiplier[0].item(), + "out_shift": out_shift[0].item(), + "out_zero_point": get_arg(quant_node, "zero_point", int), + "transposed": False, + } + return replace_with_op(gm, anchor_node, replacement_op, args, kwargs, quant_node) diff --git a/backends/cadence/aot/quantizer/patterns.py b/backends/cadence/aot/quantizer/patterns.py index 54c01227d07..642f2362f1e 100644 --- a/backends/cadence/aot/quantizer/patterns.py +++ b/backends/cadence/aot/quantizer/patterns.py @@ -9,11 +9,25 @@ import operator from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import List, Tuple, Union +from typing import List, Optional, Tuple, Union import torch -from executorch.backends.cadence.aot.quantizer.utils import get_bias_qparams - +from executorch.backends.cadence.aot.compiler_utils import get_shape +from executorch.backends.cadence.aot.pass_utils import get_arg +from executorch.backends.cadence.aot.quantizer.pattern_utils import ( + DQ_PER_TENSOR, + find_quant_user, + fuse_conv, + fuse_linear, + fuse_matmul, + replace_with_op, +) +from executorch.backends.cadence.aot.quantizer.utils import ( + check_out_zero_point_is_min_range, + copy_node_metadata, + get_bias_qparams, + quantize_tensor_multiplier, +) from torch import fx from torch._ops import OpOverload from torchao.quantization.pt2e.quantizer import ( @@ -79,6 +93,22 @@ def replacement_op(self) -> OpOverload: """ pass + def anchor_ops(self) -> tuple[OpOverload, ...]: + return tuple(self.partition_types()) + + def fuse( + self, + gm: fx.GraphModule, + anchor_node: fx.Node, + ) -> Optional[fx.Node]: + """Replace the dq→op→q subgraph around ``anchor_node`` with a fused op. + + Called by ``QuantFusionPass`` for each node matching ``anchor_ops()``. + Returns the new fused node on success, or ``None`` to skip this match. + Subclasses override to implement pattern-specific fusion logic. + """ + return None + class AddmmPattern(QuantizationPattern): def partition_types(self) -> List[OpOverload]: @@ -115,6 +145,46 @@ def get_anchors( def replacement_op(self) -> OpOverload: return torch.ops.cadence.quantized_linear.per_tensor + def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> Optional[fx.Node]: + # addmm(bias, input, weight) + dq_input = anchor_node.args[1] + if not isinstance(dq_input, fx.Node) or dq_input.target != DQ_PER_TENSOR: + return None + dq_weight = anchor_node.args[2] + if not isinstance(dq_weight, fx.Node) or dq_weight.target != DQ_PER_TENSOR: + return None + quant_node = find_quant_user(anchor_node) + if quant_node is None: + return None + bias_arg = anchor_node.args[0] + dq_bias = ( + bias_arg + if isinstance(bias_arg, fx.Node) and bias_arg.target == DQ_PER_TENSOR + else None + ) + weight_q = dq_weight.args[0] + assert isinstance(weight_q, fx.Node) + with gm.graph.inserting_before(anchor_node): + transposed = gm.graph.call_function( + torch.ops.aten.transpose.int, (weight_q, 0, 1) + ) + assert "val" in weight_q.meta + original_val = weight_q.meta["val"] + assert original_val.fake_mode is not None + with original_val.fake_mode: + transposed.meta["val"] = torch.ops.aten.transpose.int(original_val, 0, 1) + copy_node_metadata(transposed, weight_q) + return fuse_linear( + gm, + dq_input, + dq_weight, + dq_bias, + quant_node, + anchor_node, + self.replacement_op(), + weight_q=transposed, + ) + class AddPattern(QuantizationPattern): def partition_types(self) -> List[OpOverload]: @@ -153,6 +223,37 @@ def get_anchors( def replacement_op(self) -> OpOverload: return torch.ops.cadence.quantized_add.per_tensor + def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> Optional[fx.Node]: + if not ( + isinstance(anchor_node.args[0], fx.Node) + and isinstance(anchor_node.args[1], fx.Node) + ): + return None + if len(anchor_node.kwargs) > 0: + return None + dq0 = anchor_node.args[0] + if not isinstance(dq0, fx.Node) or dq0.target != DQ_PER_TENSOR: + return None + dq1 = anchor_node.args[1] + if not isinstance(dq1, fx.Node) or dq1.target != DQ_PER_TENSOR: + return None + quant_node = find_quant_user(anchor_node) + if quant_node is None: + return None + args = ( + dq0.args[0], + dq0.args[1], + dq0.args[2], + dq1.args[0], + dq1.args[1], + dq1.args[2], + quant_node.args[1], + quant_node.args[2], + ) + return replace_with_op( + gm, anchor_node, self.replacement_op(), args, {}, quant_node + ) + # This is a base class for Add+ReLU fusion, since it can be used with two different relu aten ops class AddReluBasePattern(QuantizationPattern): @@ -196,6 +297,48 @@ def get_anchors( def replacement_op(self) -> OpOverload: return torch.ops.cadence.quantized_add.per_tensor + def anchor_ops(self) -> tuple[OpOverload, ...]: + return (torch.ops.aten.add.Tensor,) + + def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> Optional[fx.Node]: + add_users = list(anchor_node.users) + if len(add_users) != 1: + return None + relu_node = add_users[0] + if relu_node.target != self.partition_types()[1]: + return None + if not ( + isinstance(anchor_node.args[0], fx.Node) + and isinstance(anchor_node.args[1], fx.Node) + ): + return None + if len(anchor_node.kwargs) > 0: + return None + dq0 = anchor_node.args[0] + if not isinstance(dq0, fx.Node) or dq0.target != DQ_PER_TENSOR: + return None + dq1 = anchor_node.args[1] + if not isinstance(dq1, fx.Node) or dq1.target != DQ_PER_TENSOR: + return None + quant_node = find_quant_user(relu_node) + if quant_node is None: + return None + # pyre-ignore[6]: Argument -> int/dtype narrowing + check_out_zero_point_is_min_range(quant_node.args[2], quant_node.args[5]) + args = ( + dq0.args[0], + dq0.args[1], + dq0.args[2], + dq1.args[0], + dq1.args[1], + dq1.args[2], + quant_node.args[1], + quant_node.args[2], + ) + return replace_with_op( + gm, anchor_node, self.replacement_op(), args, {}, quant_node + ) + # Add + regular relu op fusion class AddReluPattern0(AddReluBasePattern): @@ -234,6 +377,18 @@ def replacement_op(self) -> OpOverload: # we just need to change the name of the op return torch.ops.cadence.quantized_matmul.default + def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> Optional[fx.Node]: + dq0 = anchor_node.args[0] + if not isinstance(dq0, fx.Node) or dq0.target != DQ_PER_TENSOR: + return None + dq1 = anchor_node.args[1] + if not isinstance(dq1, fx.Node) or dq1.target != DQ_PER_TENSOR: + return None + quant_node = find_quant_user(anchor_node) + if quant_node is None: + return None + return fuse_matmul(gm, anchor_node, dq0, dq1, quant_node, self.replacement_op()) + class CatPattern(QuantizationPattern): def partition_types(self) -> List[OpOverload]: @@ -283,6 +438,28 @@ def get_anchors( def replacement_op(self) -> OpOverload: return torch.ops.aten.cat.default + def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> Optional[fx.Node]: + cat_inputs = anchor_node.args[0] + if not isinstance(cat_inputs, (list, tuple)) or not cat_inputs: + return None + inputs_q = [] + for inp in cat_inputs: + if not isinstance(inp, fx.Node) or inp.target != DQ_PER_TENSOR: + return None + dq = inp + if dq is None: + return None + inputs_q.append(dq.args[0]) + quant_node = find_quant_user(anchor_node) + if quant_node is None: + return None + dim = get_arg(anchor_node, "dim", int) + args = (inputs_q,) + kwargs = {"dim": dim} + return replace_with_op( + gm, anchor_node, self.replacement_op(), args, kwargs, quant_node + ) + class Conv1dPattern(QuantizationPattern): def partition_types(self) -> List[OpOverload]: @@ -325,6 +502,18 @@ def get_anchors( def replacement_op(self) -> OpOverload: return torch.ops.cadence.quantized_conv1d_ncl.per_tensor + def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> Optional[fx.Node]: + dq_input = anchor_node.args[0] + if not isinstance(dq_input, fx.Node) or dq_input.target != DQ_PER_TENSOR: + return None + dq_weight = anchor_node.args[1] + if not isinstance(dq_weight, fx.Node) or dq_weight.target != DQ_PER_TENSOR: + return None + quant_node = find_quant_user(anchor_node) + if quant_node is None: + return None + return fuse_conv(self, gm, anchor_node, dq_input, dq_weight, quant_node) + class Conv2dPattern(QuantizationPattern): def partition_types(self) -> List[OpOverload]: @@ -367,6 +556,18 @@ def get_anchors( def replacement_op(self) -> OpOverload: return torch.ops.cadence.quantized_conv2d_nchw.per_tensor + def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> Optional[fx.Node]: + dq_input = anchor_node.args[0] + if not isinstance(dq_input, fx.Node) or dq_input.target != DQ_PER_TENSOR: + return None + dq_weight = anchor_node.args[1] + if not isinstance(dq_weight, fx.Node) or dq_weight.target != DQ_PER_TENSOR: + return None + quant_node = find_quant_user(anchor_node) + if quant_node is None: + return None + return fuse_conv(self, gm, anchor_node, dq_input, dq_weight, quant_node) + class LayerNormPattern(QuantizationPattern): def partition_types(self) -> List[OpOverload]: @@ -405,6 +606,73 @@ def get_anchors( def replacement_op(self) -> OpOverload: return torch.ops.cadence.quantized_layer_norm.per_tensor + def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> Optional[fx.Node]: + dq_input = anchor_node.args[0] + if not isinstance(dq_input, fx.Node) or dq_input.target != DQ_PER_TENSOR: + return None + quant_node = find_quant_user(anchor_node) + if quant_node is None: + return None + assert isinstance(dq_input.args[1], float) and isinstance( + dq_input.args[2], int + ), "per-channel quantization is not supported for layer norm" + scale = dq_input.args[1] + zero_point = dq_input.args[2] + normalized_shape = anchor_node.args[1] + assert isinstance(normalized_shape, list) + weight = ( + anchor_node.args[2] + if len(anchor_node.args) > 2 and anchor_node.args[2] + else None + ) + bias = ( + anchor_node.args[3] + if len(anchor_node.args) > 3 and anchor_node.args[3] + else None + ) + input_q = dq_input.args[0] + assert isinstance(input_q, fx.Node) + if not weight: + with gm.graph.inserting_before(anchor_node): + weight = gm.graph.call_function( + torch.ops.aten.full.default, + (normalized_shape, 1), + {"dtype": torch.float32}, + ) + assert "val" in input_q.meta + fake_mode = input_q.meta["val"].fake_mode + assert fake_mode is not None + with fake_mode: + weight.meta["val"] = torch.full( + normalized_shape, 1, dtype=torch.float32 + ) + copy_node_metadata(weight, input_q) + if not bias: + with gm.graph.inserting_before(anchor_node): + bias = gm.graph.call_function( + torch.ops.aten.full.default, + (normalized_shape, 0), + {"dtype": torch.float32}, + ) + assert "val" in input_q.meta + fake_mode = input_q.meta["val"].fake_mode + assert fake_mode is not None + with fake_mode: + bias.meta["val"] = torch.full(normalized_shape, 0, dtype=torch.float32) + copy_node_metadata(bias, input_q) + args = (input_q, scale, zero_point) + kwargs = { + "normalized_shape": normalized_shape, + "weight": weight, + "bias": bias, + "eps": 1e-05, + "output_scale": quant_node.args[1], + "output_zero_point": quant_node.args[2], + } + return replace_with_op( + gm, anchor_node, self.replacement_op(), args, kwargs, quant_node + ) + class LinearPattern(QuantizationPattern): def partition_types(self) -> List[OpOverload]: @@ -447,6 +715,31 @@ def get_anchors( def replacement_op(self) -> OpOverload: return torch.ops.cadence.quantized_linear.per_tensor + def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> Optional[fx.Node]: + dq_input = anchor_node.args[0] + if not isinstance(dq_input, fx.Node) or dq_input.target != DQ_PER_TENSOR: + return None + dq_weight = anchor_node.args[1] + if not isinstance(dq_weight, fx.Node) or dq_weight.target != DQ_PER_TENSOR: + return None + quant_node = find_quant_user(anchor_node) + if quant_node is None: + return None + dq_bias: Optional[fx.Node] = None + if len(anchor_node.args) > 2: + bias_arg = anchor_node.args[2] + if isinstance(bias_arg, fx.Node) and bias_arg.target == DQ_PER_TENSOR: + dq_bias = bias_arg + return fuse_linear( + gm, + dq_input, + dq_weight, + dq_bias, + quant_node, + anchor_node, + self.replacement_op(), + ) + class MatmulPattern(QuantizationPattern): def partition_types(self) -> List[OpOverload]: @@ -472,6 +765,18 @@ def replacement_op(self) -> OpOverload: # TODO: T240804887 This is actually a per-tensor variant, we just need to change the name of the op return torch.ops.cadence.quantized_matmul.default + def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> Optional[fx.Node]: + dq0 = anchor_node.args[0] + if not isinstance(dq0, fx.Node) or dq0.target != DQ_PER_TENSOR: + return None + dq1 = anchor_node.args[1] + if not isinstance(dq1, fx.Node) or dq1.target != DQ_PER_TENSOR: + return None + quant_node = find_quant_user(anchor_node) + if quant_node is None: + return None + return fuse_matmul(gm, anchor_node, dq0, dq1, quant_node, self.replacement_op()) + class MaxPool2dPattern(QuantizationPattern): """ @@ -530,6 +835,30 @@ def get_anchors( def replacement_op(self) -> OpOverload: return torch.ops.cadence.quantized_max_pool2d_nchw.default + def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> Optional[fx.Node]: + dq_input = anchor_node.args[0] + if not isinstance(dq_input, fx.Node) or dq_input.target != DQ_PER_TENSOR: + return None + quant_node = find_quant_user(anchor_node) + if quant_node is None: + return None + kernel_size = get_arg(anchor_node, "kernel_size", Optional[list[int]]) or [1, 1] + stride = get_arg(anchor_node, "stride", Optional[list[int]]) or kernel_size + padding = get_arg(anchor_node, "padding", Optional[list[int]]) or [0, 0] + dilation = get_arg(anchor_node, "dilation", Optional[list[int]]) or [1, 1] + ceil_mode = get_arg(anchor_node, "ceil_mode", Optional[bool]) or False + args = (dq_input.args[0],) + kwargs = { + "kernel_size": kernel_size, + "stride": stride, + "padding": padding, + "dilation": dilation, + "ceil_mode": ceil_mode, + } + return replace_with_op( + gm, anchor_node, self.replacement_op(), args, kwargs, quant_node + ) + class MaxPool2dWithoutIndicesPattern(QuantizationPattern): """ @@ -569,6 +898,30 @@ def get_anchors( def replacement_op(self) -> OpOverload: return torch.ops.cadence.quantized_max_pool2d_nchw.default + def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> Optional[fx.Node]: + dq_input = anchor_node.args[0] + if not isinstance(dq_input, fx.Node) or dq_input.target != DQ_PER_TENSOR: + return None + quant_node = find_quant_user(anchor_node) + if quant_node is None: + return None + kernel_size = get_arg(anchor_node, "kernel_size", Optional[list[int]]) or [1, 1] + stride = get_arg(anchor_node, "stride", Optional[list[int]]) or kernel_size + padding = get_arg(anchor_node, "padding", Optional[list[int]]) or [0, 0] + dilation = get_arg(anchor_node, "dilation", Optional[list[int]]) or [1, 1] + ceil_mode = get_arg(anchor_node, "ceil_mode", Optional[bool]) or False + args = (dq_input.args[0],) + kwargs = { + "kernel_size": kernel_size, + "stride": stride, + "padding": padding, + "dilation": dilation, + "ceil_mode": ceil_mode, + } + return replace_with_op( + gm, anchor_node, self.replacement_op(), args, kwargs, quant_node + ) + # This is a base class for ReLU @@ -598,6 +951,29 @@ def get_anchors( def replacement_op(self) -> OpOverload: return torch.ops.cadence.quantized_relu.per_tensor + def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> Optional[fx.Node]: + dq_input = anchor_node.args[0] + if not isinstance(dq_input, fx.Node) or dq_input.target != DQ_PER_TENSOR: + return None + quant_node = find_quant_user(anchor_node) + if quant_node is None: + return None + input_scale = dq_input.args[1] + # pyre-fixme[58] + requantize_scale = input_scale / quant_node.args[1] + requantize_scale_t = torch.tensor([requantize_scale]) + out_multiplier, out_shift = quantize_tensor_multiplier(requantize_scale_t) + args = (dq_input.args[0],) + kwargs = { + "X_zero_point": dq_input.args[2], + "out_zero_point": quant_node.args[2], + "out_multiplier": out_multiplier[0].item(), + "out_shift": out_shift[0].item(), + } + return replace_with_op( + gm, anchor_node, self.replacement_op(), args, kwargs, quant_node + ) + # Regular relu op class ReluPattern0(ReluBasePattern): @@ -657,6 +1033,37 @@ def get_anchors( def replacement_op(self) -> OpOverload: return torch.ops.cadence.quantized_conv2d_nchw.per_tensor + def anchor_ops(self) -> tuple[OpOverload, ...]: + return (self.partition_types()[0],) + + def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> Optional[fx.Node]: + conv_users = list(anchor_node.users) + if len(conv_users) != 1: + return None + relu_node = conv_users[0] + if relu_node.target != self.partition_types()[1]: + return None + _arg0 = anchor_node.args[0] + dq_input = ( + _arg0 + if isinstance(_arg0, fx.Node) and _arg0.target == DQ_PER_TENSOR + else None + ) + _arg1 = anchor_node.args[1] + dq_weight = ( + _arg1 + if isinstance(_arg1, fx.Node) and _arg1.target == DQ_PER_TENSOR + else None + ) + if dq_input is None or dq_weight is None: + return None + quant_node = find_quant_user(relu_node) + if quant_node is None: + return None + # pyre-ignore[6]: Argument -> int/dtype narrowing + check_out_zero_point_is_min_range(quant_node.args[2], quant_node.args[5]) + return fuse_conv(self, gm, anchor_node, dq_input, dq_weight, quant_node) + # Conv1d + regular relu op fusion class Conv1dReluPattern0(ConvReluBasePattern): @@ -711,6 +1118,52 @@ def get_anchors( def replacement_op(self) -> OpOverload: return torch.ops.cadence.quantized_softmax.per_tensor + def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> Optional[fx.Node]: + dq_input = anchor_node.args[0] + if not isinstance(dq_input, fx.Node) or dq_input.target != DQ_PER_TENSOR: + return None + quant_node = find_quant_user(anchor_node) + if quant_node is None: + return None + input_q = dq_input.args[0] + assert isinstance(input_q, fx.Node) + quant_input = quant_node.args[0] + assert isinstance(quant_input, fx.Node) + mask_shape = get_shape(gm, quant_input) + mask_shape = list(mask_shape) if mask_shape else [] + mask_shape[-1] = mask_shape[-1] // 16 + with gm.graph.inserting_before(anchor_node): + mask_tensor = gm.graph.call_function( + torch.ops.aten.full.default, (mask_shape, 0.0), {"dtype": torch.int32} + ) + assert "val" in input_q.meta + fake_mode = input_q.meta["val"].fake_mode + assert fake_mode is not None + with fake_mode: + mask_tensor.meta["val"] = torch.full(mask_shape, 0.0, dtype=torch.int32) + copy_node_metadata(mask_tensor, input_q) + with gm.graph.inserting_before(anchor_node): + pos_tensor = gm.graph.call_function( + torch.ops.aten.full.default, ([1], 0), {"dtype": torch.int64} + ) + with fake_mode: + pos_tensor.meta["val"] = torch.full([1], 0, dtype=torch.int64) + copy_node_metadata(pos_tensor, input_q) + args = ( + input_q, + mask_tensor, + get_arg(anchor_node, "dim", int), + 0, + pos_tensor, + dq_input.args[1], + dq_input.args[2], + quant_node.args[1], + quant_node.args[2], + ) + return replace_with_op( + gm, anchor_node, self.replacement_op(), args, {}, quant_node + ) + class MixedW8A32LinearPattern(QuantizationPattern): def partition_types(self) -> List[OpOverload]: @@ -765,6 +1218,38 @@ def get_anchors( def replacement_op(self) -> OpOverload: return torch.ops.cadence.quantized_w8a32_linear.default + def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> Optional[fx.Node]: + if len(anchor_node.args) != 3 or len(anchor_node.kwargs) > 0: + return None + _arg1 = anchor_node.args[1] + dq_weight = ( + _arg1 + if isinstance(_arg1, fx.Node) and _arg1.target == DQ_PER_TENSOR + else None + ) + _arg2 = anchor_node.args[2] + dq_bias = ( + _arg2 + if isinstance(_arg2, fx.Node) and _arg2.target == DQ_PER_TENSOR + else None + ) + if dq_weight is None or dq_bias is None: + return None + input_node = anchor_node.args[0] + assert isinstance(input_node, fx.Node) + args = ( + input_node, + dq_weight.args[0], + dq_weight.args[1], + dq_bias.args[0], + dq_bias.args[1], + ) + with gm.graph.inserting_after(anchor_node): + fused = gm.graph.call_function(self.replacement_op(), args, {}) + fused.meta = anchor_node.meta + anchor_node.replace_all_uses_with(fused) + return fused + class MixedW8A32ConvPattern(QuantizationPattern): def partition_types(self) -> List[OpOverload]: @@ -839,6 +1324,78 @@ def get_anchors( def replacement_op(self) -> OpOverload: return torch.ops.cadence.quantized_w8a32_conv.default + def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> Optional[fx.Node]: + if len(anchor_node.args) != 3 or len(anchor_node.kwargs) > 0: + return None + _arg1 = anchor_node.args[1] + dq_weight = ( + _arg1 + if isinstance(_arg1, fx.Node) and _arg1.target == DQ_PER_TENSOR + else None + ) + _arg2 = anchor_node.args[2] + dq_bias = ( + _arg2 + if isinstance(_arg2, fx.Node) and _arg2.target == DQ_PER_TENSOR + else None + ) + if dq_weight is None or dq_bias is None: + return None + input_node = anchor_node.args[0] + assert isinstance(input_node, fx.Node) + assert get_arg(anchor_node, "stride", list[int]) == [1] + assert get_arg(anchor_node, "padding", list[int]) == [0] + assert get_arg(anchor_node, "dilation", list[int]) == [1] + assert get_arg(anchor_node, "groups", int) == 1 + weight_q = dq_weight.args[0] + assert isinstance(weight_q, fx.Node) + with gm.graph.inserting_before(anchor_node): + transposed_inputs = gm.graph.call_function( + torch.ops.aten.permute.default, (input_node, [0, 2, 1]) + ) + if "val" in input_node.meta: + original_val = input_node.meta["val"] + fake_mode = original_val.fake_mode + if fake_mode is not None: + with fake_mode: + transposed_inputs.meta["val"] = torch.ops.aten.permute.default( + original_val, [0, 2, 1] + ) + else: + transposed_inputs.meta["val"] = torch.ops.aten.permute.default( + original_val, [0, 2, 1] + ) + copy_node_metadata(transposed_inputs, input_node) + with gm.graph.inserting_before(anchor_node): + transposed_weights = gm.graph.call_function( + torch.ops.aten.permute.default, (weight_q, [2, 0, 1]) + ) + if "val" in weight_q.meta: + original_val = weight_q.meta["val"] + fake_mode = original_val.fake_mode + if fake_mode is not None: + with fake_mode: + transposed_weights.meta["val"] = torch.ops.aten.permute.default( + original_val, [2, 0, 1] + ) + else: + transposed_weights.meta["val"] = torch.ops.aten.permute.default( + original_val, [2, 0, 1] + ) + copy_node_metadata(transposed_weights, weight_q) + args = ( + transposed_inputs, + transposed_weights, + dq_weight.args[1], + dq_bias.args[0], + dq_bias.args[1], + ) + with gm.graph.inserting_after(anchor_node): + fused = gm.graph.call_function(self.replacement_op(), args, {}) + fused.meta = anchor_node.meta + anchor_node.replace_all_uses_with(fused) + return fused + class MixedW8A32GruPattern(QuantizationPattern): def partition_types(self) -> List[OpOverload]: @@ -911,6 +1468,43 @@ def __init__(self, args, meta): def replacement_op(self) -> OpOverload: return torch.ops.cadence.quantized_w8a32_gru.default + def fuse(self, gm: fx.GraphModule, anchor_node: fx.Node) -> Optional[fx.Node]: + if len(anchor_node.kwargs) > 0: + return None + params = anchor_node.args[2] + if not isinstance(params, (list, tuple)) or len(params) < 4: + return None + dq_w_ih = params[0] + if not isinstance(dq_w_ih, fx.Node) or dq_w_ih.target != DQ_PER_TENSOR: + return None + dq_w_hh = params[1] + if not isinstance(dq_w_hh, fx.Node) or dq_w_hh.target != DQ_PER_TENSOR: + return None + dq_b_ih = params[2] + if not isinstance(dq_b_ih, fx.Node) or dq_b_ih.target != DQ_PER_TENSOR: + return None + dq_b_hh = params[3] + if not isinstance(dq_b_hh, fx.Node) or dq_b_hh.target != DQ_PER_TENSOR: + return None + input_node = anchor_node.args[0] + hidden_node = anchor_node.args[1] + args = ( + input_node, + hidden_node, + dq_w_ih.args[0], + dq_w_ih.args[1], + dq_w_hh.args[0], + dq_w_hh.args[1], + dq_b_ih.args[0], + dq_b_ih.args[1], + dq_b_hh.args[0], + ) + with gm.graph.inserting_after(anchor_node): + fused = gm.graph.call_function(self.replacement_op(), args, {}) + fused.meta = anchor_node.meta + anchor_node.replace_all_uses_with(fused) + return fused + class RmsNormPattern(QuantizationPattern): """Pattern that preserves rms_norm from decomposition without matching anything.""" diff --git a/backends/cadence/aot/quantizer/utils.py b/backends/cadence/aot/quantizer/utils.py index 51182a4ce92..6b13809ab68 100644 --- a/backends/cadence/aot/quantizer/utils.py +++ b/backends/cadence/aot/quantizer/utils.py @@ -6,6 +6,7 @@ # pyre-unsafe +import contextlib import itertools from collections import OrderedDict from math import frexp, isclose, trunc @@ -116,6 +117,7 @@ def create_zero_bias_int32( graph_module: GraphModule, weight_node: fx.Node, bias_scale: float, + insert_before: fx.Node | None = None, ) -> fx.Node: """ Creates a zero bias tensor with the shape of weight[0] @@ -141,11 +143,17 @@ def create_zero_bias_int32( weight_shape = list(attr_node.shape) bias_shape = weight_shape[0] - new_node = graph_module.graph.call_function( - torch.ops.aten.full.default, - ([bias_shape], 0.0), - {"dtype": torch.int32}, + ctx = ( + graph_module.graph.inserting_before(insert_before) + if insert_before is not None + else contextlib.nullcontext() ) + with ctx: + new_node = graph_module.graph.call_function( + torch.ops.aten.full.default, + ([bias_shape], 0.0), + {"dtype": torch.int32}, + ) if "val" in weight_node.meta: fake_mode = weight_node.meta["val"].fake_mode diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py index 4b60feb2121..f3b70b18cb8 100644 --- a/backends/cadence/aot/replace_ops.py +++ b/backends/cadence/aot/replace_ops.py @@ -162,14 +162,25 @@ def targets(self) -> list[EdgeOpOverload]: def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: ns = exir_ops.edge if isinstance(node.target, EdgeOpOverload) else torch.ops + out_dtype = node.kwargs.get("out_dtype") + kwargs = {k: v for k, v in node.kwargs.items() if k != "out_dtype"} with node.graph.inserting_before(node): new_node = node.graph.call_function( ns.cadence.dequantize_per_tensor.default, args=node.args, - kwargs=node.kwargs, + kwargs=kwargs, ) new_node.meta = node.meta - node.replace_all_uses_with(new_node) + if out_dtype is not None: + with node.graph.inserting_after(new_node): + cast_node = node.graph.call_function( + torch.ops.aten.to.dtype, + args=(new_node, out_dtype), + ) + cast_node.meta = node.meta.copy() + node.replace_all_uses_with(cast_node) + else: + node.replace_all_uses_with(new_node) return True