From 38e09c64ff4f4458bbe20167c2944bd671359c77 Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Thu, 31 Oct 2024 12:31:27 -0700 Subject: [PATCH] [ET-VK][AOT][ez] Introduce vulkan export utils lib ## Changes As title. Introduce a common Python utility library for scripts in the Vulkan backend. Differential Revision: [D65291064](https://our.internmc.facebook.com/intern/diff/D65291064/) [ghstack-poisoned] --- backends/vulkan/_passes/TARGETS | 1 + .../vulkan/_passes/insert_prepack_nodes.py | 21 +---------- .../serialization/vulkan_graph_builder.py | 37 +++++-------------- backends/vulkan/targets.bzl | 13 +++++++ backends/vulkan/utils.py | 30 +++++++++++++++ 5 files changed, 56 insertions(+), 46 deletions(-) create mode 100644 backends/vulkan/utils.py diff --git a/backends/vulkan/_passes/TARGETS b/backends/vulkan/_passes/TARGETS index 3f328deb485..cf50f170cf3 100644 --- a/backends/vulkan/_passes/TARGETS +++ b/backends/vulkan/_passes/TARGETS @@ -12,6 +12,7 @@ runtime.python_library( deps = [ "//caffe2:torch", "//executorch/exir:pass_base", + "//executorch/backends/vulkan:utils_lib", ], ) diff --git a/backends/vulkan/_passes/insert_prepack_nodes.py b/backends/vulkan/_passes/insert_prepack_nodes.py index 1d3f4047efe..5dd01aeedfe 100644 --- a/backends/vulkan/_passes/insert_prepack_nodes.py +++ b/backends/vulkan/_passes/insert_prepack_nodes.py @@ -13,10 +13,10 @@ import torch from executorch.backends.vulkan.op_registry import handles_own_prepacking +from executorch.backends.vulkan.utils import is_param_node from executorch.exir.dialects._ops import ops as exir_ops -from torch._export.utils import is_buffer, is_param from torch.export import ExportedProgram @@ -31,25 +31,8 @@ def insert_prepack_nodes(program: ExportedProgram) -> ExportedProgram: argument into the operator implementation. """ - def is_get_attr_node(node: torch.fx.Node) -> bool: - return isinstance(node, torch.fx.Node) and node.op == "get_attr" - - def is_constant(node: torch.fx.Node) -> bool: - return node.name in program.graph_signature.inputs_to_lifted_tensor_constants - - def is_param_node(node: torch.fx.Node) -> bool: - """ - Check if the given node is a parameter within the exported program - """ - return ( - is_get_attr_node(node) - or is_param(program, node) - or is_buffer(program, node) - or is_constant(node) - ) - def prepack_not_required(node: torch.fx.Node) -> bool: - if not is_param_node(node): + if not is_param_node(program, node): return True for user in node.users: diff --git a/backends/vulkan/serialization/vulkan_graph_builder.py b/backends/vulkan/serialization/vulkan_graph_builder.py index f9ae83ddc68..bc77bc40cfb 100644 --- a/backends/vulkan/serialization/vulkan_graph_builder.py +++ b/backends/vulkan/serialization/vulkan_graph_builder.py @@ -12,6 +12,11 @@ import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema import torch +from executorch.backends.vulkan.utils import ( + is_constant, + is_get_attr_node, + is_param_node, +) from executorch.exir.backend.utils import DelegateMappingBuilder from executorch.exir.tensor import TensorSpec @@ -68,34 +73,12 @@ def get_vk_datatype(torch_dtype: torch.dtype) -> vk_graph_schema.VkDataType: else: raise AssertionError(f"Invalid dtype for vulkan_preprocess ({torch_dtype})") - def is_constant(self, node: Node): - return ( - node.name in self.program.graph_signature.inputs_to_lifted_tensor_constants - ) - - def is_get_attr_node(self, node: Node) -> bool: - """ - Returns true if the given node is a get attr node for a tensor of the model - """ - return isinstance(node, Node) and node.op == "get_attr" - - def is_param_node(self, node: Node) -> bool: - """ - Check if the given node is a parameter within the exported program - """ - return ( - self.is_get_attr_node(node) - or is_param(self.program, node) - or is_buffer(self.program, node) - or self.is_constant(node) - ) - def get_constant(self, node: Node) -> Optional[torch.Tensor]: """ Returns the constant associated with the given node in the exported program. Returns None if the node is not a constant within the exported program """ - if self.is_constant(node): + if is_constant(self.program, node): constant_name = ( self.program.graph_signature.inputs_to_lifted_tensor_constants[ node.name @@ -116,9 +99,9 @@ def get_param_tensor(self, node: Node) -> torch.Tensor: tensor = get_param(self.program, node) elif is_buffer(self.program, node): tensor = get_buffer(self.program, node) - elif self.is_constant(node): + elif is_constant(self.program, node): tensor = self.get_constant(node) - elif self.is_get_attr_node(node): + elif is_get_attr_node(node): # This is a hack to support both lifted and unlifted graph try: tensor = getattr(node.graph.owning_module, node.target) @@ -132,7 +115,7 @@ def get_param_tensor(self, node: Node) -> torch.Tensor: def maybe_add_constant_tensor(self, node: Node) -> int: constant_id = -1 - if self.is_param_node(node): + if is_param_node(self.program, node): constant_id = len(self.const_tensors) self.const_tensors.append(self.get_param_tensor(node)) @@ -280,7 +263,7 @@ def process_placeholder_node(self, node: Node) -> None: if len(node.users) == 0: return None ids = self.create_node_value(node) - if not self.is_param_node(node): + if not is_param_node(self.program, node): if isinstance(ids, int): self.input_ids.append(ids) else: diff --git a/backends/vulkan/targets.bzl b/backends/vulkan/targets.bzl index 38e0183318c..994c7473943 100644 --- a/backends/vulkan/targets.bzl +++ b/backends/vulkan/targets.bzl @@ -204,6 +204,19 @@ def define_common_targets(is_fbcode = False): ## AOT targets ## if is_fbcode: + runtime.python_library( + name = "utils_lib", + srcs = [ + "utils.py", + ], + visibility = [ + "//executorch/backends/vulkan/...", + ], + deps = [ + "//caffe2:torch", + ] + ) + runtime.python_library( name = "custom_ops_lib", srcs = [ diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py new file mode 100644 index 00000000000..ae0b8c69406 --- /dev/null +++ b/backends/vulkan/utils.py @@ -0,0 +1,30 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch._export.utils import is_buffer, is_param + +from torch.export import ExportedProgram + + +def is_get_attr_node(node: torch.fx.Node) -> bool: + return isinstance(node, torch.fx.Node) and node.op == "get_attr" + + +def is_constant(program: ExportedProgram, node: torch.fx.Node) -> bool: + return node.name in program.graph_signature.inputs_to_lifted_tensor_constants + + +def is_param_node(program: ExportedProgram, node: torch.fx.Node) -> bool: + """ + Check if the given node is a parameter within the exported program + """ + return ( + is_get_attr_node(node) + or is_param(program, node) + or is_buffer(program, node) + or is_constant(program, node) + )