Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions backends/cadence/aot/compiler_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import torch
from torch._inductor.decomposition import remove_decompositions
from torch.fx import GraphModule
from torch.fx.passes.infra.pass_base import PassBase, PassResult
from torchao.quantization.pt2e.quantize_pt2e import prepare_pt2e, prepare_qat_pt2e
from torchao.quantization.pt2e.quantizer import Quantizer

Expand Down Expand Up @@ -607,3 +608,32 @@ def sink_input_dequant_through_transparent_ops(
graph_module.recompile()

return modified


class QuantFusionPass(PassBase):
"""
Iterates patterns, finds anchor ops in the converted graph, and calls
pattern.fuse() to replace dq-op-q subgraphs with fused ops.
"""

def __init__(self, patterns: Sequence[object]) -> None:
super().__init__()
self.patterns = patterns

def call(self, graph_module: GraphModule) -> Optional[PassResult]:
changed = False
for pattern in self.patterns:
pattern_changed = False
for target in pattern.anchor_ops(): # pyre-ignore[16]
for node in graph_module.graph.find_nodes(
op="call_function", target=target
):
result = pattern.fuse(graph_module, node) # pyre-ignore[16]
if result is not None:
changed = True
pattern_changed = True
if pattern_changed:
graph_module.graph.eliminate_dead_code()
if changed:
graph_module.recompile()
return PassResult(graph_module, changed)
17 changes: 17 additions & 0 deletions backends/cadence/aot/quantizer/BUCK
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,32 @@ fbcode_target(_kind = runtime.python_library,
],
)

# Python library target for pattern_utils.py, the helper module that holds the
# graph-rewrite utilities (find_quant_user, replace_with_op, fuse_conv,
# fuse_linear, fuse_matmul).
# pattern_utils.py imports from pass_utils, quantizer.utils (presumably the
# ":utils" target in this directory - confirm) and aot/utils.
# NOTE(review): the visible pattern_utils.py source does not import anything
# from compiler_utils - confirm that dep is actually required.
fbcode_target(_kind = runtime.python_library,
name = "pattern_utils",
srcs = [
"pattern_utils.py",
],
typing = True,
deps = [
":utils",
"//caffe2:torch",
"//executorch/backends/cadence/aot:compiler_utils",
"//executorch/backends/cadence/aot:pass_utils",
"//executorch/backends/cadence/aot:utils",
],
)

# Python library target for patterns.py. It depends on the sibling
# ":pattern_utils" and ":utils" targets in this directory, plus torch and the
# cadence aot pass_utils target.
fbcode_target(_kind = runtime.python_library,
name = "patterns",
srcs = [
"patterns.py",
],
typing = True,
deps = [
":pattern_utils",
":utils",
"//caffe2:torch",
"//executorch/backends/cadence/aot:pass_utils",
],
)

Expand Down
194 changes: 194 additions & 0 deletions backends/cadence/aot/quantizer/pattern_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import operator
from typing import Any, Optional

import torch
from executorch.backends.cadence.aot.pass_utils import get_arg
from executorch.backends.cadence.aot.quantizer.utils import (
create_zero_bias_int32,
quantize_tensor_multiplier,
)
from executorch.backends.cadence.aot.utils import is_depthwise_conv
from torch import fx
from torch._ops import OpOverload

DQ_PER_TENSOR: OpOverload = torch.ops.quantized_decomposed.dequantize_per_tensor.default
Q_PER_TENSOR: OpOverload = torch.ops.quantized_decomposed.quantize_per_tensor.default


def find_quant_user(node: fx.Node) -> Optional[fx.Node]:
    """Return the quantize_per_tensor node that consumes ``node``, if any.

    Only the *first* user of ``node`` is examined. When that user is an
    ``operator.getitem`` selecting index 0, the search hops to the first user
    of the getitem instead; a getitem with any other index yields ``None``.
    ``None`` is also returned when there are no users or the inspected user
    is not ``Q_PER_TENSOR``.
    """
    candidate = next(iter(node.users), None)
    if candidate is None:
        return None

    if candidate.target is operator.getitem:
        selects_first_output = len(candidate.args) >= 2 and candidate.args[1] == 0
        if not selects_first_output:
            return None
        # Look through the getitem to whatever consumes element 0.
        candidate = next(iter(candidate.users), None)
        if candidate is None:
            return None

    return candidate if candidate.target == Q_PER_TENSOR else None


