Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 56 additions & 5 deletions backends/cadence/aot/fuse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,11 +814,61 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class FuseMulScalarIntoDequantPass(ExportPass):
    """
    Looks for the pattern where aten.mul.Scalar is multiplying the
    outputs of dequantize. If found, updates the dequant scale
    to reflect the multiplication and removes the mul node.
    """

    def attempt_fusion(
        self, graph_module: torch.fx.GraphModule, node: torch.fx.Node
    ) -> None:
        """If `node` is a dequantize whose sole user is aten.mul.Scalar with a
        constant multiplier, fold the multiplier into the dequant scale and
        erase the mul node."""
        if node.target not in {
            exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
            exir_ops.edge.cadence.dequantize_per_tensor.default,
        }:
            return

        # Fuse only when the dequant has exactly one user and that user is
        # aten.mul.Scalar. Check the count first so we never index into an
        # empty users dict (a dead dequant node has no users).
        if len(node.users) != 1:
            return
        user = next(iter(node.users))
        if user.target != exir_ops.edge.aten.mul.Scalar:
            return

        # The multiplier must be a compile-time constant (a Number), not a
        # graph node; otherwise it cannot be folded into the dequant scale.
        if len(user.args) > 1 and isinstance(user.args[1], torch.fx.Node):
            return

        new_deq_args = list(node.args)
        assert isinstance(node.args[1], Number)
        assert isinstance(user.args[1], Number)
        # Fold the scalar into the scale: ((x - zp) * scale) * c
        # == (x - zp) * (scale * c).
        # pyre-ignore[58]: Unsupported operand *
        new_deq_args[1] = node.args[1] * user.args[1]

        logging.debug(
            f"Fused {node} and {user} into {node}. Updated scale from {node.args[1]} to {new_deq_args[1]}"
        )

        # Reroute consumers of the mul to the dequant, then update the
        # dequant's scale argument and drop the now-dead mul node.
        user.replace_all_uses_with(node)
        node.args = tuple(new_deq_args)

        graph_module.graph.erase_node(user)

        graph_module.recompile()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        # Snapshot-free iteration is safe here: attempt_fusion only erases
        # the (already-visited or distinct) mul node, never `node` itself.
        for node in graph_module.graph.nodes:
            self.attempt_fusion(graph_module, node)
        result = super().call(graph_module)
        return result


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class FuseMulTensorIntoDequantPass(ExportPass):
"""
Looks for the pattern where aten.mul is multiplying the outputs of dequantize
and aten.full, or vice versa. If found, updates the dequant scale to reflect
the multiplication and removes the full and mul nodes.
"""

def attempt_fusion(
Expand Down Expand Up @@ -1017,7 +1067,8 @@ class CadenceFuseOpsInGraph:
FuseCascadedTransposeOrPermuteOps,
FuseCascadedViewOps,
FuseQuantDequantToRequantizePass,
FuseMulIntoDequantPass,
FuseMulTensorIntoDequantPass,
FuseMulScalarIntoDequantPass,
FuseFullThenReshapePass,
FuseTransposeOrPermuteOpPairsPass,
]
46 changes: 44 additions & 2 deletions backends/cadence/aot/tests/test_fusion_ops_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
)
from executorch.backends.cadence.aot.fuse_ops import (
FuseFullThenReshapePass,
FuseMulIntoDequantPass,
FuseMulScalarIntoDequantPass,
FuseMulTensorIntoDequantPass,
FuseQuantDequantToRequantizePass,
FuseTransposeOrPermuteOpPairsPass,
)
Expand Down Expand Up @@ -446,7 +447,7 @@ def forward(self, x):

inputs = (torch.randint(0, 255, [4, 32], dtype=torch.uint8),)
graph_module = export_to_edge(M(), inputs).exported_program().graph_module
graph_module = FuseMulIntoDequantPass()(graph_module).graph_module
graph_module = FuseMulTensorIntoDequantPass()(graph_module).graph_module

# verify that the mul and full ops were removed
self.check_op_counts(
Expand All @@ -467,6 +468,47 @@ def forward(self, x):
deq_scale = node.args[1]
self.assertEqual(deq_scale, 4.5)

def test_fuse_mul_scalar_into_dequant(self):
    """Builds quantize -> dequantize -> mul.Scalar and checks that
    FuseMulScalarIntoDequantPass removes the mul and folds its constant
    multiplier into the dequantize scale."""
    dequant_scale = 0.006
    mul_value = 0.3

    builder = GraphBuilder()
    x = builder.placeholder("x", torch.randn(2, 3, 4, dtype=torch.float32))
    quant = builder.call_operator(
        op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
        args=(x, 1, 0, -128, 127, torch.int8),
    )
    # Dequant with the scale the pass is expected to rewrite
    # (args: input, scale, zero_point, qmin, qmax, dtype).
    dequant = builder.call_operator(
        op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
        args=(quant, dequant_scale, 5, -128, 127, torch.int8),
    )
    # Scalar multiply of the dequant output — the fusion target.
    mul_scalar = builder.call_operator(
        op=exir_ops.edge.aten.mul.Scalar,
        args=(dequant, mul_value),
    )
    builder.output(mul_scalar)
    graph_module = builder.get_graph_module()

    graph_module = FuseMulScalarIntoDequantPass()(graph_module).graph_module

    # verify that the mul op was removed and the dequant op was kept
    self.check_op_counts(
        graph_module,
        expected_op_counts={
            exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 1,
            exir_ops.edge.aten.mul.Scalar: 0,
        },
    )

    # verify that the dequant scale value was updated correctly
    # (original scale times the fused scalar multiplier)
    for node in graph_module.graph.nodes:
        if (
            node.target
            == exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
        ):
            deq_scale = node.args[1]
            self.assertEqual(deq_scale, dequant_scale * mul_value)

def test_fuse_then_transpose_pass(self):
# Create a graph with full -> transpose.
builder = GraphBuilder()
Expand Down
Loading