diff --git a/backends/vulkan/_passes/remove_redundant_ops.py b/backends/vulkan/_passes/remove_redundant_ops.py
index 530505f7003..8e602dd17b4 100644
--- a/backends/vulkan/_passes/remove_redundant_ops.py
+++ b/backends/vulkan/_passes/remove_redundant_ops.py
@@ -30,6 +30,7 @@ class RemoveRedundantOpsTransform(ExportPass):
         exir_ops.edge.aten.alias.default,
         exir_ops.edge.aten.lift_fresh_copy.default,
         exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
+        exir_ops.edge.dim_order_ops._clone_dim_order.default,
     }
 
     def __init__(self) -> None:
diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index 1b74ef1ac65..c17b618a36e 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -301,6 +301,32 @@ def check_dim_order_copy_node(node: torch.fx.Node) -> bool:
     )
 
 
+@update_features(exir_ops.edge.dim_order_ops._clone_dim_order.default)
+def register_clone_dim_order_op():
+    # Similar to to_dim_order_copy, _clone_dim_order can be removed as long as the
+    # operator does not change the dtype, i.e. the operator call only modifies the
+    # dim order. Therefore, check that the input and output dtypes are the same; if
+    # so, the operator is safe to remove.
+    def check_clone_dim_order_node(node: torch.fx.Node) -> bool:
+        in_arg = node.args[0]
+        if not isinstance(in_arg, torch.fx.Node):
+            return False
+
+        in_tensor = in_arg.meta.get("val", None)
+        out_tensor = node.meta.get("val", None)
+
+        if in_tensor.dtype != out_tensor.dtype:
+            return False
+
+        return True
+
+    return OpFeatures(
+        inputs_storage=utils.ANY_STORAGE,
+        supports_resize=True,
+        are_node_inputs_supported_fn=check_clone_dim_order_node,
+    )
+
+
 @update_features(
     [
         exir_ops.edge.aten.bmm.default,
diff --git a/exir/passes/constant_prop_pass.py b/exir/passes/constant_prop_pass.py
index fc93aa1b0ca..7daa3a247e8 100644
--- a/exir/passes/constant_prop_pass.py
+++ b/exir/passes/constant_prop_pass.py
@@ -295,7 +295,11 @@ def create_constant_nodes_and_return_specs(
     return name_to_spec_dict
 
 
-def _update_output_node_and_specs(exported_program: ExportedProgram) -> None:
+# _skip_dim_order ensures the correct clone op is introduced for the given dim order schema.
+# TODO(gasoonjia): rely only on _clone_dim_order once the _skip_dim_order option is removed from EdgeCompileConfig.
+def _update_output_node_and_specs(
+    exported_program: ExportedProgram, _skip_dim_order: bool
+) -> None:
     """
     Update the output node and output specs in the exported program.
     In case a constant node is used as output, we replace it with a clone of the constant node.
@@ -307,15 +311,19 @@ def _update_output_node_and_specs(exported_program: ExportedProgram) -> None:
     output_specs = exported_program.graph_signature.output_specs
     assert len(output_nodes) == len(output_specs)
 
+    clone_op = (
+        exir_ops.edge.aten.clone.default
+        if _skip_dim_order
+        else exir_ops.edge.dim_order_ops._clone_dim_order.default
+    )
+
     for i in range(len(output_specs)):
         out_node = output_nodes[i]
         if out_node not in updated_constant_placeholders:
             continue
         with exported_program.graph.inserting_after(out_node):
-            new_node = exported_program.graph.call_function(
-                exir_ops.edge.aten.clone.default, (out_node,)
-            )
+            new_node = exported_program.graph.call_function(clone_op, (out_node,))
             assert "val" in out_node.meta
             new_node.meta["val"] = out_node.meta["val"]
             output_nodes[i] = new_node
 
@@ -329,6 +337,7 @@ def _update_output_node_and_specs(exported_program: ExportedProgram) -> None:
 def constant_prop_pass(
     exported_program: ExportedProgram,
     custom_skip_targets: Optional[set[EdgeOpOverload]] = None,
+    _skip_dim_order: bool = True,
 ) -> ExportedProgram:
     """
     This pass is for constant propagation for Exported Program with lifted parameters,
@@ -376,7 +385,7 @@ def constant_prop_pass(
         new_input_specs.append(name_to_spec_dict[node.name])
     exported_program.graph_signature.input_specs = new_input_specs
 
-    _update_output_node_and_specs(exported_program)
+    _update_output_node_and_specs(exported_program, _skip_dim_order=_skip_dim_order)
 
     # Cleanup the graph.
     exported_program.graph.eliminate_dead_code()
diff --git a/exir/passes/memory_format_ops_pass.py b/exir/passes/memory_format_ops_pass.py
index db597b2b233..421f30960b6 100644
--- a/exir/passes/memory_format_ops_pass.py
+++ b/exir/passes/memory_format_ops_pass.py
@@ -19,9 +19,6 @@
 logger = logging.getLogger(__file__)
 logger.setLevel(logging.INFO)
 
-# TODO - these passes are too specialized on a single to_copy op.
-# We should be able to replace (or revert) any of the dim_order ops in the future.
-
 
 class MemoryFormatOpsPass(ExportPass):
     """
@@ -43,7 +40,7 @@ def call_operator(self, op, args, kwargs, meta):
         # new kwargs with dim_order, and no memory_format for the new op
         nkwargs = dict(copy.deepcopy(kwargs))  # orig kwargs are immutable
 
-        # get the "to" memory format for the EdgeOp
+        # get the target memory format for the EdgeOp
         mem_format = nkwargs.pop("memory_format", torch.contiguous_format)
 
         # can always get the shape, assuming rank is specialized
diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py
index 9d56123d83d..781c4d716e4 100644
--- a/exir/tests/test_passes.py
+++ b/exir/tests/test_passes.py
@@ -1192,7 +1192,7 @@ def forward(self) -> torch.Tensor:
         )
 
         edge._edge_programs["forward"] = constant_prop_pass(
-            edge.exported_program("forward")
+            edge.exported_program("forward"), _skip_dim_order=False
        )
 
         # Check (c_lifted_tensor_*) nodes are all replaced by _prop_tensor_constant.
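Note for reviewers: below is a minimal usage sketch, not part of the patch, showing how the new _skip_dim_order argument of constant_prop_pass is expected to interact with dim order mode, mirroring the updated test in exir/tests/test_passes.py. The ConstModule class and the final graph inspection are illustrative assumptions; only to_edge, EdgeCompileConfig, exir_ops, and constant_prop_pass(..., _skip_dim_order=...) come from the code shown above.

import torch
from executorch.exir import EdgeCompileConfig, to_edge
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.passes.constant_prop_pass import constant_prop_pass


class ConstModule(torch.nn.Module):
    # Hypothetical module: a lifted tensor constant that ends up feeding the output,
    # which is the case where _update_output_node_and_specs inserts a clone op.
    def __init__(self) -> None:
        super().__init__()
        self.w = torch.tensor([[1.0, 2.0], [3.0, 4.0]])

    def forward(self) -> torch.Tensor:
        return self.w + 1


edge = to_edge(
    torch.export.export(ConstModule(), ()),
    compile_config=EdgeCompileConfig(_skip_dim_order=False),  # dim order enabled
)

# With _skip_dim_order=False, the pass should pick _clone_dim_order rather than
# aten.clone when cloning a constant node that is used as a program output.
prog = constant_prop_pass(edge.exported_program(), _skip_dim_order=False)

clone_dim_order_nodes = [
    n
    for n in prog.graph.nodes
    if n.op == "call_function"
    and n.target == exir_ops.edge.dim_order_ops._clone_dim_order.default
]
print("_clone_dim_order nodes found:", len(clone_dim_order_nodes))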