From c54e093f1a9154d63a38ef0502afe1b274575677 Mon Sep 17 00:00:00 2001 From: Andrew Grebenisan Date: Wed, 12 Nov 2025 11:15:32 -0800 Subject: [PATCH 1/2] Update ReplaceLogicalNotBooleanWhereWithWherePass to use new pass interface (#15755) Summary: As titled, more efficient now and properly updates the modified bit. Also, this pass was missing tests, so added some for a variety of different cases. Differential Revision: D86782910 --- backends/cadence/aot/replace_ops.py | 62 ++++---- .../aot/tests/test_replace_ops_passes.py | 148 ++++++++++++++++++ 2 files changed, 177 insertions(+), 33 deletions(-) diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py index a05548a028c..e2a475e3392 100644 --- a/backends/cadence/aot/replace_ops.py +++ b/backends/cadence/aot/replace_ops.py @@ -69,50 +69,46 @@ def contains_placeholder_or_param(nodes: Iterable[torch.fx.Node]) -> bool: @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplaceLogicalNotBooleanWhereWithWherePass(ExportPass): +class ReplaceLogicalNotBooleanWhereWithWherePass(RemoveOrReplacePassInterface): """ A where op with a logical_not and a boolean tensor can be replaced by a where op with flipped inputs and the initial boolean tensor. """ - def replace_logical_nop_where_with_where( - self, graph_module: torch.fx.GraphModule - ) -> None: - graph = graph_module.graph - for node in graph.nodes: - # We are only interested in where nodes - if node.target != exir_ops.edge.aten.where.self: - continue + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.where.self] - # If the third arg is not a logical_not, bail. - if node.args[0].target != exir_ops.edge.aten.logical_not.default: - continue + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + # If the first arg is not a logical_not, bail. + if not isinstance(node.args[0], torch.fx.Node): + return False - # Get the third arg node and its input - logical_not_node = node.args[0] - logical_not_input_node = logical_not_node.args[0] + logical_not_node = cast(torch.fx.Node, node.args[0]) + if logical_not_node.target != exir_ops.edge.aten.logical_not.default: + return False - # If the logical_not input is not a boolean tensor, bail. - if logical_not_input_node.meta["val"].dtype != torch.bool: - continue + # Get the first arg node and its input + if not isinstance(logical_not_node.args[0], torch.fx.Node): + return False - # Replace the where op with another one, flipping the inputs and using the boolean - # tensor from logical_not. - with graph.inserting_before(node): - linear_node = graph.call_function( - exir_ops.edge.aten.where.self, - args=(logical_not_node.args[0], node.args[2], node.args[1]), - ) - # Replace all the uses - node.replace_all_uses_with(linear_node) + logical_not_input_node = cast(torch.fx.Node, logical_not_node.args[0]) - graph_module.recompile() - graph_module.graph.eliminate_dead_code() + # If the logical_not input is not a boolean tensor, bail. + if logical_not_input_node.meta["val"].dtype != torch.bool: + return False - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - self.replace_logical_nop_where_with_where(graph_module) - result = super().call(graph_module) - return result + # Replace the where op with another one, flipping the inputs and using the boolean + # tensor from logical_not. + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + exir_ops.edge.aten.where.self, + args=(logical_not_input_node, node.args[2], node.args[1]), + ) + new_node.meta = node.meta + # Replace all the uses + node.replace_all_uses_with(new_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=0)) diff --git a/backends/cadence/aot/tests/test_replace_ops_passes.py b/backends/cadence/aot/tests/test_replace_ops_passes.py index 4bcc8bf371c..ace6703170c 100644 --- a/backends/cadence/aot/tests/test_replace_ops_passes.py +++ b/backends/cadence/aot/tests/test_replace_ops_passes.py @@ -33,6 +33,7 @@ ReplaceFunctionallyEquivalentOpTargets, ReplaceIm2RowWithViewPass, ReplaceLinearWithFullyConnectedOpPass, + ReplaceLogicalNotBooleanWhereWithWherePass, ReplaceMatmulWithTransposedMatmulPass, ReplaceMMWithAddMMPass, ReplaceMulTensorWithMulAndFullOpsPass, @@ -2183,3 +2184,150 @@ def test_replace_quantized_embedding( ), 1, ) + + +class TestReplaceLogicalNotBooleanWhereWithWherePass(unittest.TestCase): + """Tests for the ReplaceLogicalNotBooleanWhereWithWherePass.""" + + @torch.no_grad() + def test_replace_where_with_logical_not_boolean(self) -> None: + """Test that where(logical_not(bool_cond), x, y) is replaced with where(bool_cond, y, x).""" + # Setup: Create a graph with where(logical_not(bool_cond), x, y) + builder = GraphBuilder() + bool_cond = builder.placeholder("bool_cond", torch.randn(4, 8) > 0) + x = builder.placeholder("x", torch.randn(4, 8)) + y = builder.placeholder("y", torch.randn(4, 8)) + + # Create logical_not node + logical_not = builder.call_operator( + op=exir_ops.edge.aten.logical_not.default, + args=(bool_cond,), + ) + + # Create where node using logical_not + where_node = builder.call_operator( + op=exir_ops.edge.aten.where.self, + args=(logical_not, x, y), + ) + builder.output([where_node]) + original_gm = builder.get_graph_module() + + # Execute: Apply the replacement pass + p = ReplaceLogicalNotBooleanWhereWithWherePass() + result = cast(PassResult, p(original_gm)) + + # Assert: Verify the pass modified the graph + self.assertTrue(result.modified) + graph_after_passes = result.graph_module + + # Assert: Verify logical_not is removed (dead code elimination) + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.logical_not.default), + 0, + ) + + # Assert: Verify where node still exists + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.where.self), + 1, + ) + + # Assert: Verify the arguments are flipped (condition uses original bool_cond, x and y are swapped) + found_node = False + for node in graph_after_passes.graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.where.self + ): + found_node = True + # First arg should be the original bool_cond (not the logical_not) + self.assertEqual(node.args[0].name, "bool_cond") + # Second and third args should be swapped (y, x instead of x, y) + self.assertEqual(node.args[1].name, "y") + self.assertEqual(node.args[2].name, "x") + self.assertTrue(found_node) + + @torch.no_grad() + def test_no_replacement_when_not_boolean_tensor(self) -> None: + """Test that the pass does NOT apply when logical_not input is not a boolean tensor.""" + # Setup: Create a graph with where(logical_not(float_tensor > 0), x, y) + # The logical_not input is not directly a boolean tensor + builder = GraphBuilder() + float_tensor = builder.placeholder("float_tensor", torch.randn(4, 8)) + x = builder.placeholder("x", torch.randn(4, 8)) + y = builder.placeholder("y", torch.randn(4, 8)) + + # Create a comparison that produces a boolean + gt_node = builder.call_operator( + op=exir_ops.edge.aten.gt.Scalar, + args=(float_tensor, 0.0), + ) + + # Create logical_not node using the comparison result + logical_not = builder.call_operator( + op=exir_ops.edge.aten.logical_not.default, + args=(gt_node,), + ) + + # Create where node + where_node = builder.call_operator( + op=exir_ops.edge.aten.where.self, + args=(logical_not, x, y), + ) + builder.output([where_node]) + original_gm = builder.get_graph_module() + + # Execute: Apply the replacement pass + p = ReplaceLogicalNotBooleanWhereWithWherePass() + result = cast(PassResult, p(original_gm)) + + # Assert: Verify the pass modified the graph (gt_node is a boolean tensor) + # The pass SHOULD apply because gt.Scalar returns a boolean tensor + self.assertTrue(result.modified) + graph_after_passes = result.graph_module + + # Assert: Verify logical_not is removed + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.logical_not.default), + 0, + ) + + @torch.no_grad() + def test_no_replacement_without_logical_not(self) -> None: + """Test that the pass does NOT apply when there's no logical_not.""" + # Setup: Create a graph with where(bool_cond, x, y) without logical_not + builder = GraphBuilder() + bool_cond = builder.placeholder("bool_cond", torch.randn(4, 8) > 0) + x = builder.placeholder("x", torch.randn(4, 8)) + y = builder.placeholder("y", torch.randn(4, 8)) + + # Create where node directly without logical_not + where_node = builder.call_operator( + op=exir_ops.edge.aten.where.self, + args=(bool_cond, x, y), + ) + builder.output([where_node]) + original_gm = builder.get_graph_module() + + # Execute: Apply the replacement pass + p = ReplaceLogicalNotBooleanWhereWithWherePass() + result = cast(PassResult, p(original_gm)) + + # Assert: Verify the pass did NOT modify the graph + self.assertFalse(result.modified) + graph_after_passes = result.graph_module + + # Assert: Verify where node still exists unchanged + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.where.self), + 1, + ) + + # Assert: Verify the arguments are unchanged + found_node = False + for node in graph_after_passes.graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.where.self + ): + found_node = True + self.assertEqual(node.args[0].name, "bool_cond") + self.assertEqual(node.args[1].name, "x") + self.assertEqual(node.args[2].name, "y") + self.assertTrue(found_node) From c9f830f116dab6467c3bcc6142722582344a518b Mon Sep 17 00:00:00 2001 From: Andrew Grebenisan Date: Wed, 12 Nov 2025 11:15:32 -0800 Subject: [PATCH 2/2] Update ReplaceSqueezeAndUnsqueezeWithViewPass to use new pass interface (#15757) Summary: As titled, now it is more efficient and correctly updates the modified bit. Updated tests, too Differential Revision: D86785126 --- backends/cadence/aot/replace_ops.py | 44 +++++++++---------- .../aot/tests/test_replace_ops_passes.py | 14 +++++- 2 files changed, 34 insertions(+), 24 deletions(-) diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py index e2a475e3392..df08b7eba65 100644 --- a/backends/cadence/aot/replace_ops.py +++ b/backends/cadence/aot/replace_ops.py @@ -193,39 +193,39 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplaceSqueezeAndUnsqueezeWithViewPass(ExportPass): +class ReplaceSqueezeAndUnsqueezeWithViewPass(RemoveOrReplacePassInterface): """ When the shape is static, replace squeeze_copy and unsqueeze_copy ops with view_copy op """ - def call_operator( - self, - op, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - # Instead of testing EdgeOpOverload, test EdgeOpOverloadPacket, - # which allows us to cover all overloads. - if get_edge_overload_packet(op) not in { - exir_ops.edge.aten.squeeze_copy, - exir_ops.edge.aten.unsqueeze_copy, - }: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [ + exir_ops.edge.aten.squeeze_copy.default, + exir_ops.edge.aten.squeeze_copy.dim, + exir_ops.edge.aten.squeeze_copy.dims, + exir_ops.edge.aten.unsqueeze_copy.default, + ] + + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: # Get the output tensor shape - out_shape = meta["val"].shape + out_shape = node.meta["val"].shape # Bail out if any dim is not an int (dynamic shape) for dim in list(out_shape): if not isinstance(dim, int): - return super().call_operator(op, args, kwargs, meta) + return False - # Return a view op with the new shape - view_args = (args[0], list(out_shape)) - return super().call_operator( - exir_ops.edge.aten.view_copy.default, view_args, kwargs, meta - ) + # Replace with view op with the new shape + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + exir_ops.edge.aten.view_copy.default, + args=(node.args[0], list(out_shape)), + ) + new_node.meta = node.meta + node.replace_all_uses_with(new_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=0)) diff --git a/backends/cadence/aot/tests/test_replace_ops_passes.py b/backends/cadence/aot/tests/test_replace_ops_passes.py index ace6703170c..375459a4e29 100644 --- a/backends/cadence/aot/tests/test_replace_ops_passes.py +++ b/backends/cadence/aot/tests/test_replace_ops_passes.py @@ -989,7 +989,12 @@ def test_replace_squeeze_with_view( args=(x,), ) p = ReplaceSqueezeAndUnsqueezeWithViewPass() - graph_after_passes = cast(PassResult, p(original_gm)).graph_module + result = cast(PassResult, p(original_gm)) + + # Assert: Verify the pass modified the graph + self.assertTrue(result.modified) + graph_after_passes = result.graph_module + self.assertIsNotNone(graph_after_passes) self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default), @@ -1024,7 +1029,12 @@ def test_replace_unsqueeze_with_view(self, shape: Tuple[int], dim: int) -> None: args=(x, dim), ) p = ReplaceSqueezeAndUnsqueezeWithViewPass() - graph_after_passes = cast(PassResult, p(original_gm)).graph_module + result = cast(PassResult, p(original_gm)) + + # Assert: Verify the pass modified the graph + self.assertTrue(result.modified) + graph_after_passes = result.graph_module + self.assertIsNotNone(graph_after_passes) self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default),