def replace_with_op(
    gm: fx.GraphModule,
    insert_after: fx.Node,
    replacement_op: OpOverload,
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
    node_to_replace: fx.Node,
) -> fx.Node:
    """Insert ``replacement_op`` after ``insert_after`` and redirect all uses
    of ``node_to_replace`` to the newly created node.

    Args:
        gm: Graph module whose graph is edited in place.
        insert_after: Node after which the new call is inserted.
        replacement_op: Target of the new ``call_function`` node.
        args: Positional arguments for the new node.
        kwargs: Keyword arguments for the new node.
        node_to_replace: Node whose users are rewired to the new node and
            whose metadata is carried over.

    Returns:
        The newly created node. ``node_to_replace`` is left in the graph
        (without users); removing it is up to the caller.
    """
    with gm.graph.inserting_after(insert_after):
        new_node = gm.graph.call_function(replacement_op, args, kwargs)
    # Shallow-copy instead of aliasing: sharing one dict between two nodes
    # would make a later metadata write on either node show up on the other.
    new_node.meta = node_to_replace.meta.copy()
    node_to_replace.replace_all_uses_with(new_node)
    return new_node


def fuse_conv(
    pattern: object,
    gm: fx.GraphModule,
    conv_node: fx.Node,
    dq_input: fx.Node,
    dq_weight: fx.Node,
    quant_node: fx.Node,
) -> fx.Node:
    """Fuse a dq→conv→q chain into a single quantized conv op.

    Args:
        pattern: Object exposing ``replacement_op()``, which supplies the
            fused op to emit.
        gm: Graph module being rewritten.
        conv_node: The floating-point conv node at the centre of the chain.
        dq_input: Dequantize node feeding the conv activation.
        dq_weight: Dequantize node feeding the conv weight.
        quant_node: Quantize node consuming the conv result; its uses are
            redirected to the fused op.

    Returns:
        The newly inserted fused node.
    """
    # A bias is reused only when positional arg 2 is produced by a per-tensor
    # dequantize; in every other case the create_zero_bias_int32 node is used.
    # NOTE(review): a non-None bias that is not a DQ_PER_TENSOR node is
    # therefore not passed to the fused op - confirm this is intended.
    bias_dq = None
    if len(conv_node.args) > 2 and conv_node.args[2] is not None:
        maybe_bias = conv_node.args[2]
        assert isinstance(maybe_bias, fx.Node)
        if maybe_bias.target == DQ_PER_TENSOR:
            bias_dq = maybe_bias

    w_scale = get_arg(dq_weight, "scale", float)
    in_scale = get_arg(dq_input, "scale", float)
    # pyre-fixme[58]
    bias_scale = in_scale * w_scale

    if bias_dq is None:
        quantized_weight = get_arg(dq_weight, "input", fx.Node)
        bias_q = create_zero_bias_int32(
            gm, quantized_weight, bias_scale, insert_before=conv_node
        )
    else:
        bias_q = get_arg(bias_dq, "input", fx.Node)

    # Requantization factor (bias scale over output scale), converted by
    # quantize_tensor_multiplier into a multiplier/shift pair.
    requant = torch.tensor([bias_scale / get_arg(quant_node, "scale", float)])
    out_multiplier, out_shift = quantize_tensor_multiplier(requant)

    fused_args = (
        get_arg(dq_input, "input", fx.Node),
        get_arg(dq_weight, "input", fx.Node),
        bias_q,
    )
    groups = get_arg(conv_node, "groups", int)
    fused_kwargs = {
        "stride": get_arg(conv_node, "stride", list[int]),
        "padding": get_arg(conv_node, "padding", list[int]),
        "dilation": get_arg(conv_node, "dilation", list[int]),
        "groups": groups,
        "input_zero_point": get_arg(dq_input, "zero_point", int),
        "weight_zero_point": get_arg(dq_weight, "zero_point", int),
        "bias_scale": bias_scale,
        "out_scale": get_arg(quant_node, "scale", float),
        "out_zero_point": get_arg(quant_node, "zero_point", int),
        "out_multiplier": out_multiplier[0].item(),
        "out_shift": out_shift[0].item(),
    }

    fused_op = pattern.replacement_op()  # pyre-ignore[16]
    if fused_op == torch.ops.cadence.quantized_conv1d_ncl.per_tensor:
        # For the conv1d NCL op, switch to the depthwise variant when
        # is_depthwise_conv says so; channel count is dim 1 of the input's
        # recorded "val" metadata.
        quantized_input = get_arg(dq_input, "input", fx.Node)
        in_channels = quantized_input.meta["val"].shape[1]
        if is_depthwise_conv(groups, in_channels):
            fused_op = torch.ops.cadence.quantized_depthwise_conv1d_ncl.per_tensor

    return replace_with_op(
        gm, conv_node, fused_op, fused_args, fused_kwargs, quant_node
    )


