diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py index 5b910df5358..08503ade0cd 100644 --- a/backends/cadence/aot/replace_ops.py +++ b/backends/cadence/aot/replace_ops.py @@ -69,50 +69,46 @@ def contains_placeholder_or_param(nodes: Iterable[torch.fx.Node]) -> bool: @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplaceLogicalNotBooleanWhereWithWherePass(ExportPass): +class ReplaceLogicalNotBooleanWhereWithWherePass(RemoveOrReplacePassInterface): """ A where op with a logical_not and a boolean tensor can be replaced by a where op with flipped inputs and the initial boolean tensor. """ - def replace_logical_nop_where_with_where( - self, graph_module: torch.fx.GraphModule - ) -> None: - graph = graph_module.graph - for node in graph.nodes: - # We are only interested in where nodes - if node.target != exir_ops.edge.aten.where.self: - continue + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.where.self] - # If the third arg is not a logical_not, bail. - if node.args[0].target != exir_ops.edge.aten.logical_not.default: - continue + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + # If the first arg is not a logical_not, bail. + if not isinstance(node.args[0], torch.fx.Node): + return False - # Get the third arg node and its input - logical_not_node = node.args[0] - logical_not_input_node = logical_not_node.args[0] + logical_not_node = cast(torch.fx.Node, node.args[0]) + if logical_not_node.target != exir_ops.edge.aten.logical_not.default: + return False - # If the logical_not input is not a boolean tensor, bail. - if logical_not_input_node.meta["val"].dtype != torch.bool: - continue + # Get the first arg node and its input + if not isinstance(logical_not_node.args[0], torch.fx.Node): + return False - # Replace the where op with another one, flipping the inputs and using the boolean - # tensor from logical_not. - with graph.inserting_before(node): - linear_node = graph.call_function( - exir_ops.edge.aten.where.self, - args=(logical_not_node.args[0], node.args[2], node.args[1]), - ) - # Replace all the uses - node.replace_all_uses_with(linear_node) + logical_not_input_node = cast(torch.fx.Node, logical_not_node.args[0]) - graph_module.recompile() - graph_module.graph.eliminate_dead_code() + # If the logical_not input is not a boolean tensor, bail. + if logical_not_input_node.meta["val"].dtype != torch.bool: + return False - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - self.replace_logical_nop_where_with_where(graph_module) - result = super().call(graph_module) - return result + # Replace the where op with another one, flipping the inputs and using the boolean + # tensor from logical_not. + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + exir_ops.edge.aten.where.self, + args=(logical_not_input_node, node.args[2], node.args[1]), + ) + new_node.meta = node.meta + # Replace all the uses + node.replace_all_uses_with(new_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=0)) diff --git a/backends/cadence/aot/tests/test_replace_ops_passes.py b/backends/cadence/aot/tests/test_replace_ops_passes.py index 8bac42e6772..64214e75bfc 100644 --- a/backends/cadence/aot/tests/test_replace_ops_passes.py +++ b/backends/cadence/aot/tests/test_replace_ops_passes.py @@ -33,6 +33,7 @@ ReplaceFunctionallyEquivalentOpTargets, ReplaceIm2RowWithViewPass, ReplaceLinearWithFullyConnectedOpPass, + ReplaceLogicalNotBooleanWhereWithWherePass, ReplaceMatmulWithTransposedMatmulPass, ReplaceMMWithAddMMPass, ReplaceMulTensorWithMulAndFullOpsPass, @@ -2053,3 +2054,114 @@ def test_replace_quantized_embedding( ), 1, ) + + +class TestReplaceLogicalNotBooleanWhereWithWherePass(unittest.TestCase): + """Tests for the ReplaceLogicalNotBooleanWhereWithWherePass.""" + + def test_replace_where_with_logical_not_boolean(self) -> None: + """Test that where(logical_not(bool_cond), x, y) is replaced with where(bool_cond, y, x).""" + # Setup: Create a graph with where(logical_not(bool_cond), x, y) + builder = GraphBuilder() + bool_cond_ = torch.randn(4, 8) > 0 + x_ = torch.randn(4, 8) + y_ = torch.randn(4, 8) + + bool_cond = builder.placeholder("bool_cond", bool_cond_) + x = builder.placeholder("x", x_) + y = builder.placeholder("y", y_) + + # Create logical_not node + logical_not = builder.call_operator( + op=exir_ops.edge.aten.logical_not.default, + args=(bool_cond,), + ) + + # Create where node using logical_not + where_node = builder.call_operator( + op=exir_ops.edge.aten.where.self, + args=(logical_not, x, y), + ) + builder.output([where_node]) + original_gm = builder.get_graph_module() + + # Make a copy of the original graph before applying the pass + original_gm_copy = copy.deepcopy(original_gm) + + # Execute: Apply the replacement pass + p = ReplaceLogicalNotBooleanWhereWithWherePass() + result = cast(PassResult, p(original_gm)) + + # Assert: Verify the pass modified the graph + self.assertTrue(result.modified) + graph_after_passes = result.graph_module + + # Assert: Verify logical_not is removed (dead code elimination) + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.logical_not.default), + 0, + ) + + # Assert: Verify where node still exists + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.where.self), + 1, + ) + + # Assert: Verify the arguments are flipped (condition uses original bool_cond, x and y are swapped) + where_nodes = list( + graph_after_passes.graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.where.self + ) + ) + for node in where_nodes: + # First arg should be the original bool_cond (not the logical_not) + self.assertEqual(node.args[0].name, "bool_cond") + # Second and third args should be swapped (y, x instead of x, y) + self.assertEqual(node.args[1].name, "y") + self.assertEqual(node.args[2].name, "x") + + # Assert: Verify outputs match exactly by running both graphs + validate( + original_gm_copy, + graph_after_passes, + (bool_cond_, x_, y_), + "ReplaceLogicalNotBooleanWhereWithWherePass", + ) + + def test_no_replacement_without_logical_not(self) -> None: + """Test that the pass does NOT apply when there's no logical_not.""" + # Setup: Create a graph with where(bool_cond, x, y) without logical_not + builder = GraphBuilder() + bool_cond = builder.placeholder("bool_cond", torch.randn(4, 8) > 0) + x = builder.placeholder("x", torch.randn(4, 8)) + y = builder.placeholder("y", torch.randn(4, 8)) + + # Create where node directly without logical_not + where_node = builder.call_operator( + op=exir_ops.edge.aten.where.self, + args=(bool_cond, x, y), + ) + builder.output([where_node]) + original_gm = builder.get_graph_module() + + # Execute: Apply the replacement pass + p = ReplaceLogicalNotBooleanWhereWithWherePass() + result = cast(PassResult, p(original_gm)) + + # Assert: Verify the pass did NOT modify the graph + self.assertFalse(result.modified) + graph_after_passes = result.graph_module + + # Assert: Verify where node still exists unchanged + self.assertEqual( + count_node(graph_after_passes, exir_ops.edge.aten.where.self), + 1, + ) + + for node in graph_after_passes.graph.find_nodes( + op="call_function", target=exir_ops.edge.aten.where.self + ): + self.assertEqual(node.args[0].name, "bool_cond") + self.assertEqual(node.args[1].name, "x") + self.assertEqual(node.args[2].name, "y")