diff --git a/backends/vulkan/_passes/insert_prepack_nodes.py b/backends/vulkan/_passes/insert_prepack_nodes.py
index 3f6588d84ad..cafeedbd5da 100644
--- a/backends/vulkan/_passes/insert_prepack_nodes.py
+++ b/backends/vulkan/_passes/insert_prepack_nodes.py
@@ -6,6 +6,8 @@
 
 # pyre-strict
 
+from copy import deepcopy
+
 import executorch.backends.vulkan.custom_ops_lib  # noqa
 
 import torch
@@ -69,9 +71,15 @@ def prepack_not_required(node: torch.fx.Node) -> bool:
             exir_ops.edge.et_vk.prepack.default,
             (node,),
         )
-        prepack_node.meta["spec"] = node.meta["spec"]
+        # This pass assumes that SpecPropPass() has already been applied
+        assert "spec" in node.meta
+        # Validate that the original node is marked as a constant. Constant tensors
+        # do not participate in memory planning.
+        assert node.meta["spec"].const
+        prepack_node.meta["val"] = node.meta["val"]
+        prepack_node.meta["spec"] = deepcopy(node.meta["spec"])
         # Set the mem_obj_id to -1 to indicate that this node requires a dedicated
-        # memory object. This pass must be executed AFTER the memory planning pass.
+        # memory object.
         prepack_node.meta["spec"].mem_obj_id = -1
         node.replace_all_uses_with(prepack_node, lambda x, y=prepack_node: x != y)
 
diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py
index 0e116ad2c4c..96eee198f4d 100644
--- a/backends/vulkan/vulkan_preprocess.py
+++ b/backends/vulkan/vulkan_preprocess.py
@@ -17,8 +17,10 @@
 from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
 from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
 
-from executorch.backends.vulkan._passes import RemoveLocalScalarDenseOpsTransform
-from executorch.backends.vulkan._passes.insert_prepack_nodes import insert_prepack_nodes
+from executorch.backends.vulkan._passes import (
+    insert_prepack_nodes,
+    RemoveLocalScalarDenseOpsTransform,
+)
 
 from executorch.backends.vulkan.serialization.vulkan_graph_builder import VkGraphBuilder
 from executorch.backends.vulkan.serialization.vulkan_graph_serialize import (
@@ -32,6 +34,7 @@
     PreprocessResult,
 )
 from executorch.exir.backend.utils import DelegateMappingBuilder
 
+from executorch.exir.pass_base import ExportPass, PassBase
 from executorch.exir.passes import MemoryPlanningPass, SpecPropPass
 
@@ -46,6 +49,35 @@
 
 DEFAULT_DEBUG_HANDLE = 65535
 
+
+# pyre-ignore
+def apply_passes(program: ExportedProgram, passes) -> ExportedProgram:
+    for p in passes:
+
+        if issubclass(type(p), ExportPass) or issubclass(type(p), PassBase):
+            new_gm = program.graph_module
+            # This is a workaround to allow the memory planning pass to work without
+            # having to first apply ToOutVarPass(). See the `greedy()` function in
+            # `exir.memory_planning`; if this attribute isn't set, assertions in
+            # `collect_spec_from_nodes()` will fail.
+            if isinstance(p, MemoryPlanningPass):
+                new_gm.encounter_to_out_var_failure = True
+
+            new_gm_res = p(new_gm)
+            assert new_gm_res is not None
+            new_gm = new_gm_res.graph_module
+
+            # See the application of this function in exir/program/_program.py for more
+            # details on why this step is necessary.
+            if isinstance(p, SpecPropPass):
+                p.update_placeholder_tensor_specs(program, new_gm)
+
+            _copy_module(program.graph_module, new_gm)
+        else:
+            program = p(program)
+
+    return program
+
 @final
 class VulkanBackend(BackendDetails):
     @classmethod
@@ -57,35 +89,44 @@ def preprocess(  # noqa: C901
     ) -> PreprocessResult:
         program = unsafe_remove_auto_functionalized_pass(program)
 
-        passes = [
-            RemoveCloneOpsTransform(),
-            AddmmToLinearTransform(),
-            FuseDequantLinearPass(),
-            FuseViewCopyTransform(),
-            FuseBatchNormWithConvPass(program),
-            FuseClampPass(),
-            SpecPropPass(),
-            ConstraintBasedSymShapeEvalPass(),
-            RemoveLocalScalarDenseOpsTransform(),
-            MemoryPlanningPass(),
-        ]
-
-        new_gm = program.graph_module
-
-        for p in passes:
-            # This is a workaround to allow the memory planning pass to work without
-            # having to first apply ToOutVarPass(). See the `greedy()` function in
-            # `exir.memory_planning`; if this attribute isn't set, assertions in
-            # `collect_spec_from_nodes()` will fail.
-            if isinstance(p, MemoryPlanningPass):
-                new_gm.encounter_to_out_var_failure = True
-            new_gm_res = p(new_gm)
-            assert new_gm_res is not None
-            new_gm = new_gm_res.graph_module
+        # First, apply passes that fuse/remove operators to consolidate the graph
+        # structure but still preserve an "ATen-compliant" graph structure (i.e. all
+        # arguments to ATen operators must match the ATen function schema).
+        program = apply_passes(
+            program,
+            [
+                RemoveCloneOpsTransform(),
+                AddmmToLinearTransform(),
+                FuseDequantLinearPass(),
+                FuseViewCopyTransform(),
+                FuseBatchNormWithConvPass(program),
+                FuseClampPass(),
+            ],
+        )
 
-            _copy_module(program.graph_module, new_gm)
+        # Next, annotate tensor nodes with TensorSpec structs, which are needed for
+        # dynamic shapes and memory planning. Until this point, the graph must be ATen
+        # compliant because SpecPropPass will be calling the underlying ATen operators
+        # during its execution.
+        program = apply_passes(program, [SpecPropPass()])
+
+        # Apply graph transforms which either require `TensorSpec`s to have been created
+        # or would create a non ATen compliant graph structure.
+        program = apply_passes(
+            program,
+            [
+                # Since this pass may replace a scalar argument with a tensor argument,
+                # it may result in a non ATen compliant graph structure.
+                RemoveLocalScalarDenseOpsTransform(),
+                insert_prepack_nodes,
+            ],
+        )
 
-        program = insert_prepack_nodes(program)
+        # Finally, apply the dynamic shape and memory planning passes. These passes
+        # must be applied only when the graph structure is finalized.
+        program = apply_passes(
+            program, [ConstraintBasedSymShapeEvalPass(), MemoryPlanningPass()]
+        )
 
         graph_builder = VkGraphBuilder(
             program, DelegateMappingBuilder(generated_identifiers=True)
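Reviewer note, not part of the patch: the move from aliasing `node.meta["spec"]` to `deepcopy`-ing it is load-bearing, because the very next line mutates `mem_obj_id` on the prepack node's spec. Under the old aliasing, that write would also clobber the constant tensor's own spec. A minimal sketch of the aliasing hazard, using a stand-in `TensorSpec` class rather than the real one from `exir`:

```python
from copy import deepcopy


# Stand-in for exir's TensorSpec; only the fields relevant here.
class TensorSpec:
    def __init__(self, const: bool) -> None:
        self.const = const
        self.mem_obj_id = None


node_spec = TensorSpec(const=True)

# Old behavior: prepack_node.meta["spec"] aliases node.meta["spec"],
# so mutating one silently mutates the other.
prepack_spec = node_spec
prepack_spec.mem_obj_id = -1
assert node_spec.mem_obj_id == -1  # the constant's own spec was clobbered

# New behavior: deepcopy gives the prepack node an independent spec, so
# setting mem_obj_id = -1 no longer leaks back into the original node.
node_spec.mem_obj_id = None
prepack_spec = deepcopy(node_spec)
prepack_spec.mem_obj_id = -1
assert node_spec.mem_obj_id is None  # original spec untouched
```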
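It may also help to spell out the two calling conventions `apply_passes` dispatches between: `ExportPass`/`PassBase` subclasses run against the `graph_module` and return a result object carrying a new `graph_module`, while bare callables such as `insert_prepack_nodes` consume and return the whole `ExportedProgram`. A toy sketch of that dispatch, with stand-in types rather than the real ExecuTorch classes:

```python
from typing import Callable, List, Union


# Stand-ins for the real ExecuTorch types, just to show the dispatch shape.
class GraphModule:
    pass


class ExportedProgram:
    def __init__(self) -> None:
        self.graph_module = GraphModule()


class PassResult:
    def __init__(self, graph_module: GraphModule) -> None:
        self.graph_module = graph_module


class PassBase:
    def __call__(self, gm: GraphModule) -> PassResult:
        # A real pass would rewrite gm here.
        return PassResult(gm)


ProgramPass = Callable[[ExportedProgram], ExportedProgram]


def apply_passes(
    program: ExportedProgram, passes: List[Union[PassBase, ProgramPass]]
) -> ExportedProgram:
    for p in passes:
        if isinstance(p, PassBase):
            # Graph-module-level pass: run it and adopt the resulting graph.
            result = p(program.graph_module)
            assert result is not None
            program.graph_module = result.graph_module
        else:
            # Program-level pass: it consumes and returns the whole program.
            program = p(program)
    return program


def insert_prepack_nodes(program: ExportedProgram) -> ExportedProgram:
    return program  # no-op stand-in


program = apply_passes(ExportedProgram(), [PassBase(), insert_prepack_nodes])
```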