diff --git a/backends/cadence/aot/fuse_ops.py b/backends/cadence/aot/fuse_ops.py
index 7a20a3f64b4..1aec147cd67 100644
--- a/backends/cadence/aot/fuse_ops.py
+++ b/backends/cadence/aot/fuse_ops.py
@@ -814,11 +814,61 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
-class FuseMulIntoDequantPass(ExportPass):
+class FuseMulScalarIntoDequantPass(ExportPass):
     """
-    Looks for the pattern where atem.mul is multiplying the outputs of dequantize
-    and aten.full. If found, updates the dequant scale to reflect the multiplication
-    and removes the full and mul nodes.
+    Looks for the pattern where aten.mul.Scalar is multiplying the output
+    of dequantize. If found, updates the dequant scale to reflect the
+    multiplication and removes the mul node.
+    """
+
+    def attempt_fusion(
+        self, graph_module: torch.fx.GraphModule, node: torch.fx.Node
+    ) -> None:
+        if node.target not in {
+            exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+            exir_ops.edge.cadence.dequantize_per_tensor.default,
+        }:
+            return
+
+        # ensure that dequant has exactly one user and that it is aten.mul.Scalar
+        user = next(iter(node.users), None)
+        if len(node.users) != 1 or user.target != exir_ops.edge.aten.mul.Scalar:
+            return
+
+        # bail if the other arg to mul is a node (i.e. not a constant scalar)
+        if len(user.args) > 1 and isinstance(user.args[1], torch.fx.Node):
+            return
+
+        new_deq_args = list(node.args)
+        assert isinstance(node.args[1], Number)
+        assert isinstance(user.args[1], Number)
+        # pyre-ignore[58]: Unsupported operand *
+        new_deq_args[1] = node.args[1] * user.args[1]
+
+        logging.debug(
+            f"Fused {user} into {node}, updating dequant scale from {node.args[1]} to {new_deq_args[1]}"
+        )
+
+        user.replace_all_uses_with(node)
+        node.args = tuple(new_deq_args)
+
+        graph_module.graph.erase_node(user)
+
+        graph_module.recompile()
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        for node in graph_module.graph.nodes:
+            self.attempt_fusion(graph_module, node)
+        result = super().call(graph_module)
+        return result
+
+
+@register_cadence_pass(CadencePassAttribute(opt_level=1))
+class FuseMulTensorIntoDequantPass(ExportPass):
+    """
+    Looks for the pattern where aten.mul is multiplying the outputs of dequantize
+    and aten.full, or vice versa. If found, updates the dequant scale to reflect
+    the multiplication and removes the full and mul nodes.
""" def attempt_fusion( @@ -1017,7 +1067,8 @@ class CadenceFuseOpsInGraph: FuseCascadedTransposeOrPermuteOps, FuseCascadedViewOps, FuseQuantDequantToRequantizePass, - FuseMulIntoDequantPass, + FuseMulTensorIntoDequantPass, + FuseMulScalarIntoDequantPass, FuseFullThenReshapePass, FuseTransposeOrPermuteOpPairsPass, ] diff --git a/backends/cadence/aot/tests/test_fusion_ops_passes.py b/backends/cadence/aot/tests/test_fusion_ops_passes.py index 4e267254488..3d9cadf741b 100644 --- a/backends/cadence/aot/tests/test_fusion_ops_passes.py +++ b/backends/cadence/aot/tests/test_fusion_ops_passes.py @@ -19,7 +19,8 @@ ) from executorch.backends.cadence.aot.fuse_ops import ( FuseFullThenReshapePass, - FuseMulIntoDequantPass, + FuseMulScalarIntoDequantPass, + FuseMulTensorIntoDequantPass, FuseQuantDequantToRequantizePass, FuseTransposeOrPermuteOpPairsPass, ) @@ -446,7 +447,7 @@ def forward(self, x): inputs = (torch.randint(0, 255, [4, 32], dtype=torch.uint8),) graph_module = export_to_edge(M(), inputs).exported_program().graph_module - graph_module = FuseMulIntoDequantPass()(graph_module).graph_module + graph_module = FuseMulTensorIntoDequantPass()(graph_module).graph_module # verify that the mul and full ops were removed self.check_op_counts( @@ -467,6 +468,47 @@ def forward(self, x): deq_scale = node.args[1] self.assertEqual(deq_scale, 4.5) + def test_fuse_mul_scalar_into_dequant(self): + dequant_scale = 0.006 + mul_value = 0.3 + + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(2, 3, 4, dtype=torch.float32)) + quant = builder.call_operator( + op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + args=(x, 1, 0, -128, 127, torch.int8), + ) + dequant = builder.call_operator( + op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + args=(quant, dequant_scale, 5, -128, 127, torch.int8), + ) + mul_scalar = builder.call_operator( + op=exir_ops.edge.aten.mul.Scalar, + args=(dequant, mul_value), + ) + builder.output(mul_scalar) + graph_module = builder.get_graph_module() + + graph_module = FuseMulScalarIntoDequantPass()(graph_module).graph_module + + # verify that the mul and full ops were removed + self.check_op_counts( + graph_module, + expected_op_counts={ + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 1, + exir_ops.edge.aten.mul.Scalar: 0, + }, + ) + + # verify that the dequant scale value was updated correctly + for node in graph_module.graph.nodes: + if ( + node.target + == exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default + ): + deq_scale = node.args[1] + self.assertEqual(deq_scale, dequant_scale * mul_value) + def test_fuse_then_transpose_pass(self): # Create a graph with full -> transpose. builder = GraphBuilder()