diff --git a/backends/cadence/aot/ref_implementations.py b/backends/cadence/aot/ref_implementations.py index 1128ad3167c..240a5ab28d6 100644 --- a/backends/cadence/aot/ref_implementations.py +++ b/backends/cadence/aot/ref_implementations.py @@ -506,7 +506,9 @@ def variant( raise ValueError("out_shift must be a scalar") if out_shift.dtype != torch.int32: - raise ValueError("out_shift must be an int32") + raise ValueError( + f"out_shift must be an int32. Got {out_shift.dtype} instead" + ) _out_shift = int(out_shift.item()) _out_multiplier = int(out_multiplier[0].item()) diff --git a/backends/cadence/aot/reorder_ops.py b/backends/cadence/aot/reorder_ops.py index 0026c35ed57..857446592ee 100644 --- a/backends/cadence/aot/reorder_ops.py +++ b/backends/cadence/aot/reorder_ops.py @@ -12,7 +12,7 @@ import copy from collections import defaultdict from math import prod -from typing import cast, DefaultDict, List, Set, Tuple +from typing import cast, DefaultDict, List, Tuple import torch import torch.fx @@ -21,6 +21,7 @@ CadencePassAttribute, get_overload_packet, register_cadence_pass, + RemoveOrReplacePassInterface, ) from executorch.backends.cadence.aot.utils import get_edge_overload_packet from executorch.exir.dialects._ops import ops as exir_ops @@ -484,7 +485,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class SinkOpsCloserToUsePass(ExportPass): +class SinkOpsCloserToUsePass(RemoveOrReplacePassInterface): """ Assume that the dequantize op D = dequantize(I) has only a single user. If the current graph looks like @@ -504,47 +505,38 @@ class SinkOpsCloserToUsePass(ExportPass): much smaller size. """ - sinkable_ops: Set[EdgeOpOverload] = { - exir_ops.edge.aten.dequantize, - exir_ops.edge.quantized_decomposed.dequantize_per_tensor, - exir_ops.edge.quantized_decomposed.dequantize_per_channel, - exir_ops.edge.cadence.dequantize_per_tensor, - } - - def sink_ops_closer_to_use(self, graph_module: torch.fx.GraphModule): - graph = graph_module.graph - # We are only interested in sinkable nodes - sinkable_nodes = [ - node - for node in graph.nodes - if isinstance(node.target, EdgeOpOverload) - and get_edge_overload_packet(node.target) in self.sinkable_ops + @property + def targets(self) -> list[EdgeOpOverload]: + return [ + exir_ops.edge.aten.dequantize, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, + exir_ops.edge.cadence.dequantize_per_tensor.default, ] - for node in sinkable_nodes: - # The sinkable node must have a single user - users = list(node.users.keys()) - if len(users) != 1: - continue - # Insert the dequant node just before its user - with graph.inserting_before(users[0]): - new_node = graph.call_function( - node.target, args=node.args, kwargs=node.kwargs - ) - new_node.meta = node.meta - node.replace_all_uses_with(new_node) - graph.erase_node(node) + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + # The sinkable node must have a single user + users = list(node.users.keys()) + if len(users) != 1: + return False - graph_module.recompile() + # Insert the dequant node just before its user + with node.graph.inserting_before(users[0]): + # Target is guaranteed to be a callable since it's from our targets list + target_callable = node.target + assert callable(target_callable), "Target must be callable" + new_node = node.graph.call_function( + target_callable, args=node.args, kwargs=node.kwargs + ) + new_node.meta 
= node.meta + node.replace_all_uses_with(new_node) + node.graph.erase_node(node) - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - self.sink_ops_closer_to_use(graph_module) - result = super().call(graph_module) - return result + return True @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class HoistOpsCloserToDefPass(ExportPass): +class HoistOpsCloserToDefPass(RemoveOrReplacePassInterface): """ Assume that the input I to a quantize op Q = quantize(I) has only a single use, the quantize node itself. @@ -565,77 +557,71 @@ class HoistOpsCloserToDefPass(ExportPass): much smaller size. The same transformation also applies to slice/select op. """ - hoistable_ops: Set[EdgeOpOverload] = { - exir_ops.edge.quantized_decomposed.quantize_per_tensor, - exir_ops.edge.cadence.quantize_per_tensor, - exir_ops.edge.aten.slice_copy, - exir_ops.edge.aten.select_copy, - } - - def hoist_ops_closer_to_def(self, graph_module: torch.fx.GraphModule): - graph = graph_module.graph - # We are only interested in hoistable nodes - hoistable_nodes = [ - node - for node in graph.nodes - if isinstance(node.target, EdgeOpOverload) - and get_edge_overload_packet(node.target) in self.hoistable_ops + @property + def targets(self) -> list[EdgeOpOverload]: + return [ + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + exir_ops.edge.cadence.quantize_per_tensor.default, + exir_ops.edge.aten.slice_copy.Tensor, + exir_ops.edge.aten.select_copy.int, ] - for node in hoistable_nodes: - def_node = node.args[0] - if not isinstance(def_node, torch.fx.Node): - continue - # The def node must have a single user - users = list(def_node.users.keys()) - if len(users) != 1: - continue - # Get the node args as list - args = list(node.args) + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + def_node = node.args[0] + if not isinstance(def_node, torch.fx.Node): + return False - # If the graph has placeholders, we do not want to hoist above the - # last placeholder. Otherwise we will shrink the live range of the - # def_node considerably, which could lead to reuse of input memory. - def_node = ( - get_placeholders(graph)[-1] - if def_node.op == "placeholder" - else def_node - ) + # The def node must have a single user + users = list(def_node.users.keys()) + if len(users) != 1: + return False - # If the node is quantize_per_channel, we need to hoist the scale - # and zero_point tensors as well. - if ( - node.target - == exir_ops.edge.quantized_decomposed.quantize_per_channel.default - ): - scale, zero_point = args[1], args[2] - with graph.inserting_after(def_node): - zero_point_copy = graph.node_copy(zero_point) - scale_copy = graph.node_copy(scale) - args[1], args[2] = scale_copy, zero_point_copy - def_node = zero_point_copy - - # Insert the quant node just after def_node - with graph.inserting_after(def_node): - new_node = graph.call_function( - node.target, args=tuple(args), kwargs=node.kwargs - ) - new_node.meta = node.meta - node.replace_all_uses_with(new_node) - graph.erase_node(node) + # Get the node args as list + args = list(node.args) - # Eliminate dead code - graph_module.recompile() - graph_module.graph.eliminate_dead_code() + # If the graph has placeholders, we do not want to hoist above the + # last placeholder. Otherwise we will shrink the live range of the + # def_node considerably, which could lead to reuse of input memory. 
+ insertion_point = ( + get_placeholders(node.graph)[-1] + if def_node.op == "placeholder" + else def_node + ) - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - self.hoist_ops_closer_to_def(graph_module) - result = super().call(graph_module) - return result + # If the node is quantize_per_channel, we need to hoist the scale + # and zero_point tensors as well. + if ( + node.target + == exir_ops.edge.quantized_decomposed.quantize_per_channel.default + ): + scale, zero_point = args[1], args[2] + if not isinstance(scale, torch.fx.Node) or not isinstance( + zero_point, torch.fx.Node + ): + return False + with node.graph.inserting_after(insertion_point): + zero_point_copy = node.graph.node_copy(zero_point) + scale_copy = node.graph.node_copy(scale) + args[1], args[2] = scale_copy, zero_point_copy + insertion_point = zero_point_copy + + # Insert the quant node just after insertion_point + with node.graph.inserting_after(insertion_point): + # Target is guaranteed to be a callable since it's from our targets list + target_callable = node.target + assert callable(target_callable), "Target must be callable" + new_node = node.graph.call_function( + target_callable, args=tuple(args), kwargs=node.kwargs + ) + new_node.meta = node.meta + node.replace_all_uses_with(new_node) + node.graph.erase_node(node) + + return True @register_cadence_pass(CadencePassAttribute(opt_level=1)) -class PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView(ExportPass): +class PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView(RemoveOrReplacePassInterface): """ A common pattern seen in transformer models. If the consumer of permute is a view op, swap their order so permute is below view. @@ -649,14 +635,16 @@ class PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView(ExportPass): mean the view_copy is normalized from squeeze or unsqueeze. """ - def __init__(self): - super().__init__() - self.graph_module = None + @property + def targets(self) -> list[EdgeOpOverload]: + return [exir_ops.edge.aten.permute_copy.default] # If list1 and list2 are same (same values and in same order) except # list1 has one more element with value of 1. Return index of the extra 1. # Otherwise return -1. - def check_if_shapes_differ_in_single_dim_of_size_1(self, list1, list2) -> int: + def check_if_shapes_differ_in_single_dim_of_size_1( + self, list1: List, list2: List + ) -> int: if len(list1) != len(list2) + 1: return -1 for i in range(len(list2)): @@ -672,7 +660,104 @@ def check_if_shapes_differ_in_single_dim_of_size_1(self, list1, list2) -> int: else: return -1 - def insert_nodes( + def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: + users = list(node.users.keys()) + # Transform only for pattern permute_copy->view_copy, and + # view_copy op is the only user of permute_copy. + if len(users) != 1 or users[0].target not in ( + exir_ops.edge.aten.view_copy.default, + exir_ops.edge.aten.view.default, + ): + return False + + # If the permute_node/view_node was newly added to the + # graph, it may not have the meta["val"] FakeTensor. + # Skip in this case. 
+ if node.meta.get("val") is None: + return False + + permute_node_shape = [*cast(list, get_shape(node.graph.owning_module, node))] + + permute_dims = cast(list, node.args[1]) + view_node = users[0] + + if view_node.meta.get("val") is None: + return False + + view_node_shape = [*cast(list, get_shape(node.graph.owning_module, view_node))] + + pred = node.args[0] + if not isinstance(pred, torch.fx.Node) or pred.meta.get("val") is None: + return False + + pred_shape = [*cast(list, get_shape(node.graph.owning_module, pred))] + + # Handle three cases + # 1. view_node_shape is almost same as permute_node_shape + # except the view_node has one more dim somewhere + # and the extra dim has value of 1. + # 2. view_node_shape is almost same as permute_node_shape + # except permute_node_shape has one more dim somewhere + # and the extra dim has value of 1. + # 3. view_node_shape is the same as permute_node_shape. + + if len(permute_node_shape) + 1 == len(view_node_shape): + index = self.check_if_shapes_differ_in_single_dim_of_size_1( + view_node_shape, permute_node_shape + ) + if index != -1: + # view_node_shape is almost same as permute_node_shape + # except it has one more dim somewhere + # and the extra dim has value of 1. + new_view_shape = copy.deepcopy(pred_shape) + new_view_shape.insert(index, 1) + new_permute_dims = [x + 1 if x >= index else x for x in permute_dims] + new_permute_dims.insert(index, index) + self._insert_nodes( + node.graph, + pred, + node, + view_node, + new_view_shape, + new_permute_dims, + ) + return True + + elif len(view_node_shape) + 1 == len(permute_node_shape): + index = self.check_if_shapes_differ_in_single_dim_of_size_1( + permute_node_shape, view_node_shape + ) + if index != -1: + # view_node_shape is almost same as permute_node_shape + # except permute_node_shape has one more dim somewhere + # and the extra dim has value of 1. 
+ # Convert permute_dims to list of ints + index_to_remove = permute_dims[index] + new_view_shape = copy.deepcopy(pred_shape) + del new_view_shape[index_to_remove] + new_permute_dims = [ + x - 1 if x > index_to_remove else x for x in permute_dims + ] + del new_permute_dims[index] + self._insert_nodes( + node.graph, + pred, + node, + view_node, + new_view_shape, + new_permute_dims, + ) + return True + + elif permute_node_shape == view_node_shape: + # view_node_shape is the same as permute_node_shape + # Replace the uses of view_node with permute_node + view_node.replace_all_uses_with(node) + return True + + return False + + def _insert_nodes( self, graph: torch.fx.Graph, pred: torch.fx.Node, @@ -680,16 +765,22 @@ def insert_nodes( view_node: torch.fx.Node, new_view_shape: List, new_permute_dims: List, - ): + ) -> None: with graph.inserting_after(view_node): + # Target is guaranteed to be a callable since it's from the graph + view_target = view_node.target + assert callable(view_target), "View target must be callable" new_view_node = graph.call_function( - view_node.target, # pyre-fixme[6] + view_target, args=(pred, new_view_shape), ) with graph.inserting_after(new_view_node): + # Target is guaranteed to be a callable since it's from our targets list + permute_target = permute_node.target + assert callable(permute_target), "Permute target must be callable" new_permute_node = graph.call_function( - permute_node.target, # pyre-fixme[6] + permute_target, args=(new_view_node, new_permute_dims), ) new_permute_node.meta = view_node.meta @@ -699,125 +790,22 @@ def insert_nodes( graph.erase_node(view_node) graph.erase_node(permute_node) - # flake8: noqa 'PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView.postpone_permute_op' is too complex (13) - def postpone_permute_op(self, graph_module: torch.fx.GraphModule): - packet_to_overload_map = { - exir_ops.edge.aten.permute_copy: "default", - } - graph = graph_module.graph - changed = True - modified = False - # Loop iteratively until no more changes are made - while changed: - changed = False - for permute_node in graph.nodes: - permute_overload_packet = get_overload_packet(permute_node.target) - if permute_overload_packet not in packet_to_overload_map.keys(): - continue - - users = list(permute_node.users.keys()) - # Transform only for pattern permute_copy->view_copy, and - # view_copy op is the only user of permute_copy. - if len(users) == 1 and users[0].target in ( - exir_ops.edge.aten.view_copy.default, - exir_ops.edge.aten.view.default, - ): - # If the permute_node/view_node was newly added to the - # graph, it may not have the meta["val"] FakeTensor. - # Skip in this case. - if permute_node.meta.get("val") is None: - continue - permute_node_shape = [ - *cast(list, get_shape(graph_module, permute_node)) - ] - permute_dims = permute_node.args[1] - view_node = users[0] - if view_node.meta.get("val") is None: - continue - view_node_shape = [*cast(list, get_shape(graph_module, view_node))] - pred = permute_node.args[0] - if pred.meta.get("val") is None: - continue - pred_shape = [*cast(list, get_shape(graph_module, pred))] - # Handle two cases - # 1. view_node_shape is almost same as permute_node_shape - # except the view_node has one more dim somewhere - # and the extra dim has value of 1. - # 2. view_node_shape is almost same as permute_node_shape - # except permute_node_shape has one more dim somewhere - # and the extra dim has value of 1. - # 3. view_node_shape is the same as permute_node_shape. 
- if len(permute_node_shape) + 1 == len(view_node_shape): - index = self.check_if_shapes_differ_in_single_dim_of_size_1( - view_node_shape, permute_node_shape - ) - if index != -1: - # view_node_shape is almost same as permute_node_shape - # except it has one more dim somewhere - # and the extra dim has value of 1. - new_view_shape = copy.deepcopy(pred_shape) - new_view_shape.insert(index, 1) - new_permute_dims = [ - x + 1 if x >= index else x for x in permute_dims - ] - new_permute_dims.insert(index, index) - self.insert_nodes( - graph, - pred, - permute_node, - view_node, - new_view_shape, - new_permute_dims, - ) - changed = True - modified = True - elif len(view_node_shape) + 1 == len(permute_node_shape): - index = self.check_if_shapes_differ_in_single_dim_of_size_1( - permute_node_shape, view_node_shape - ) - if index != -1: - # view_node_shape is almost same as permute_node_shape - # except permute_node_shape has one more dim somewhere - # and the extra dim has value of 1. - index_to_remove = permute_dims[index] - new_view_shape = copy.deepcopy(pred_shape) - del new_view_shape[index_to_remove] - new_permute_dims = [ - x - 1 if x > index_to_remove else x - for x in permute_dims - ] - del new_permute_dims[index] - self.insert_nodes( - graph, - pred, - permute_node, - view_node, - new_view_shape, - new_permute_dims, - ) - changed = True - modified = True - elif permute_node_shape == view_node_shape: - # view_node_shape is the same as permute_node_shape - # Replace the uses of view_node with permute_node - view_node.replace_all_uses_with(permute_node) - changed = True - modified = True - - graph_module.recompile() - return modified - def call(self, graph_module: torch.fx.GraphModule) -> PassResult: - self.graph_module = graph_module + # This pass needs to iterate until convergence because postponing + # one permute may enable postponing another in a chain iter_count = 0 - modified = True - - while modified and iter_count <= 3: - modified = self.postpone_permute_op(self.graph_module) - self.graph_module = super().call(self.graph_module).graph_module + local_modified = False + overall_modified = False + while local_modified or iter_count == 0: + result = super().call(graph_module) + local_modified = result.modified + overall_modified |= local_modified + graph_module = result.graph_module iter_count += 1 + if iter_count == 4: + break - return super().call(self.graph_module) + return PassResult(graph_module, overall_modified) class CommonReorderPasses: diff --git a/backends/cadence/aot/simplify_ops.py b/backends/cadence/aot/simplify_ops.py index 92c14cb0f5d..5ee2a2f01ef 100644 --- a/backends/cadence/aot/simplify_ops.py +++ b/backends/cadence/aot/simplify_ops.py @@ -12,18 +12,21 @@ import sys from typing import Optional +import torch from executorch.backends.cadence.aot.pass_utils import ( CadencePassAttribute, + EdgeOpOverload, register_cadence_pass, + RemoveOrReplacePassInterface, ) from executorch.backends.cadence.aot.utils import rebind from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.dialects.edge._ops import EdgeOpOverload -from executorch.exir.pass_base import ExportPass +from executorch.exir.pass_base import ExportPass, PassResult +from torch.fx import Node @register_cadence_pass(CadencePassAttribute(opt_level=0)) -class SimplifySliceOpPass(ExportPass): +class SimplifySliceOpPass(RemoveOrReplacePassInterface): """ Simplify the start and end indices of slice and slice_scatter ops. 
""" @@ -62,66 +65,149 @@ def adjust_slice_range( # Return the adjusted start and end indices return (start_val, end_val) - def call_operator(self, op, args, kwargs, meta): - # We are only interested in slice_copy or slice_scatter ops - if op not in { + @property + def targets(self) -> list[EdgeOpOverload]: + return [ exir_ops.edge.aten.slice_copy.Tensor, exir_ops.edge.aten.slice_scatter.default, - }: - return super().call_operator(op, args, kwargs, meta) + ] + def maybe_remove_or_replace(self, node: Node) -> bool: # Check if it is a slice_scatter op or not. The slice_scatter op has # an extra src argument at index 1. - slice_scatter = op == exir_ops.edge.aten.slice_scatter.default - # Parse the arguments - # Extract the tensor to be sliced, and the slicing dimension - in_tensor = args[0].to_tensor() - dim = args[1 + slice_scatter] if len(args) > 1 + slice_scatter else 0 + slice_scatter = node.target == exir_ops.edge.aten.slice_scatter.default + + # Get input tensor metadata + input_node = node.args[0] + if not isinstance(input_node, Node) or "val" not in input_node.meta: + return False + + in_tensor = input_node.meta["val"] + + # Extract the slicing dimension + dim_idx = 1 + (1 if slice_scatter else 0) + dim = node.args[dim_idx] if len(node.args) > dim_idx else 0 + if not isinstance(dim, int): + return False + # Make dim non-negative + original_dim = dim dim = dim if dim >= 0 else dim + in_tensor.dim() length = in_tensor.size(dim) # Get the adjusted start and end indices - start_val = args[2 + slice_scatter] if len(args) > 2 + slice_scatter else None - end_val = args[3 + slice_scatter] if len(args) > 3 + slice_scatter else None - step = args[4 + slice_scatter] if len(args) > 4 + slice_scatter else 1 - (start_val, end_val) = self.adjust_slice_range(length, start_val, end_val, step) - - # If the start_val is geq end_val, then we can return an empty tensor - # for slice op, or input for slice_scatter op. 
- if start_val >= end_val and slice_scatter: - return args[0] - if start_val >= end_val: - empty_shape = [x for x in in_tensor.shape if x != 0] - empty_shape[dim] = 0 - return super().call_operator( - exir_ops.edge.aten.full.default, - (tuple(empty_shape), 0), - {"dtype": in_tensor.dtype}, - meta, - ) - - # Create new args - new_args = ( - (args[0],) - + ((args[1],) if slice_scatter else ()) - + (dim, start_val, end_val, step) + start_idx = 2 + (1 if slice_scatter else 0) + end_idx = 3 + (1 if slice_scatter else 0) + step_idx = 4 + (1 if slice_scatter else 0) + + start_val = node.args[start_idx] if len(node.args) > start_idx else None + end_val = node.args[end_idx] if len(node.args) > end_idx else None + step = node.args[step_idx] if len(node.args) > step_idx else 1 + + # Validate types + if start_val is not None and not isinstance(start_val, int): + return False + if end_val is not None and not isinstance(end_val, int): + return False + if not isinstance(step, int): + return False + + # Get the adjusted start and end indices + original_start = start_val + original_end = end_val + (adjusted_start, adjusted_end) = self.adjust_slice_range( + length, start_val, end_val, step + ) + + # Check if anything changed + nothing_changed = ( + adjusted_start == original_start + and adjusted_end == original_end + and dim == original_dim ) - return super().call_operator(op, new_args, kwargs, meta) + if nothing_changed: + return False + + # Replace the node based on the adjusted range + with node.graph.inserting_before(node): + if adjusted_start >= adjusted_end and slice_scatter: + # For slice_scatter with empty range, return the input + node.replace_all_uses_with(input_node) + elif adjusted_start >= adjusted_end: + # For slice with empty range, create an empty tensor + empty_shape = list(in_tensor.shape) + empty_shape[dim] = 0 + new_node = node.graph.call_function( + exir_ops.edge.aten.full.default, + (tuple(empty_shape), 0), + {"dtype": in_tensor.dtype}, + ) + new_node.meta = node.meta.copy() + node.replace_all_uses_with(new_node) + else: + # Create new args with simplified indices + if slice_scatter: + new_args = ( + node.args[0], # input + node.args[1], # src + dim, + adjusted_start, + adjusted_end, + step, + ) + else: + new_args = ( + node.args[0], # input + dim, + adjusted_start, + adjusted_end, + step, + ) + # Target is guaranteed to be a callable since it's from our targets list + target_callable = node.target + assert callable(target_callable), "Target must be callable" + new_node = node.graph.call_function(target_callable, new_args, {}) + new_node.meta = node.meta.copy() + node.replace_all_uses_with(new_node) + + node.graph.erase_node(node) + return True @register_cadence_pass(CadencePassAttribute(opt_level=0)) class BindOptionalArgsPass(ExportPass): """Bind all optional args and kwargs.""" - def call_operator(self, op, args, kwargs, meta): - if not isinstance(op, EdgeOpOverload): - return super().call_operator(op, args, kwargs, meta) - - if (updated_args := rebind(op, args, kwargs)) is not None: - args, kwargs = updated_args - - return super().call_operator(op, args, kwargs, meta) + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + """ + Bind all optional args and kwargs for EdgeOpOverload operations. + Only reports modified=True if arguments were actually changed. 
+ """ + modified = False + + for node in graph_module.graph.nodes: + if node.op != "call_function": + continue + if not isinstance(node.target, EdgeOpOverload): + continue + + # Try to rebind the args/kwargs to populate optional arguments + updated_args = rebind(node.target, tuple(node.args), dict(node.kwargs)) + if updated_args is None: + # No schema matched or no changes needed + continue + + new_args, new_kwargs = updated_args + # Check if anything actually changed + if new_args != node.args or new_kwargs != node.kwargs: + node.args = new_args + node.kwargs = new_kwargs + modified = True + + if modified: + graph_module.recompile() + + return PassResult(graph_module, modified) # This class encapsulates all the functions that simplify the op's args diff --git a/backends/cadence/aot/tests/test_reorder_ops_passes.py b/backends/cadence/aot/tests/test_reorder_ops_passes.py index 50f5ca32c47..18b081ae85b 100644 --- a/backends/cadence/aot/tests/test_reorder_ops_passes.py +++ b/backends/cadence/aot/tests/test_reorder_ops_passes.py @@ -6,13 +6,12 @@ # pyre-strict - +import copy import unittest from typing import cast import executorch.backends.cadence.aot.ops_registrations # noqa import torch - from executorch.backends.cadence.aot.fuse_ops import FuseQuantDequantToRequantizePass from executorch.backends.cadence.aot.graph_builder import GraphBuilder from executorch.backends.cadence.aot.pass_utils import ( @@ -30,17 +29,70 @@ SinkOpsCloserToUsePass, ) from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import PassResult +from executorch.exir.pass_base import PassBase, PassResult +from torch.utils import _pytree as pytree + + +def transform_and_check_numerics( + original_graph: torch.fx.GraphModule, + inputs: tuple[torch.Tensor, ...] | list[torch.Tensor], + pass_to_run: PassBase, + pass_name: str, + rtol: float = 1e-5, + atol: float = 1e-6, +) -> PassResult: + """Run a graph transformation and validate numerical equivalence. + + Args: + original_graph: The original graph module before transformation + inputs: Input tensors to run through both graphs + pass_to_run: The pass to apply to the graph + pass_name: Name of the pass being validated (for error messages) + rtol: Relative tolerance for allclose comparison + atol: Absolute tolerance for allclose comparison + + Returns: + The PassResult from the transformation + """ + # Deepcopy to preserve original for comparison + gm_before = copy.deepcopy(original_graph) + + # Run the transformation + result = cast(PassResult, pass_to_run.call(original_graph)) + + # Validate numerical equivalence + gm_before.eval() + result.graph_module.eval() + with torch.no_grad(): + orig_out = gm_before(*inputs) + mod_out = result.graph_module(*inputs) + + flat_orig_out, _ = pytree.tree_flatten(orig_out) + flat_mod_out, _ = pytree.tree_flatten(mod_out) + + # Check that outputs match within tolerance + for i, (orig_tensor, mod_tensor) in enumerate(zip(flat_orig_out, flat_mod_out)): + if not torch.allclose(orig_tensor, mod_tensor, rtol=rtol, atol=atol): + max_diff = torch.max(torch.abs(orig_tensor - mod_tensor)).item() + raise AssertionError( + f"Pass validation failed for pass {pass_name}. " + f"Output tensor {i} differs by max {max_diff:.6e}. " + f"Expected rtol={rtol}, atol={atol}. 
" + f"Original output: {orig_tensor}, Modified output: {mod_tensor}" + ) + + return result class TestReorderPasses(unittest.TestCase): def test_sink_dequantize(self) -> None: builder = GraphBuilder() - x = builder.placeholder("x", torch.randn(32, 6, dtype=torch.float32)) - y = builder.placeholder("y", torch.randn(32, 6, dtype=torch.float32)) - weights = builder.placeholder( - "weights", torch.randint(-128, 127, (6, 8), dtype=torch.int8) - ) + x_data = torch.randn(32, 6, dtype=torch.float32) + y_data = torch.randn(32, 6, dtype=torch.float32) + weight_data = torch.randint(-128, 127, (8, 6), dtype=torch.int8) + x = builder.placeholder("x", x_data) + y = builder.placeholder("y", y_data) + weights = builder.placeholder("weights", weight_data) x_quantized = builder.call_operator( op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, args=(x, 0.02252197265625, 20, -128, 127, torch.int8), @@ -52,6 +104,7 @@ def test_sink_dequantize(self) -> None: full = builder.call_operator( op=exir_ops.edge.aten.full.default, args=([1], -7), + kwargs={"dtype": torch.int32}, ) full_1 = builder.call_operator( op=exir_ops.edge.aten.full.default, @@ -63,11 +116,13 @@ def test_sink_dequantize(self) -> None: ) full_3 = builder.call_operator( op=exir_ops.edge.aten.full.default, - args=([1], 0.0), + args=([1], 0), + kwargs={"dtype": torch.int32}, ) full_4 = builder.call_operator( op=exir_ops.edge.aten.full.default, args=([1], -7), + kwargs={"dtype": torch.int32}, ) full_5 = builder.call_operator( op=exir_ops.edge.aten.full.default, @@ -79,7 +134,8 @@ def test_sink_dequantize(self) -> None: ) full_7 = builder.call_operator( op=exir_ops.edge.aten.full.default, - args=([1], 0.0), + args=([1], 0), + kwargs={"dtype": torch.int32}, ) quantized_linear = builder.call_operator( op=exir_ops.edge.cadence.quantized_linear.default, @@ -107,8 +163,14 @@ def test_sink_dequantize(self) -> None: ) builder.output([cat]) original_graph = builder.get_graph_module() - p = SinkOpsCloserToUsePass() - converted_graph = cast(PassResult, p(original_graph)).graph_module + result = transform_and_check_numerics( + original_graph, + (x_data, y_data, weight_data), + SinkOpsCloserToUsePass(), + "SinkOpsCloserToUsePass", + ) + self.assertTrue(result.modified) + converted_graph = result.graph_module # Expect the SinkDequant pass to move dequant(y) from above the relu to just below it self.assertTrue( @@ -349,12 +411,11 @@ def test_postpone_dequantize1(self) -> None: def test_postpone_dequantize_branched(self) -> None: builder = GraphBuilder() - x = builder.placeholder( - "x", torch.randint(0, 255, [1, 18, 3], dtype=torch.uint8) - ) - p_linear_weight = builder.placeholder( - "weights", torch.randint(-128, 127, (3, 3), dtype=torch.int8) - ) + x_data = torch.randint(0, 255, [1, 18, 3], dtype=torch.uint8) + weights_data = torch.randn([3, 3], dtype=torch.float32) + + x = builder.placeholder("x", x_data) + p_linear_weight = builder.placeholder("weights", weights_data) quantized_decomposed_dequantize_per_tensor_default = builder.call_operator( op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, args=(x, 0.1, 10, 0, 255, torch.uint8), @@ -404,9 +465,15 @@ def test_postpone_dequantize_branched(self) -> None: ) builder.output([aten_mm_default, aten_mm_default_1, aten_mm_default_2]) original_graph = builder.get_graph_module() - p = PostponeDequantizeOpBelowUseChainPass() - converted_graph = cast(PassResult, p(original_graph)).graph_module - converted_graph.graph.eliminate_dead_code() + result = transform_and_check_numerics( + 
original_graph, + (x_data, weights_data), + PostponeDequantizeOpBelowUseChainPass(), + "PostponeDequantizeOpBelowUseChainPass", + ) + self.assertTrue(result.modified) + converted_graph = result.graph_module + # Asset that the dequant node was split into 4, one per branch self.assertEqual( count_node( @@ -436,7 +503,8 @@ def test_postpone_dequantize_branched(self) -> None: # 4d -> permute -> 4d -> view -> 3d def test_permute3_view4_chains(self) -> None: builder = GraphBuilder() - x = builder.placeholder("x", torch.randn(3, 1, 768)) + x_data = torch.randn(3, 1, 768) + x = builder.placeholder("x", x_data) aten_view_copy_default = builder.call_operator( op=exir_ops.edge.aten.view_copy.default, args=(x, [3, 12, 64]), @@ -455,8 +523,14 @@ def test_permute3_view4_chains(self) -> None: ) builder.output([aten_permute_copy_default_1]) original_graph = builder.get_graph_module() - p = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView() - converted_graph = cast(PassResult, p(original_graph)).graph_module + result = transform_and_check_numerics( + original_graph, + (x_data,), + PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView(), + "PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView", + ) + self.assertTrue(result.modified) + converted_graph = result.graph_module converted_graph.graph.eliminate_dead_code() # Assert the order becomes view, view, permute, permute nodes = get_compute_nodes_in_gm(converted_graph) @@ -469,7 +543,8 @@ def test_permute3_view4_chains(self) -> None: # 3d -> permute -> 3d -> view -> 4d def test_permute4_view3_chains(self) -> None: builder = GraphBuilder() - x = builder.placeholder("x", torch.randn(3, 1, 768)) + x_data = torch.randn(3, 1, 768) + x = builder.placeholder("x", x_data) aten_view_copy_default = builder.call_operator( op=exir_ops.edge.aten.view_copy.default, args=(x, [1, 3, 12, 64]), @@ -489,9 +564,14 @@ def test_permute4_view3_chains(self) -> None: builder.output([aten_permute_copy_default_1]) original_graph = builder.get_graph_module() - p = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView() - converted_graph = cast(PassResult, p(original_graph)).graph_module - converted_graph.graph.eliminate_dead_code() + result = transform_and_check_numerics( + original_graph, + (x_data,), + PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView(), + "PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView", + ) + self.assertTrue(result.modified) + converted_graph = result.graph_module # Assert the order becomes view, view, permute, permute nodes = get_compute_nodes_in_gm(converted_graph) @@ -506,7 +586,8 @@ def test_permute4_view3_chains(self) -> None: # size is 1 (this is ok), but also changes the size of the dimensions (not ok). 
def test_permute_view_chains_neg(self) -> None: builder = GraphBuilder() - x = builder.placeholder("x", torch.randn(3, 1, 768)) + x_data = torch.randn(3, 1, 768) + x = builder.placeholder("x", x_data) aten_view_copy_default = builder.call_operator( op=exir_ops.edge.aten.view_copy.default, args=(x, [1, 3, 12, 64]), @@ -527,9 +608,14 @@ def test_permute_view_chains_neg(self) -> None: original_graph = builder.get_graph_module() # Performing transform (nothing should happen) - p = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView() - converted_graph = cast(PassResult, p(original_graph)).graph_module - converted_graph.graph.eliminate_dead_code() + result = transform_and_check_numerics( + original_graph, + (x_data,), + PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView(), + "PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView", + ) + self.assertFalse(result.modified) + converted_graph = result.graph_module # Assert the order is still view, permute, view, permute nodes = get_compute_nodes_in_gm(converted_graph) diff --git a/backends/cadence/aot/tests/test_simplify_ops_passes.py b/backends/cadence/aot/tests/test_simplify_ops_passes.py index f26fe897e1e..a20b7fd535e 100644 --- a/backends/cadence/aot/tests/test_simplify_ops_passes.py +++ b/backends/cadence/aot/tests/test_simplify_ops_passes.py @@ -6,7 +6,7 @@ # pyre-strict - +import copy import unittest from typing import cast, Optional, Tuple @@ -20,7 +20,59 @@ ) from executorch.backends.cadence.aot.typing_stubs import expand from executorch.exir.dialects._ops import ops as exir_ops -from torch.fx.passes.infra.pass_base import PassResult +from torch.fx.passes.infra.pass_base import PassBase, PassResult +from torch.utils import _pytree as pytree + + +def transform_and_check_numerics( + original_graph: torch.fx.GraphModule, + inputs: tuple[torch.Tensor, ...] | list[torch.Tensor], + pass_to_run: PassBase, + pass_name: str, + rtol: float = 1e-5, + atol: float = 1e-6, +) -> PassResult: + """Run a graph transformation and validate numerical equivalence. + + Args: + original_graph: The original graph module before transformation + inputs: Input tensors to run through both graphs + pass_to_run: The pass to apply to the graph + pass_name: Name of the pass being validated (for error messages) + rtol: Relative tolerance for allclose comparison + atol: Absolute tolerance for allclose comparison + + Returns: + The PassResult from the transformation + """ + # Deepcopy to preserve original for comparison + gm_before = copy.deepcopy(original_graph) + + # Run the transformation + result = cast(PassResult, pass_to_run.call(original_graph)) + + # Validate numerical equivalence + gm_before.eval() + result.graph_module.eval() + with torch.no_grad(): + orig_out = gm_before(*inputs) + mod_out = result.graph_module(*inputs) + + flat_orig_out, _ = pytree.tree_flatten(orig_out) + flat_mod_out, _ = pytree.tree_flatten(mod_out) + + # Check that outputs match within tolerance + for i, (orig_tensor, mod_tensor) in enumerate(zip(flat_orig_out, flat_mod_out)): + if not torch.allclose(orig_tensor, mod_tensor, rtol=rtol, atol=atol): + max_diff = torch.max(torch.abs(orig_tensor - mod_tensor)).item() + raise AssertionError( + f"Pass validation failed for pass {pass_name}. " + f"Output tensor {i} differs by max {max_diff:.6e}. " + f"Expected rtol={rtol}, atol={atol}. 
" + f"Original output: {orig_tensor}, Modified output: {mod_tensor}" + ) + + return result class TestSimplifyOpsPasses(unittest.TestCase): @@ -46,8 +98,11 @@ def test_simplify_slice_scatter_op( op=exir_ops.edge.aten.slice_scatter.default, args=(x, y, dim, start, end, step), ) - p = SimplifySliceOpPass() - gm = cast(PassResult, p(gm)).graph_module + result = transform_and_check_numerics( + gm, (x, y), SimplifySliceOpPass(), "SimplifySliceOpPass" + ) + self.assertTrue(result.modified) + gm = result.graph_module self.assertEqual(count_node(gm, exir_ops.edge.aten.slice_scatter.default), 0) @expand( @@ -76,8 +131,11 @@ def test_simplify_slice_op( step, ), ) - p = SimplifySliceOpPass() - gm = cast(PassResult, p(gm)).graph_module + result = transform_and_check_numerics( + gm, (x,), SimplifySliceOpPass(), "SimplifySliceOpPass" + ) + self.assertTrue(result.modified) + gm = result.graph_module self.assertEqual(count_node(gm, exir_ops.edge.aten.slice_copy.Tensor), 0) self.assertEqual(count_node(gm, exir_ops.edge.aten.full.default), 1) @@ -92,7 +150,11 @@ def test_simplify_slice_op_args(self) -> None: original_slice_copy = list(gm.graph.nodes)[1] self.assertEqual(original_slice_copy.args[1:], (1,)) self.assertEqual(original_slice_copy.kwargs, {"end": 3}) - gm = BindOptionalArgsPass().call(gm).graph_module + result = transform_and_check_numerics( + gm, (x,), BindOptionalArgsPass(), "BindOptionalArgsPass" + ) + self.assertTrue(result.modified) + gm = result.graph_module modified_slice_copy = list(gm.graph.nodes)[1] self.assertEqual(modified_slice_copy.args[1:], (1, None, 3, 1)) self.assertEqual(modified_slice_copy.kwargs, {})