12 changes: 10 additions & 2 deletions backends/vulkan/_passes/insert_prepack_nodes.py
@@ -6,6 +6,8 @@

# pyre-strict

from copy import deepcopy

import executorch.backends.vulkan.custom_ops_lib # noqa

import torch
@@ -69,9 +71,15 @@ def prepack_not_required(node: torch.fx.Node) -> bool:
exir_ops.edge.et_vk.prepack.default,
(node,),
)
prepack_node.meta["spec"] = node.meta["spec"]
# This pass assumes that the SpecPropPass() has already been applied
assert "spec" in node.meta
# Validate that the original node is marked as a constant. Constant tensors
# do not participate in memory planning.
assert node.meta["spec"].const
prepack_node.meta["val"] = node.meta["val"]
prepack_node.meta["spec"] = deepcopy(node.meta["spec"])
# Set the mem_obj_id to -1 to indicate that this node requires a dedicated
# memory object. This pass must be executed AFTER the memory planning pass.
# memory object.
prepack_node.meta["spec"].mem_obj_id = -1
node.replace_all_uses_with(prepack_node, lambda x, y=prepack_node: x != y)

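Before moving to the second file, here is a standalone sketch of the reasoning behind the deepcopy added above: if the prepack node simply aliased the placeholder's spec, overwriting mem_obj_id would silently mutate the constant placeholder's spec as well. The _Spec dataclass below is a hypothetical stand-in for exir's TensorSpec, not ExecuTorch code.

from copy import deepcopy
from dataclasses import dataclass


# Hypothetical stand-in for exir's TensorSpec; not ExecuTorch code.
@dataclass
class _Spec:
    const: bool
    mem_obj_id: int = 0


# Aliasing (the old behavior): mutating the prepack node's spec also mutates
# the constant placeholder's spec, since both names point at the same object.
placeholder_spec = _Spec(const=True)
prepack_spec = placeholder_spec
prepack_spec.mem_obj_id = -1
assert placeholder_spec.mem_obj_id == -1  # unintended side effect

# Deep copy (the new behavior): only the prepack node's spec is changed.
placeholder_spec = _Spec(const=True)
prepack_spec = deepcopy(placeholder_spec)
prepack_spec.mem_obj_id = -1
assert placeholder_spec.mem_obj_id == 0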
99 changes: 70 additions & 29 deletions backends/vulkan/vulkan_preprocess.py
@@ -17,8 +17,10 @@
from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform

from executorch.backends.vulkan._passes import RemoveLocalScalarDenseOpsTransform
from executorch.backends.vulkan._passes.insert_prepack_nodes import insert_prepack_nodes
from executorch.backends.vulkan._passes import (
insert_prepack_nodes,
RemoveLocalScalarDenseOpsTransform,
)

from executorch.backends.vulkan.serialization.vulkan_graph_builder import VkGraphBuilder
from executorch.backends.vulkan.serialization.vulkan_graph_serialize import (
@@ -32,6 +34,7 @@
PreprocessResult,
)
from executorch.exir.backend.utils import DelegateMappingBuilder
from executorch.exir.pass_base import ExportPass, PassBase

from executorch.exir.passes import MemoryPlanningPass, SpecPropPass

@@ -46,6 +49,35 @@
DEFAULT_DEBUG_HANDLE = 65535


# pyre-ignore
def apply_passes(program: ExportedProgram, passes) -> ExportedProgram:
for p in passes:

if issubclass(type(p), ExportPass) or issubclass(type(p), PassBase):
new_gm = program.graph_module
# This is a workaround to allow the memory planning pass to work without
# having to first apply ToOutVarPass(). See the `greedy()` function in
# `exir.memory_planning`; if this attribute isn't set, assertions in
# `collect_spec_from_nodes()` will fail.
if isinstance(p, MemoryPlanningPass):
new_gm.encounter_to_out_var_failure = True

new_gm_res = p(new_gm)
assert new_gm_res is not None
new_gm = new_gm_res.graph_module

# See the application of this function in exir/program/_program.py for more
# details on why this step is necessary.
if isinstance(p, SpecPropPass):
p.update_placeholder_tensor_specs(program, new_gm)

_copy_module(program.graph_module, new_gm)
else:
program = p(program)

return program


@final
class VulkanBackend(BackendDetails):
@classmethod
@@ -57,35 +89,44 @@ def preprocess( # noqa: C901
) -> PreprocessResult:
program = unsafe_remove_auto_functionalized_pass(program)

passes = [
RemoveCloneOpsTransform(),
AddmmToLinearTransform(),
FuseDequantLinearPass(),
FuseViewCopyTransform(),
FuseBatchNormWithConvPass(program),
FuseClampPass(),
SpecPropPass(),
ConstraintBasedSymShapeEvalPass(),
RemoveLocalScalarDenseOpsTransform(),
MemoryPlanningPass(),
]

new_gm = program.graph_module

for p in passes:
# This is a workaround to allow the memory planning pass to work without
# having to first apply ToOutVarPass(). See the `greedy()` function in
# `exir.memory_planning`; if this attribute isn't set, assertions in
# `collect_spec_from_nodes()` will fail.
if isinstance(p, MemoryPlanningPass):
new_gm.encounter_to_out_var_failure = True
new_gm_res = p(new_gm)
assert new_gm_res is not None
new_gm = new_gm_res.graph_module
# First, apply passes that fuse/remove operators to consolidate the graph
# structure but still preserve an "ATen-compliant" graph structure (i.e. all
# arguments to ATen operators must match the ATen function schema).
program = apply_passes(
program,
[
RemoveCloneOpsTransform(),
AddmmToLinearTransform(),
FuseDequantLinearPass(),
FuseViewCopyTransform(),
FuseBatchNormWithConvPass(program),
FuseClampPass(),
],
)

_copy_module(program.graph_module, new_gm)
# Next, annotate tensor nodes with TensorSpec structs, which are needed for dynamic
# shapes and memory planning. Until this point, the graph must be ATen compliant
# because SpecPropPass will be calling the underlying ATen operators during its
# execution.
program = apply_passes(program, [SpecPropPass()])

# Apply graph transforms which either require `TensorSpec`s to have been created
# or would create a non ATen compliant graph structure.
program = apply_passes(
program,
[
# Since this pass may replace a scalar argument with a tensor argument,
# this pass may result in a non ATen compliant graph structure.
RemoveLocalScalarDenseOpsTransform(),
insert_prepack_nodes,
],
)

program = insert_prepack_nodes(program)
# Finally, apply dynamic shape passes and memory planning pass. These passes
# must be applied only when the graph structure is finalized.
program = apply_passes(
program, [ConstraintBasedSymShapeEvalPass(), MemoryPlanningPass()]
)

graph_builder = VkGraphBuilder(
program, DelegateMappingBuilder(generated_identifiers=True)
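As a closing note, a standalone sketch of the dispatch pattern the new apply_passes helper follows: ExportPass/PassBase-style passes run on program.graph_module and return a result whose graph module replaces the previous one, while plain callables such as insert_prepack_nodes receive and return the whole ExportedProgram. All Fake* names below are hypothetical stand-ins, not ExecuTorch APIs.

from dataclasses import dataclass, field
from typing import Callable, List, Union


@dataclass
class FakeGraphModule:
    ops: List[str] = field(default_factory=list)


@dataclass
class FakeProgram:
    graph_module: FakeGraphModule = field(default_factory=FakeGraphModule)


@dataclass
class FakePassResult:
    graph_module: FakeGraphModule


class FakeGraphPass:
    # Stand-in for an ExportPass/PassBase subclass: operates on a graph module
    # and returns a result object wrapping the (possibly new) graph module.
    def __init__(self, op: str) -> None:
        self.op = op

    def __call__(self, gm: FakeGraphModule) -> FakePassResult:
        gm.ops.append(self.op)
        return FakePassResult(gm)


def fake_program_pass(program: FakeProgram) -> FakeProgram:
    # Stand-in for a program-level pass such as insert_prepack_nodes.
    program.graph_module.ops.append("prepack")
    return program


def run_passes(
    program: FakeProgram, passes: List[Union[FakeGraphPass, Callable]]
) -> FakeProgram:
    for p in passes:
        if isinstance(p, FakeGraphPass):
            result = p(program.graph_module)
            assert result is not None
            program.graph_module = result.graph_module
        else:
            program = p(program)
    return program


prog = run_passes(FakeProgram(), [FakeGraphPass("fuse_clamp"), fake_program_pass])
assert prog.graph_module.ops == ["fuse_clamp", "prepack"]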