diff --git a/backends/cadence/aot/fuse_ops.py b/backends/cadence/aot/fuse_ops.py index 25afdf7ee47..8bc3bb6acb3 100644 --- a/backends/cadence/aot/fuse_ops.py +++ b/backends/cadence/aot/fuse_ops.py @@ -26,7 +26,6 @@ get_cascaded_ops, get_permuted_dims, get_scale, - get_shape, get_tensor_from_attr, get_transposed_dims, get_zero_point, @@ -39,138 +38,149 @@ from executorch.backends.cadence.aot.utils import get_edge_overload_packet from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket -from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue -from executorch.exir.passes import dead_code_elimination_pass -from executorch.exir.passes.spec_prop_pass import SpecPropPass -from torch.fx.node import Argument +from executorch.exir.pass_base import ExportPass, PassResult from torch.nn.utils.fusion import fuse_conv_bn_weights @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class FuseMMWithAdd(ExportPass): - # Return true if the node is a view node. +class FuseMMWithAdd(RemoveOrReplacePassInterface): + """ + Fuses mm -> add patterns into addmm. + + Given a graph of the form: + X = aten.mm(A, B) + Y = aten.add(X, C) + + Fuse X and Y into a single addmm node, after making sure that we can + broadcast C into X. + + There could be view node that takes a view of X, and feeds that + to the aten.add node: + X = aten.mm(A, B) + Y = X.view() + Z = aten.add(Y, C) + + Handle this case as well. + """ + + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.mm.default] - def is_view_node(self, node: torch.fx.Node): + def _is_view_node(self, node: torch.fx.Node) -> bool: + """Return true if the node is a view node.""" return node.target == exir_ops.edge.aten.view_copy.default - def fuse_mm_with_add(self, graph_module: torch.fx.GraphModule): + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: """ - Given a graph of the form: - X = aten.mm(A, B) - Y = aten.add(X, C) - Fuse X and Y into a single addmm node, after making sure that we can - broadcast C into X. - There could be view node that takes a view of X, and feeds that - to the aten.add node: - X = aten.mm(A, B) - Y = X.view() - Z = aten.add(Y, C) - Handle this case as well. There are a few conditions for the - optimization to be valid: - 1. There should be a single user of the mm node, otherwise we cannot - remove it. - 2. There should be a single user of the add node, otherwise we cannot - fuse it with mm. + Try to fuse this mm node with a following add node. + + Returns True if fusion was performed, False otherwise. """ - graph = graph_module.graph - for node in graph.find_nodes( - op="call_function", target=exir_ops.edge.aten.mm.default - ): - # We want to discover a chain of mm -> add, or mm -> view -> add. - # Only proceed if the current node is an mm node, and has only one - # user/successor. - if len(node.users) != 1: - continue + # We want to discover a chain of mm -> add, or mm -> view -> add. + # Only proceed if the current node is an mm node, and has only one + # user/successor. + if len(node.users) != 1: + return False - # Our addmm implementation computes (mat1 * mat2 + bias). So the - # addmm node in the graph should have three args. We collectively - # term mat1 and mat2 as mm_arg since they are the args of mm node, - # and bias as bias_arg. - # Since we already have discovered the mm node, we can get mat1 and - # mat2 by iterating over its args. So the current node is mm_arg. - # bias_arg can be found once we discover the add op that consumes - # the output of this mm node. Our next step is to find the add op. - mm_arg = node - user = list(node.users.keys())[0] - # intermediate_view is True when the fusion case is mm -> view -> add - intermediate_view = False - # Check if the single user of the mm node is a view op. If so, our - # graph could potentially have mm -> view -> add. We need to skip - # the view op, and check if its successor is the add op. One condition - # we need to verify is that the view op must have only a single user - # (the add op). - if self.is_view_node(user) and len(user.users) == 1: - # We want to maintain two invariants: - # (1) 'user' is a potential add op that will get fused with the - # mm node; - # (2) 'node' is the single predecessor of 'user' that is either - # the mm node, or the current view node; - # To maintain the invariant, we must mark this view op as 'node', - # and its single successor as 'user'. - intermediate_view = True - node = user - user = list(node.users.keys())[0] - - # Thanks to the invariant, we can now simply check if 'user' is an - # add op. We also want to ensure that the add op has only one user, - # otherwise we will get not be able to eliminate add op post fusion. - if user.target != exir_ops.edge.aten.add.Tensor or len(user.users) != 1: - continue + # Our addmm implementation computes (mat1 * mat2 + bias). So the + # addmm node in the graph should have three args. We collectively + # term mat1 and mat2 as mm_arg since they are the args of mm node, + # and bias as bias_arg. + # Since we already have discovered the mm node, we can get mat1 and + # mat2 by iterating over its args. So the current node is mm_arg. + # bias_arg can be found once we discover the add op that consumes + # the output of this mm node. Our next step is to find the add op. + mm_node = node + user = list(node.users.keys())[0] - # At this point, we have found an mm and an add node that we can - # fuse together. One arg of the add op is 'node' (thanks to the - # invariant). Find the other arg, and tag it as bias_arg. - assert len(user.args) == 2 - bias_arg = user.args[1] if user.args[0] == node else user.args[0] + # intermediate_view is True when the fusion case is mm -> view -> add + intermediate_view = False + view_node = None + + # Check if the single user of the mm node is a view op. If so, our + # graph could potentially have mm -> view -> add. We need to skip + # the view op, and check if its successor is the add op. One condition + # we need to verify is that the view op must have only a single user + # (the add op). + if self._is_view_node(user) and len(user.users) == 1: + # We want to maintain two invariants: + # (1) 'user' is a potential add op that will get fused with the + # mm node; + # (2) 'view_node' is the intermediate view node (if present) + intermediate_view = True + view_node = user + user = list(view_node.users.keys())[0] + + # Check if 'user' is an add op. We also want to ensure that the add op + # has only one user, otherwise we will not be able to eliminate add op + # post fusion. + if user.target != exir_ops.edge.aten.add.Tensor or len(user.users) != 1: + return False - # As a last check, make sure that we can broadcast the bias tensor - # to the output of mm. - mm_arg_shape = get_shape(graph_module, mm_arg) - bias_arg_shape = get_shape(graph_module, bias_arg) - if ( - mm_arg_shape is None - or bias_arg_shape is None - or not broadcastable(mm_arg_shape, bias_arg_shape) - or len(bias_arg_shape) > 2 - ): - continue + # At this point, we have found an mm and an add node that we can + # fuse together. One arg of the add op is either mm_node or view_node. + # Find the other arg, and tag it as bias_arg. + assert len(user.args) == 2 + add_input = view_node if intermediate_view else mm_node + bias_arg = user.args[1] if user.args[0] == add_input else user.args[0] - # Create a new addmm node, and insert it before add node. DCE should - # take care of removing the dead mm and/or view node. Based on the - # invariant, add node corresponds to 'user'. - with graph.inserting_before(user): - addmm_node = graph.call_function( - exir_ops.edge.aten.addmm.default, - args=(bias_arg, mm_arg.args[0], mm_arg.args[1]), - ) - # Replace all the uses of add node with addmm node, and remove add - # node from the graph. - user.replace_all_uses_with(addmm_node) - graph.erase_node(user) - - # As a finishing step, we want to ensure that the output of addmm is - # in the expected shape. For example, Let us assume the following - # input, where A, B are (4, 4) sized tensors, and C is (1, 4) sized - # tensor. - # T1 = torch.mm(A, B) - # T2 = T1.view((2, 2, 4)) - # return torch.add(T2, C) - # Here, the expectation is to get an output of size (2, 2, 4), which - # is the shape out of view node T2. However, the fused addmm will - # return an output of shape (4, 4). In a nutshell, we need to take - # care of the output shape when the following two conditions are met: - # 1. The fusion case is mm -> view -> add (i.e., intermediate_view - # is True) - # 2. The single successor of addmm is not a view op. + # As a last check, make sure that we can broadcast the bias tensor + # to the output of mm. + mm_shape = mm_node.meta.get("val") + bias_shape = bias_arg.meta.get("val") if isinstance(bias_arg, torch.fx.Node) else None + + if mm_shape is None or bias_shape is None: + return False + + mm_arg_shape = mm_shape.shape + bias_arg_shape = bias_shape.shape + + if ( + not broadcastable(mm_arg_shape, bias_arg_shape) + or len(bias_arg_shape) > 2 + ): + return False + + graph = node.graph + + # Create a new addmm node, and insert it before add node. + with graph.inserting_before(user): + addmm_node = graph.call_function( + exir_ops.edge.aten.addmm.default, + args=(bias_arg, mm_node.args[0], mm_node.args[1]), + ) + addmm_node.meta = user.meta + + # Replace all the uses of add node with addmm node, and remove add + # node from the graph. + user.replace_all_uses_with(addmm_node) + graph.erase_node(user) + + # As a finishing step, we want to ensure that the output of addmm is + # in the expected shape. For example, Let us assume the following + # input, where A, B are (4, 4) sized tensors, and C is (1, 4) sized + # tensor. + # T1 = torch.mm(A, B) + # T2 = T1.view((2, 2, 4)) + # return torch.add(T2, C) + # Here, the expectation is to get an output of size (2, 2, 4), which + # is the shape out of view node T2. However, the fused addmm will + # return an output of shape (4, 4). In a nutshell, we need to take + # care of the output shape when the following two conditions are met: + # 1. The fusion case is mm -> view -> add (i.e., intermediate_view + # is True) + # 2. The single successor of addmm is not a view op. + if len(addmm_node.users) > 0: addmm_user = list(addmm_node.users.keys())[0] - if intermediate_view and not self.is_view_node(addmm_user): + if intermediate_view and not self._is_view_node(addmm_user): + assert view_node is not None # Create a view node that correctly reshapes the output of addmm - # (i.e., 'user') to match the output shape of the add node. - # Thanks to our invariant, we know that the correct shape is held - # by 'node', which points to the view op in mm -> view -> add chain. - # We create its copy, and insert it just before addmm_user. + # to match the output shape of the add node. + # The correct shape is held by 'view_node', which points to the + # view op in mm -> view -> add chain. with graph.inserting_before(addmm_user): - view_copy_node = graph_module.graph.node_copy(node) + view_copy_node = graph.node_copy(view_node) # Any uses of addmm are replaced with this view_copy node. addmm_node.replace_all_uses_with(view_copy_node) # Now we massage the args of the view_copy node, so that it takes @@ -179,15 +189,7 @@ def fuse_mm_with_add(self, graph_module: torch.fx.GraphModule): view_args[0] = addmm_node view_copy_node.args = tuple(view_args) - graph_module.recompile() - - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - # Compute the spec prop pass before we begin the fusion pipeline - result = SpecPropPass()(graph_module) - assert result is not None - self.fuse_mm_with_add(result.graph_module) - result = super().call(result.graph_module) - return result + return True @register_cadence_pass(CadencePassAttribute(opt_level=1)) @@ -523,32 +525,39 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class FuseCascadedViewOps(ExportPass): +class FuseCascadedViewOps(RemoveOrReplacePassInterface): """ Fuse a cascaded chain of view ops """ - def fuse_cascaded_view_ops(self, graph_module: torch.fx.GraphModule): - view_target = exir_ops.edge.aten.view_copy.default - for view_node in graph_module.graph.find_nodes( - op="call_function", target=view_target, sort=True - ): - input_view = view_node.args[0] - if input_view.op != "call_function" or input_view.target != view_target: - continue + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.view_copy.default] - view_node.replace_input_with(input_view, input_view.args[0]) + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + # Check if the input to this view node is also a view node + input_view = node.args[0] + if not isinstance(input_view, torch.fx.Node): + return False - graph_module.recompile() + if ( + input_view.op != "call_function" + or input_view.target != exir_ops.edge.aten.view_copy.default + ): + return False - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - self.fuse_cascaded_view_ops(graph_module) - dead_code_elimination_pass(graph_module) - result = super().call(graph_module) - return result + # Replace the input of this view node with the input of the cascaded view + # This effectively "skips" the intermediate view node + node.replace_input_with(input_view, cast(torch.fx.Node, input_view.args[0])) + return True class FuseOpPairsAcrossBranchesPass(ExportPass): + """ + Base class for passes that fuse op pairs across branches. + Provides common functionality for finding and fusing producer-consumer chains. + """ + def check_ok_to_fuse( self, producer: torch.fx.Node, @@ -606,7 +615,13 @@ def find_and_fuse( producer_op_packets: set[EdgeOpOverloadPacket], consumer_op_packets: set[EdgeOpOverloadPacket], bypass_ops: set[EdgeOpOverload], - ) -> None: + ) -> bool: + """ + Find and fuse producer-consumer op pairs. + + Returns True if any fusion was performed, False otherwise. + """ + modified = False for node in graph_module.graph.nodes: # We are only interested in ops that have overload target in # producer_op. @@ -629,8 +644,12 @@ def find_and_fuse( continue self.fuse(node, removal_candidates, graph_module) + modified = True - graph_module.recompile() + if modified: + graph_module.recompile() + + return modified def get_fused_node( self, @@ -773,14 +792,14 @@ def get_fused_node( def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # Remove any dequantize op that has only quantize ops as its users. - self.find_and_fuse( + modified = self.find_and_fuse( graph_module, producer_op_packets=self.dequantize_op_packets, consumer_op_packets=self.quantize_op_packets, bypass_ops=self.bypass_ops, ) # Remove any quantize op that has only dequantze ops as its users. - self.find_and_fuse( + modified |= self.find_and_fuse( graph_module, producer_op_packets=self.quantize_op_packets, consumer_op_packets=self.dequantize_op_packets, @@ -793,35 +812,38 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: else {exir_ops.edge.aten.view_copy.default} ), ) - result = super().call(graph_module) - return result + if modified: + return super().call(graph_module) + return PassResult(graph_module, False) @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class FuseMulScalarIntoDequantPass(ExportPass): +class FuseMulScalarIntoDequantPass(RemoveOrReplacePassInterface): """ Looks for the pattern where aten.mul.Scalar is multiplying the outputs of dequantize. If found, updates the dequant scale to reflect the multiplication and removes the mul node. """ - def attempt_fusion( - self, graph_module: torch.fx.GraphModule, node: torch.fx.Node - ) -> None: - if node.target not in { + @property + def targets(self) -> list[EdgeOpOverload]: + return [ exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, exir_ops.edge.cadence.dequantize_per_tensor.default, - }: - return + ] + + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + # Ensure that the single user of dequant is aten.mul.Scalar + if len(node.users) != 1: + return False - # ensure that the single user of dequant is aten.mul.Scalar user = list(node.users.keys())[0] - if len(node.users) != 1 or user.target != exir_ops.edge.aten.mul.Scalar: - return + if user.target != exir_ops.edge.aten.mul.Scalar: + return False - # ensure that the other arg to mul is a node (i.e. not a constant) + # Ensure that the other arg to mul is not a node (i.e. it's a constant) if len(user.args) > 1 and isinstance(user.args[1], torch.fx.Node): - return + return False new_deq_args = list(node.args) assert isinstance(node.args[1], Number) @@ -833,36 +855,36 @@ def attempt_fusion( f"Fused {node} and {user} into {node}. Updated scale from {node.args[1]} to {new_deq_args[1]}" ) + # Replace all uses of mul with the dequant node user.replace_all_uses_with(node) + # Update the dequant node's args with the new scale node.args = tuple(new_deq_args) - graph_module.graph.erase_node(user) - - graph_module.recompile() + # Erase the mul node + node.graph.erase_node(user) - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - for node in graph_module.graph.nodes: - self.attempt_fusion(graph_module, node) - result = super().call(graph_module) - return result + return True @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class FuseMulTensorIntoQuantPass(ExportPass): +class FuseMulTensorIntoQuantPass(RemoveOrReplacePassInterface): """ Looks for the pattern where aten.mul.Tensor is followed by quant node. If found, updates the quant scale to reflect the multiplication and removes the mul node. """ - def attempt_fusion( - self, graph_module: torch.fx.GraphModule, mul_node: torch.fx.Node - ) -> None: - if len(mul_node.args) != 2 or len(mul_node.users) != 1: - return + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.mul.Tensor] - first_arg = cast(torch.fx.Node, mul_node.args[0]) - second_arg = cast(torch.fx.Node, mul_node.args[1]) + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + # Check that mul has exactly 2 args and 1 user + if len(node.args) != 2 or len(node.users) != 1: + return False + + first_arg = cast(torch.fx.Node, node.args[0]) + second_arg = cast(torch.fx.Node, node.args[1]) input_node = first_arg full_node = second_arg @@ -875,20 +897,20 @@ def attempt_fusion( input_node = second_arg else: # Full node is not found, skip. - return + return False # Ensure that the mul op does not do any broadcasting. - if input_node.meta["val"].shape != mul_node.meta["val"].shape: - return + if input_node.meta["val"].shape != node.meta["val"].shape: + return False - mul_user = list(mul_node.users.keys())[0] + mul_user = list(node.users.keys())[0] # Ensure only the expected quant ops are using the current mul op. if mul_user.target not in { exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, exir_ops.edge.cadence.quantize_per_tensor.default, }: - return + return False quant_node = mul_user @@ -907,39 +929,32 @@ def attempt_fusion( new_scale = float(old_scale) / float(mul_scalar) logging.debug( - f"Fused {mul_node} and {full_node} into {quant_node}. Updated scale from {quant_node.args[1]} to {new_scale}" + f"Fused {node} and {full_node} into {quant_node}. Updated scale from {quant_node.args[1]} to {new_scale}" ) # Update quant node input and scale. old_quant_input = cast(torch.fx.Node, quant_node.args[0]) - new_quant_input = cast(torch.fx.Node, mul_node.args[0]) + new_quant_input = input_node quant_node.replace_input_with(old_quant_input, new_quant_input) quant_node.update_arg(1, new_scale) - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - for node in graph_module.graph.find_nodes( - op="call_function", target=exir_ops.edge.aten.mul.Tensor - ): - self.attempt_fusion(graph_module, node) - graph_module.graph.eliminate_dead_code() - return super().call(graph_module) + return True @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class FuseMulTensorIntoDequantPass(ExportPass): +class FuseMulTensorIntoDequantPass(RemoveOrReplacePassInterface): """ Looks for the pattern where aten.mul is multiplying the outputs of dequantize and aten.full, or vice versa. If found, updates the dequant scale to reflect the multiplication and removes the full and mul nodes. """ - def attempt_fusion( - self, graph_module: torch.fx.GraphModule, node: torch.fx.Node - ) -> None: - if node.target != exir_ops.edge.aten.mul.Tensor: - return + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.mul.Tensor] - # ensure that one of the args to mul is dequantize and the other is aten.full + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + # Ensure that one of the args to mul is dequantize and the other is aten.full dequant_nodes = [ arg for arg in node.args @@ -959,14 +974,14 @@ def attempt_fusion( ] if len(dequant_nodes) != 1 or len(multiplier_nodes) != 1: - return + return False deq_node = dequant_nodes[0] mplier_node = multiplier_nodes[0] - # ensure that dequant and full don't have any other users + # Ensure that dequant and full don't have any other users if len(deq_node.users) > 1 or len(mplier_node.users) > 1: - return + return False new_deq_args = list(deq_node.args) assert isinstance(deq_node.args[1], Number) @@ -978,18 +993,16 @@ def attempt_fusion( f"Fused {node} and {mplier_node} into {deq_node}. Updated scale from {deq_node.args[1]} to {new_deq_args[1]}" ) + # Replace all uses of the mul node with the dequant node node.replace_all_uses_with(deq_node) + # Update the dequant node's args with the new scale deq_node.args = tuple(new_deq_args) - graph_module.graph.erase_node(node) - graph_module.graph.erase_node(mplier_node) - graph_module.recompile() + # Erase the mul and full nodes + node.graph.erase_node(node) + node.graph.erase_node(mplier_node) - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - for node in graph_module.graph.nodes: - self.attempt_fusion(graph_module, node) - result = super().call(graph_module) - return result + return True @register_cadence_pass(CadencePassAttribute(opt_level=1)) @@ -1058,7 +1071,7 @@ def get_fused_node( def call(self, graph_module: torch.fx.GraphModule) -> PassResult: # Remove any transpose/permutation op pair that cancel each other. - self.find_and_fuse( + modified = self.find_and_fuse( graph_module, producer_op_packets={ exir_ops.edge.aten.transpose_copy, @@ -1070,55 +1083,63 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: }, bypass_ops=self.bypass_ops, ) - result = super().call(graph_module) - return result + if modified: + return super().call(graph_module) + return PassResult(graph_module, False) @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class FuseFullThenReshapePass(ExportPass): +class FuseFullThenReshapePass(RemoveOrReplacePassInterface): """ A pass that fuses a chain of full and reshape-like operations into a single full operation. """ - fusion_candidates: set[EdgeOpOverload] = { - exir_ops.edge.aten.transpose_copy.int, - exir_ops.edge.aten.permute_copy.default, - exir_ops.edge.aten.view_copy.default, - } + @property + def targets(self) -> list[EdgeOpOverload]: + return [ + exir_ops.edge.aten.transpose_copy.int, + exir_ops.edge.aten.permute_copy.default, + exir_ops.edge.aten.view_copy.default, + ] + + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + # Check if the input to this reshape-like node is a full node + full_node = node.args[0] + if not isinstance(full_node, torch.fx.Node): + return False - def call_operator( - self, - op, - args: tuple[Argument, ...], - kwargs: dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op not in self.fusion_candidates: - return super().call_operator(op, args, kwargs, meta) - - full_node = cast(ProxyValue, args[0]).node if not ( full_node.op == "call_function" and full_node.target == exir_ops.edge.aten.full.default ): - # full -> self.fusion_candidates. - return super().call_operator(op, args, kwargs, meta) + return False + # Get the fill value from the full node fill_value = full_node.args[1] - return super().call_operator( - exir_ops.edge.aten.full.default, - ( - meta["val"].shape, - fill_value, - ), - {"dtype": meta["val"].dtype}, - meta, - ) - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - graph_module = super().call(graph_module).graph_module - graph_module.graph.eliminate_dead_code() - return PassResult(graph_module, True) + # Get the output shape and dtype from this node's metadata + val = node.meta.get("val") + if val is None: + return False + + output_shape = val.shape + output_dtype = val.dtype + + graph = node.graph + + # Create a new full node with the final shape + with graph.inserting_before(node): + new_full_node = graph.call_function( + exir_ops.edge.aten.full.default, + args=(output_shape, fill_value), + kwargs={"dtype": output_dtype}, + ) + new_full_node.meta = node.meta + + # Replace all uses of the reshape node with the new full node + node.replace_all_uses_with(new_full_node) + + return True class CadenceFuseOpsInGraph: diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py index 441ec58d80a..44acec8c373 100644 --- a/backends/cadence/aot/ops_registrations.py +++ b/backends/cadence/aot/ops_registrations.py @@ -2439,6 +2439,11 @@ def transposed_im2row_meta( in_zero_point: torch.Tensor, channel_last: bool = False, ) -> torch.Tensor: + """ + Shape inference for transposed_im2row operation. + + Returns shape: (N, H_out * W_out, K_h * K_w * C_in) + """ if len(input.shape) == 3: height_dim = 1 if channel_last else 2 input = input.unsqueeze(height_dim) @@ -2447,6 +2452,8 @@ def transposed_im2row_meta( n_input_plane = input.shape[3] if channel_last else input.shape[1] input_height = input.shape[1] if channel_last else input.shape[2] input_width = input.shape[2] if channel_last else input.shape[3] + + # Calculate output spatial dimensions output_height = ( (input_height - 1) * stride[0] - 2 * padding[0] @@ -2461,9 +2468,11 @@ def transposed_im2row_meta( + output_padding[1] + 1 ) - n_output_plane = n_input_plane * kernel_size[0] * kernel_size[1] - output_length = output_height * output_width - output_size = torch.Size((batch_size, output_length, n_output_plane)) + + # Patch size is kernel_h * kernel_w * in_channels + patch_size = kernel_size[0] * kernel_size[1] * n_input_plane + num_patches = output_height * output_width + output_size = torch.Size((batch_size, num_patches, patch_size)) return input.new_empty(output_size, dtype=input.dtype) diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py index 00400403983..abad2c40553 100644 --- a/backends/cadence/aot/ref_implementations.py +++ b/backends/cadence/aot/ref_implementations.py @@ -1411,7 +1411,22 @@ def transposed_convolution( channel_last: bool = False, ) -> torch.Tensor: + # Cadence transposed conv receives weights that have been transformed by the pass: + # 1. Transposed (dims 0 and 1 swapped): [out_channels, in_channels, *kernel] + # 2. Flipped (spatial dimensions reversed) + # We need to reverse both transformations to call PyTorch's conv_transpose + conv_is_1d = len(input_tensor.shape) == 3 + + # Determine flip dimensions based on weight dimensionality + weight_dim = len(weight.shape) + flip_dims = [-1] if weight_dim == 3 else [-1, -2] + + # Reverse transformation step 1: Unflip the spatial dimensions + weight = torch.flip(weight, dims=flip_dims) + + # Reverse transformation step 2: Transpose back to PyTorch format [in, out, *kernel] + weight = weight.transpose(0, 1).contiguous() if channel_last: if conv_is_1d: input_tensor = input_tensor.movedim(-1, 1).contiguous() @@ -1856,12 +1871,13 @@ def transposed_im2row( channel_last: bool = False, ) -> torch.Tensor: """ - Converts input tensor patches into im2row format for transposed convolutions. - This function extracts patches from input in a pattern suitable for transposed convolution. + Converts input tensor into im2row format for transposed convolutions. + For each output position, extracts the kernel-sized patch of input values that + contribute to that position in a transposed convolution. Args: - input_tensor: Input spatial tensor, NCHW or NHWC format (3D or 4D). - - kernel_size: Size of the convolution kernel. + - kernel_size: Size of the convolution kernel (kernel_h, kernel_w). - dilation: Dilation of the convolution kernel. - padding: Padding to apply to the input. - stride: Stride of the convolution. @@ -1886,117 +1902,136 @@ def transposed_im2row( input_tensor = input_tensor.movedim(-1, -3).contiguous() # NHWC -> NCHW N, C, H_in, W_in = input_tensor.shape - - # Output: (N, C*H_in*W_in, H_out, W_out) - H_out = ( - (H_in - 1) * stride[0] - + kernel_size[0] - + output_padding[0] - - 2 * padding[0] - + dilation[0] * (kernel_size[0] - 1) - ) - W_out = ( - (W_in - 1) * stride[1] - + kernel_size[1] - + output_padding[1] - - 2 * padding[1] - + dilation[1] * (kernel_size[1] - 1) - ) - - # For each input pixel, create a channel where the upsampled (transposed conv) patch is placed - # Output: (N, C*H_in*W_in, H_out, W_out) - inp_flat = input_tensor.reshape(N, C * H_in * W_in) + K_h, K_w = kernel_size + device = input_tensor.device # Calculate output spatial size H_out = ( (H_in - 1) * stride[0] - 2 * padding[0] - + dilation[0] * (kernel_size[0] - 1) + + dilation[0] * (K_h - 1) + output_padding[0] + 1 ) W_out = ( (W_in - 1) * stride[1] - 2 * padding[1] - + dilation[1] * (kernel_size[1] - 1) + + dilation[1] * (K_w - 1) + output_padding[1] + 1 ) - # Compute the upsampled (top-left) position for each input pixel - h_idx = torch.arange(H_in, device=input_tensor.device) - w_idx = torch.arange(W_in, device=input_tensor.device) - grid_h, grid_w = torch.meshgrid(h_idx, w_idx, indexing="ij") - out_h_idx = grid_h * stride[0] - padding[0] - out_w_idx = grid_w * stride[1] - padding[1] - - # Compute all input pixel positions (flattened) - ch_idx = torch.arange(C * H_in * W_in, device=input_tensor.device) - ij_idx = ch_idx % (H_in * W_in) - i_idx = ij_idx // W_in - j_idx = ij_idx % W_in - - # For each input pixel, compute the output positions for the kernel window - kh_idx = torch.arange(kernel_size[0], device=input_tensor.device) - kw_idx = torch.arange(kernel_size[1], device=input_tensor.device) - kh_grid, kw_grid = torch.meshgrid(kh_idx, kw_idx, indexing="ij") - kh_grid = kh_grid.reshape(-1) - kw_grid = kw_grid.reshape(-1) - num_kernel = kernel_size[0] * kernel_size[1] - - # Broadcast to all channels and kernel positions - ch_idx_b = ch_idx.repeat_interleave(num_kernel) - n_kernel = ch_idx.shape[0] * num_kernel - - i_idx_b = i_idx.repeat_interleave(num_kernel) - j_idx_b = j_idx.repeat_interleave(num_kernel) - kh_b = kh_grid.repeat(ch_idx.shape[0]) - kw_b = kw_grid.repeat(ch_idx.shape[0]) - - h_out = out_h_idx[i_idx_b, j_idx_b] + kh_b * dilation[0] - w_out = out_w_idx[i_idx_b, j_idx_b] + kw_b * dilation[1] - - # Mask for valid output positions - valid = (h_out >= 0) & (h_out < H_out) & (w_out >= 0) & (w_out < W_out) - - # Prepare indices for advanced indexing - n_idx = ( - torch.arange(N, device=input_tensor.device) - .view(-1, 1) - .expand(N, n_kernel) - .reshape(-1) - ) - ch_idx_full = ch_idx_b.expand(N, n_kernel).reshape(-1) - h_out_full = h_out.expand(N, n_kernel).reshape(-1) - w_out_full = w_out.expand(N, n_kernel).reshape(-1) - valid_full = valid.expand(N, n_kernel).reshape(-1) - - # Gather input values for each channel - inp_vals = inp_flat[:, ch_idx_b].reshape(-1) - - # Create output tensor - patches = torch.zeros((N, C * H_in * W_in, H_out, W_out), dtype=input_tensor.dtype) + # Create meshgrids for all output positions and kernel positions + h_out_grid = torch.arange(H_out, device=device).view( + -1, 1, 1, 1 + ) # [H_out, 1, 1, 1] + w_out_grid = torch.arange(W_out, device=device).view( + 1, -1, 1, 1 + ) # [1, W_out, 1, 1] + kh_grid = torch.arange(K_h, device=device).view(1, 1, -1, 1) # [1, 1, K_h, 1] + kw_grid = torch.arange(K_w, device=device).view(1, 1, 1, -1) # [1, 1, 1, K_w] + + # Compute input positions for all (h_out, w_out, kh, kw) combinations + # From C++ reference: h_im = _h - ((kernel_h - 1) * dilation_h) + _kh * dilation_h + pad_h + h_im = h_out_grid - (K_h - 1) * dilation[0] + kh_grid * dilation[0] + padding[0] + w_im = w_out_grid - (K_w - 1) * dilation[1] + kw_grid * dilation[1] + padding[1] + + # Check which positions are valid (divisible by stride and within bounds) + # From C++ reference: if (h_im < 0 || h_im >= stride_h * height || h_im % stride_h != 0) + h_valid = (h_im % stride[0] == 0) & (h_im >= 0) & (h_im < stride[0] * H_in) + w_valid = (w_im % stride[1] == 0) & (w_im >= 0) & (w_im < stride[1] * W_in) + valid = h_valid & w_valid # [H_out, W_out, K_h, K_w] + + # Actual input indices (h_im / stride_h from C++ reference) + h_in = h_im // stride[0] + w_in = w_im // stride[1] + + # Clamp indices to valid range (will be masked out anyway) + h_in_safe = h_in.clamp(0, H_in - 1) + w_in_safe = w_in.clamp(0, W_in - 1) + + # Initialize output patches with zero points (vectorized across batches) + # Layout depends on channel_last: NHWC uses [K_h, K_w, C], NCHW uses [C, K_h, K_w] + if channel_last: + # NHWC: patches layout [N, H_out, W_out, K_h, K_w, C] + patches = torch.zeros( + (N, H_out, W_out, K_h, K_w, C), + dtype=input_tensor.dtype, + device=device, + ) + else: + # NCHW: patches layout [N, H_out, W_out, C, K_h, K_w] + patches = torch.zeros( + (N, H_out, W_out, C, K_h, K_w), + dtype=input_tensor.dtype, + device=device, + ) - # If in_zero_point is provided, fill patches with it + # Initialize patches with zero points (vectorized) if in_zero_point is not None: if in_zero_point.numel() == 1: + # Scalar zero point - fill all patches patches.fill_(in_zero_point.item()) else: - # Broadcast in_zero_point to (N, C, H_in, W_in) - assert in_zero_point.shape == (N,) - in_zero_point = in_zero_point.view(N, 1, 1, 1) - patches = patches + in_zero_point - - # Scatter input values to output positions (only valid positions) - patches[ - n_idx[valid_full], - ch_idx_full[valid_full], - h_out_full[valid_full], - w_out_full[valid_full], - ] = inp_vals[valid_full] - - # Optionally, flatten to (N, num_patches, patch_size) if needed - patches = patches.view(N, C * H_in * W_in, -1).transpose(1, 2).contiguous() + # Per-batch zero points - expand and fill + # in_zero_point: [N] -> [N, 1, 1, 1, 1, 1] or [N, 1, 1, 1, 1, 1] + zp_shape = [N] + [1] * (patches.ndim - 1) + patches = patches + in_zero_point.view(*zp_shape) + + # Flatten the spatial and kernel dimensions for efficient gathering + # h_in_safe, w_in_safe: [H_out, W_out, K_h, K_w] (broadcast shape) + h_flat = h_in_safe.expand(H_out, W_out, K_h, K_w).contiguous().view(-1) + w_flat = w_in_safe.expand(H_out, W_out, K_h, K_w).contiguous().view(-1) + + # Vectorized gathering across all batches and channels using advanced indexing + # Create index tensors with appropriate broadcasting shapes + num_positions = h_flat.shape[0] + + # batch_indices: [N, 1, 1] -> broadcasts to [N, C, num_positions] + batch_indices = torch.arange(N, device=device).view(N, 1, 1) + + # channel_indices: [1, C, 1] -> broadcasts to [N, C, num_positions] + channel_indices = torch.arange(C, device=device).view(1, C, 1) + + # h_flat, w_flat: [1, 1, num_positions] -> broadcasts to [N, C, num_positions] + h_indices = h_flat.view(1, 1, num_positions) + w_indices = w_flat.view(1, 1, num_positions) + + # Advanced indexing gathers all values at once: [N, C, num_positions] + gathered = input_tensor[batch_indices, channel_indices, h_indices, w_indices] + + # Reshape based on channel_last flag + if channel_last: + # NHWC: Reshape to [N, H_out, W_out, K_h, K_w, C] + # gathered: [N, C, H_out*W_out*K_h*K_w] -> [N, H_out*W_out*K_h*K_w, C] -> [N, H_out, W_out, K_h, K_w, C] + gathered = gathered.transpose(1, 2).contiguous() # [N, num_positions, C] + gathered = gathered.view(N, H_out, W_out, K_h, K_w, C) + else: + # NCHW: Reshape to [N, H_out, W_out, C, K_h, K_w] + # gathered: [N, C, H_out*W_out*K_h*K_w] -> [N, C, H_out, W_out, K_h, K_w] -> [N, H_out, W_out, C, K_h, K_w] + gathered = gathered.view(N, C, H_out, W_out, K_h, K_w) + gathered = gathered.permute(0, 2, 3, 1, 4, 5).contiguous() + + # Apply validity mask (vectorized across batches) + # valid: [H_out, W_out, K_h, K_w] -> expand to match gathered shape + if channel_last: + # gathered: [N, H_out, W_out, K_h, K_w, C] + valid_exp = valid.unsqueeze(0).unsqueeze(-1) # [1, H_out, W_out, K_h, K_w, 1] + else: + # gathered: [N, H_out, W_out, C, K_h, K_w] + valid_exp = valid.unsqueeze(0).unsqueeze(3) # [1, H_out, W_out, 1, K_h, K_w] + + patches = torch.where(valid_exp, gathered, patches) + + # Reshape to final output format: [N, H_out * W_out, K_h * K_w * C] + # The reshaping will preserve the correct dimension ordering + if channel_last: + # patches: [N, H_out, W_out, K_h, K_w, C] -> [N, H_out*W_out, K_h*K_w*C] + patches = patches.view(N, H_out * W_out, K_h * K_w * C) + else: + # patches: [N, H_out, W_out, C, K_h, K_w] -> [N, H_out*W_out, C*K_h*K_w] + patches = patches.view(N, H_out * W_out, C * K_h * K_w) + return patches diff --git a/backends/cadence/aot/remove_ops.py b/backends/cadence/aot/remove_ops.py index 9dc695c68af..de5e046328e 100644 --- a/backends/cadence/aot/remove_ops.py +++ b/backends/cadence/aot/remove_ops.py @@ -27,11 +27,10 @@ from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket -from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue +from executorch.exir.pass_base import ExportPass, PassResult from executorch.exir.pass_manager import PassManager, PassType from executorch.exir.passes import dead_code_elimination_pass -from executorch.exir.passes.spec_prop_pass import SpecPropPass -from torch.fx.node import Argument, Node +from torch.fx.node import Node @register_cadence_pass(CadencePassAttribute(opt_level=0)) @@ -246,7 +245,7 @@ def maybe_remove_or_replace(self, node: Node) -> bool: @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class RemoveNopSelectOpPass(ExportPass): +class RemoveNopSelectOpPass(RemoveOrReplacePassInterface): """ A select op that selects from a dimension that is size 1 can be eliminated in a few cases. For example, @@ -273,87 +272,57 @@ class RemoveNopSelectOpPass(ExportPass): exir_ops.edge.aten.div.Tensor, } - def __init__(self) -> None: - super().__init__() - self.op_sizes: dict[str, tuple[torch.Size, torch.Size]] = {} + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.select_copy.int] - # For select, view, or any op in binary_broadcast_ops, record the shapes of - # input and output tensors. - def call_operator( - self, - op, # pyre-ignore - args: tuple[Argument, ...], - kwargs: dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - res = super().call_operator(op, args, kwargs, meta) - # Unary ops: input and output - if op in { - exir_ops.edge.aten.select_copy.int, - exir_ops.edge.aten.view_copy.default, - }: - arg0 = cast(ProxyValue, args[0]) - self.op_sizes[res.node.name] = (arg0.to_tensor().shape, meta["val"].shape) - # Binary ops: two inputs, output shape can be inferred - elif op in self.binary_broadcast_ops: - arg0 = cast(ProxyValue, args[0]) - arg1 = cast(ProxyValue, args[1]) - self.op_sizes[res.node.name] = ( - arg0.to_tensor().shape, - arg1.to_tensor().shape, - ) - return res - - # Eliminate nop select ops. We begin by inspecting the binary_broadcast_ops, - # and check if their arg is a select op. - def eliminate_nop_select_op(self, graph_module: torch.fx.GraphModule) -> None: - for sel_node in graph_module.graph.nodes: - # We are only interested in select ops - if sel_node.target != exir_ops.edge.aten.select_copy.int: - continue - # The shape of the input/output operands for this select op should - # have been precomputed. - assert sel_node.name in self.op_sizes - (sel_in_shape, sel_out_shape) = self.op_sizes[sel_node.name] - # Get the select dimension - sel_dim = ( - sel_node.args[1] - if sel_node.args[1] >= 0 - else sel_node.args[1] + len(sel_in_shape) - ) - # If the input size along select dimension is not 1, bail. - if sel_in_shape[sel_dim] != 1: - continue + def maybe_remove_or_replace(self, node: Node) -> bool: + # Get the select input node and shapes + sel_in_node = node.args[0] + assert isinstance(sel_in_node, Node) - # Get all the users of the select op that are either view, or - # binary_broadcast_ops. - users = [x for x in list(sel_node.users.keys()) if x.name in self.op_sizes] - sel_in = sel_node.args[0] - - # Iterate over the users of select op, and remove the use of the - # select op in the user if feasible. - for node in users: - args = list(node.args) - for idx, sel_arg in enumerate(args): - # Check if the arg is the select op - if sel_arg != sel_node: - continue - # If the input of select has the same shape as the other arg - # of the binary op, the select op can be bypassed. - if sel_in_shape == self.op_sizes[node.name][(idx + 1) % 2]: - args[idx] = sel_in - # update the node's args - node.args = tuple(args) - - graph_module.recompile() - graph_module.graph.eliminate_dead_code() + sel_in_shape = sel_in_node.meta["val"].shape - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - result = SpecPropPass()(graph_module) - assert result is not None - result = super().call(result.graph_module) - self.eliminate_nop_select_op(result.graph_module) - return result + # Get the select dimension + sel_dim = cast(int, node.args[1]) + if sel_dim < 0: + sel_dim += len(sel_in_shape) + + # If the input size along select dimension is not 1, bail. + if sel_in_shape[sel_dim] != 1: + return False + + # Check if ALL users of the select op can be bypassed. + # A user can be bypassed if: + # 1. It's a view_copy op, OR + # 2. It's a binary_broadcast_op and the other operand has the same shape as sel_in + for user_node in node.users.keys(): + can_bypass = False + + # View ops can always bypass the select + if user_node.target == exir_ops.edge.aten.view_copy.default: + can_bypass = True + # For binary ops, check if the other operand has the same shape + elif user_node.target in self.binary_broadcast_ops: + # Find which argument is the select node + for idx, arg in enumerate(user_node.args): + if arg == node: + # Get the other argument + other_idx = (idx + 1) % 2 + other_arg = user_node.args[other_idx] + if isinstance(other_arg, Node): + other_shape = other_arg.meta["val"].shape + if sel_in_shape == other_shape: + can_bypass = True + break + + # If any user cannot be bypassed, we can't remove this select + if not can_bypass: + return False + + # All users can be bypassed, so replace the select node with its input + node.replace_all_uses_with(sel_in_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=1)) @@ -547,10 +516,7 @@ class Subgraph: def call(self, graph_module: torch.fx.GraphModule) -> PassResult: subgraphs_found: list[RemovePermutesAroundElementwiseOps.Subgraph] = [] processed_nodes: set[torch.fx.Node] = set() - for node in graph_module.graph.nodes: - if node.target != exir_ops.edge.aten.permute_copy.default: - continue - + for node in graph_module.graph.find_nodes(op="call_function", target=exir_ops.edge.aten.permute_copy.default): start_permute = self.get_permutation(node) # Expected end permutation for the subgraph. end_permute = [start_permute.index(i) for i in range(len(start_permute))] @@ -566,13 +532,18 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: for node in subgraph.nodes: processed_nodes.add(node) + modified = False for subgraph in subgraphs_found: self.permute_subgraph(subgraph) + modified = True - graph_module.graph.eliminate_dead_code() - graph_module.recompile() + if modified: + graph_module.graph.eliminate_dead_code() + graph_module.recompile() - return super().call(graph_module) + return super().call(graph_module) + + return PassResult(graph_module, False) def visit( self, diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py index ccadd1e7a88..ff2c5e4c200 100644 --- a/backends/cadence/aot/replace_ops.py +++ b/backends/cadence/aot/replace_ops.py @@ -15,7 +15,7 @@ import math import operator from operator import neg -from typing import cast, Dict, Iterable, Optional, Sequence, Tuple +from typing import cast, Dict, Iterable, Optional, Sequence import torch import torch.fx @@ -30,7 +30,6 @@ ) from executorch.backends.cadence.aot.pass_utils import ( CadencePassAttribute, - none_throws, register_cadence_pass, RemoveOrReplacePassInterface, ) @@ -41,8 +40,7 @@ ) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload -from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue -from torch.fx.node import Argument +from executorch.exir.pass_base import ExportPass, PassResult # A map to represent ops that: # (a) are functionally equivalent; and @@ -741,7 +739,7 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: # Make that pass runnable standalone at opt level 0. @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplaceAtenConvolutionWithCadenceConvolutionPass(ExportPass): +class ReplaceAtenConvolutionWithCadenceConvolutionPass(RemoveOrReplacePassInterface): """ Replace aten convolution op with jarvis-specific convolution op, since the aten version is not supported by jarvis. @@ -750,11 +748,14 @@ class ReplaceAtenConvolutionWithCadenceConvolutionPass(ExportPass): for unit-stride convolutions. """ - def call_operator(self, op, args, kwargs, meta): - if get_edge_overload_packet(op) != exir_ops.edge.aten.convolution: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.convolution.default] + + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: # There must be 9 total args. - assert len(args) == 9 + if len(node.args) != 9: + return False # Unpack the args ( @@ -767,11 +768,18 @@ def call_operator(self, op, args, kwargs, meta): transposed, output_padding, groups, - ) = args + ) = node.args + + # Cast to appropriate types + stride = cast(Sequence[int], stride) + padding = cast(Sequence[int], padding) + dilation = cast(Sequence[int], dilation) + output_padding = cast(Sequence[int], output_padding) + # Currently we only handle conversion to conv1d, conv2d, and conv3d, therefore # verify that the stride, padding, dilation, and output_padding have # len <=3. - assert ( + if not ( (len(stride) == len(padding) == len(dilation) == len(output_padding) == 1) or ( len(stride) == len(padding) == len(dilation) == len(output_padding) == 2 @@ -779,7 +787,8 @@ def call_operator(self, op, args, kwargs, meta): or ( len(stride) == len(padding) == len(dilation) == len(output_padding) == 3 ) - ), "Can only map convolution to conv1d, conv2d, and conv3d at present" + ): + return False # Determine if this is 1D, 2D, or 3D convolution based on parameter lengths if transposed: @@ -791,66 +800,62 @@ def call_operator(self, op, args, kwargs, meta): else: # len(stride) == 3 target = exir_ops.edge.cadence.conv3d.default - if transposed: - # Flip the height and width dimensions of weight, since we apply a - # gather stencil. Also, the first two dimensions of weight must be - # transposed/interchanged. - # If weight is a ProxyValue, new_weight needs to be the output of a - # graph operation (in this case a transpose_copy op) to be an explicit - # ProxyValue as well. If not, the view op can be done directly on the - # tensor. - transposed_weight = super().call_operator( - exir_ops.edge.aten.transpose_copy.int, - ( - weight, - 0, - 1, - ), - kwargs, - meta, - ) + with node.graph.inserting_before(node): + if transposed: + # Flip the height and width dimensions of weight, since we apply a + # gather stencil. Also, the first two dimensions of weight must be + # transposed/interchanged. + assert isinstance(weight, torch.fx.Node) + transposed_weight = node.graph.call_function( + exir_ops.edge.aten.transpose_copy.int, + args=(weight, 0, 1), + ) + transposed_weight.meta = weight.meta - flipped_weight = super().call_operator( - exir_ops.edge.aten.flip.default, - ( - transposed_weight, - [-1] if transposed_weight.to_tensor().dim() == 3 else [-1, -2], - ), - kwargs, - meta, - ) + # Get the dimension for flip based on weight shape + weight_dim = len(weight.meta["val"].shape) + flip_dims = [-1] if weight_dim == 3 else [-1, -2] - new_args = ( - in_tensor, - flipped_weight, - bias, - stride, - padding, - dilation, - output_padding, - groups, - False, - ) - else: - # Verify that output_padding is 0. - assert all( - x == 0 for x in output_padding - ), f"Cannot handle padded output in convolution. Got {output_padding=}" + flipped_weight = node.graph.call_function( + exir_ops.edge.aten.flip.default, + args=(transposed_weight, flip_dims), + ) + flipped_weight.meta = transposed_weight.meta - # Keep the original stride to maintain correct output dimensions - new_stride = stride + new_args = ( + in_tensor, + flipped_weight, + bias, + stride, + padding, + dilation, + output_padding, + groups, + False, + ) + else: + # Verify that output_padding is 0. + if not all(x == 0 for x in output_padding): + return False - new_args = ( - in_tensor, - weight, - bias, - new_stride, - padding, - dilation, - groups, - ) + # Keep the original stride to maintain correct output dimensions + new_stride = stride + + new_args = ( + in_tensor, + weight, + bias, + new_stride, + padding, + dilation, + groups, + ) - return super().call_operator(target, new_args, kwargs, meta) + new_node = node.graph.call_function(target, args=new_args) + new_node.meta = node.meta + + node.replace_all_uses_with(new_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=2)) @@ -1014,131 +1019,197 @@ def canonicalize_transposed_dim(dim: int, shape: Sequence[int]) -> int: return dim -class ExportPassWithTransposeHelper(ExportPass): - def transpose_dims( - self: ExportPass, proxy: ProxyValue, meta: NodeMetadata, dim0: int, dim1: int - ) -> ProxyValue: - """Helper function to transpose dims of a `proxy` with given `meta`.""" - shape = proxy.data.shape +@register_cadence_pass(CadencePassAttribute(opt_level=3)) +class ReplaceConvWithChannelLastConvPass(RemoveOrReplacePassInterface): + """ + Replace NCHW convolutions with NHWC (channel-last) convolutions by adding + transpose operations before and after the convolution. + """ + + @property + def targets(self) -> list[EdgeOpOverload]: + return [ + exir_ops.edge.cadence.conv1d.default, + exir_ops.edge.cadence.conv2d.default, + exir_ops.edge.cadence.conv3d.default, + exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor, + ] + + def _transpose_dims( + self, graph: torch.fx.Graph, node: torch.fx.Node, dim0: int, dim1: int + ) -> torch.fx.Node: + """Helper function to transpose dims of a node.""" + shape = node.meta["val"].shape dim0, dim1 = ( canonicalize_transposed_dim(dim0, shape), canonicalize_transposed_dim(dim1, shape), ) dim0, dim1 = min(dim0, dim1), max(dim0, dim1) - return super().call_operator( - exir_ops.edge.aten.transpose_copy.int, (proxy, dim0, dim1), {}, meta + transpose_node = graph.call_function( + exir_ops.edge.aten.transpose_copy.int, (node, dim0, dim1), {} ) - - -@register_cadence_pass(CadencePassAttribute(opt_level=3)) -class ReplaceConvWithChannelLastConvPass(ExportPassWithTransposeHelper): - def change_nchw_to_nhwc(self, proxy: ProxyValue, meta: NodeMetadata) -> ProxyValue: - shape = proxy.to_tensor().shape + transpose_node.meta = node.meta + return transpose_node + + def _change_nchw_to_nhwc( + self, graph: torch.fx.Graph, node: torch.fx.Node + ) -> torch.fx.Node: + """Convert NCHW format to NHWC format.""" + shape = node.meta["val"].shape if len(shape) == 3: - return self.transpose_dims(proxy, meta, 1, -1) + return self._transpose_dims(graph, node, 1, -1) indices = list(range(len(shape))) permute_indices = [indices[0]] + indices[2:] + [indices[1]] - return super().call_operator( - exir_ops.edge.aten.permute_copy.default, (proxy, permute_indices), {}, meta + permute_node = graph.call_function( + exir_ops.edge.aten.permute_copy.default, (node, permute_indices), {} ) - - def change_nhwc_to_nchw(self, proxy: ProxyValue, meta: NodeMetadata) -> ProxyValue: - shape = proxy.to_tensor().shape + permute_node.meta = node.meta + return permute_node + + def _change_nhwc_to_nchw( + self, graph: torch.fx.Graph, node: torch.fx.Node + ) -> torch.fx.Node: + """Convert NHWC format to NCHW format.""" + shape = node.meta["val"].shape if len(shape) == 3: - return self.transpose_dims(proxy, meta, 1, -1) + return self._transpose_dims(graph, node, 1, -1) indices = list(range(len(shape))) permute_indices = [indices[0], indices[-1]] + indices[1:-1] - return super().call_operator( - exir_ops.edge.aten.permute_copy.default, (proxy, permute_indices), {}, meta + permute_node = graph.call_function( + exir_ops.edge.aten.permute_copy.default, (node, permute_indices), {} ) + permute_node.meta = node.meta + return permute_node - def call_operator( - self, - op, - args: tuple[Argument, ...], - kwargs: dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op not in { - exir_ops.edge.cadence.conv1d.default, - exir_ops.edge.cadence.conv2d.default, - exir_ops.edge.cadence.conv3d.default, - exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor, - }: - return super().call_operator(op, args, kwargs, meta) - - quantized_op = op == exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + assert isinstance(node.target, EdgeOpOverload) + quantized_op = node.target == exir_ops.edge.cadence.quantized_conv2d_nchw.per_tensor - if not quantized_op and len(args) == 8 and args[-1] is True: - # Already in NHWC layout. - return super().call_operator(op, args, kwargs, meta) + # Check if already in NHWC layout + if not quantized_op and len(node.args) == 8 and node.args[-1] is True: + return False + # Determine the new op target if quantized_op: new_op = exir_ops.edge.cadence.quantized_conv2d_nhwc.per_tensor else: - # Determine if 1D or 2D convolution based on op - new_op = op + new_op = node.target - input_proxy = cast(ProxyValue, args[0]) - weight_proxy = cast(ProxyValue, args[1]) - input_proxy = self.change_nchw_to_nhwc(input_proxy, meta) - weight_proxy = self.change_nchw_to_nhwc(weight_proxy, meta) + graph = node.graph - # Non-quantized ops still need to set the last optional argument to True. - channel_last_arg = [] if quantized_op else [True] + # Get input and weight nodes + input_node = cast(torch.fx.Node, node.args[0]) + weight_node = cast(torch.fx.Node, node.args[1]) - new_args = ( - # Transposed input/weights. - (input_proxy, weight_proxy) - # All other args (bias, quant params, etc) - + tuple(args[2:]) - + tuple(channel_last_arg) - ) - output_proxy = super().call_operator(new_op, new_args, kwargs, meta) - nchw_proxy = self.change_nhwc_to_nchw(output_proxy, meta) - return nchw_proxy + # Insert transpose operations before the node + with graph.inserting_before(node): + # Convert input from NCHW to NHWC + input_nhwc = self._change_nchw_to_nhwc(graph, input_node) + # Convert weight from NCHW to NHWC + weight_nhwc = self._change_nchw_to_nhwc(graph, weight_node) + + # Non-quantized ops need to set the last optional argument to True + channel_last_arg = [] if quantized_op else [True] + + # Create new args with transposed input/weights + new_args = ( + (input_nhwc, weight_nhwc) + + tuple(node.args[2:]) + + tuple(channel_last_arg) + ) + + # Create the new conv operation + new_conv = graph.call_function(new_op, new_args, node.kwargs) + new_conv.meta = node.meta + + # Convert output back from NHWC to NCHW + nchw_output = self._change_nhwc_to_nchw(graph, new_conv) + + # Replace all uses with the final output + node.replace_all_uses_with(nchw_output) + return True @register_cadence_pass(CadencePassAttribute(opt_level=3)) -class MakeSliceAndCatDimOutermostPass(ExportPassWithTransposeHelper): - def call_operator( - self, - op, - args: tuple[Argument, ...], - kwargs: dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op not in { +class MakeSliceAndCatDimOutermostPass(RemoveOrReplacePassInterface): + """ + Make the slice/cat dimension the outermost dimension by adding transpose + operations before and after the slice/cat operation. + """ + + @property + def targets(self) -> list[EdgeOpOverload]: + return [ exir_ops.edge.aten.cat.default, exir_ops.edge.aten.slice_copy.Tensor, - }: - return super().call_operator(op, args, kwargs, meta) - dim = cast(int, args[1]) if len(args) > 1 else 0 - output_shape = meta["val"].shape + ] + + def _transpose_dims( + self, graph: torch.fx.Graph, node: torch.fx.Node, dim0: int, dim1: int + ) -> torch.fx.Node: + """Helper function to transpose dims of a node.""" + shape = node.meta["val"].shape + dim0, dim1 = ( + canonicalize_transposed_dim(dim0, shape), + canonicalize_transposed_dim(dim1, shape), + ) + dim0, dim1 = min(dim0, dim1), max(dim0, dim1) + transpose_node = graph.call_function( + exir_ops.edge.aten.transpose_copy.int, (node, dim0, dim1), {} + ) + transpose_node.meta = node.meta + return transpose_node + + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + # Get the dimension argument + dim = cast(int, node.args[1]) if len(node.args) > 1 else 0 + output_shape = node.meta["val"].shape + + # Canonicalize dim to be positive if dim < 0: - # Keep dim positive. dim += len(output_shape) + # Not needed if dim is already outermost or all dims before it are 1 if dim == 0 or math.prod(output_shape[:dim]) == 1: - # Not needed if dim is already outermost or all dims before it are 1. - return super().call_operator(op, (args[0], dim) + args[2:], kwargs, meta) - - if op == exir_ops.edge.aten.slice_copy.Tensor: - # Transpose -> slice. - slice_args = ( - self.transpose_dims(cast(ProxyValue, args[0]), meta, dim, 0), - 0, - ) + args[2:] - new_op = super().call_operator(op, slice_args, kwargs, meta) - else: - # (Transpose input0, Transpose input1, ...) -> cat. - cat_in_tensors = [ - self.transpose_dims(t, meta, dim, 0) - for t in cast(list[ProxyValue], args[0]) - ] - new_op = super().call_operator(op, (cat_in_tensors, 0), kwargs, meta) - # slice/cat -> transpose. - return self.transpose_dims(new_op, meta, 0, dim) + return False + + graph = node.graph + + with graph.inserting_before(node): + if node.target == exir_ops.edge.aten.slice_copy.Tensor: + # Transpose input -> slice with dim=0 -> transpose back + input_node = cast(torch.fx.Node, node.args[0]) + transposed_input = self._transpose_dims(graph, input_node, dim, 0) + + # Create slice operation with dim=0 + slice_args = (transposed_input, 0) + node.args[2:] + sliced = graph.call_function( + exir_ops.edge.aten.slice_copy.Tensor, slice_args, node.kwargs + ) + sliced.meta = node.meta + + # Transpose back + result = self._transpose_dims(graph, sliced, 0, dim) + else: + # Cat operation: transpose all inputs -> cat with dim=0 -> transpose back + cat_inputs = cast(list[torch.fx.Node], node.args[0]) + transposed_inputs = [ + self._transpose_dims(graph, t, dim, 0) + for t in cat_inputs + ] + + # Create cat operation with dim=0 + catted = graph.call_function( + exir_ops.edge.aten.cat.default, (transposed_inputs, 0), node.kwargs + ) + catted.meta = node.meta + + # Transpose back + result = self._transpose_dims(graph, catted, 0, dim) + + # Replace all uses with the final result + node.replace_all_uses_with(result) + return True @register_cadence_pass(CadencePassAttribute(opt_level=2)) @@ -1334,7 +1405,7 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: @register_cadence_pass(CadencePassAttribute(opt_level=2)) -class ReplaceTransposedConvWithLinearPass(ExportPass): +class ReplaceTransposedConvWithLinearPass(RemoveOrReplacePassInterface): """ Replace transposed convolution where groups=1 with transposed_im2row followed by a linear op. @@ -1347,15 +1418,20 @@ class ReplaceTransposedConvWithLinearPass(ExportPass): exir_ops.edge.cadence.quantized_transposed_conv.default: exir_ops.edge.cadence.quantized_linear.default, } - def call_operator(self, op, args, kwargs, meta): - if op not in self.transposed_conv_op_to_linear_op: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return list(self.transposed_conv_op_to_linear_op.keys()) + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: # Get the relevant args from transposed_convolution node. - quantized_op = op == exir_ops.edge.cadence.quantized_transposed_conv.default - assert len(args) == ( - 16 if quantized_op else 9 - ), "Inconsistent args for transposed_convolution" + assert isinstance(node.target, EdgeOpOverload) + quantized_op = ( + node.target == exir_ops.edge.cadence.quantized_transposed_conv.default + ) + expected_args = 16 if quantized_op else 9 + if len(node.args) != expected_args: + return False + ( in_tensor, weight, @@ -1365,21 +1441,23 @@ def call_operator(self, op, args, kwargs, meta): dilation, output_padding, groups, - ) = args[0:8] + ) = node.args[0:8] # We do not replace depthwise transposed_convolution with gemm yet. if groups != 1: - return super().call_operator(op, args, kwargs, meta) + return False # Get the shapes - out_shape = meta["val"].shape - weight_shape = weight.to_tensor().shape - assert None not in {weight_shape, out_shape} + assert isinstance(weight, torch.fx.Node) + out_shape = node.meta["val"].shape + weight_shape = weight.meta["val"].shape + if None in {weight_shape, out_shape}: + return False # Determine if the transposed_convolution is NCHW or NHWC. The NHWC, # i.e., the channel_last layout is specified by the channel_last arg # of transposed_conv op, which is the last argument. - channel_last = args[-1] + channel_last = node.args[-1] # The weight tensor is [out_channels, in_channels, X] for NCHW layout, # and [out_channels, X, in_channels] for NHWC layout. Here, X is the # kernel_width for conv1d, and X = kernel_height * kernel_width for @@ -1388,22 +1466,35 @@ def call_operator(self, op, args, kwargs, meta): # If the transposed_convolution op was quantized, we need the input tensor's # zero_point for im2row. Otherwise in_zero_point defaults to a zero # tensor. + assert isinstance(in_tensor, torch.fx.Node) in_zero_point = ( - get_zero_point(in_tensor.to_tensor()) + get_zero_point(in_tensor.meta["val"]) if quantized_op else torch.tensor(0, dtype=torch.int32) ) + + # Cast to appropriate types + stride = cast(Sequence[int], stride) + padding = cast(Sequence[int], padding) + dilation = cast(Sequence[int], dilation) + output_padding = cast(Sequence[int], output_padding) + # transposed_im2row expects every kernel parameter to be 2d. So we extend the # parameters for conv1d by prepending their default values. - stride = ([1] + stride) if len(stride) == 1 else stride - padding = ([0] + padding) if len(padding) == 1 else padding - dilation = ([1] + dilation) if len(dilation) == 1 else dilation - output_padding = ( - ([0] + output_padding) if len(output_padding) == 1 else output_padding + stride_list = ([1] + list(stride)) if len(stride) == 1 else list(stride) + padding_list = ([0] + list(padding)) if len(padding) == 1 else list(padding) + dilation_list = ([1] + list(dilation)) if len(dilation) == 1 else list(dilation) + output_padding_list = ( + ([0] + list(output_padding)) + if len(output_padding) == 1 + else list(output_padding) ) kernel_size = ([1] + kernel_size) if len(kernel_size) == 1 else kernel_size - # Assert that kernel size does not have a 0 - assert 0 not in kernel_size + # Check that kernel size does not have a 0 + if 0 in kernel_size: + return False + + graph = node.graph # Create a transposed_im2row node with the input. This will create a 2d # matrix of shape [out_height*out_weight, X*in_channels]. X is as @@ -1411,32 +1502,33 @@ def call_operator(self, op, args, kwargs, meta): transposed_im2row_args = ( in_tensor, kernel_size, - dilation, - padding, - stride, - output_padding, + dilation_list, + padding_list, + stride_list, + output_padding_list, in_zero_point, channel_last, ) - transposed_im2row = super().call_operator( - exir_ops.edge.cadence.transposed_im2row.default, - transposed_im2row_args, - kwargs, - meta, - ) + with graph.inserting_before(node): + transposed_im2row = graph.call_function( + exir_ops.edge.cadence.transposed_im2row.default, + args=transposed_im2row_args, + ) + transposed_im2row.meta = node.meta + # Reshape the weight to [out_channels, in_channels * X] K = math.prod(weight_shape[1:]) - # Weight is always a ProxyValue, so we need a view_copy operation - linear_weight = super().call_operator( - exir_ops.edge.aten.view_copy.default, - ( - weight, - [weight_shape[0], K], - ), - kwargs, - meta, - ) + # Weight is always a Node, so we need a view_copy operation + with graph.inserting_before(node): + linear_weight = graph.call_function( + exir_ops.edge.aten.view_copy.default, + args=( + weight, + [weight_shape[0], K], + ), + ) + linear_weight.meta = node.meta # Create the linear node, which multiplies the 3d input with 2d weight # tensors with bias addition. The outermost dimension of the input is @@ -1448,7 +1540,8 @@ def call_operator(self, op, args, kwargs, meta): bias_scale, out_scale, out_zero_point, - ) = args[8:13] + ) = node.args[8:13] + # pyre-ignore[58]: Division operands requantize_scale = bias_scale / out_scale (out_multiplier, out_shift) = quantize_tensor_multiplier(requantize_scale) linear_args = ( @@ -1464,58 +1557,67 @@ def call_operator(self, op, args, kwargs, meta): ) else: linear_args = (transposed_im2row, linear_weight, bias) - linear_res = super().call_operator( - self.transposed_conv_op_to_linear_op[op], - linear_args, - kwargs, - meta, - ) + + with graph.inserting_before(node): + linear_res = graph.call_function( + self.transposed_conv_op_to_linear_op[cast(EdgeOpOverload, node.target)], + args=linear_args, + ) + linear_res.meta = node.meta + # The output of linear is a 3D tensor. However, the output is in NHWC # layout by default, because an input vector of size X is multiplied # with the weight matrix, i.e., column values are contiguous. If the # channel_last is False, we want to transpose this output. if not channel_last: - linear_res = super().call_operator( - exir_ops.edge.aten.transpose_copy.int, - (linear_res, 1, 2), - kwargs, - meta, - ) + with graph.inserting_before(node): + linear_res = graph.call_function( + exir_ops.edge.aten.transpose_copy.int, + args=(linear_res, 1, 2), + ) + linear_res.meta = node.meta + # And finally, we want to view the 3D output of linear op as 4D tensor - return super().call_operator( - exir_ops.edge.aten.view_copy.default, - (linear_res, list(out_shape)), - kwargs, - meta, - ) + with graph.inserting_before(node): + out_res = graph.call_function( + exir_ops.edge.aten.view_copy.default, + args=(linear_res, list(out_shape)), + ) + out_res.meta = node.meta + + node.replace_all_uses_with(out_res) + return True @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class ReplaceNopTransposeOrPermuteWithViewPass(ExportPass): +class ReplaceNopTransposeOrPermuteWithViewPass(RemoveOrReplacePassInterface): """ If the transpose/permute op does not change the byte order (e.g., transpose/permute from Nx1xHxW to NxHx1xW), then it can be replaced by view op. """ - def call_operator(self, op, args, kwargs, meta): - # Only proceed for transpose or permute op. - if op not in { + @property + def targets(self) -> list[EdgeOpOverload]: + return [ exir_ops.edge.aten.transpose_copy.int, exir_ops.edge.aten.permute_copy.default, - }: - return super().call_operator(op, args, kwargs, meta) + ] + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: # Get the input tensor and shape - in_tensor = args[0].to_tensor() - in_shape = in_tensor.shape + in_tensor_node = node.args[0] + assert isinstance(in_tensor_node, torch.fx.Node) + in_shape = in_tensor_node.meta["val"].shape # Get the output tensor shape - out_shape = meta["val"].shape + out_shape = node.meta["val"].shape - if op == exir_ops.edge.aten.transpose_copy.int: + if node.target == exir_ops.edge.aten.transpose_copy.int: # Get the two dims to be transposed - dim0 = args[1] if args[1] >= 0 else in_tensor.dim() + args[1] - dim1 = args[2] if args[2] >= 0 else in_tensor.dim() + args[2] + dim0 = cast(int, node.args[1]) + dim1 = cast(int, node.args[2]) + dim0 = dim0 if dim0 >= 0 else len(in_shape) + dim0 + dim1 = dim1 if dim1 >= 0 else len(in_shape) + dim1 # We can eliminate transpose if (a) the size at dim0 and dim1 is 1; # (b) the size at dim0 or dim1 is 1, and dim0 and dim1 are consecutive. both_one = in_shape[dim0] == 1 and in_shape[dim1] == 1 @@ -1523,17 +1625,22 @@ def call_operator(self, op, args, kwargs, meta): in_shape[dim0] == 1 or in_shape[dim1] == 1 ) if both_one or either_one_and_consecutive: - new_args = (args[0], list(out_shape)) - return super().call_operator( - exir_ops.edge.aten.view_copy.default, new_args, kwargs, meta - ) - - elif op == exir_ops.edge.aten.permute_copy.default: - old_dims = list(range(in_tensor.dim())) - new_dims = args[1] + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + exir_ops.edge.aten.view_copy.default, + args=(in_tensor_node, list(out_shape)), + ) + new_node.meta = node.meta + node.replace_all_uses_with(new_node) + return True + + elif node.target == exir_ops.edge.aten.permute_copy.default: + old_dims = list(range(len(in_shape))) + new_dims = cast(Sequence[int], node.args[1]) # If the permute does not change anything, return the input as output. - if old_dims == new_dims: - return args[0] + if old_dims == list(new_dims): + node.replace_all_uses_with(in_tensor_node) + return True # Get the old dim order, and the permuted dim order for all dims that # are not 1. old_order = [ @@ -1544,22 +1651,30 @@ def call_operator(self, op, args, kwargs, meta): ] # If the byte ordering for non-unit dims is unchanged, this is a nop. if old_order == new_order: - new_args = (args[0], list(out_shape)) - return super().call_operator( - exir_ops.edge.aten.view_copy.default, new_args, kwargs, meta - ) + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + exir_ops.edge.aten.view_copy.default, + args=(in_tensor_node, list(out_shape)), + ) + new_node.meta = node.meta + node.replace_all_uses_with(new_node) + return True - return super().call_operator(op, args, kwargs, meta) + return False def call(self, graph_module: torch.fx.GraphModule) -> PassResult: result = super().call(graph_module) - fuse_cascaded_result = none_throws(FuseCascadedViewOps()(result.graph_module)) - result = none_throws(ExportPass()(fuse_cascaded_result.graph_module)) + # If this pass made modifications, fuse any cascaded view ops that may have been created + if result.modified: + fuse_cascaded_result = FuseCascadedViewOps().call(result.graph_module) + + # True because we are in the 'if modified' block + return PassResult(fuse_cascaded_result.graph_module, True) return result @register_cadence_pass(CadencePassAttribute(opt_level=2)) -class ReplaceLinearWithFullyConnectedOpPass(ExportPass): +class ReplaceLinearWithFullyConnectedOpPass(RemoveOrReplacePassInterface): """ If the input of linear/quantized_linear op is a vector, replace it with fully_connected op. @@ -1570,25 +1685,32 @@ class ReplaceLinearWithFullyConnectedOpPass(ExportPass): exir_ops.edge.cadence.quantized_linear.default: exir_ops.edge.cadence.quantized_fully_connected.default, } - def call_operator(self, op, args, kwargs, meta): - # Only proceed for linear or quantized_linear ops. - if op not in self.linear_to_fc_op: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return list(self.linear_to_fc_op.keys()) + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: # Extract the input tensor - in_tensor = args[0].to_tensor() - leading_dims = math.prod(in_tensor.shape[:-1]) + in_tensor_arg = node.args[0] + assert isinstance(in_tensor_arg, torch.fx.Node) + in_tensor_shape = in_tensor_arg.meta["val"].shape + leading_dims = math.prod(in_tensor_shape[:-1]) # If the tensor is not a vector, do nothing. if leading_dims != 1: - return super().call_operator(op, args, kwargs, meta) + return False # Replace the linear with fully connected op - return super().call_operator( - self.linear_to_fc_op[op], - args, - kwargs, - meta, - ) + assert isinstance(node.target, EdgeOpOverload) + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + self.linear_to_fc_op[cast(EdgeOpOverload, node.target)], + args=node.args, + kwargs=node.kwargs, + ) + new_node.meta = node.meta + + node.replace_all_uses_with(new_node) + return True register_cadence_pass(CadencePassAttribute(opt_level=0))(ReplaceScalarWithTensorArgPass) @@ -1687,58 +1809,73 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplaceAtenAvgPoolWithCadenceAvgPoolPass(ExportPass): +class ReplaceAtenAvgPoolWithCadenceAvgPoolPass(RemoveOrReplacePassInterface): """ Replace the aten avg_pool op with the cadence custom avg_pool2d op. """ - def call_operator(self, op, args, kwargs, meta): - # Only continue for avg_pool op - if op not in { + @property + def targets(self) -> list[EdgeOpOverload]: + return [ exir_ops.edge.aten.avg_pool1d.default, exir_ops.edge.aten.avg_pool2d.default, - }: - return super().call_operator(op, args, kwargs, meta) + ] + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: # Determine if the op is avg_pool1d or avg_pool2d - avg_pool1d: bool = op == exir_ops.edge.aten.avg_pool1d.default - # Get the input tensor - in_tensor = args[0].to_tensor() + avg_pool1d: bool = node.target == exir_ops.edge.aten.avg_pool1d.default + + # Get the input tensor node + in_tensor_node = node.args[0] + assert isinstance(in_tensor_node, torch.fx.Node) # Replace avg_pool2d with custom avg_pool2d, and if the input tensor is # quantized, pass its zero_point tensor as arg to the custom avg_pool2d. # stride, padding, ceil_mode, count_include_pad, divisor_override, are # the native avg_pool2d args. 'channel_last' denotes NCHW vs NHWC layout, # and is False by default. - kernel_size = args[1] - stride = args[2] if len(args) >= 3 else [1, 1] - padding = args[3] if len(args) >= 4 else [0, 0] - ceil_mode = args[4] if len(args) >= 5 else False - count_include_pad = args[5] if len(args) >= 6 else True - divisor_override = args[6] if len(args) >= 7 else None - zero_point = args[7] if len(args) >= 8 else None + kernel_size = node.args[1] + # When stride is not provided or is empty, PyTorch defaults to kernel_size + stride = node.args[2] if len(node.args) >= 3 and node.args[2] else kernel_size + padding = node.args[3] if len(node.args) >= 4 else [0, 0] + ceil_mode = node.args[4] if len(node.args) >= 5 else False + count_include_pad = node.args[5] if len(node.args) >= 6 else True + divisor_override = node.args[6] if len(node.args) >= 7 else None + zero_point = node.args[7] if len(node.args) >= 8 else None + + graph = node.graph + out_shape = node.meta["val"].shape + + kernel_size = cast(Sequence[int], kernel_size) + stride = cast(Sequence[int], stride) + padding = cast(Sequence[int], padding) # If the op is avg_pool1d, then we need to reshape the 3d input to a 4d # tensor. if avg_pool1d: - in_shape = list(in_tensor.shape) + in_shape = list(in_tensor_node.meta["val"].shape) assert len(in_shape) == 3, "Expected 3d input for avg_pool1d" - in_shape.insert(2, 1) - out_shape = meta["val"].shape - in_view_op = super().call_operator( - exir_ops.edge.aten.view_copy.default, - (in_tensor, in_shape), - kwargs, - meta, - ) + in_shape_4d = in_shape[:2] + [1] + in_shape[2:] + + with graph.inserting_before(node): + in_view_node = graph.call_function( + exir_ops.edge.aten.view_copy.default, + args=(in_tensor_node, in_shape_4d), + ) + in_view_node.meta = node.meta + # Extend the kernel_size, stride and padding to 2d - kernel_size = [1] + kernel_size if len(kernel_size) == 1 else kernel_size - stride = [1] + stride if len(stride) == 1 else stride - padding = [0] + padding if len(padding) == 1 else padding + kernel_size = [1] + list(kernel_size) if len(kernel_size) == 1 else kernel_size + stride = [1] + list(stride) if len(stride) == 1 else stride + padding = [0] + list(padding) if len(padding) == 1 else padding + + input_for_pool = in_view_node + else: + input_for_pool = in_tensor_node # Create a new avg_pool node with the updated args new_args = ( - in_view_op if avg_pool1d else args[0], + input_for_pool, kernel_size, stride, padding, @@ -1748,70 +1885,66 @@ def call_operator(self, op, args, kwargs, meta): zero_point, False, ) - avg_pool2d_op = super().call_operator( - exir_ops.edge.cadence.avg_pool2d.default, - new_args, - kwargs, - meta, - ) - # If the node was avg_pool1d, we again reshape the 4d output to 3d output - return ( - super().call_operator( - exir_ops.edge.aten.view_copy.default, - (avg_pool2d_op, list(out_shape)), - kwargs, - meta, + with graph.inserting_before(node): + avg_pool2d_node = graph.call_function( + exir_ops.edge.cadence.avg_pool2d.default, + args=new_args, ) - if avg_pool1d - else avg_pool2d_op - ) + avg_pool2d_node.meta = node.meta + + # If the node was avg_pool1d, we again reshape the 4d output to 3d output + if avg_pool1d: + with graph.inserting_before(node): + result_node = graph.call_function( + exir_ops.edge.aten.view_copy.default, + args=(avg_pool2d_node, list(out_shape)), + ) + result_node.meta = node.meta + node.replace_all_uses_with(result_node) + else: + node.replace_all_uses_with(avg_pool2d_node) + + return True @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class ReplaceIm2RowWithViewPass(ExportPass): - def can_replace(self, op, args, kwargs, meta) -> bool: - if op != exir_ops.edge.cadence.im2row.default: - return False +class ReplaceIm2RowWithViewPass(RemoveOrReplacePassInterface): + """ + Replace im2row with view when possible (no padding, no dilation, and output spatial dimensions are 1). + """ + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.cadence.im2row.default] + + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: # Check if im2row applies padding. If yes, we cannot replace it with view. - pad = cast(tuple[int, ...], args[3]) + pad = cast(Sequence[int], node.args[3]) if any(p != 0 for p in pad): return False # Check if im2row has dilation. If yes, we cannot replace it with view. - dilation = cast(tuple[int, ...], args[2]) + dilation = cast(Sequence[int], node.args[2]) if any(d != 1 for d in dilation): return False # im2row works on 3D or 4D tensors. # Output shape[1:-1] will be unit if input spatial dimensions are the same as kernel spatial dimensions. - output_shape = meta["val"].shape - if math.prod(output_shape[1:-1]) == 1: - return True + output_shape = node.meta["val"].shape + if math.prod(output_shape[1:-1]) != 1: + return False - return False + # Replace im2row with view_copy + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + exir_ops.edge.aten.view_copy.default, + args=(node.args[0], list(output_shape)), + ) + new_node.meta = node.meta - def call_operator( - self, - op, - args: tuple[Argument, ...], - kwargs: dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op != exir_ops.edge.cadence.im2row.default: - return super().call_operator(op, args, kwargs, meta) - - if not self.can_replace(op, args, kwargs, meta): - return super().call_operator(op, args, kwargs, meta) - - output_shape = meta["val"].shape - return super().call_operator( - exir_ops.edge.aten.view_copy.default, - (args[0], tuple(output_shape)), - kwargs, - meta, - ) + node.replace_all_uses_with(new_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=1)) @@ -1830,57 +1963,84 @@ def call_operator(self, op, args, kwargs, meta): return super().call_operator(op, args, kwargs, meta) def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - ret = super().call(graph_module) - modified = ret.graph_module.graph.eliminate_dead_code() or ret.modified - return PassResult(ret.graph_module, modified) + changed = False + for module in filter( + lambda m: isinstance(m, torch.fx.GraphModule), graph_module.modules() + ): + module = cast(torch.fx.GraphModule, module) + for node in module.graph.nodes: + if node.op != "call_function": + continue + val = node.meta.get("val", None) + if isinstance(val, torch.Tensor) and val.numel() == 0: + with module.graph.inserting_before(node): + new_node = module.graph.call_function( + exir_ops.edge.aten.full.default, + args=(val.shape, 0), + kwargs={"dtype": val.dtype}, + ) + new_node.meta = node.meta + node.replace_all_uses_with(new_node) + changed = True + + if changed: + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + return super().call(graph_module) + + return PassResult(graph_module, False) @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class ReplaceWhereWithFullArgsWithWhereScalar(ExportPass): +class ReplaceWhereWithFullArgsWithWhereScalar(RemoveOrReplacePassInterface): """Replaces where ops using two full ops as tensors with a scalar version. """ - def call_operator( - self, - op, - args: Tuple[Argument, ...], - kwargs: Dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - if op not in { - exir_ops.edge.aten.where.self, - }: - return super().call_operator(op, args, kwargs, meta) - - # If the args are not full ops, bail - # pyre-ignore[16]: `ProxyValue` has no attribute `node`. - if (args[1].node.target != exir_ops.edge.aten.full.default) or ( - args[2].node.target != exir_ops.edge.aten.full.default - ): - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.where.self] + + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + # Check if args[1] and args[2] are full ops + arg1 = node.args[1] + arg2 = node.args[2] + + if not isinstance(arg1, torch.fx.Node) or not isinstance(arg2, torch.fx.Node): + return False - # If one of the full ops is a different size than than the cond tensor, we need to broadcast. Bail. if ( - # pyre-ignore[16]: `ProxyValue` has no attribute `node`. - list(args[0].to_tensor().shape) != args[1].node.args[0] - or list(args[0].to_tensor().shape) != args[2].node.args[0] + arg1.target != exir_ops.edge.aten.full.default + or arg2.target != exir_ops.edge.aten.full.default ): - return super().call_operator(op, args, kwargs, meta) + return False + + # Get the condition tensor shape + cond_arg = node.args[0] + assert isinstance(cond_arg, torch.fx.Node) + cond_shape = list(cond_arg.meta["val"].shape) + + # Check if the full ops have the same size as the cond tensor + full1_shape = arg1.args[0] + full2_shape = arg2.args[0] + + if cond_shape != full1_shape or cond_shape != full2_shape: + return False # Get the scalar values from the full ops - scalar_value_1 = args[1].node.args[1] - scalar_value_2 = args[2].node.args[1] + scalar_value_1 = arg1.args[1] + scalar_value_2 = arg2.args[1] # Replace the where op with a scalar where op - return super().call_operator( - exir_ops.edge.cadence.where_Scalar.default, - (args[0], scalar_value_1, scalar_value_2), - kwargs, - meta, - ) + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + exir_ops.edge.cadence.where_Scalar.default, + args=(cond_arg, scalar_value_1, scalar_value_2), + ) + new_node.meta = node.meta - return super().call_operator(op, args, kwargs, meta) + node.replace_all_uses_with(new_node) + return True # Adapted from fbcode/pyspeech/opt_passes/replace_ops.py @@ -2116,53 +2276,56 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class ReplaceMulTensorWithMulAndFullOpsPass(ExportPass): +class ReplaceMulTensorWithMulAndFullOpsPass(RemoveOrReplacePassInterface): """ Extracts a single value argument of mul op to a separate full op. """ - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - for mul_node in graph_module.graph.find_nodes( - op="call_function", target=torch.ops.aten.mul.Tensor - ): - x_arg, const_arg = mul_node.args + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.mul.Tensor] - # Swap arguments if the order is wrong - if isinstance(const_arg, torch.fx.Node): - x_arg, const_arg = const_arg, x_arg + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + x_arg, const_arg = node.args - # Skip if the const_arg is not a scalar - if not isinstance(const_arg, (float, int)) or not isinstance( - x_arg, torch.fx.Node - ): - continue + # Swap arguments if the order is wrong + if isinstance(const_arg, torch.fx.Node): + x_arg, const_arg = const_arg, x_arg - # Cast the const_arg to the dtype of the x_arg - full_arg = self.resolve_full_arg(x_arg, const_arg) + # Skip if the const_arg is not a scalar + if not isinstance(const_arg, (float, int)) or not isinstance( + x_arg, torch.fx.Node + ): + return False - full_output_dtype = ( - torch.int32 if isinstance(full_arg, int) else torch.float32 - ) + # Cast the const_arg to the dtype of the x_arg + full_arg = self.resolve_full_arg(x_arg, const_arg) - # Extract an argument to a separate full op. - with graph_module.graph.inserting_before(mul_node): - full_node = graph_module.graph.call_function( - torch.ops.aten.full.default, - args=([1], full_arg), - kwargs={"dtype": full_output_dtype}, - ) - full_node.meta = mul_node.meta - full_node.meta["val"] = [1] - new_mul_node = graph_module.graph.call_function( - torch.ops.aten.mul.Tensor, args=(x_arg, full_node) - ) - new_mul_node.meta = mul_node.meta - # Replace the old mul with a newly created mul. - mul_node.replace_all_uses_with(new_mul_node) - graph_module.graph.erase_node(mul_node) - return super().call(graph_module) + full_output_dtype = ( + torch.int32 if isinstance(full_arg, int) else torch.float32 + ) + + # Extract an argument to a separate full op. + with node.graph.inserting_before(node): + full_node = node.graph.call_function( + exir_ops.edge.aten.full.default, + args=([1], full_arg), + kwargs={"dtype": full_output_dtype}, + ) + full_node.meta = node.meta + full_node.meta["val"] = [1] + new_mul_node = node.graph.call_function( + exir_ops.edge.aten.mul.Tensor, args=(x_arg, full_node) + ) + new_mul_node.meta = node.meta + # Replace the old mul with a newly created mul. + node.replace_all_uses_with(new_mul_node) + node.graph.erase_node(node) + return True - def resolve_full_arg(self, x_arg, const_arg): + def resolve_full_arg( + self, x_arg: torch.fx.Node, const_arg: float | int + ) -> float | int: if x_arg.meta["val"].dtype == torch.float32 and isinstance(const_arg, int): const_arg = float(const_arg) if x_arg.meta["val"].dtype == torch.int32 and isinstance(const_arg, float): @@ -2171,40 +2334,41 @@ def resolve_full_arg(self, x_arg, const_arg): @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass(ExportPass): +class ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass(RemoveOrReplacePassInterface): """ Replace the aten adaptive avg_pool op with the aten avg_pool2d op. """ - def call_operator(self, op, args, kwargs, meta): - # Only continue for avg_pool op - if op not in {exir_ops.edge.aten._adaptive_avg_pool2d.default}: - return super().call_operator(op, args, kwargs, meta) + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten._adaptive_avg_pool2d.default] - # Get the input tensor - in_tensor = args[0].to_tensor() - # Permute NCHW to NHWC for computation - in_tensor_permuted = in_tensor.permute(0, 2, 3, 1) - in_tensor_shape = in_tensor_permuted.shape + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + # Get the input tensor node + in_tensor_node = node.args[0] + assert isinstance(in_tensor_node, torch.fx.Node) - output_size = args[1] + # Get input shape (in NCHW format) + in_shape = in_tensor_node.meta["val"].shape + output_size = cast(Sequence[int], node.args[1]) num_dims = len(output_size) + # Spatial dimensions are at indices [2:] for NCHW format # TODO: If in_tensor_shape is not a multiple of output size, # this pass will not work. T224984800 dim_multiples = [ - (in_tensor_shape[i + 1] % output_size[i]) == 0 for i in range(num_dims) + (in_shape[i + 2] % output_size[i]) == 0 for i in range(num_dims) ] if not all(dim_multiples): logging.info( - f"Unable to replace adaptive average pool with average pool. Input tensor shape of {in_tensor_shape} is not a multiple of output size: {output_size}" + f"Unable to replace adaptive average pool with average pool. Input tensor shape of {in_shape} is not a multiple of output size: {output_size}" ) - return super().call_operator(op, args, kwargs, meta) + return False - # Compute stride and kernel_size, then set default values for other arguments - stride = [(in_tensor_shape[i + 1] // output_size[i]) for i in range(num_dims)] + # Compute stride and kernel_size based on spatial dimensions + stride = [(in_shape[i + 2] // output_size[i]) for i in range(num_dims)] kernel_size = [ - in_tensor_shape[i + 1] - (output_size[i] - 1) * stride[i] + in_shape[i + 2] - (output_size[i] - 1) * stride[i] for i in range(num_dims) ] padding = [0] * num_dims @@ -2212,9 +2376,9 @@ def call_operator(self, op, args, kwargs, meta): count_include_pad = True divisor_override = None - # Create a new avg_pool node with the updated args + # Create a new avg_pool2d node with the computed args new_args = ( - args[0], + in_tensor_node, kernel_size, stride, padding, @@ -2222,12 +2386,16 @@ def call_operator(self, op, args, kwargs, meta): count_include_pad, divisor_override, ) - return super().call_operator( - exir_ops.edge.aten.avg_pool2d.default, - new_args, - kwargs, - meta, - ) + + with node.graph.inserting_before(node): + new_node = node.graph.call_function( + exir_ops.edge.aten.avg_pool2d.default, + args=new_args, + ) + new_node.meta = node.meta + + node.replace_all_uses_with(new_node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=0)) diff --git a/backends/cadence/aot/tests/test_fusion_ops_passes.py b/backends/cadence/aot/tests/test_fusion_ops_passes.py index 980acee5b66..ebc178cefc4 100644 --- a/backends/cadence/aot/tests/test_fusion_ops_passes.py +++ b/backends/cadence/aot/tests/test_fusion_ops_passes.py @@ -96,7 +96,9 @@ def test_no_fuse_for_3d_bias(self) -> None: original_graph = builder.get_graph_module() p = FuseMMWithAdd() - converted_graph = cast(PassResult, p(original_graph)).graph_module + result = cast(PassResult, p(original_graph)) + self.assertFalse(result.modified) + converted_graph = result.graph_module converted_graph.graph.eliminate_dead_code() self.assertEqual( count_node(converted_graph, exir_ops.edge.aten.addmm.default), 0 @@ -106,9 +108,12 @@ def test_no_fuse_for_3d_bias(self) -> None: def test_fuse_mm_with_add(self) -> None: builder = GraphBuilder() - x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32)) - y = builder.placeholder("y", torch.randn(5, 6, dtype=torch.float32)) - z = builder.placeholder("z", torch.randn(6, dtype=torch.float32)) + x_input = torch.randn(3, 5, dtype=torch.float32) + y_input = torch.randn(5, 6, dtype=torch.float32) + z_input = torch.randn(6, dtype=torch.float32) + x = builder.placeholder("x", x_input) + y = builder.placeholder("y", y_input) + z = builder.placeholder("z", z_input) mm = builder.call_operator( op=exir_ops.edge.aten.mm.default, args=(x, y), @@ -116,10 +121,18 @@ def test_fuse_mm_with_add(self) -> None: output = builder.call_operator(op=exir_ops.edge.aten.add.Tensor, args=(mm, z)) builder.output([output]) original_graph = builder.get_graph_module() + gm_before = copy.deepcopy(original_graph) + p = FuseMMWithAdd() - converted_graph = cast(PassResult, p(original_graph)).graph_module + result = cast(PassResult, p(original_graph)) + self.assertTrue(result.modified) + converted_graph = result.graph_module converted_graph.graph.eliminate_dead_code() + + # Validate numerical accuracy + validate(gm_before, converted_graph, (x_input, y_input, z_input), "FuseMMWithAdd") + self.assertEqual( count_node(converted_graph, exir_ops.edge.aten.addmm.default), 1 ) @@ -273,7 +286,8 @@ def test_permute_transpose_fusion(self) -> None: def test_view_fusion(self) -> None: builder = GraphBuilder() - x = builder.placeholder("x", torch.randn(8, 5, 3, dtype=torch.float32)) + x_input = torch.randn(8, 5, 3, dtype=torch.float32) + x = builder.placeholder("x", x_input) view1 = builder.call_operator( op=exir_ops.edge.aten.view_copy.default, args=(x, [1, 8, 15]) ) @@ -285,9 +299,17 @@ def test_view_fusion(self) -> None: ) builder.output([output]) original_graph = builder.get_graph_module() + + gm_before = copy.deepcopy(original_graph) p = FuseCascadedViewOps() - converted_graph = cast(PassResult, p(original_graph)).graph_module - converted_graph.graph.eliminate_dead_code() + result = cast(PassResult, p(original_graph)) + self.assertTrue(result.modified) + converted_graph = result.graph_module + + # Validate numerical accuracy + inputs = [x_input] + validate(gm_before, converted_graph, inputs, "FuseCascadedViewOps") + # Assert that only one view op remains self.assertEqual( count_node(converted_graph, exir_ops.edge.aten.view_copy.default), 1 @@ -295,7 +317,8 @@ def test_view_fusion(self) -> None: def test_view_fusion_branched(self) -> None: builder = GraphBuilder() - x = builder.placeholder("x", torch.randn(8, 5, 3, dtype=torch.float32)) + x_input = torch.randn(8, 5, 3, dtype=torch.float32) + x = builder.placeholder("x", x_input) y = builder.call_operator( op=exir_ops.edge.aten.view_copy.default, args=(x, [1, 8, 15]) ) @@ -307,9 +330,17 @@ def test_view_fusion_branched(self) -> None: ) builder.output([z, t]) original_graph = builder.get_graph_module() + + gm_before = copy.deepcopy(original_graph) p = FuseCascadedViewOps() - converted_graph = cast(PassResult, p(original_graph)).graph_module - converted_graph.graph.eliminate_dead_code() + result = cast(PassResult, p(original_graph)) + self.assertTrue(result.modified) + converted_graph = result.graph_module + + # Validate numerical accuracy + inputs = [x_input] + validate(gm_before, converted_graph, inputs, "FuseCascadedViewOps") + # z and t should be fused and y should be eliminated. self.assertEqual( count_node(converted_graph, exir_ops.edge.aten.view_copy.default), 2 @@ -361,7 +392,9 @@ def test_no_replace_quant_permute_dequant_with_requantize(self) -> None: original_graph = builder.get_graph_module() p = FuseQuantDequantToRequantizePass(force_quant_dequant_fusion=False) - converted_graph = cast(PassResult, p(original_graph)).graph_module + result = cast(PassResult, p(original_graph)) + self.assertFalse(result.modified) + converted_graph = result.graph_module self.check_op_counts( converted_graph, expected_op_counts={ @@ -403,7 +436,8 @@ def test_replace_quant_view_dequant_with_requantize(self) -> None: def test_replace_dequant_quant_with_requantize(self) -> None: builder = GraphBuilder() - x = builder.placeholder("x", torch.randn(2, 12, 1, 6, dtype=torch.float32)) + x_input = torch.randint(low=0, high=5, size=(2, 12, 1, 6), dtype=torch.int8) + x = builder.placeholder("x", x_input) dequant = builder.call_operator( op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, args=(x, 1.2, 3, 0, 127, torch.int8), @@ -414,8 +448,15 @@ def test_replace_dequant_quant_with_requantize(self) -> None: ) builder.output([quant]) original_graph = builder.get_graph_module() + gm_before = copy.deepcopy(original_graph) + p = FuseQuantDequantToRequantizePass() - converted_graph = cast(PassResult, p(original_graph)).graph_module + result = cast(PassResult, p(original_graph)) + self.assertTrue(result.modified) + converted_graph = result.graph_module + + # Validate numerical accuracy + validate(gm_before, converted_graph, (x_input,), "FuseQuantDequantToRequantizePass") self.check_op_counts( converted_graph, @@ -429,7 +470,8 @@ def test_replace_dequant_quant_with_requantize(self) -> None: def test_replace_dequant_permute_quant_with_requantize(self) -> None: builder = GraphBuilder() - x = builder.placeholder("x", torch.randn(2, 12, 1, 6, dtype=torch.float32)) + x_input = torch.randint(low=0, high=5, size=(2, 12, 1, 6), dtype=torch.int8) + x = builder.placeholder("x", x_input) dequant = builder.call_operator( op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, args=(x, 1.2, 3, 0, 127, torch.int8), @@ -443,8 +485,15 @@ def test_replace_dequant_permute_quant_with_requantize(self) -> None: ) builder.output([quant]) original_graph = builder.get_graph_module() + gm_before = copy.deepcopy(original_graph) + p = FuseQuantDequantToRequantizePass() - converted_graph = cast(PassResult, p(original_graph)).graph_module + result = cast(PassResult, p(original_graph)) + self.assertTrue(result.modified) + converted_graph = result.graph_module + + # Validate numerical accuracy + validate(gm_before, converted_graph, (x_input,), "FuseQuantDequantToRequantizePass") self.check_op_counts( converted_graph, @@ -550,7 +599,8 @@ def test_fuse_mul_into_dequant(self) -> None: FULL_VALUE: Final[float] = 3 builder = GraphBuilder() - x = builder.placeholder("x", torch.randn(*INPUT_SHAPE, dtype=torch.float32)) + x_input = torch.randint(low=0, high=255, size=INPUT_SHAPE, dtype=torch.uint8) + x = builder.placeholder("x", x_input) dequant = builder.call_operator( op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, args=(x, DEQUANT_SCALE, 0, 0, 255, torch.uint8), @@ -565,8 +615,15 @@ def test_fuse_mul_into_dequant(self) -> None: ) builder.output([mul]) original_graph = builder.get_graph_module() + gm_before = copy.deepcopy(original_graph) + p = FuseMulTensorIntoDequantPass() - converted_graph = cast(PassResult, p(original_graph)).graph_module + result = cast(PassResult, p(original_graph)) + self.assertTrue(result.modified) + converted_graph = result.graph_module + + # Validate numerical accuracy + validate(gm_before, converted_graph, (x_input,), "FuseMulTensorIntoDequantPass") # verify that the mul and full ops were removed self.check_op_counts( @@ -593,7 +650,8 @@ def test_fuse_mul_scalar_into_dequant(self) -> None: mul_value = 0.3 builder = GraphBuilder() - x = builder.placeholder("x", torch.randn(2, 3, 4, dtype=torch.float32)) + x_input = torch.randn(2, 3, 4, dtype=torch.float32) + x = builder.placeholder("x", x_input) quant = builder.call_operator( op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, args=(x, 1, 0, -128, 127, torch.int8), @@ -608,8 +666,15 @@ def test_fuse_mul_scalar_into_dequant(self) -> None: ) builder.output([mul_scalar]) original_graph = builder.get_graph_module() + gm_before = copy.deepcopy(original_graph) + p = FuseMulScalarIntoDequantPass() - converted_graph = cast(PassResult, p(original_graph)).graph_module + result = cast(PassResult, p(original_graph)) + self.assertTrue(result.modified) + converted_graph = result.graph_module + + # Validate numerical accuracy + validate(gm_before, converted_graph, (x_input,), "FuseMulScalarIntoDequantPass") # verify that the mul and full ops were removed self.check_op_counts( @@ -635,7 +700,8 @@ def test_fuse_mul_into_quant(self) -> None: mul_value = 10 builder = GraphBuilder() - x = builder.placeholder("x", torch.randn(4, 32, dtype=torch.float32)) + x_input = torch.randn(4, 32, dtype=torch.float32) + x = builder.placeholder("x", x_input) full = builder.call_operator( op=exir_ops.edge.aten.full.default, args=([1], mul_value), @@ -650,8 +716,15 @@ def test_fuse_mul_into_quant(self) -> None: ) builder.output([quant]) original_graph = builder.get_graph_module() + gm_before = copy.deepcopy(original_graph) + p = FuseMulTensorIntoQuantPass() - converted_graph = cast(PassResult, p(original_graph)).graph_module + result = cast(PassResult, p(original_graph)) + self.assertTrue(result.modified) + converted_graph = result.graph_module + + # Validate numerical accuracy + validate(gm_before, converted_graph, (x_input,), "FuseMulTensorIntoQuantPass") # verify that the mul and full ops were removed self.check_op_counts( @@ -671,14 +744,8 @@ def test_fuse_mul_into_quant(self) -> None: new_quant_scale = node.args[1] self.assertEqual(new_quant_scale, quant_scale / mul_value) - # verify the math is correct - inp = torch.randn(4, 32, dtype=torch.float32) - original_out = original_graph(inp)[0] - new_out = converted_graph(inp)[0] - assert torch.equal(original_out, new_out) - def test_fuse_then_transpose_pass(self) -> None: - # Create a graph with full -> transpose. + # Create a graph with full -> transpose -> permute -> view. builder = GraphBuilder() full_node = builder.call_operator( op=exir_ops.edge.aten.full.default, args=((2, 3), 1) @@ -697,6 +764,8 @@ def test_fuse_then_transpose_pass(self) -> None: ) builder.output([view_node]) original_graph = builder.get_graph_module() + gm_before = copy.deepcopy(original_graph) + self.check_op_counts( original_graph, expected_op_counts={ @@ -709,7 +778,13 @@ def test_fuse_then_transpose_pass(self) -> None: # Check that the pass fuses the full with all other ops (transpose, permute, view). p = FuseFullThenReshapePass() - gm_after_pass = cast(PassResult, p(original_graph)).graph_module + result = cast(PassResult, p(original_graph)) + self.assertTrue(result.modified) + gm_after_pass = result.graph_module + + # Validate numerical accuracy + validate(gm_before, gm_after_pass, [], "FuseFullThenReshapePass") + self.check_op_counts( gm_after_pass, expected_op_counts={ diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py index 24bbe7ee644..d3ea1695735 100644 --- a/backends/cadence/aot/tests/test_ref_implementations.py +++ b/backends/cadence/aot/tests/test_ref_implementations.py @@ -1813,7 +1813,7 @@ def test_convolution( ), # input: 1x1x2x2 torch.tensor( [[[[1.0, 1.0], [1.0, 1.0]]]], dtype=torch.float32 - ), # weight: 1x1x2x2 + ), # weight: 1x1x2x2 (in PyTorch format, will be transformed to Cadence format) torch.tensor([0.0], dtype=torch.float32), # bias (1, 1), # stride (0, 0), # padding @@ -1834,7 +1834,7 @@ def test_convolution( ), # input: 1x1x2x2 torch.tensor( [[[[1.0, 0.0], [0.0, 1.0]]]], dtype=torch.float32 - ), # weight: 1x1x2x2 + ), # weight: 1x1x2x2 (in PyTorch format, will be transformed to Cadence format) torch.tensor([5.0], dtype=torch.float32), # bias=5.0 (1, 1), # stride (0, 0), # padding @@ -1863,9 +1863,19 @@ def test_transposed_convolution( channel_last: bool, expected_output: torch.Tensor, ) -> None: + # Apply the same transformations that ReplaceAtenConvolutionWithCadenceConvolutionPass + # applies to weights: transpose(0,1) then flip spatial dimensions. + # This converts weights from PyTorch format to Cadence format. + weight_dim = len(weight.shape) + flip_dims = [-1] if weight_dim == 3 else [-1, -2] + + # Transform: transpose dims 0 and 1, then flip spatial dimensions + cadence_weight = weight.transpose(0, 1) + cadence_weight = torch.flip(cadence_weight, dims=flip_dims) + output = torch.ops.cadence.transposed_convolution( input_tensor, - weight, + cadence_weight, bias, stride, padding, @@ -2374,15 +2384,15 @@ def test_im2row( torch.tensor( [ [ - [1, 0, 0, 0], - [1, 2, 0, 0], - [0, 2, 0, 0], - [1, 0, 3, 0], + [0, 0, 0, 1], + [0, 0, 1, 2], + [0, 0, 2, 0], + [0, 1, 0, 3], [1, 2, 3, 4], - [0, 2, 0, 4], - [0, 0, 3, 0], - [0, 0, 3, 4], - [0, 0, 0, 4], + [2, 0, 4, 0], + [0, 3, 0, 0], + [3, 4, 0, 0], + [4, 0, 0, 0], ] ], dtype=torch.int32, @@ -2401,15 +2411,15 @@ def test_im2row( torch.tensor( [ [ - [1, 100, 100, 100], - [1, 2, 100, 100], - [100, 2, 100, 100], - [1, 100, 3, 100], + [100, 100, 100, 1], + [100, 100, 1, 2], + [100, 100, 2, 100], + [100, 1, 100, 3], [1, 2, 3, 4], - [100, 2, 100, 4], - [100, 100, 3, 100], - [100, 100, 3, 4], - [100, 100, 100, 4], + [2, 100, 4, 100], + [100, 3, 100, 100], + [3, 4, 100, 100], + [4, 100, 100, 100], ] ], dtype=torch.int32, @@ -2428,22 +2438,22 @@ def test_im2row( torch.tensor( [ [ + [0, 0, 0, 1], + [0, 0, 1, 0], + [0, 0, 0, 2], + [0, 0, 2, 0], + [0, 1, 0, 0], [1, 0, 0, 0], - [1, 0, 0, 0], - [0, 2, 0, 0], - [0, 2, 0, 0], - [1, 0, 0, 0], - [1, 0, 0, 0], - [0, 2, 0, 0], [0, 2, 0, 0], + [2, 0, 0, 0], + [0, 0, 0, 3], [0, 0, 3, 0], - [0, 0, 3, 0], - [0, 0, 0, 4], - [0, 0, 0, 4], - [0, 0, 3, 0], - [0, 0, 3, 0], - [0, 0, 0, 4], [0, 0, 0, 4], + [0, 0, 4, 0], + [0, 3, 0, 0], + [3, 0, 0, 0], + [0, 4, 0, 0], + [4, 0, 0, 0], ] ], dtype=torch.int32, @@ -2468,26 +2478,26 @@ def test_im2row( torch.tensor( [ [ - [1, 100, 100, 100], - [1, 2, 100, 100], - [100, 2, 100, 100], - [1, 100, 3, 100], + [100, 100, 100, 1], + [100, 100, 1, 2], + [100, 100, 2, 100], + [100, 1, 100, 3], [1, 2, 3, 4], - [100, 2, 100, 4], - [100, 100, 3, 100], - [100, 100, 3, 4], - [100, 100, 100, 4], + [2, 100, 4, 100], + [100, 3, 100, 100], + [3, 4, 100, 100], + [4, 100, 100, 100], ], [ - [5, 200, 200, 200], - [5, 6, 200, 200], - [200, 6, 200, 200], - [5, 200, 7, 200], + [200, 200, 200, 5], + [200, 200, 5, 6], + [200, 200, 6, 200], + [200, 5, 200, 7], [5, 6, 7, 8], - [200, 6, 200, 8], - [200, 200, 7, 200], - [200, 200, 7, 8], - [200, 200, 200, 8], + [6, 200, 8, 200], + [200, 7, 200, 200], + [7, 8, 200, 200], + [8, 200, 200, 200], ], ], dtype=torch.int32, diff --git a/backends/cadence/aot/tests/test_remove_ops_passes.py b/backends/cadence/aot/tests/test_remove_ops_passes.py index c957eb04b87..462f5c37036 100644 --- a/backends/cadence/aot/tests/test_remove_ops_passes.py +++ b/backends/cadence/aot/tests/test_remove_ops_passes.py @@ -7,14 +7,16 @@ # pyre-strict +import copy import unittest -from copy import deepcopy from typing import cast, List, Tuple import executorch.backends.cadence.aot.ops_registrations # noqa import torch from executorch.backends.cadence.aot.fuse_ops import FuseQuantDequantToRequantizePass -from executorch.backends.cadence.aot.graph_builder import GraphBuilder +from executorch.backends.cadence.aot.graph_builder import ( + GraphBuilder, +) from executorch.backends.cadence.aot.pass_utils import count_node from executorch.backends.cadence.aot.remove_ops import ( @@ -41,6 +43,46 @@ from pyre_extensions import none_throws from torch.fx.passes.infra.pass_base import PassResult +from torch.utils import _pytree as pytree + + +def validate( + original: torch.fx.GraphModule, + modified: torch.fx.GraphModule, + inputs: tuple[torch.Tensor, ...] | list[torch.Tensor], + pass_name: str, + rtol: float = 1e-5, + atol: float = 1e-6, +) -> None: + """Validate that two graph modules produce numerically equivalent outputs. + + Args: + original: The original graph module before the pass + modified: The modified graph module after the pass + inputs: Input tensors to run through both graphs + pass_name: Name of the pass being validated (for error messages) + rtol: Relative tolerance for allclose comparison + atol: Absolute tolerance for allclose comparison + """ + original.eval() + modified.eval() + with torch.no_grad(): + orig_out = original(*inputs) + mod_out = modified(*inputs) + + flat_orig_out, _ = pytree.tree_flatten(orig_out) + flat_mod_out, _ = pytree.tree_flatten(mod_out) + + # Check that outputs match within tolerance + for i, (orig_tensor, mod_tensor) in enumerate(zip(flat_orig_out, flat_mod_out)): + if not torch.allclose(orig_tensor, mod_tensor, rtol=rtol, atol=atol): + max_diff = torch.max(torch.abs(orig_tensor - mod_tensor)).item() + raise AssertionError( + f"Pass validation failed for pass {pass_name}. " + f"Output tensor {i} differs by max {max_diff:.6e}. " + f"Expected rtol={rtol}, atol={atol}. " + f"Original output: {orig_tensor}, Modified output: {mod_tensor}" + ) class TestRemoveOpsPasses(unittest.TestCase): @@ -337,6 +379,10 @@ def test_remove_nop_select_before_view(self) -> None: ) builder.output([view]) original = builder.get_graph_module() + + # Deepcopy before the pass + gm_before = copy.deepcopy(original) + graph_after_passes = cast( PassResult, RemoveNopSelectOpPass()(original) ).graph_module @@ -344,6 +390,10 @@ def test_remove_nop_select_before_view(self) -> None: count_node(graph_after_passes, exir_ops.edge.aten.select_copy.int), 0 ) + # Verify numerical correctness + sample_input = [torch.randn(1, 5, 6, dtype=torch.float32)] + validate(gm_before, graph_after_passes, sample_input, "RemoveNopSelectOpPass") + def test_remove_nop_select_before_add(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(1, 5, 6, dtype=torch.float32)) @@ -359,6 +409,10 @@ def test_remove_nop_select_before_add(self) -> None: add = builder.call_operator(op=exir_ops.edge.aten.add.Tensor, args=(select, y)) builder.output([add]) original = builder.get_graph_module() + + # Deepcopy before the pass + gm_before = copy.deepcopy(original) + graph_after_passes = cast( PassResult, RemoveNopSelectOpPass()(original) ).graph_module @@ -366,6 +420,13 @@ def test_remove_nop_select_before_add(self) -> None: count_node(graph_after_passes, exir_ops.edge.aten.select_copy.int), 0 ) + # Verify numerical correctness + sample_inputs = [ + torch.randn(1, 5, 6, dtype=torch.float32), + torch.randn(1, 5, 6, dtype=torch.float32), + ] + validate(gm_before, graph_after_passes, sample_inputs, "RemoveNopSelectOpPass") + def test_remove_nop_select_before_mul(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(1, 5, 6, dtype=torch.float32)) @@ -381,6 +442,10 @@ def test_remove_nop_select_before_mul(self) -> None: mul = builder.call_operator(op=exir_ops.edge.aten.mul.Tensor, args=(select, y)) builder.output([mul]) original = builder.get_graph_module() + + # Deepcopy before the pass + gm_before = copy.deepcopy(original) + graph_after_passes = cast( PassResult, RemoveNopSelectOpPass()(original) ).graph_module @@ -388,6 +453,13 @@ def test_remove_nop_select_before_mul(self) -> None: count_node(graph_after_passes, exir_ops.edge.aten.select_copy.int), 0 ) + # Verify numerical correctness + sample_inputs = [ + torch.randn(1, 5, 6, dtype=torch.float32), + torch.randn(1, 5, 6, dtype=torch.float32), + ] + validate(gm_before, graph_after_passes, sample_inputs, "RemoveNopSelectOpPass") + def test_remove_nop_select_before_div(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(1, 5, 6, dtype=torch.float32)) @@ -403,6 +475,10 @@ def test_remove_nop_select_before_div(self) -> None: div = builder.call_operator(op=exir_ops.edge.aten.div.Tensor, args=(select, y)) builder.output([div]) original = builder.get_graph_module() + + # Deepcopy before the pass + gm_before = copy.deepcopy(original) + graph_after_passes = cast( PassResult, RemoveNopSelectOpPass()(original) ).graph_module @@ -410,6 +486,13 @@ def test_remove_nop_select_before_div(self) -> None: count_node(graph_after_passes, exir_ops.edge.aten.select_copy.int), 0 ) + # Verify numerical correctness + sample_inputs = [ + torch.randn(1, 5, 6, dtype=torch.float32), + torch.randn(1, 5, 6, dtype=torch.float32), + ] + validate(gm_before, graph_after_passes, sample_inputs, "RemoveNopSelectOpPass") + def test_remove_nop_quant_dequant(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(8, 8)) @@ -494,12 +577,25 @@ def test_remove_permutes_around_elemwise_ops_add(self) -> None: ) builder.output([permute]) original = builder.get_graph_module() + + # Deepcopy before the pass + gm_before = copy.deepcopy(original) + p = RemovePermutesAroundElementwiseOps() graph_after_passes = cast(PassResult, p(original)).graph_module self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.permute_copy.default), 0 ) + # Verify numerical correctness + sample_inputs = [torch.randn(1, 8, 4, 4, dtype=torch.float32)] + validate( + gm_before, + graph_after_passes, + sample_inputs, + "RemovePermutesAroundElementwiseOps", + ) + def test_keep_permutes_around_elemwise_ops_add(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(1, 8, 4, 4, dtype=torch.float32)) @@ -542,6 +638,10 @@ def test_remove_permutes_around_elemwise_ops_add_mean(self) -> None: ) builder.output([permute]) original = builder.get_graph_module() + + # Deepcopy before the pass + gm_before = copy.deepcopy(original) + p = RemovePermutesAroundElementwiseOps() graph_after_passes = cast(PassResult, p(original)).graph_module self.assertEqual( @@ -554,6 +654,18 @@ def test_remove_permutes_around_elemwise_ops_add_mean(self) -> None: ][0] self.assertEqual(mean_op.args[1], [2, 3]) + # Verify numerical correctness + sample_inputs = [ + torch.randn(1, 8, 4, 4, dtype=torch.float32), + torch.randn(1, 8, 4, 4, dtype=torch.float32), + ] + validate( + gm_before, + graph_after_passes, + sample_inputs, + "RemovePermutesAroundElementwiseOps", + ) + def test_remove_permutes_around_elemwise_ops_slice(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(1, 8, 4, 4)) @@ -572,6 +684,9 @@ def test_remove_permutes_around_elemwise_ops_slice(self) -> None: builder.output([output]) original = builder.get_graph_module() + # Deepcopy before the pass + gm_before = copy.deepcopy(original) + p = RemovePermutesAroundElementwiseOps() graph_after_passes = cast(PassResult, p(original)).graph_module @@ -587,6 +702,15 @@ def test_remove_permutes_around_elemwise_ops_slice(self) -> None: self.assertEqual(len(slices), 1) self.assertEqual(slices[0].args[1], 2) + # Verify numerical correctness + sample_inputs = [torch.randn(1, 8, 4, 4)] + validate( + gm_before, + graph_after_passes, + sample_inputs, + "RemovePermutesAroundElementwiseOps", + ) + def test_remove_squeeze_view_before_elemwise_ops(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(8, 1, 4, 4)) @@ -608,7 +732,7 @@ def test_remove_squeeze_view_before_elemwise_ops(self) -> None: ) builder.output([unsqueeze]) model = builder.get_graph_module() - original = deepcopy(model) + original = copy.deepcopy(model) p = RemoveSqueezeViewBeforeElementwiseOps() pass_result = cast(PassResult, p(model)) @@ -659,7 +783,7 @@ def test_remove_squeeze_view_before_elemwise_ops_multiple_squeeze(self) -> None: ) builder.output([view_copy]) model = builder.get_graph_module() - original = deepcopy(model) + original = copy.deepcopy(model) p = RemoveSqueezeViewBeforeElementwiseOps() transformed = cast(PassResult, p(model)).graph_module @@ -722,12 +846,28 @@ def test_remove_permutes_around_elemwise_ops_mul(self) -> None: ) builder.output([output]) original = builder.get_graph_module() + + # Deepcopy before the pass + gm_before = copy.deepcopy(original) + p = RemovePermutesAroundElementwiseOps() graph_after_passes = cast(PassResult, p(original)).graph_module self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.permute_copy.default), 0 ) + # Verify numerical correctness + sample_inputs = [ + torch.randint(0, 5, (2, 4, 4, 8), dtype=torch.uint8), + torch.randint(0, 5, (2, 4, 4, 8), dtype=torch.uint8), + ] + validate( + gm_before, + graph_after_passes, + sample_inputs, + "RemovePermutesAroundElementwiseOps", + ) + def test_remove_permutes_around_elemwise_ops_double_permutes(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(2, 4, 4, 8)) @@ -777,6 +917,10 @@ def test_remove_permutes_around_elemwise_ops_double_permutes(self) -> None: ) builder.output([output]) original = builder.get_graph_module() + + # Deepcopy before the pass + gm_before = copy.deepcopy(original) + p = RemovePermutesAroundElementwiseOps() graph_after_passes = cast(PassResult, p(original)).graph_module # Expect 2 permutes to remain, one on input x and one on output z @@ -791,6 +935,18 @@ def test_remove_permutes_around_elemwise_ops_double_permutes(self) -> None: ][0] self.assertEqual(cat.args[1], 3) + # Verify numerical correctness + sample_inputs = [ + torch.randint(0, 5, (2, 4, 4, 8), dtype=torch.uint8), + torch.randint(0, 5, (1, 8, 4, 4), dtype=torch.uint8), + ] + validate( + gm_before, + graph_after_passes, + sample_inputs, + "RemovePermutesAroundElementwiseOps", + ) + def test_remove_permutes_around_elemwise_ops_complicated_case(self) -> None: """ A complicated case touching many edge cases. @@ -847,6 +1003,9 @@ def test_remove_permutes_around_elemwise_ops_complicated_case(self) -> None: builder.output([f, g, h, i, k]) graph_module = builder.get_graph_module() + # Deepcopy before the pass + gm_before = copy.deepcopy(graph_module) + p = RemovePermutesAroundElementwiseOps() graph_module = cast(PassResult, p(graph_module)).graph_module @@ -855,6 +1014,19 @@ def test_remove_permutes_around_elemwise_ops_complicated_case(self) -> None: count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 4 ) + # Verify numerical correctness + sample_inputs = [ + torch.randn(1, 4, 4, 8), + torch.randn(1, 4, 4, 8), + torch.randn(1, 4, 4, 8), + ] + validate( + gm_before, + graph_module, + sample_inputs, + "RemovePermutesAroundElementwiseOps", + ) + def test_remove_dequant_on_branch(self) -> None: builder = GraphBuilder() x = builder.placeholder("x", torch.randn(1, 8, 4, 6)) diff --git a/backends/cadence/aot/tests/test_replace_ops_passes.py b/backends/cadence/aot/tests/test_replace_ops_passes.py index cc891da4f46..19a09805fdf 100644 --- a/backends/cadence/aot/tests/test_replace_ops_passes.py +++ b/backends/cadence/aot/tests/test_replace_ops_passes.py @@ -23,6 +23,7 @@ MakeSliceAndCatDimOutermostPass, ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass, ReplaceAddMMWithLinearPass, + ReplaceAtenAvgPoolWithCadenceAvgPoolPass, ReplaceAtenConvolutionWithCadenceConvolutionPass, ReplaceAtenLinalgSvdWithCadenceLinalgSvdPass, ReplaceConstantPadNdWithSlicePass, @@ -459,9 +460,9 @@ def test_replace_aten_conv_with_cadence_conv( builder = GraphBuilder() x_tensor = torch.randn(*shape, dtype=torch.float32) x = builder.placeholder("x", x_tensor) - weights_tensor = torch.randn( - [out_channels, in_channels // groups, kernel], dtype=torch.float32 - ) + # For regular conv: weight shape is [out_channels, in_channels // groups, kernel] + weights_shape = [out_channels, in_channels // groups, kernel] + weights_tensor = torch.randn(weights_shape, dtype=torch.float32) weights = builder.placeholder("weights", weights_tensor) bias: Optional[ProxyValue] = None bias_tensor: Optional[torch.Tensor] = None @@ -485,12 +486,24 @@ def test_replace_aten_conv_with_cadence_conv( builder.output([convolution]) original_gm = builder.get_graph_module() - replacement_pass_result = ( - ReplaceAtenConvolutionWithCadenceConvolutionPass().call(original_gm) - ) + gm_before = copy.deepcopy(original_gm) + p = ReplaceAtenConvolutionWithCadenceConvolutionPass() + replacement_pass_result = cast(PassResult, p(original_gm)) self.assertIsNotNone(replacement_pass_result) + self.assertTrue(replacement_pass_result.modified) graph_after_passes = replacement_pass_result.graph_module + # Validate numerical accuracy + inputs = (x_tensor, weights_tensor) + if bias is not None: + inputs += (cast(torch.Tensor, bias_tensor),) + validate( + gm_before, + graph_after_passes, + inputs, + "ReplaceAtenConvolutionWithCadenceConvolutionPass", + ) + self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.convolution.default), 0, @@ -507,26 +520,18 @@ def test_replace_aten_conv_with_cadence_conv( 0, ) - inputs = (x.to_tensor(), weights.to_tensor()) - if bias is not None: - inputs += (bias.to_tensor(),) - self.assertTensorMetadataIsSame( - pytree.tree_flatten(original_gm.forward(*inputs))[0], - pytree.tree_flatten(graph_after_passes.forward(*inputs))[0], - ) - @expand( [ - [(1, 8, 18), 8, 16, 3], - [(1, 8, 18), 8, 16, 5, 2], + [(1, 8, 16), 8, 16, 3], + [(1, 8, 16), 8, 16, 5, 2], # depthwise + bias - [(1, 8, 18), 8, 16, 5, 2, 0, 1, True, True], + [(1, 8, 16), 8, 16, 5, 2, 0, 1, True, True], # no bias - [(1, 8, 18), 8, 16, 3, 2, 4, 3, False, False], + [(1, 8, 16), 8, 16, 3, 2, 4, 3, False, False], # depthwise + no bias - [(1, 8, 18), 8, 16, 3, 1, 0, 1, True, False], + [(1, 8, 16), 8, 16, 3, 1, 0, 1, True, False], # bias - [(1, 8, 18), 8, 16, 5, 2, 0, 1, False, True], + [(1, 8, 16), 8, 16, 5, 2, 0, 1, False, True], ] ) @torch.no_grad() @@ -545,16 +550,20 @@ def test_replace_aten_transposed_conv_with_cadence_transposed_conv( ) -> None: groups = in_channels if depthwise else 1 builder = GraphBuilder() - x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32)) + x_tensor = torch.randn(*shape, dtype=torch.float32) + x = builder.placeholder("x", x_tensor) + # For transposed conv: weight shape is [in_channels, out_channels // groups, kernel] weights_shape = [in_channels, out_channels // groups, kernel] + weights_tensor = torch.randn(weights_shape, dtype=torch.float32) weights = builder.placeholder( "weights", - torch.randn(weights_shape, dtype=torch.float32), + weights_tensor, + ) + bias_tensor = ( + torch.randn([out_channels], dtype=torch.float32) if bias_enabled else None ) bias = ( - builder.placeholder( - "bias", torch.randn([out_channels], dtype=torch.float32) - ) + builder.placeholder("bias", cast(torch.Tensor, bias_tensor)) if bias_enabled else None ) @@ -574,13 +583,25 @@ def test_replace_aten_transposed_conv_with_cadence_transposed_conv( ) builder.output([convolution]) original_gm = builder.get_graph_module() + gm_before = copy.deepcopy(original_gm) - replacement_pass_result = ( - ReplaceAtenConvolutionWithCadenceConvolutionPass().call(original_gm) - ) + p = ReplaceAtenConvolutionWithCadenceConvolutionPass() + replacement_pass_result = cast(PassResult, p(original_gm)) self.assertIsNotNone(replacement_pass_result) + self.assertTrue(replacement_pass_result.modified) graph_after_passes = replacement_pass_result.graph_module + inputs = (x_tensor, weights_tensor) + if bias_tensor is not None: + inputs += (bias_tensor,) + + validate( + gm_before, + graph_after_passes, + inputs, + "ReplaceAtenConvolutionWithCadenceConvolutionPass", + ) + self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.convolution.default), 0, @@ -592,14 +613,6 @@ def test_replace_aten_transposed_conv_with_cadence_transposed_conv( 1, ) - inputs = (x.to_tensor(), weights.to_tensor()) - if bias is not None: - inputs += (bias.to_tensor(),) - self.assertTensorMetadataIsSame( - pytree.tree_flatten(original_gm.forward(*inputs))[0], - pytree.tree_flatten(graph_after_passes.forward(*inputs))[0], - ) - @expand( [ [(1, 8, 33), 8, 16, 3], @@ -628,15 +641,21 @@ def test_replace_transposed_conv_with_linear( output_padding = [0] groups = in_channels if depthwise else 1 builder = GraphBuilder() - x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32)) + x_tensor = torch.randn(*shape, dtype=torch.float32) + x = builder.placeholder("x", x_tensor) + # For transposed conv: weight shape is [in_channels, out_channels // groups, kernel] + weights_tensor = torch.randn( + [in_channels, out_channels // groups, kernel], dtype=torch.float32 + ) weights = builder.placeholder( "weights", - torch.randn([in_channels, out_channels, kernel], dtype=torch.float32), + weights_tensor, + ) + bias_tensor = ( + torch.randn([out_channels], dtype=torch.float32) if bias_enabled else None ) bias = ( - builder.placeholder( - "bias", torch.randn([out_channels], dtype=torch.float32) - ) + builder.placeholder("bias", cast(torch.Tensor, bias_tensor)) if bias_enabled else None ) @@ -669,9 +688,37 @@ def test_replace_transposed_conv_with_linear( p1 = ReplaceAtenConvolutionWithCadenceConvolutionPass() p2 = ReplaceTransposedConvWithLinearPass() - graph_after_passes = cast( - PassResult, p2(cast(PassResult, p1(original_gm)).graph_module) - ).graph_module + + gm_before = copy.deepcopy(original_gm) + result1 = p1.call(original_gm) + self.assertTrue(result1.modified) + graph_after_p1 = result1.graph_module + + # Validate after first pass + inputs = ( + (x_tensor, weights_tensor) + if bias_tensor is None + else (x_tensor, weights_tensor, bias_tensor) + ) + validate( + gm_before, + graph_after_p1, + inputs, + "ReplaceAtenConvolutionWithCadenceConvolutionPass", + ) + + result2 = p2.call(graph_after_p1) + self.assertTrue(result2.modified) + graph_after_passes = result2.graph_module + + # Validate after second pass (end-to-end) + validate( + gm_before, + graph_after_passes, + inputs, + "ReplaceAtenConvolutionWithCadenceConvolutionPass and ReplaceTransposedConvWithLinearPass", + ) + self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.linear.default), 1, @@ -1316,8 +1363,21 @@ def test_replace_nop_transpose_with_view( op=exir_ops.edge.aten.transpose_copy.int, args=(x, dim0, dim1), ) + + gm_before = copy.deepcopy(original_gm) p = ReplaceNopTransposeOrPermuteWithViewPass() - graph_after_passes = cast(PassResult, p(original_gm)).graph_module + result = cast(PassResult, p(original_gm)) + self.assertTrue(result.modified) + graph_after_passes = result.graph_module + + # Validate numerical accuracy + inputs = [x] + validate( + gm_before, + graph_after_passes, + inputs, + "ReplaceNopTransposeOrPermuteWithViewPass", + ) # Assert that transpose op was removed, and a view op was placed instead self.assertEqual( @@ -1344,8 +1404,21 @@ def test_replace_nop_permute_with_view( op=exir_ops.edge.aten.permute_copy.default, args=(x, dims), ) + + gm_before = copy.deepcopy(original_gm) p = ReplaceNopTransposeOrPermuteWithViewPass() - graph_after_passes = cast(PassResult, p(original_gm)).graph_module + result = cast(PassResult, p(original_gm)) + self.assertTrue(result.modified) + graph_after_passes = result.graph_module + + # Validate numerical accuracy + inputs = [x] + validate( + gm_before, + graph_after_passes, + inputs, + "ReplaceNopTransposeOrPermuteWithViewPass", + ) # Assert that permute op was removed, and a view op was placed instead self.assertEqual( @@ -1412,9 +1485,11 @@ def test_replace_permute_with_transpose_nop( count_node(graph_after_passes, exir_ops.edge.aten.transpose_copy.int), 0 ) +class TestReplaceWhereWithFullArgsWithWhereScalar(unittest.TestCase): def test_replace_aten_where_with_cadence(self) -> None: builder = GraphBuilder() - cond = builder.placeholder("cond", torch.randn(4, 8)) + cond_input = torch.randn(4, 8) + cond = builder.placeholder("cond", cond_input) aten_gt_scalar = builder.call_operator( op=exir_ops.edge.aten.gt.Scalar, args=(cond, 0), @@ -1433,8 +1508,24 @@ def test_replace_aten_where_with_cadence(self) -> None: ) builder.output([aten_where_self]) original_gm = builder.get_graph_module() + + # Deepcopy before the pass + gm_before = copy.deepcopy(original_gm) + p = ReplaceWhereWithFullArgsWithWhereScalar() - graph_after_passes = cast(PassResult, p(original_gm)).graph_module + result = cast(PassResult, p(original_gm)) + self.assertTrue(result.modified) + graph_after_passes = result.graph_module + + # Validate numerical accuracy + inputs = [cond_input] + validate( + gm_before, + graph_after_passes, + inputs, + "ReplaceWhereWithFullArgsWithWhereScalar", + ) + self.assertEqual( count_node( graph_after_passes, @@ -1462,9 +1553,9 @@ def test_replace_aten_where_with_cadence_broadcast( val1: float, val2: float, ) -> None: - # cond_shape, a_shape, b_shape, val1, val2 = builder = GraphBuilder() - cond = builder.placeholder("cond", torch.randn(cond_shape)) + cond_input = torch.randn(cond_shape) + cond = builder.placeholder("cond", cond_input) aten_gt_scalar = builder.call_operator( op=exir_ops.edge.aten.gt.Scalar, args=(cond, 0), @@ -1483,8 +1574,25 @@ def test_replace_aten_where_with_cadence_broadcast( ) builder.output([aten_where_self]) original_gm = builder.get_graph_module() + + # Deepcopy before the pass + gm_before = copy.deepcopy(original_gm) + p = ReplaceWhereWithFullArgsWithWhereScalar() - graph_after_passes = cast(PassResult, p(original_gm)).graph_module + result = cast(PassResult, p(original_gm)) + # Broadcast case should not be replaced + self.assertFalse(result.modified) + graph_after_passes = result.graph_module + + # Validate numerical accuracy (should be same since not modified) + inputs = [cond_input] + validate( + gm_before, + graph_after_passes, + inputs, + "ReplaceWhereWithFullArgsWithWhereScalar", + ) + self.assertEqual( count_node( graph_after_passes, @@ -1600,7 +1708,7 @@ class TestReplaceIm2rowWithViewPass(unittest.TestCase): def test_no_replacement_for_conv(self) -> None: # Create a graph with a single im2row node. x = torch.randn(1, 3, 224, 224) - pad_value = torch.randn(1) + pad_value = torch.tensor(0, dtype=torch.int32) channels_last = False gm = single_op_builder( placeholders=(x, pad_value), @@ -1612,9 +1720,19 @@ def test_no_replacement_for_conv(self) -> None: self.assertEqual(count_node(gm, exir_ops.edge.cadence.im2row.default), 1) self.assertEqual(count_node(gm, exir_ops.edge.aten.view_copy.default), 0) + # Deepcopy before the pass + gm_before = copy.deepcopy(gm) + # Apply replacement pass. p = ReplaceIm2RowWithViewPass() - gm_after_replacement = p.call(gm).graph_module + result = p.call(gm) + self.assertFalse(result.modified) + gm_after_replacement = result.graph_module + + # Validate numerical accuracy + inputs = [x, pad_value] + validate(gm_before, gm_after_replacement, inputs, "ReplaceIm2RowWithViewPass") + # Check that no replacement was made. self.assertEqual( count_node(gm_after_replacement, exir_ops.edge.cadence.im2row.default), 1 @@ -1626,7 +1744,7 @@ def test_no_replacement_for_conv(self) -> None: def test_no_replace_for_dilation(self) -> None: # Create a graph with a single im2row node. x = torch.randn(1, 3, 5, 7) - pad_value = torch.randn(1) + pad_value = torch.tensor(0, dtype=torch.int32) channels_last = False gm = single_op_builder( placeholders=(x, pad_value), @@ -1638,9 +1756,19 @@ def test_no_replace_for_dilation(self) -> None: self.assertEqual(count_node(gm, exir_ops.edge.cadence.im2row.default), 1) self.assertEqual(count_node(gm, exir_ops.edge.aten.view_copy.default), 0) + # Deepcopy before the pass + gm_before = copy.deepcopy(gm) + # Apply replacement pass. p = ReplaceIm2RowWithViewPass() - gm_after_replacement = p.call(gm).graph_module + result = p.call(gm) + self.assertFalse(result.modified) + gm_after_replacement = result.graph_module + + # Validate numerical accuracy + inputs = [x, pad_value] + validate(gm_before, gm_after_replacement, inputs, "ReplaceIm2RowWithViewPass") + self.assertEqual( count_node(gm_after_replacement, exir_ops.edge.cadence.im2row.default), 1 ) @@ -1652,7 +1780,7 @@ def test_replace_linear_like_conv(self) -> None: # Create a graph with a single im2row node. in_h, in_w = 13, 15 x = torch.randn(1, 3, in_h, in_w) - pad_value = torch.randn(1) + pad_value = torch.tensor(0, dtype=torch.int32) channels_last = False gm = single_op_builder( placeholders=(x, pad_value), @@ -1664,9 +1792,19 @@ def test_replace_linear_like_conv(self) -> None: self.assertEqual(count_node(gm, exir_ops.edge.cadence.im2row.default), 1) self.assertEqual(count_node(gm, exir_ops.edge.aten.view_copy.default), 0) + # Deepcopy before the pass + gm_before = copy.deepcopy(gm) + # Apply replacement pass. p = ReplaceIm2RowWithViewPass() - gm_after_replacement = p.call(gm).graph_module + result = p.call(gm) + self.assertTrue(result.modified) + gm_after_replacement = result.graph_module + + # Validate numerical accuracy + inputs = [x, pad_value] + validate(gm_before, gm_after_replacement, inputs, "ReplaceIm2RowWithViewPass") + # In this test, the kernel width/height is the same as the input width/height. self.assertEqual( count_node(gm_after_replacement, exir_ops.edge.cadence.im2row.default), 0 @@ -1793,11 +1931,24 @@ def test_quantized_convolution_default_channel_last(self) -> None: ) self.assertEqual(count_node(gm, exir_ops.edge.aten.permute_copy.default), 0) + # Deepcopy before the pass + original = copy.deepcopy(gm) + # Apply replacement pass. p = ReplaceConvWithChannelLastConvPass() - original = copy.deepcopy(gm) - gm_after_replacement = p.call(gm).graph_module - # Check that no replacement was made. + result = p.call(gm) + self.assertTrue(result.modified) + gm_after_replacement = result.graph_module + + # Validate numerical accuracy + validate( + original, + gm_after_replacement, + placeholders, + "ReplaceConvWithChannelLastConvPass", + ) + + # Check that replacement was made. self.assertEqual( count_node( gm_after_replacement, @@ -1810,12 +1961,6 @@ def test_quantized_convolution_default_channel_last(self) -> None: count_node(gm_after_replacement, exir_ops.edge.aten.permute_copy.default), 3, ) - validate( - gm_after_replacement, - original, - placeholders, - "ReplaceConvWithChannelLastConvPass", - ) def test_no_transpose_if_already_quantized_conv_channel_last(self) -> None: # Create a graph with a single im2row node. @@ -1866,14 +2011,23 @@ def create_slice_graph( def test_slice_no_transpose_if_already_outermost(self) -> None: # Create a graph with a single slice node. + x = torch.randn(3, 224, 224) gm = self.create_slice_graph((3, 224, 224), 0, 1, 2) # Check if graph module is valid by running exportpass on it. gm = ExportPass().call(gm).graph_module self.assertEqual(count_node(gm, exir_ops.edge.aten.slice_copy.Tensor), 1) + # Deepcopy before the pass + original = copy.deepcopy(gm) + # Apply replacement pass. p = MakeSliceAndCatDimOutermostPass() - gm_after_pass = cast(PassResult, p(gm)).graph_module + result = cast(PassResult, p(gm)) + self.assertFalse(result.modified) + gm_after_pass = result.graph_module + + # Validate numerical accuracy + validate(original, gm_after_pass, [x], "MakeSliceAndCatDimOutermostPass") # Assert that no transpose ops were added. self.assertEqual( @@ -1883,14 +2037,23 @@ def test_slice_no_transpose_if_already_outermost(self) -> None: def test_slice_no_transpose_if_outermost_dimensions_are_one(self) -> None: # Create a graph with a single slice node on second outermost dimension. + x = torch.randn(1, 3, 4, 6) gm = self.create_slice_graph((1, 3, 4, 6), 1, 1, 2) # Check if graph module is valid by running exportpass on it. gm = ExportPass().call(gm).graph_module self.assertEqual(count_node(gm, exir_ops.edge.aten.slice_copy.Tensor), 1) + # Deepcopy before the pass + original = copy.deepcopy(gm) + # Apply replacement pass. p = MakeSliceAndCatDimOutermostPass() - gm_after_pass = cast(PassResult, p(gm)).graph_module + result = cast(PassResult, p(gm)) + self.assertFalse(result.modified) + gm_after_pass = result.graph_module + + # Validate numerical accuracy + validate(original, gm_after_pass, [x], "MakeSliceAndCatDimOutermostPass") # Assert that no transpose ops were added. The slice is on the second # outermost dimension, but the outermost dimension is already 1. @@ -1901,14 +2064,23 @@ def test_slice_no_transpose_if_outermost_dimensions_are_one(self) -> None: def test_slice_insert_transpose(self) -> None: # Create a graph with a single slice node. + x = torch.randn(1, 3, 4, 6) gm = self.create_slice_graph((1, 3, 4, 6), 2, 1, 2) # Check if graph module is valid by running exportpass on it. gm = ExportPass().call(gm).graph_module self.assertEqual(count_node(gm, exir_ops.edge.aten.slice_copy.Tensor), 1) + # Deepcopy before the pass + original = copy.deepcopy(gm) + # Apply replacement pass. p = MakeSliceAndCatDimOutermostPass() - gm_after_pass = cast(PassResult, p(gm)).graph_module + result = cast(PassResult, p(gm)) + self.assertTrue(result.modified) + gm_after_pass = result.graph_module + + # Validate numerical accuracy + validate(original, gm_after_pass, [x], "MakeSliceAndCatDimOutermostPass") # Assert that there are two transpose ops added. self.assertEqual( @@ -1930,14 +2102,24 @@ def create_cat_graph( def test_cat_no_transpose_if_already_outermost(self) -> None: # Create a graph with a single slice node on second outermost dimension. + input1 = torch.randn(1, 3, 5) + input2 = torch.randn(2, 3, 5) gm = self.create_cat_graph(input_shapes=((1, 3, 5), (2, 3, 5)), cat_dim=0) # Check if graph module is valid by running exportpass on it. gm = ExportPass().call(gm).graph_module self.assertEqual(count_node(gm, exir_ops.edge.aten.cat.default), 1) + # Deepcopy before the pass + original = copy.deepcopy(gm) + # Apply replacement pass. p = MakeSliceAndCatDimOutermostPass() - gm_after_pass = cast(PassResult, p(gm)).graph_module + result = cast(PassResult, p(gm)) + self.assertFalse(result.modified) + gm_after_pass = result.graph_module + + # Validate numerical accuracy + validate(original, gm_after_pass, [input1, input2], "MakeSliceAndCatDimOutermostPass") # Assert that no transpose ops were added. The slice is on the second # outermost dimension, but the outermost dimension is already 1. @@ -1948,14 +2130,24 @@ def test_cat_no_transpose_if_already_outermost(self) -> None: def test_cat_no_transpose_if_outermost_dimensions_are_one(self) -> None: # Create a graph with a single slice node on second outermost dimension. + input1 = torch.randn(1, 1, 3, 5) + input2 = torch.randn(1, 2, 3, 5) gm = self.create_cat_graph(input_shapes=((1, 1, 3, 5), (1, 2, 3, 5)), cat_dim=1) # Check if graph module is valid by running exportpass on it. gm = ExportPass().call(gm).graph_module self.assertEqual(count_node(gm, exir_ops.edge.aten.cat.default), 1) + # Deepcopy before the pass + original = copy.deepcopy(gm) + # Apply replacement pass. p = MakeSliceAndCatDimOutermostPass() - gm_after_pass = cast(PassResult, p(gm)).graph_module + result = cast(PassResult, p(gm)) + self.assertFalse(result.modified) + gm_after_pass = result.graph_module + + # Validate numerical accuracy + validate(original, gm_after_pass, [input1, input2], "MakeSliceAndCatDimOutermostPass") # Assert that no transpose ops were added. The slice is on the second # outermost dimension, but the outermost dimension is already 1. @@ -1966,6 +2158,8 @@ def test_cat_no_transpose_if_outermost_dimensions_are_one(self) -> None: def test_cat_insert_transpose(self) -> None: # Create a graph with a single slice node on second outermost dimension. + input1 = torch.randn(1, 1, 3, 5) + input2 = torch.randn(1, 1, 3, 3) gm = self.create_cat_graph( input_shapes=((1, 1, 3, 5), (1, 1, 3, 3)), cat_dim=-1 ) @@ -1973,9 +2167,17 @@ def test_cat_insert_transpose(self) -> None: gm = ExportPass().call(gm).graph_module self.assertEqual(count_node(gm, exir_ops.edge.aten.cat.default), 1) + # Deepcopy before the pass + original = copy.deepcopy(gm) + # Apply replacement pass. p = MakeSliceAndCatDimOutermostPass() - gm_after_pass = cast(PassResult, p(gm)).graph_module + result = cast(PassResult, p(gm)) + self.assertTrue(result.modified) + gm_after_pass = result.graph_module + + # Validate numerical accuracy + validate(original, gm_after_pass, [input1, input2], "MakeSliceAndCatDimOutermostPass") # Assert that transpose ops were added to make cat on outermost dimension. self.assertEqual( @@ -1985,9 +2187,10 @@ def test_cat_insert_transpose(self) -> None: class TestReplaceEmptyTensorsWithFullPass(unittest.TestCase): - def _get_slice_empty_gm(self) -> torch.fx.GraphModule: + def _get_slice_empty_gm(self) -> tuple[torch.fx.GraphModule, torch.Tensor]: builder = GraphBuilder() - x = builder.placeholder("x", torch.randn(4)) + x_input = torch.randn(4) + x = builder.placeholder("x", x_input) # This is empty (numel == 0). slice0 = builder.call_operator( exir_ops.edge.aten.slice_copy.Tensor, (x, 0, 0, 0) @@ -1999,10 +2202,10 @@ def _get_slice_empty_gm(self) -> torch.fx.GraphModule: ((slice0, slice1),), ) builder.output([cat]) - return builder.get_graph_module() + return builder.get_graph_module(), x_input def test_empty_slice(self) -> None: - gm = self._get_slice_empty_gm() + gm, x_input = self._get_slice_empty_gm() self.assertEqual( len( gm.graph.find_nodes( @@ -2019,8 +2222,19 @@ def test_empty_slice(self) -> None: ), 0, ) + + # Deepcopy before the pass + gm_before = copy.deepcopy(gm) + p = ReplaceEmptyTensorsWithFullPass() - updated_gm = cast(PassResult, p(gm)).graph_module + result = cast(PassResult, p(gm)) + self.assertTrue(result.modified) + updated_gm = result.graph_module + + # Validate numerical accuracy + inputs = [x_input] + validate(gm_before, updated_gm, inputs, "ReplaceEmptyTensorsWithFullPass") + self.assertEqual( len( updated_gm.graph.find_nodes( @@ -2048,21 +2262,37 @@ def test_empty_slice(self) -> None: def test_extract_mul_argument_to_full( self, _: str, value: Union[int, float] ) -> None: - x = torch.randn(2, 1, 64) + x_input = torch.randn(2, 1, 64) gm = single_op_builder( - placeholders=(x,), - op=torch.ops.aten.mul.Tensor, - args=(x, value), + placeholders=(x_input,), + op=exir_ops.edge.aten.mul.Tensor, + args=(x_input, value), kwargs={}, ) + + # Deepcopy before the pass + gm_before = copy.deepcopy(gm) + p = ReplaceMulTensorWithMulAndFullOpsPass() - graph_after_passes = p.call(gm).graph_module + result = p.call(gm) + self.assertTrue(result.modified) + graph_after_passes = result.graph_module + + # Validate numerical accuracy + inputs = [x_input] + validate( + gm_before, + graph_after_passes, + inputs, + "ReplaceMulTensorWithMulAndFullOpsPass", + ) + self.assertTrue( op_counts_match( graph_after_passes, expected_op_counts={ - torch.ops.aten.mul.Tensor: 1, - torch.ops.aten.full.default: 1, + exir_ops.edge.aten.mul.Tensor: 1, + exir_ops.edge.aten.full.default: 1, }, ) ) @@ -2071,17 +2301,18 @@ def test_extract_mul_argument_to_full( class TestReplaceAdaptiveAvgPoolWithAtenAvgPoolPass(unittest.TestCase): def _get_adaptive_avg_pool_gm( self, input_shape: Tuple[int, int, int, int], output_shape: Tuple[int, int] - ) -> torch.fx.GraphModule: + ) -> tuple[torch.Tensor, torch.fx.GraphModule]: builder = GraphBuilder() - x = builder.placeholder("x", torch.randn(*input_shape)) + x_input = torch.randn(*input_shape) + x = builder.placeholder("x", x_input) adaptive_avg_pool2d = builder.call_operator( exir_ops.edge.aten._adaptive_avg_pool2d.default, (x, output_shape) ) builder.output([adaptive_avg_pool2d]) - return builder.get_graph_module() + return x_input, builder.get_graph_module() def test_replace_adaptive_avg_pool_with_aten_avg_pool(self) -> None: - gm = self._get_adaptive_avg_pool_gm((1, 64, 128, 128), (8, 8)) + x_input, gm = self._get_adaptive_avg_pool_gm((1, 64, 128, 128), (8, 8)) self.assertEqual( len( gm.graph.find_nodes( @@ -2100,8 +2331,24 @@ def test_replace_adaptive_avg_pool_with_aten_avg_pool(self) -> None: ), 0, ) + + # Deepcopy before the pass + gm_before = copy.deepcopy(gm) + p = ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass() - updated_gm = p.call(gm).graph_module + result = p.call(gm) + self.assertTrue(result.modified) + updated_gm = result.graph_module + + # Validate numerical accuracy + inputs = [x_input] + validate( + gm_before, + updated_gm, + inputs, + "ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass", + ) + self.assertEqual( len( updated_gm.graph.find_nodes( @@ -2128,7 +2375,7 @@ def test_replace_adaptive_avg_pool_with_aten_avg_pool(self) -> None: self.assertEqual(avg_pool2d_node.args[6], None) # divisor_override is None def test_replace_adaptive_avg_pool_with_aten_avg_pool_irregular(self) -> None: - gm = self._get_adaptive_avg_pool_gm((1, 64, 128, 128), (9, 9)) + x_input, gm = self._get_adaptive_avg_pool_gm((1, 64, 128, 128), (9, 9)) self.assertEqual( len( gm.graph.find_nodes( @@ -2146,9 +2393,25 @@ def test_replace_adaptive_avg_pool_with_aten_avg_pool_irregular(self) -> None: ), 0, ) + + # Deepcopy before the pass + gm_before = copy.deepcopy(gm) + # Shapes are not multiples of each other, so pass will not trigger p = ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass() - updated_gm = p.call(gm).graph_module + result = p.call(gm) + self.assertFalse(result.modified) + updated_gm = result.graph_module + + # Validate numerical accuracy (should be same since not modified) + inputs = [x_input] + validate( + gm_before, + updated_gm, + inputs, + "ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass", + ) + self.assertEqual( len( updated_gm.graph.find_nodes( @@ -2167,6 +2430,113 @@ def test_replace_adaptive_avg_pool_with_aten_avg_pool_irregular(self) -> None: ) +class TestReplaceAtenAvgPoolWithCadenceAvgPoolPass(unittest.TestCase): + def _get_aten_avg_pool1d_gm( + self, input_shape: Tuple[int, int, int], kernel_size: int + ) -> tuple[torch.Tensor, torch.fx.GraphModule]: + builder = GraphBuilder() + x_input = torch.randn(*input_shape) + x = builder.placeholder("x", x_input) + avg_pool1d = builder.call_operator( + exir_ops.edge.aten.avg_pool1d.default, (x, [kernel_size]) + ) + builder.output([avg_pool1d]) + return x_input, builder.get_graph_module() + + def _get_aten_avg_pool2d_gm( + self, input_shape: Tuple[int, int, int, int], kernel_size: Tuple[int, int] + ) -> tuple[torch.Tensor, torch.fx.GraphModule]: + builder = GraphBuilder() + x_input = torch.randn(*input_shape) + x = builder.placeholder("x", x_input) + avg_pool2d = builder.call_operator( + exir_ops.edge.aten.avg_pool2d.default, (x, list(kernel_size)) + ) + builder.output([avg_pool2d]) + return x_input, builder.get_graph_module() + + def test_replace_aten_avg_pool1d_with_cadence(self) -> None: + x_input, gm = self._get_aten_avg_pool1d_gm((1, 32, 64), 3) + self.assertEqual( + count_node(gm, exir_ops.edge.aten.avg_pool1d.default), + 1, + ) + self.assertEqual( + count_node(gm, exir_ops.edge.cadence.avg_pool2d.default), + 0, + ) + + # Deepcopy before the pass + gm_before = copy.deepcopy(gm) + + p = ReplaceAtenAvgPoolWithCadenceAvgPoolPass() + result = p.call(gm) + self.assertTrue(result.modified) + updated_gm = result.graph_module + + # Validate numerical accuracy + inputs = [x_input] + validate( + gm_before, + updated_gm, + inputs, + "ReplaceAtenAvgPoolWithCadenceAvgPoolPass", + ) + + # avg_pool1d should be replaced with view operations and avg_pool2d + self.assertEqual( + count_node(updated_gm, exir_ops.edge.aten.avg_pool1d.default), + 0, + ) + self.assertEqual( + count_node(updated_gm, exir_ops.edge.cadence.avg_pool2d.default), + 1, + ) + # Should have view operations for reshaping + self.assertGreater( + count_node(updated_gm, exir_ops.edge.aten.view_copy.default), + 0, + ) + + def test_replace_aten_avg_pool2d_with_cadence(self) -> None: + x_input, gm = self._get_aten_avg_pool2d_gm((1, 32, 64, 64), (3, 3)) + self.assertEqual( + count_node(gm, exir_ops.edge.aten.avg_pool2d.default), + 1, + ) + self.assertEqual( + count_node(gm, exir_ops.edge.cadence.avg_pool2d.default), + 0, + ) + + # Deepcopy before the pass + gm_before = copy.deepcopy(gm) + + p = ReplaceAtenAvgPoolWithCadenceAvgPoolPass() + result = p.call(gm) + self.assertTrue(result.modified) + updated_gm = result.graph_module + + # Validate numerical accuracy + inputs = [x_input] + validate( + gm_before, + updated_gm, + inputs, + "ReplaceAtenAvgPoolWithCadenceAvgPoolPass", + ) + + # avg_pool2d should be replaced with cadence avg_pool2d + self.assertEqual( + count_node(updated_gm, exir_ops.edge.aten.avg_pool2d.default), + 0, + ) + self.assertEqual( + count_node(updated_gm, exir_ops.edge.cadence.avg_pool2d.default), + 1, + ) + + class TestReplaceLinalgSvdPass(unittest.TestCase): @expand( [