def fuse_linear(
    gm: fx.GraphModule,
    dq_input: fx.Node,
    dq_weight: fx.Node,
    dq_bias: Optional[fx.Node],
    quant_node: fx.Node,
    op_node: fx.Node,
    replacement_op: OpOverload,
    weight_q: Optional[fx.Node] = None,
) -> fx.Node:
    """Fuse a dq→linear→q chain into a single quantized linear op.

    Args:
        gm: Graph module being rewritten.
        dq_input: Dequantize node feeding the activation.
        dq_weight: Dequantize node feeding the weight.
        dq_bias: Dequantize node feeding the bias, or ``None``; when ``None``
            a bias node is created with ``create_zero_bias_int32``.
        quant_node: Quantize node consuming the result; its uses are
            redirected to the fused op.
        op_node: The ``aten.linear`` / ``aten.addmm`` node being fused.
        replacement_op: Fused op to emit.
        weight_q: Optional weight node to pass to the fused op instead of the
            input of ``dq_weight``.

    Returns:
        The newly inserted fused node.
    """
    supported_targets = (
        torch.ops.aten.linear.default,
        torch.ops.aten.addmm.default,
    )
    assert (
        op_node.target in supported_targets
    ), f"Expected linear/addmm, got {op_node.target}"

    w_scale = get_arg(dq_weight, "scale", float)
    in_scale = get_arg(dq_input, "scale", float)
    # pyre-fixme[58]
    bias_scale = in_scale * w_scale

    # Requantization factor (bias scale over output scale), converted by
    # quantize_tensor_multiplier into a multiplier/shift pair.
    requant = torch.tensor([bias_scale / get_arg(quant_node, "scale", float)])
    out_multiplier, out_shift = quantize_tensor_multiplier(requant)

    if dq_bias is None:
        quantized_weight = get_arg(dq_weight, "input", fx.Node)
        bias_q = create_zero_bias_int32(
            gm, quantized_weight, bias_scale, insert_before=op_node
        )
    else:
        bias_q = get_arg(dq_bias, "input", fx.Node)

    # An explicitly supplied weight node wins over the dequantize node's input.
    if weight_q is not None:
        final_weight = weight_q
    else:
        final_weight = get_arg(dq_weight, "input", fx.Node)

    fused_args = (get_arg(dq_input, "input", fx.Node), final_weight, bias_q)
    fused_kwargs = {
        "src_zero_point": get_arg(dq_input, "zero_point", int),
        "weight_zero_point": get_arg(dq_weight, "zero_point", int),
        "out_multiplier": out_multiplier[0].item(),
        "out_shift": out_shift[0].item(),
        "out_zero_point": get_arg(quant_node, "zero_point", int),
        "offset": None,
    }
    return replace_with_op(
        gm, op_node, replacement_op, fused_args, fused_kwargs, quant_node
    )


def fuse_matmul(
    gm: fx.GraphModule,
    anchor_node: fx.Node,
    dq0: fx.Node,
    dq1: fx.Node,
    quant_node: fx.Node,
    replacement_op: OpOverload,
) -> fx.Node:
    """Fuse a dq→matmul→q chain into a single quantized matmul op.

    Args:
        gm: Graph module being rewritten.
        anchor_node: The ``aten.bmm`` / ``aten.matmul`` node being fused.
        dq0: Dequantize node feeding the first operand.
        dq1: Dequantize node feeding the second operand.
        quant_node: Quantize node consuming the result; its uses are
            redirected to the fused op.
        replacement_op: Fused op to emit.

    Returns:
        The newly inserted fused node.
    """
    supported_targets = (
        torch.ops.aten.bmm.default,
        torch.ops.aten.matmul.default,
    )
    assert (
        anchor_node.target in supported_targets
    ), f"Expected bmm/matmul, got {anchor_node.target}"

    lhs_scale = get_arg(dq0, "scale", float)
    rhs_scale = get_arg(dq1, "scale", float)
    # Product of both operand scales over the output scale, converted by
    # quantize_tensor_multiplier into a multiplier/shift pair.
    # pyre-ignore[58]
    requantize_scale = (lhs_scale * rhs_scale) / get_arg(quant_node, "scale", float)
    out_multiplier, out_shift = quantize_tensor_multiplier(
        torch.tensor([requantize_scale])
    )

    # Positional layout: (lhs, lhs zero point, rhs, rhs zero point, None).
    # The last slot is always passed as None - presumably an optional bias;
    # confirm against the fused op's schema.
    fused_args = (
        get_arg(dq0, "input", fx.Node),
        get_arg(dq0, "zero_point", int),
        get_arg(dq1, "input", fx.Node),
        get_arg(dq1, "zero_point", int),
        None,
    )
    fused_kwargs = {
        "out_multiplier": out_multiplier[0].item(),
        "out_shift": out_shift[0].item(),
        "out_zero_point": get_arg(quant_node, "zero_point", int),
        "transposed": False,
    }
    return replace_with_op(
        gm, anchor_node, replacement_op, fused_args, fused_kwargs, quant_node
    )
Loading
Loading