diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py
index c383adf4162..04efd1a7b3d 100644
--- a/backends/cadence/aot/replace_ops.py
+++ b/backends/cadence/aot/replace_ops.py
@@ -21,7 +21,6 @@ import torch.fx
 from executorch.backends.cadence.aot.compiler_utils import (
     get_shape,
-    get_tensor_from_attr,
     get_zero_point,
     is_node_with_op,
     quantize_tensor_multiplier,
@@ -321,90 +320,106 @@ def call_operator(self, op, args, kwargs, meta):
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
-class ReplaceAddMMWithLinearPass(ExportPass):
+class ReplaceAddMMWithLinearPass(RemoveOrReplacePassInterface):
     """
     This pass replaces addmm with linear op.
+
+    AddMM computes: beta*bias + alpha*mm(mat1, mat2)
+    Linear computes: mat1 @ weight.T + bias
+
     """
 
-    def __init__(self):
-        super().__init__()
-        self.counter = 0
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [exir_ops.edge.aten.addmm.default]
 
-    def replace_addmm_with_linear(self, graph_module: torch.fx.GraphModule):
-        graph = graph_module.graph
-        for node in graph.nodes:
-            # We are only interested in admm nodes
-            if node.target != exir_ops.edge.aten.addmm.default:
-                continue
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
+        # The addmm op has three concrete args: bias, mat1, mat2
+        assert len(node.args) >= 3
+        (bias, mat1, mat2) = node.args[0:3]
 
-            # The addmm op has three concrete args: input, mat1, mat2
-            assert len(node.args) >= 3
-            (bias, mat1, mat2) = node.args[0:3]
-            # The other two args are optional scale args
-            beta = node.kwargs.get("beta", 1.0)
-            alpha = node.kwargs.get("alpha", 1.0)
-
-            # AddMM performs beta*bias + alpha*mm(mat1, mat2). We can convert
-            # it to linear op by multiplying beta to bias, and alpha to mat2.t().
-            # However, the following two conditions must hold:
-            # a. If bias is not a param, then beta must be 1.0
-            # b. If mat2 is not a param, then mat2 must be a transpose op. Also,
-            #    the input to the transpose must be a param, or alpha must be 1.0.
-            fit_bias = is_node_with_op(bias, "get_attr") or beta == 1.0
-            fit_mat2 = is_node_with_op(mat2, "get_attr")
-            transposed_mat2 = False
-            if (
-                not fit_mat2
-                and is_node_with_op(mat2, "call_function")
-                and mat2.target == exir_ops.edge.aten.transpose_copy.int
-            ):
-                mat2, transposed_mat2 = mat2.args[0], True
-                fit_mat2 = is_node_with_op(mat2, "get_attr") or alpha == 1.0
+        # The other two args are optional scale args
+        beta = float(node.kwargs.get("beta", 1.0))
+        alpha = float(node.kwargs.get("alpha", 1.0))
 
-            if not fit_bias or not fit_mat2:
-                continue
+        bias, mat1, mat2 = cast(
+            tuple[torch.fx.Node, torch.fx.Node, torch.fx.Node],
+            (bias, mat1, mat2),
+        )
+
+        graph = node.graph
+
+        # Handle transpose: if mat2 is a transpose op, extract the original tensor
+        transposed_mat2 = False
+        if (
+            mat2.op == "call_function"
+            and mat2.target == exir_ops.edge.aten.transpose_copy.int
+        ):
+            # mat2 is already transposed, so we use the input to the transpose
+            mat2 = cast(torch.fx.Node, mat2.args[0])
+            transposed_mat2 = True
+
+        # Multiply bias by beta if needed
+        if beta != 1.0:
+            # Create a scaled bias using element-wise multiplication in the graph
+            with graph.inserting_before(node):
+                beta_scalar = graph.call_function(
+                    exir_ops.edge.aten.full.default,
+                    args=([1], beta),
+                    kwargs={"dtype": torch.float32},
+                )
+                beta_scalar.meta = node.meta
+                bias = graph.call_function(
+                    exir_ops.edge.aten.mul.Tensor,
+                    args=(bias, beta_scalar),
+                )
 
-            # Multiply bias by beta
-            if beta != 1.0:
-                assert is_node_with_op(bias, "get_attr")
-                bias_tensor = get_tensor_from_attr(graph_module, bias)
-                assert isinstance(bias_tensor, torch.Tensor)
-                bias_tensor = beta * bias_tensor
-                with graph.inserting_before(node):
-                    bias_name = f"_bias_addmm_to_linear_{self.counter}"
-                    graph_module.register_buffer(bias_name, bias_tensor)
-                    bias = graph.get_attr(bias_name)
-
-            # Use associativity of scalar multiplication, and multiply alpha to mat2
-            if is_node_with_op(mat2, "get_attr"):
-                mat2_tensor = get_tensor_from_attr(graph_module, mat2)
-                assert isinstance(mat2_tensor, torch.Tensor)
-                mat2_tensor = alpha * mat2_tensor
-                # transpose mat2
-                mat2_tensor = mat2_tensor if transposed_mat2 else mat2_tensor.t()
-                with graph.inserting_before(node):
-                    mat2_name = f"_mat2_addmm_to_linear_{self.counter}"
-                    graph_module.register_buffer(mat2_name, mat2_tensor)
-                    mat2 = graph.get_attr(mat2_name)
-
-            # Construct the linear node
-            linear_args = (mat1, mat2, bias)
+            # Annotate the new scaled-bias node with the addmm node's metadata
+            bias.meta = node.meta
+
+        # Multiply mat2 by alpha if needed
+        if alpha != 1.0:
             with graph.inserting_before(node):
-                linear_node = graph.call_function(
-                    exir_ops.edge.aten.linear.default, args=linear_args
+                alpha_scalar = graph.call_function(
+                    exir_ops.edge.aten.full.default,
+                    args=([1], alpha),
+                    kwargs={"dtype": torch.float32},
+                )
+                alpha_scalar.meta = node.meta
+                mat2 = graph.call_function(
+                    exir_ops.edge.aten.mul.Tensor,
+                    args=(mat2, alpha_scalar),
                 )
-                linear_node.meta = node.meta
-            # Replace all the uses of the addmm op with linear op
-            node.replace_all_uses_with(linear_node)
-            self.counter += 1
-        graph_module.recompile()
-        graph_module.graph.eliminate_dead_code()
+            # Annotate the new scaled-mat2 node with the addmm node's metadata
+            mat2.meta = node.meta
 
-    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
-        self.replace_addmm_with_linear(graph_module)
-        result = super().call(graph_module)
-        return result
+        # Transpose mat2 if it wasn't already transposed
+        if not transposed_mat2:
+            with graph.inserting_before(node):
+                mat2 = graph.call_function(
+                    exir_ops.edge.aten.transpose_copy.int,
+                    args=(mat2, -1, -2),
+                )
+
+            # Annotate the new transpose node with the addmm node's metadata
+            mat2.meta = node.meta
+
+        # Construct the linear node: linear(input, weight, bias)
+        # linear computes: input @ weight.T + bias
+        linear_args = (mat1, mat2, bias)
+        with graph.inserting_before(node):
+            linear_node = graph.call_function(
+                exir_ops.edge.aten.linear.default,
+                args=linear_args,
+            )
+
+        # The linear node takes over the addmm node's metadata
+        linear_node.meta = node.meta
+
+        # Replace all uses of the addmm op with linear op
+        node.replace_all_uses_with(linear_node)
+        return True
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
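The identity the new pass relies on is easy to check numerically. Below is a minimal sketch (ours, not part of the diff) using plain ATen ops in place of the edge-dialect overloads: addmm(bias, mat1, mat2, beta, alpha) matches linear(mat1, (alpha * mat2).t(), beta * bias) up to floating-point rounding.

```python
import torch

M, K, N = 14, 12, 10
bias = torch.randn(N)
mat1 = torch.randn(M, K)
mat2 = torch.randn(K, N)
beta, alpha = 3.0, 2.0

# beta*bias + alpha*(mat1 @ mat2)
addmm_out = torch.addmm(bias, mat1, mat2, beta=beta, alpha=alpha)
# mat1 @ weight.T + bias, with the scales folded into weight and bias
linear_out = torch.nn.functional.linear(mat1, (alpha * mat2).t(), beta * bias)

# Bitwise equality is not guaranteed (the scaling is reassociated), but the
# two results agree within floating-point tolerance.
assert torch.allclose(addmm_out, linear_out, rtol=1e-5, atol=1e-6)
```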
diff --git a/backends/cadence/aot/tests/test_replace_ops_passes.py b/backends/cadence/aot/tests/test_replace_ops_passes.py
index 573489f40b9..29b0308f137 100644
--- a/backends/cadence/aot/tests/test_replace_ops_passes.py
+++ b/backends/cadence/aot/tests/test_replace_ops_passes.py
@@ -65,7 +65,19 @@ def validate(
     modified: torch.fx.GraphModule,
     inputs: tuple[torch.Tensor, ...] | list[torch.Tensor],
     pass_name: str,
+    rtol: float = 1e-5,
+    atol: float = 1e-6,
 ) -> None:
+    """Validate that two graph modules produce numerically equivalent outputs.
+
+    Args:
+        original: The original graph module before the pass
+        modified: The modified graph module after the pass
+        inputs: Input tensors to run through both graphs
+        pass_name: Name of the pass being validated (for error messages)
+        rtol: Relative tolerance for allclose comparison
+        atol: Absolute tolerance for allclose comparison
+    """
     original.eval()
     modified.eval()
     with torch.no_grad():
@@ -74,10 +86,17 @@ def validate(
 
     flat_orig_out, _ = pytree.tree_flatten(orig_out)
     flat_mod_out, _ = pytree.tree_flatten(mod_out)
-    if not all(pytree.tree_map(torch.equal, flat_orig_out, flat_mod_out)):
-        raise AssertionError(
-            f"Pass validation failed with exact match for pass {pass_name}. Original graph {original} and modified graph {modified}"
-        )
+
+    # Check that outputs match within tolerance
+    for i, (orig_tensor, mod_tensor) in enumerate(zip(flat_orig_out, flat_mod_out)):
+        if not torch.allclose(orig_tensor, mod_tensor, rtol=rtol, atol=atol):
+            max_diff = torch.max(torch.abs(orig_tensor - mod_tensor)).item()
+            raise AssertionError(
+                f"Pass validation failed for pass {pass_name}. "
+                f"Output tensor {i} differs by max {max_diff:.6e}. "
+                f"Expected rtol={rtol}, atol={atol}. "
+                f"Original output: {orig_tensor}, Modified output: {mod_tensor}"
+            )
 
 
 class TestReplaceOpsPasses(unittest.TestCase):
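The move from torch.equal to torch.allclose in validate is forced by rewrites like the one above: folding alpha into mat2 reassociates the scalar multiplication, which changes rounding, so bitwise equality is too strict. A standalone illustration of the effect (shapes and values arbitrary):

```python
import torch

torch.manual_seed(0)
a, b = torch.randn(64, 64), torch.randn(64, 64)
alpha = 3.0

scaled_after = alpha * (a @ b)   # scale applied to the matmul result
scaled_before = a @ (alpha * b)  # scale folded into an operand, as the pass does

print(torch.equal(scaled_after, scaled_before))     # typically False in float32
print(torch.allclose(scaled_after, scaled_before))  # True within default tolerances
```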
@@ -840,10 +859,10 @@ def test_replace_scalar_tensor_with_full(
     def test_replace_linear_with_fully_connected(self) -> None:
         shape, in_channels, out_channels = (1, 14), 14, 128
         builder = GraphBuilder()
-        x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32))
-        weights = builder.placeholder(
-            "weights", torch.randn([out_channels, in_channels], dtype=torch.float32)
-        )
+        x_input = torch.randn(*shape, dtype=torch.float32)
+        weights_input = torch.randn([out_channels, in_channels], dtype=torch.float32)
+        x = builder.placeholder("x", x_input)
+        weights = builder.placeholder("weights", weights_input)
         permute_copy = builder.call_operator(
             op=exir_ops.edge.aten.permute_copy.default,
             args=(weights, [1, 0]),
@@ -854,14 +873,31 @@ def test_replace_linear_with_fully_connected(self) -> None:
         )
         builder.output([mm])
         original_gm = builder.get_graph_module()
+
         gm = cast(
             PassResult, ReplacePermuteWithTransposePass()(original_gm)
         ).graph_module
         gm = cast(PassResult, ReplaceMMWithAddMMPass()(gm)).graph_module
-        gm = cast(PassResult, ReplaceAddMMWithLinearPass()(gm)).graph_module
+
+        gm_before_linear = copy.deepcopy(gm)
+        pass_result = cast(PassResult, ReplaceAddMMWithLinearPass()(gm))
+        self.assertTrue(pass_result.modified)
+        gm = pass_result.graph_module
+
+        inputs = [x_input, weights_input]
+        validate(gm_before_linear, gm, inputs, "ReplaceAddMMWithLinearPass")
+
         gm_before_fc = copy.deepcopy(gm)
         graph_after_passes = cast(
             PassResult, ReplaceLinearWithFullyConnectedOpPass()(gm)
         ).graph_module
+
+        validate(
+            gm_before_fc,
+            graph_after_passes,
+            inputs,
+            "ReplaceLinearWithFullyConnectedOpPass",
+        )
+
         self.assertIsNotNone(graph_after_passes)
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.full.default),
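The tests above now repeat the same snapshot/run/validate sequence. Distilled into one helper for clarity (run_and_validate is our name, not part of the test suite; it assumes passes that return a PassResult exposing modified and graph_module, as the ones here do):

```python
import copy

import torch
import torch.utils._pytree as pytree


def run_and_validate(pass_cls, gm, inputs, rtol=1e-5, atol=1e-6):
    # Snapshot the graph module before the pass mutates it in place.
    gm_before = copy.deepcopy(gm)
    result = pass_cls()(gm)
    assert result is not None and result.modified
    gm_after = result.graph_module
    # Run both versions and compare outputs within tolerance.
    with torch.no_grad():
        flat_before, _ = pytree.tree_flatten(gm_before(*inputs))
        flat_after, _ = pytree.tree_flatten(gm_after(*inputs))
    for before, after in zip(flat_before, flat_after):
        assert torch.allclose(before, after, rtol=rtol, atol=atol)
    return gm_after
```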
" + f"Original output: {orig_tensor}, Modified output: {mod_tensor}" + ) class TestReplaceOpsPasses(unittest.TestCase): @@ -840,10 +859,10 @@ def test_replace_scalar_tensor_with_full( def test_replace_linear_with_fully_connected(self) -> None: shape, in_channels, out_channels = (1, 14), 14, 128 builder = GraphBuilder() - x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32)) - weights = builder.placeholder( - "weights", torch.randn([out_channels, in_channels], dtype=torch.float32) - ) + x_input = torch.randn(*shape, dtype=torch.float32) + weights_input = torch.randn([out_channels, in_channels], dtype=torch.float32) + x = builder.placeholder("x", x_input) + weights = builder.placeholder("weights", weights_input) permute_copy = builder.call_operator( op=exir_ops.edge.aten.permute_copy.default, args=(weights, [1, 0]), @@ -854,14 +873,31 @@ def test_replace_linear_with_fully_connected(self) -> None: ) builder.output([mm]) original_gm = builder.get_graph_module() + gm = cast( PassResult, ReplacePermuteWithTransposePass()(original_gm) ).graph_module gm = cast(PassResult, ReplaceMMWithAddMMPass()(gm)).graph_module - gm = cast(PassResult, ReplaceAddMMWithLinearPass()(gm)).graph_module + + gm_before_linear = copy.deepcopy(gm) + pass_result = cast(PassResult, ReplaceAddMMWithLinearPass()(gm)) + self.assertTrue(pass_result.modified) + gm = pass_result.graph_module + + inputs = [x_input, weights_input] + validate(gm_before_linear, gm, inputs, "ReplaceAddMMWithLinearPass") + gm_before_fc = copy.deepcopy(gm) graph_after_passes = cast( PassResult, ReplaceLinearWithFullyConnectedOpPass()(gm) ).graph_module + + validate( + gm_before_fc, + graph_after_passes, + inputs, + "ReplaceLinearWithFullyConnectedOpPass", + ) + self.assertIsNotNone(graph_after_passes) self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.full.default), @@ -878,21 +914,17 @@ def test_replace_linear_with_fully_connected(self) -> None: 0, ) - @expand( - [ - [(4, 16, 256), 256, 512, True], - [(7, 17, 12), 12, 34, False], - ] - ) + @expand([[1.0, 1.0], [2.0, 3.0]]) @torch.no_grad() - def test_replace_addmm_with_linear( - self, shape: Tuple[int], in_features: int, out_features: int, bias: bool - ) -> None: - M, K, N, alpha, beta = 14, 48, 24, 1.0, 1.0 + def test_replace_addmm_with_linear(self, alpha: float, beta: float) -> None: + M, K, N = 14, 12, 10 builder = GraphBuilder() - x = builder.placeholder("x", torch.randn(N, dtype=torch.float32)) - y = builder.placeholder("y", torch.randn([M, K], dtype=torch.float32)) - z = builder.placeholder("z", torch.randn([N, K], dtype=torch.float32)) + x_input = torch.randn(N, dtype=torch.float32) + y_input = torch.randn([M, K], dtype=torch.float32) + z_input = torch.randn([N, K], dtype=torch.float32) + x = builder.placeholder("x", x_input) + y = builder.placeholder("y", y_input) + z = builder.placeholder("z", z_input) permute_copy = builder.call_operator( op=exir_ops.edge.aten.permute_copy.default, args=(z, [1, 0]), @@ -904,12 +936,21 @@ def test_replace_addmm_with_linear( ) builder.output([addmm]) original_gm = builder.get_graph_module() + gm = cast( PassResult, ReplacePermuteWithTransposePass()(original_gm) ).graph_module - graph_after_passes = cast( - PassResult, ReplaceAddMMWithLinearPass()(gm) - ).graph_module + + gm_before_linear = copy.deepcopy(gm) + pass_result = cast(PassResult, ReplaceAddMMWithLinearPass()(gm)) + self.assertTrue(pass_result.modified) + graph_after_passes = pass_result.graph_module + + inputs = [x_input, y_input, z_input] + validate( + 
+            gm_before_linear, graph_after_passes, inputs, "ReplaceAddMMWithLinearPass"
+        )
+
         self.assertIsNotNone(graph_after_passes)
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.linear.default),
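For reference, an eager-mode mirror of the graph the pass now emits (the function name is ours; the real pass inserts full, mul.Tensor, transpose_copy, and linear nodes in the edge dialect rather than calling these ops directly):

```python
import torch


def rewritten_addmm(bias, mat1, mat2, beta=1.0, alpha=1.0):
    # Mirrors the pass: scale via full+mul only when needed, then transpose
    # mat2 (assuming it arrived untransposed) and call linear.
    if beta != 1.0:
        bias = bias * torch.full([1], beta, dtype=torch.float32)
    if alpha != 1.0:
        mat2 = mat2 * torch.full([1], alpha, dtype=torch.float32)
    weight = mat2.transpose(-1, -2)
    return torch.nn.functional.linear(mat1, weight, bias)


# Spot-check against the reference op with non-trivial scales.
bias, mat1, mat2 = torch.randn(10), torch.randn(14, 12), torch.randn(12, 10)
ref = torch.addmm(bias, mat1, mat2, beta=3.0, alpha=2.0)
out = rewritten_addmm(bias, mat1, mat2, beta=3.0, alpha=2.0)
assert torch.allclose(ref, out, rtol=1e-5, atol=1e-6)
```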