106 changes: 51 additions & 55 deletions backends/cadence/aot/replace_ops.py
@@ -69,50 +69,46 @@ def contains_placeholder_or_param(nodes: Iterable[torch.fx.Node]) -> bool:


 @register_cadence_pass(CadencePassAttribute(opt_level=0))
-class ReplaceLogicalNotBooleanWhereWithWherePass(ExportPass):
+class ReplaceLogicalNotBooleanWhereWithWherePass(RemoveOrReplacePassInterface):
     """
     A where op with a logical_not and a boolean tensor can be replaced
     by a where op with flipped inputs and the initial boolean tensor.
     """
 
-    def replace_logical_nop_where_with_where(
-        self, graph_module: torch.fx.GraphModule
-    ) -> None:
-        graph = graph_module.graph
-        for node in graph.nodes:
-            # We are only interested in where nodes
-            if node.target != exir_ops.edge.aten.where.self:
-                continue
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [exir_ops.edge.aten.where.self]
 
-            # If the third arg is not a logical_not, bail.
-            if node.args[0].target != exir_ops.edge.aten.logical_not.default:
-                continue
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
+        # If the first arg is not a logical_not, bail.
+        if not isinstance(node.args[0], torch.fx.Node):
+            return False
 
-            # Get the third arg node and its input
-            logical_not_node = node.args[0]
-            logical_not_input_node = logical_not_node.args[0]
+        logical_not_node = cast(torch.fx.Node, node.args[0])
+        if logical_not_node.target != exir_ops.edge.aten.logical_not.default:
+            return False
 
-            # If the logical_not input is not a boolean tensor, bail.
-            if logical_not_input_node.meta["val"].dtype != torch.bool:
-                continue
+        # Get the first arg node and its input
+        if not isinstance(logical_not_node.args[0], torch.fx.Node):
+            return False
+
+        logical_not_input_node = cast(torch.fx.Node, logical_not_node.args[0])
 
-            # Replace the where op with another one, flipping the inputs and using the boolean
-            # tensor from logical_not.
-            with graph.inserting_before(node):
-                linear_node = graph.call_function(
-                    exir_ops.edge.aten.where.self,
-                    args=(logical_not_node.args[0], node.args[2], node.args[1]),
-                )
-                # Replace all the uses
-                node.replace_all_uses_with(linear_node)
+        # If the logical_not input is not a boolean tensor, bail.
+        if logical_not_input_node.meta["val"].dtype != torch.bool:
+            return False
 
-        graph_module.recompile()
-        graph_module.graph.eliminate_dead_code()
-
-    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
-        self.replace_logical_nop_where_with_where(graph_module)
-        result = super().call(graph_module)
-        return result
+        # Replace the where op with another one, flipping the inputs and using
+        # the boolean tensor from logical_not.
+        with node.graph.inserting_before(node):
+            new_node = node.graph.call_function(
+                exir_ops.edge.aten.where.self,
+                args=(logical_not_input_node, node.args[2], node.args[1]),
+            )
+            new_node.meta = node.meta
+            # Replace all the uses
+            node.replace_all_uses_with(new_node)
+        return True
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
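
Reviewer note: the rewrite above is semantics-preserving because of the identity where(logical_not(c), x, y) == where(c, y, x) for a boolean condition c. A minimal eager-mode sketch of that identity (illustrative only, not part of this diff):

    import torch

    c = torch.randn(4, 8) > 0                      # boolean condition tensor
    x, y = torch.randn(4, 8), torch.randn(4, 8)

    # Original pattern vs. the flipped pattern the pass emits.
    lhs = torch.where(torch.logical_not(c), x, y)
    rhs = torch.where(c, y, x)
    assert torch.equal(lhs, rhs)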
@@ -197,39 +193,39 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:


 @register_cadence_pass(CadencePassAttribute(opt_level=0))
-class ReplaceSqueezeAndUnsqueezeWithViewPass(ExportPass):
+class ReplaceSqueezeAndUnsqueezeWithViewPass(RemoveOrReplacePassInterface):
     """
     When the shape is static, replace squeeze_copy and unsqueeze_copy ops with
     view_copy op
     """
 
-    def call_operator(
-        self,
-        op,
-        args: Tuple[Argument, ...],
-        kwargs: Dict[str, Argument],
-        meta: NodeMetadata,
-    ) -> ProxyValue:
-        # Instead of testing EdgeOpOverload, test EdgeOpOverloadPacket,
-        # which allows us to cover all overloads.
-        if get_edge_overload_packet(op) not in {
-            exir_ops.edge.aten.squeeze_copy,
-            exir_ops.edge.aten.unsqueeze_copy,
-        }:
-            return super().call_operator(op, args, kwargs, meta)
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [
+            exir_ops.edge.aten.squeeze_copy.default,
+            exir_ops.edge.aten.squeeze_copy.dim,
+            exir_ops.edge.aten.squeeze_copy.dims,
+            exir_ops.edge.aten.unsqueeze_copy.default,
+        ]
 
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
         # Get the output tensor shape
-        out_shape = meta["val"].shape
+        out_shape = node.meta["val"].shape
 
         # Bail out if any dim is not an int (dynamic shape)
         for dim in list(out_shape):
             if not isinstance(dim, int):
-                return super().call_operator(op, args, kwargs, meta)
+                return False
 
-        # Return a view op with the new shape
-        view_args = (args[0], list(out_shape))
-        return super().call_operator(
-            exir_ops.edge.aten.view_copy.default, view_args, kwargs, meta
-        )
+        # Replace with a view op with the new shape
+        with node.graph.inserting_before(node):
+            new_node = node.graph.call_function(
+                exir_ops.edge.aten.view_copy.default,
+                args=(node.args[0], list(out_shape)),
+            )
+            new_node.meta = node.meta
+            node.replace_all_uses_with(new_node)
+        return True
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
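
Reviewer note: the squeeze/unsqueeze-to-view rewrite is only sound when the output shape is fully static, which is exactly what the dynamic-dim bail-out above checks. A small eager-mode sketch of the underlying equivalence (illustrative only, not part of this diff):

    import torch

    x = torch.randn(4, 1, 8)

    # With a static shape, squeeze/unsqueeze only re-describe the layout,
    # so reshaping to the known output shape yields the same tensor.
    squeezed = torch.squeeze(x, 1)
    assert torch.equal(squeezed, x.reshape(list(squeezed.shape)))

    unsqueezed = torch.unsqueeze(x, 0)
    assert torch.equal(unsqueezed, x.reshape(list(unsqueezed.shape)))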
162 changes: 160 additions & 2 deletions backends/cadence/aot/tests/test_replace_ops_passes.py
@@ -33,6 +33,7 @@
     ReplaceFunctionallyEquivalentOpTargets,
     ReplaceIm2RowWithViewPass,
     ReplaceLinearWithFullyConnectedOpPass,
+    ReplaceLogicalNotBooleanWhereWithWherePass,
     ReplaceMatmulWithTransposedMatmulPass,
     ReplaceMMWithAddMMPass,
     ReplaceMulTensorWithMulAndFullOpsPass,
@@ -988,7 +989,12 @@ def test_replace_squeeze_with_view(
             args=(x,),
         )
         p = ReplaceSqueezeAndUnsqueezeWithViewPass()
-        graph_after_passes = cast(PassResult, p(original_gm)).graph_module
+        result = cast(PassResult, p(original_gm))
+
+        # Assert: Verify the pass modified the graph
+        self.assertTrue(result.modified)
+        graph_after_passes = result.graph_module
+
         self.assertIsNotNone(graph_after_passes)
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default),
@@ -1023,7 +1029,12 @@ def test_replace_unsqueeze_with_view(self, shape: Tuple[int], dim: int) -> None:
             args=(x, dim),
         )
         p = ReplaceSqueezeAndUnsqueezeWithViewPass()
-        graph_after_passes = cast(PassResult, p(original_gm)).graph_module
+        result = cast(PassResult, p(original_gm))
+
+        # Assert: Verify the pass modified the graph
+        self.assertTrue(result.modified)
+        graph_after_passes = result.graph_module
+
         self.assertIsNotNone(graph_after_passes)
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default),
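
Reviewer note: the updated tests rely on the PassResult.modified flag surfaced by RemoveOrReplacePassInterface. The interface itself is not part of this diff; the sketch below is a hypothetical reconstruction of its contract as inferred from how it is used here (names and behavior assumed, details may differ from the real implementation):

    # Hypothetical sketch, inferred from this diff; not the actual source.
    class RemoveOrReplacePassInterface(ExportPass):
        @property
        def targets(self) -> list[EdgeOpOverload]:
            # Ops whose nodes this pass inspects.
            raise NotImplementedError

        def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
            # Return True iff `node` was removed or replaced.
            raise NotImplementedError

        def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
            modified = False
            for node in list(graph_module.graph.nodes):
                if node.op == "call_function" and node.target in self.targets:
                    modified |= self.maybe_remove_or_replace(node)
            if modified:
                # Orphaned producers (e.g. the logical_not) are cleaned up here,
                # which is why the tests can assert a count of zero for them.
                graph_module.graph.eliminate_dead_code()
                graph_module.recompile()
            return PassResult(graph_module, modified)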
@@ -2183,3 +2194,150 @@ def test_replace_quantized_embedding(
             ),
             1,
         )
+
+
+class TestReplaceLogicalNotBooleanWhereWithWherePass(unittest.TestCase):
+    """Tests for the ReplaceLogicalNotBooleanWhereWithWherePass."""
+
+    @torch.no_grad()
+    def test_replace_where_with_logical_not_boolean(self) -> None:
+        """Test that where(logical_not(bool_cond), x, y) is replaced with where(bool_cond, y, x)."""
+        # Setup: Create a graph with where(logical_not(bool_cond), x, y)
+        builder = GraphBuilder()
+        bool_cond = builder.placeholder("bool_cond", torch.randn(4, 8) > 0)
+        x = builder.placeholder("x", torch.randn(4, 8))
+        y = builder.placeholder("y", torch.randn(4, 8))
+
+        # Create logical_not node
+        logical_not = builder.call_operator(
+            op=exir_ops.edge.aten.logical_not.default,
+            args=(bool_cond,),
+        )
+
+        # Create where node using logical_not
+        where_node = builder.call_operator(
+            op=exir_ops.edge.aten.where.self,
+            args=(logical_not, x, y),
+        )
+        builder.output([where_node])
+        original_gm = builder.get_graph_module()
+
+        # Execute: Apply the replacement pass
+        p = ReplaceLogicalNotBooleanWhereWithWherePass()
+        result = cast(PassResult, p(original_gm))
+
+        # Assert: Verify the pass modified the graph
+        self.assertTrue(result.modified)
+        graph_after_passes = result.graph_module
+
+        # Assert: Verify logical_not is removed (dead code elimination)
+        self.assertEqual(
+            count_node(graph_after_passes, exir_ops.edge.aten.logical_not.default),
+            0,
+        )
+
+        # Assert: Verify the where node still exists
+        self.assertEqual(
+            count_node(graph_after_passes, exir_ops.edge.aten.where.self),
+            1,
+        )
+
+        # Assert: Verify the arguments are flipped (the condition uses the
+        # original bool_cond, and x and y are swapped)
+        found_node = False
+        for node in graph_after_passes.graph.find_nodes(
+            op="call_function", target=exir_ops.edge.aten.where.self
+        ):
+            found_node = True
+            # First arg should be the original bool_cond (not the logical_not)
+            self.assertEqual(node.args[0].name, "bool_cond")
+            # Second and third args should be swapped (y, x instead of x, y)
+            self.assertEqual(node.args[1].name, "y")
+            self.assertEqual(node.args[2].name, "x")
+        self.assertTrue(found_node)
+
+    @torch.no_grad()
+    def test_replace_where_with_derived_boolean_condition(self) -> None:
+        """Test that the pass also applies when the logical_not input is a
+        boolean tensor produced by a comparison rather than a placeholder."""
+        # Setup: Create a graph with where(logical_not(float_tensor > 0), x, y).
+        # The logical_not input is a boolean tensor derived from gt.Scalar.
+        builder = GraphBuilder()
+        float_tensor = builder.placeholder("float_tensor", torch.randn(4, 8))
+        x = builder.placeholder("x", torch.randn(4, 8))
+        y = builder.placeholder("y", torch.randn(4, 8))
+
+        # Create a comparison that produces a boolean
+        gt_node = builder.call_operator(
+            op=exir_ops.edge.aten.gt.Scalar,
+            args=(float_tensor, 0.0),
+        )
+
+        # Create logical_not node using the comparison result
+        logical_not = builder.call_operator(
+            op=exir_ops.edge.aten.logical_not.default,
+            args=(gt_node,),
+        )
+
+        # Create where node
+        where_node = builder.call_operator(
+            op=exir_ops.edge.aten.where.self,
+            args=(logical_not, x, y),
+        )
+        builder.output([where_node])
+        original_gm = builder.get_graph_module()
+
+        # Execute: Apply the replacement pass
+        p = ReplaceLogicalNotBooleanWhereWithWherePass()
+        result = cast(PassResult, p(original_gm))
+
+        # Assert: Verify the pass modified the graph. The pass SHOULD apply
+        # because gt.Scalar returns a boolean tensor.
+        self.assertTrue(result.modified)
+        graph_after_passes = result.graph_module
+
+        # Assert: Verify logical_not is removed
+        self.assertEqual(
+            count_node(graph_after_passes, exir_ops.edge.aten.logical_not.default),
+            0,
+        )
+
+    @torch.no_grad()
+    def test_no_replacement_without_logical_not(self) -> None:
+        """Test that the pass does NOT apply when there's no logical_not."""
+        # Setup: Create a graph with where(bool_cond, x, y) without logical_not
+        builder = GraphBuilder()
+        bool_cond = builder.placeholder("bool_cond", torch.randn(4, 8) > 0)
+        x = builder.placeholder("x", torch.randn(4, 8))
+        y = builder.placeholder("y", torch.randn(4, 8))
+
+        # Create where node directly without logical_not
+        where_node = builder.call_operator(
+            op=exir_ops.edge.aten.where.self,
+            args=(bool_cond, x, y),
+        )
+        builder.output([where_node])
+        original_gm = builder.get_graph_module()
+
+        # Execute: Apply the replacement pass
+        p = ReplaceLogicalNotBooleanWhereWithWherePass()
+        result = cast(PassResult, p(original_gm))
+
+        # Assert: Verify the pass did NOT modify the graph
+        self.assertFalse(result.modified)
+        graph_after_passes = result.graph_module
+
+        # Assert: Verify the where node still exists unchanged
+        self.assertEqual(
+            count_node(graph_after_passes, exir_ops.edge.aten.where.self),
+            1,
+        )
+
+        # Assert: Verify the arguments are unchanged
+        found_node = False
+        for node in graph_after_passes.graph.find_nodes(
+            op="call_function", target=exir_ops.edge.aten.where.self
+        ):
+            found_node = True
+            self.assertEqual(node.args[0].name, "bool_cond")
+            self.assertEqual(node.args[1].name, "x")
+            self.assertEqual(node.args[2].name, "y")
+        self.assertTrue(found_node)
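
Reviewer note: the derived-boolean test above may look surprising at first, since the condition is not a boolean placeholder. The rewrite is still valid because gt.Scalar produces a torch.bool tensor, so the same identity applies (eager-mode sketch, not part of this diff):

    import torch

    t = torch.randn(4, 8)
    x, y = torch.randn(4, 8), torch.randn(4, 8)

    # (t > 0) is itself a torch.bool tensor, so the flip applies to it too.
    assert torch.equal(
        torch.where(torch.logical_not(t > 0), x, y),
        torch.where(t > 0, y, x),
    )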