From ba4bb54f8f41de5e06ed5b873a51b4f40c297050 Mon Sep 17 00:00:00 2001 From: Stephen Jia Date: Mon, 4 Nov 2024 09:30:13 -0800 Subject: [PATCH] [ET-VK] Refine partitioner to account for storage type and memory layout ## Context There are a variety of ways that tensors can be represented in Vulkan. The two main descriptors for how a tensor is laid out in memory are: 1. Storage Type (buffer or texture) 2. Memory Layout (which dim is packed along a texel, which dim has a stride of 1, etc.) Due to the differences between buffers and textures, and the differences between different memory layouts, an implementation for an operator may only support a specific set of (storage type, memory layout) combinations. Furthermore, if an operator implementation supports multiple (storage type, memory layout) combinations, there may be a "preferred" setting which results in optimal performance. These changes lay the foundation for the implementation of a memory metadata tagging graph transform, which will make sure that all tensors participating in an operator call have a valid/optimal (storage type, memory layout) setting, and insert transition operators to transfer input tensors to the correct memory settings when necessary. An additional change that is required arises from the fact that in Vulkan, there is a limit on texture and buffer sizes. Therefore, the partitioner needs to account for the storage types and memory layouts supported by the operator implementation, and check if all tensors participating in a computation can be represented with some (storage type, memory layout) combination supported by the implementation. ## Changes Improvements to the operator registry: * Introduce utility functions to check the optimal and enabled storage types and memory layouts for an operator Improvements to the Partitioner: * Account for the storage types and memory layouts supported by an operator when deciding if a node should be partitioned * Improved logic for fusible ops (i.e. 
the permute/transpose before a mm which can be fused into linear) to check if the final target op is supported in Vulkan, and only partition those nodes if so. Otherwise, don't partition it so that it can be fused by another backend. Differential Revision: [D65428843](https://our.internmc.facebook.com/intern/diff/D65428843/) [ghstack-poisoned] --- backends/vulkan/op_registry.py | 171 ++++++++++---- backends/vulkan/partitioner/TARGETS | 1 + .../vulkan/partitioner/vulkan_partitioner.py | 209 ++++++++++++------ backends/vulkan/targets.bzl | 1 + backends/vulkan/utils.py | 140 ++++++++++++ 5 files changed, 416 insertions(+), 106 deletions(-) diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index fe67fdb30cf..e3e7e219d47 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -8,18 +8,31 @@ import operator -from typing import Callable, Dict, List, Optional, Union +from typing import Callable, Dict, Optional, Set, Union import executorch.backends.vulkan.custom_ops_lib # noqa import torch -from executorch.backends.vulkan.serialization.vulkan_graph_schema import VkMemoryLayout +from executorch.backends.vulkan.serialization.vulkan_graph_schema import ( + VkMemoryLayout, + VkStorageType, +) + +from executorch.backends.vulkan.utils import ( + all_memory_layouts, + all_packed_dims, + PackedDim, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload from torch._subclasses.fake_tensor import FakeTensor +###################### +## OpFeatures class ## +###################### + def allow_node(node: torch.fx.Node) -> bool: return True @@ -27,25 +40,33 @@ def allow_node(node: torch.fx.Node) -> bool: class TextureImplFeatures: __slots__ = [ - # Indicates if the compute shader is agnostic to the packed dimension - "uses_packed_dim", - # Indicates if the compute shader is agnostic to the texture axis mapping + "valid_packed_dims", "uses_axis_map", - # Specifies a 
specific set of memory layouts that the shader supports. If it is - # and empty list, then the supported memory layouts can be inferred from the - # `uses_packed_dim` and `uses_axis_map` flags. - "supported_layouts", ] def __init__( self, - uses_packed_dim: bool = False, uses_axis_map: bool = False, - supported_layouts: Optional[List[VkMemoryLayout]] = None, + valid_packed_dims: Optional[Set[PackedDim]] = None, ): - self.uses_packed_dim: bool = uses_packed_dim self.uses_axis_map: bool = uses_axis_map - self.supported_layouts: Optional[List[VkMemoryLayout]] = supported_layouts + self.valid_packed_dims = set() + if valid_packed_dims is not None: + self.valid_packed_dims = valid_packed_dims + + def valid_memory_layouts(self) -> Set[VkMemoryLayout]: + layouts = set() + + if PackedDim.WIDTH in self.valid_packed_dims: + layouts.add(VkMemoryLayout.TENSOR_WIDTH_PACKED) + + if PackedDim.HEIGHT in self.valid_packed_dims: + layouts.add(VkMemoryLayout.TENSOR_HEIGHT_PACKED) + + if PackedDim.CHANNELS in self.valid_packed_dims: + layouts.add(VkMemoryLayout.TENSOR_CHANNELS_PACKED) + + return layouts class OpFeatures: @@ -58,6 +79,9 @@ class OpFeatures: # bool indicating if the operator has a resize function, which allows it to # support dynamic shape tensors. "resize_fn", + # Optimal + "optimal_storage", + "optimal_layout", # bool indicating if the operator handles its own prepacking. If this is True, # then the insert_prepack_nodes pass will not insert prepack nodes for the args # of the op. 
@@ -72,17 +96,64 @@ def __init__( texture_impl: Optional[TextureImplFeatures] = None, buffer_impl: bool = False, resize_fn: bool = False, + optimal_storage: Optional[VkStorageType] = None, + optimal_layout: Optional[VkMemoryLayout] = None, handles_own_prepacking: bool = False, check_node_fn: Optional[Callable] = None, ): self.texture_impl: Optional[TextureImplFeatures] = texture_impl self.buffer_impl: bool = buffer_impl self.resize_fn: bool = resize_fn + self.optimal_storage: Optional[VkStorageType] = optimal_storage + self.optimal_layout: Optional[VkMemoryLayout] = optimal_layout self.handles_own_prepacking: bool = handles_own_prepacking self.check_node_fn: Callable = allow_node if check_node_fn is not None: self.check_node_fn = check_node_fn + def propose_storage_type(self) -> Optional[VkStorageType]: + if self.optimal_storage is not None: + return self.optimal_storage + + if self.texture_impl is not None and not self.buffer_impl: + return VkStorageType.TEXTURE_3D + elif self.buffer_impl and self.texture_impl is None: + return VkStorageType.BUFFER + + return None + + def supported_storage_types(self) -> Set[VkStorageType]: + storage_types = set() + if self.texture_impl is not None: + storage_types.add(VkStorageType.TEXTURE_3D) + if self.buffer_impl: + storage_types.add(VkStorageType.BUFFER) + + return storage_types + + def propose_memory_layout(self, storage: VkStorageType) -> Optional[VkMemoryLayout]: + if self.optimal_layout is not None: + return self.optimal_layout + + if storage == VkStorageType.TEXTURE_3D: + assert self.texture_impl is not None + possible_layouts = self.texture_impl.valid_memory_layouts() + if len(possible_layouts) == 1: + return next(iter(possible_layouts)) + + return None + + def supported_memory_layouts(self, storage: VkStorageType) -> Set[VkMemoryLayout]: + if storage == VkStorageType.TEXTURE_3D: + assert self.texture_impl is not None + return self.texture_impl.valid_memory_layouts() + else: + return all_memory_layouts + + 
+####################### +## Operator Registry ## +####################### OpKey = Union[str, torch._ops.OpOverload, EdgeOpOverload] @@ -122,8 +193,8 @@ def update_features_impl(op: OpKey): ) def register_ephemeral_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - uses_packed_dim=True, uses_axis_map=True, + valid_packed_dims=all_packed_dims, ) features.buffer_impl = True features.resize_fn = True @@ -143,8 +214,8 @@ def register_ephemeral_op(features: OpFeatures): ) def register_binary_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - uses_packed_dim=True, uses_axis_map=True, + valid_packed_dims=all_packed_dims, ) features.resize_fn = True return features @@ -170,8 +241,8 @@ def register_binary_op(features: OpFeatures): ) def register_unary_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - uses_packed_dim=True, uses_axis_map=True, + valid_packed_dims=all_packed_dims, ) features.buffer_impl = True features.resize_fn = True @@ -181,8 +252,8 @@ def register_unary_op(features: OpFeatures): @update_features(exir_ops.edge.aten._to_copy.default) def register_to_copy_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - uses_packed_dim=True, uses_axis_map=True, + valid_packed_dims=all_packed_dims, ) features.resize_fn = True @@ -220,15 +291,16 @@ def check_to_copy_node(node: torch.fx.Node) -> bool: ) def register_mm_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - uses_packed_dim=False, uses_axis_map=True, - supported_layouts=[ - VkMemoryLayout.TENSOR_WIDTH_PACKED, - VkMemoryLayout.TENSOR_CHANNELS_PACKED, - ], + valid_packed_dims={ + PackedDim.WIDTH, + PackedDim.CHANNELS, + }, ) features.buffer_impl = True features.resize_fn = True + features.optimal_storage = VkStorageType.TEXTURE_3D + features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED features.handles_own_prepacking = True return features @@ -236,12 +308,13 @@ def register_mm_op(features: OpFeatures): 
@update_features(exir_ops.edge.aten._weight_int8pack_mm.default) def register_int8_mm_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - uses_packed_dim=False, uses_axis_map=False, - supported_layouts=[VkMemoryLayout.TENSOR_WIDTH_PACKED], + valid_packed_dims={PackedDim.WIDTH}, ) features.buffer_impl = True features.resize_fn = True + features.optimal_storage = VkStorageType.TEXTURE_3D + features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED features.handles_own_prepacking = True return features @@ -249,11 +322,12 @@ def register_int8_mm_op(features: OpFeatures): @update_features(exir_ops.edge.et_vk.linear_weight_int4.default) def register_int4_mm_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - uses_packed_dim=False, uses_axis_map=False, - supported_layouts=[VkMemoryLayout.TENSOR_WIDTH_PACKED], + valid_packed_dims={PackedDim.WIDTH}, ) features.resize_fn = True + features.optimal_storage = VkStorageType.TEXTURE_3D + features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED features.handles_own_prepacking = True return features @@ -266,7 +340,7 @@ def register_int4_mm_op(features: OpFeatures): ) def register_softmax_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - uses_packed_dim=True, + valid_packed_dims=all_packed_dims, ) features.resize_fn = True return features @@ -282,7 +356,7 @@ def register_softmax_op(features: OpFeatures): ) def register_reduce_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - uses_packed_dim=True, + valid_packed_dims=all_packed_dims, ) features.resize_fn = True @@ -309,7 +383,7 @@ def check_reduce_node(node: torch.fx.Node) -> bool: ) def register_2d_pool_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - supported_layouts=[VkMemoryLayout.TENSOR_CHANNELS_PACKED], + valid_packed_dims={PackedDim.CHANNELS}, ) features.resize_fn = True return features @@ -323,9 +397,11 @@ def register_2d_pool_op(features: OpFeatures): ) def 
register_convolution_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - supported_layouts=[VkMemoryLayout.TENSOR_CHANNELS_PACKED], + valid_packed_dims={PackedDim.CHANNELS}, ) features.resize_fn = True + features.optimal_storage = VkStorageType.TEXTURE_3D + features.optimal_layout = VkMemoryLayout.TENSOR_CHANNELS_PACKED features.handles_own_prepacking = True return features @@ -333,9 +409,11 @@ def register_convolution_op(features: OpFeatures): @update_features("llama::sdpa_with_kv_cache") def register_sdpa_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - supported_layouts=[VkMemoryLayout.TENSOR_WIDTH_PACKED], + valid_packed_dims={PackedDim.WIDTH}, ) features.resize_fn = True + features.optimal_storage = VkStorageType.TEXTURE_3D + features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED features.handles_own_prepacking = True return features @@ -343,7 +421,7 @@ def register_sdpa_op(features: OpFeatures): @update_features(exir_ops.edge.et_vk.apply_rotary_emb.default) def register_rotary_emb_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - supported_layouts=[VkMemoryLayout.TENSOR_WIDTH_PACKED], + valid_packed_dims={PackedDim.WIDTH}, ) features.resize_fn = True return features @@ -352,7 +430,7 @@ def register_rotary_emb_op(features: OpFeatures): @update_features(exir_ops.edge.aten.view_copy.default) def register_view_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - uses_packed_dim=True, + valid_packed_dims=all_packed_dims, ) features.resize_fn = True return features @@ -393,7 +471,7 @@ def register_view_op(features: OpFeatures): ) def register_ported_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - supported_layouts=[VkMemoryLayout.TENSOR_CHANNELS_PACKED], + valid_packed_dims={PackedDim.CHANNELS}, ) return features @@ -408,15 +486,24 @@ def register_ported_op(features: OpFeatures): ) def register_ported_ops_with_prepacking(features: OpFeatures): 
features.texture_impl = TextureImplFeatures( - supported_layouts=[VkMemoryLayout.TENSOR_CHANNELS_PACKED], + valid_packed_dims={PackedDim.CHANNELS}, ) features.handles_own_prepacking = True return features -## -## Utility Functions -## +####################### +## Utility functions ## +####################### + + +def has_impl(target: OpKey) -> bool: + if not isinstance(target, str): + if target not in vulkan_supported_ops: + return target.name() in vulkan_supported_ops + return target in vulkan_supported_ops + else: + return target in vulkan_supported_ops def get_op_features(target: OpKey) -> OpFeatures: @@ -430,5 +517,13 @@ def get_op_features(target: OpKey) -> OpFeatures: return vulkan_supported_ops[target] +def optimal_storage_type(target: OpKey) -> Optional[VkStorageType]: + return get_op_features(target).optimal_storage + + +def optimal_memory_layout(target: OpKey) -> Optional[VkMemoryLayout]: + return get_op_features(target).optimal_layout + + def handles_own_prepacking(target: OpKey) -> bool: return get_op_features(target).handles_own_prepacking diff --git a/backends/vulkan/partitioner/TARGETS b/backends/vulkan/partitioner/TARGETS index d68a82ade05..1d1d29f6fb0 100644 --- a/backends/vulkan/partitioner/TARGETS +++ b/backends/vulkan/partitioner/TARGETS @@ -13,6 +13,7 @@ runtime.python_library( ], deps = [ "//executorch/backends/vulkan:op_registry", + "//executorch/backends/vulkan:utils_lib", "//executorch/backends/vulkan:vulkan_preprocess", "//executorch/exir:delegate", "//executorch/exir:lib", diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py index 2e916fd5810..b7b015ae2dc 100644 --- a/backends/vulkan/partitioner/vulkan_partitioner.py +++ b/backends/vulkan/partitioner/vulkan_partitioner.py @@ -11,10 +11,19 @@ import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema +import executorch.backends.vulkan.utils as utils + import torch -from 
executorch.backends.vulkan.op_registry import vulkan_supported_ops +from executorch.backends.vulkan.op_registry import ( + get_op_features, + has_impl, + OpFeatures, + vulkan_supported_ops, +) +from executorch.backends.vulkan.serialization.vulkan_graph_schema import VkStorageType from executorch.backends.vulkan.vulkan_preprocess import VulkanBackend + from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.backend.partitioner import ( DelegationSpec, @@ -24,7 +33,6 @@ from executorch.exir.backend.utils import tag_constant_data from executorch.exir.dialects._ops import ops as exir_ops -from torch._subclasses.fake_tensor import FakeTensor from torch.export.exported_program import ExportedProgram from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner @@ -40,81 +48,104 @@ class VulkanSupportedOperators(OperatorSupportBase): - def __init__(self, require_dynamic_shape: bool = False) -> None: + def __init__( + self, texture_limits: utils.ImageExtents, require_dynamic_shape: bool = False + ) -> None: super().__init__() self.require_dynamic_shapes = require_dynamic_shape - # The tensor dim limit is to guard against tensors with one or more - # large dimensions, which cannot be represented by an image texture due - # to the texture axis limits. 
- self.tensor_dim_limit = 16384 - - # pyre-ignore - def node_val_is_compatible(self, node_val: Any) -> bool: - # Skip nodes that don't have a value - if node_val is None: - return True - - # TODO(ssjia) support symbolic ints - if isinstance(node_val, torch.SymInt): - return False - - if isinstance(node_val, FakeTensor): - # Vulkan currently only supports tensors of up to 4D - if len(node_val.shape) > 4: - return False + self.texture_limits: utils.ImageExtents = texture_limits - # bool dtype not currently supported - if node_val.dtype == torch.bool: - return False + def op_node_is_compatible( + self, node: torch.fx.Node, features: Optional[OpFeatures] = None + ) -> Tuple[bool, str]: + """ + Check if a given node is compatible with the Vulkan delegate's implementation + of the operator called by the node. + """ + target = node.target + # Account for custom operators + if node.target == torch.ops.higher_order.auto_functionalized: + first_arg = node.args[0] + assert isinstance(first_arg, torch._ops.OpOverload) + target = first_arg.name() - for dim in node_val.shape: - if dim > self.tensor_dim_limit: - return False + # Extract the features for the node's operator, if no override was provided + op_features = features + if features is None: + # pyre-ignore + if not has_impl(target): + return False, "no operator implementation" + # pyre-ignore + op_features = get_op_features(target) - if isinstance(node_val, (list, tuple)): - for item in node_val: - if not self.node_val_is_compatible(item): - return False + assert op_features is not None - return True + valid_texture_layouts = utils.possible_node_memory_layouts( + node, self.texture_limits + ) + for arg in node.args: + if isinstance(arg, torch.fx.Node) and utils.is_tensor_node(arg): + arg_texture_layouts = utils.possible_node_memory_layouts( + arg, self.texture_limits + ) + valid_texture_layouts = valid_texture_layouts.intersection( + arg_texture_layouts + ) + + # If there are no valid texture memory layouts, then buffer 
storage must be + # supported by the operator implementation. + if len(valid_texture_layouts) == 0: + # TODO: once memory metadata tagging pass is implemented, check that the + # op impl supports buffers instead + return False, "requires buffer representation" + + op_available_layouts = op_features.supported_memory_layouts( + VkStorageType.TEXTURE_3D + ) - def all_args_compatible(self, node: torch.fx.Node) -> bool: - node_val = node.meta.get("val", None) - if not self.node_val_is_compatible(node_val): - return False + is_compatible = any( + layout in op_available_layouts for layout in valid_texture_layouts + ) + if not is_compatible: + return False, "Required texutre memory layout not supported" - for arg in node.args: - if not isinstance(arg, torch.fx.Node): - continue + return is_compatible, "Op is compatible" - arg_val = arg.meta.get("val", None) - if not self.node_val_is_compatible(arg_val): - return False + def node_is_compatible( + self, node: torch.fx.Node, features: Optional[OpFeatures] = None + ) -> Tuple[bool, str]: + # TODO(ssjia) support symbolic ints + if utils.is_symint_node(node): + return False, "symint node not supported yet" + elif utils.is_tensor_node(node): + return self.op_node_is_compatible(node, features=features) - return True + return False, f"Unsupported node type: {node.format_node()}" - def is_linear_permute(self, node: torch.fx.Node) -> bool: + def is_linear_permute(self, node: torch.fx.Node) -> Tuple[bool, bool]: if node.target not in [ exir_ops.edge.aten.t_copy.default, exir_ops.edge.aten.permute_copy.default, ]: - return False + return False, False if len(node.users) != 1: - return False + return False, False first_user = list(node.users.keys())[0] if first_user.target in [ exir_ops.edge.aten.mm.default, exir_ops.edge.aten.addmm.default, ]: - # Only mark this node if the overall linear op is valid - if self.all_args_compatible(first_user): - return True + # Only mark this node if the target linear op is valid + if 
self.node_is_compatible(first_user)[0]: + return True, True + else: + return True, False - return False + return False, False - def is_in_local_scalar_dense_chain(self, node: torch.fx.Node) -> bool: + def is_in_local_scalar_dense_chain(self, node: torch.fx.Node) -> Tuple[bool, bool]: """ Scalar tensors are usually converted to scalar values in the graph via` scalar_tensor[0].item()` in Python, which translates to a chain of @@ -126,18 +157,21 @@ def is_in_local_scalar_dense_chain(self, node: torch.fx.Node) -> bool: """ if node.target == exir_ops.edge.aten.select_copy.int: if len(node.users) != 1: - return False + return False, False # pyre-ignore if node.args[0].meta["val"].numel() != 1: - return False + return False, False + + local_scalar_dense = list(node.users.keys())[0] + if local_scalar_dense.target != torch.ops.aten._local_scalar_dense.default: + return False, False - user = list(node.users.keys())[0] - return user.target == torch.ops.aten._local_scalar_dense.default + return self.is_in_local_scalar_dense_chain(local_scalar_dense) if node.target == torch.ops.aten._local_scalar_dense.default: - return True + return True, all(self.node_is_compatible(user)[0] for user in node.users) - return False + return False, False def log_skip(self, node: torch.fx.Node, reason: str) -> None: if node.op == "call_function": @@ -148,26 +182,35 @@ def log_skip(self, node: torch.fx.Node, reason: str) -> None: def is_node_supported( self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node ) -> bool: - r = self._is_node_supported(submodules, node) + r = self._is_node_supported(node) return r - def _is_node_supported( - self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node - ) -> bool: + def _is_node_supported(self, node: torch.fx.Node) -> bool: target = node.target if node.target == torch.ops.higher_order.auto_functionalized: first_arg = node.args[0] assert isinstance(first_arg, torch._ops.OpOverload) target = first_arg.name() - if 
self.is_linear_permute(node): + is_linear_permute, target_linear_is_compatible = self.is_linear_permute(node) + if is_linear_permute and target_linear_is_compatible: return True + elif is_linear_permute: + # Skip so that the permute can be fused into a linear by another backend + self.log_skip(node, "permute node of non compatible linear node") + return False - if self.is_in_local_scalar_dense_chain(node): + is_in_local_scalar_dense_chain, dst_node_is_compatible = ( + self.is_in_local_scalar_dense_chain(node) + ) + if is_in_local_scalar_dense_chain and dst_node_is_compatible: return True + elif is_in_local_scalar_dense_chain: + self.log_skip(node, "local scalar dense of incompatible op node") + return False if target not in vulkan_supported_ops: - self.log_skip(node, "not in vulkan_supported_ops") + self.log_skip(node, "no operator implementation") return False features = vulkan_supported_ops[target] @@ -180,7 +223,11 @@ def _is_node_supported( self.log_skip(node, "no dynamic shape support") return False - return self.all_args_compatible(node) + is_compatible, reason = self.node_is_compatible(node, features=features) + if not is_compatible: + self.log_skip(node, reason) + + return is_compatible def parse_compile_options(compile_options: Dict[str, Any]) -> List[CompileSpec]: @@ -193,6 +240,23 @@ def parse_compile_options(compile_options: Dict[str, Any]) -> List[CompileSpec]: value_bytes = int(value).to_bytes(4, byteorder="little") compile_specs.append(CompileSpec(key, value_bytes)) + if key == "texture_limits": + compile_specs.append( + CompileSpec( + "texture_limits_x", int(value[0]).to_bytes(4, byteorder="little") + ) + ) + compile_specs.append( + CompileSpec( + "texture_limits_y", int(value[1]).to_bytes(4, byteorder="little") + ) + ) + compile_specs.append( + CompileSpec( + "texture_limits_z", int(value[2]).to_bytes(4, byteorder="little") + ) + ) + # Unhandled options are ignored return compile_specs @@ -200,7 +264,10 @@ def parse_compile_options(compile_options: 
Dict[str, Any]) -> List[CompileSpec]: @final class VulkanPartitioner(Partitioner): - def __init__(self, compile_options: Optional[Dict[str, Any]] = None) -> None: + def __init__( + self, + compile_options: Optional[Dict[str, Any]] = None, + ) -> None: self.options: Dict[str, Any] = {} if compile_options is not None: self.options = compile_options @@ -218,9 +285,15 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult: # subgraphs containing the nodes with the tags partition_tags = {} + texture_limits: utils.ImageExtents = self.options.get( + "texture_limits", (16384, 16384, 2048) + ) capability_partitioner = CapabilityBasedPartitioner( exported_program.graph_module, - VulkanSupportedOperators(self.options.get("require_dynamic_shapes", False)), + VulkanSupportedOperators( + texture_limits, + require_dynamic_shape=self.options.get("require_dynamic_shapes", False), + ), allows_single_node_partition=True, ) partition_list = capability_partitioner.propose_partitions() diff --git a/backends/vulkan/targets.bzl b/backends/vulkan/targets.bzl index 0d3b17ccccc..9785b349516 100644 --- a/backends/vulkan/targets.bzl +++ b/backends/vulkan/targets.bzl @@ -253,6 +253,7 @@ def define_common_targets(is_fbcode = False): ], deps = [ ":custom_ops_lib", + ":utils_lib", "//caffe2:torch", "//executorch/exir/dialects:lib", "//executorch/backends/vulkan/serialization:lib", diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py index ae0b8c69406..16077f008d7 100644 --- a/backends/vulkan/utils.py +++ b/backends/vulkan/utils.py @@ -4,11 +4,25 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+from enum import IntEnum +from typing import Set, Tuple + import torch + +from executorch.backends.vulkan.serialization.vulkan_graph_schema import ( + VkMemoryLayout, + VkStorageType, +) from torch._export.utils import is_buffer, is_param +from torch._subclasses.fake_tensor import FakeTensor + from torch.export import ExportedProgram +## +## Node type determination +## + def is_get_attr_node(node: torch.fx.Node) -> bool: return isinstance(node, torch.fx.Node) and node.op == "get_attr" @@ -28,3 +42,129 @@ def is_param_node(program: ExportedProgram, node: torch.fx.Node) -> bool: or is_buffer(program, node) or is_constant(program, node) ) + + +def is_symint_node(node: torch.fx.Node) -> bool: + """ + Returns true if the given node produces a SymInt value + """ + if "val" not in node.meta: + return False + + if isinstance(node.meta["val"], torch.SymInt): + return True + + return False + + +def is_tensor_node(node: torch.fx.Node) -> bool: + """ + Returns true if the given node produces a tensor value, or a collection of tensor values + """ + # All nodes with tensor values are tagged by the SpecPropPass transform + if "spec" in node.meta: + return True + + if "val" not in node.meta: + return False + + if isinstance(node.meta["val"], FakeTensor): + return True + + if isinstance(node.meta["val"], list) or isinstance(node.meta["val"], tuple): + return all(isinstance(x, FakeTensor) for x in node.meta["val"]) + + return False + + +## +## Memory Layout, Storage Type Determination +## + +ImageExtents = Tuple[int, int, int] + + +class PackedDim(IntEnum): + WIDTH = 0 + HEIGHT = 1 + CHANNELS = 2 + + +all_packed_dims: Set[PackedDim] = { + PackedDim.WIDTH, + PackedDim.HEIGHT, + PackedDim.CHANNELS, +} + +all_storage_types: Set[VkStorageType] = { + VkStorageType.BUFFER, + VkStorageType.TEXTURE_3D, +} + +all_memory_layouts: Set[VkMemoryLayout] = { + VkMemoryLayout.TENSOR_WIDTH_PACKED, + VkMemoryLayout.TENSOR_HEIGHT_PACKED, + VkMemoryLayout.TENSOR_CHANNELS_PACKED, +} + + +def 
required_image_extents(sizes: torch.Size, layout: VkMemoryLayout) -> ImageExtents: + """ + Calculate the image extents that will be used to represent a tensor with the given sizes + and memory layout in the Vulkan Delegate. + """ + width = sizes[-1] if len(sizes) >= 1 else 1 + height = sizes[-2] if len(sizes) >= 2 else 1 + channels = sizes[-3] if len(sizes) >= 3 else 1 + batch = sizes[0] if len(sizes) >= 4 else 1 + + if layout == VkMemoryLayout.TENSOR_WIDTH_PACKED: + width = (width + 3) // 4 + elif layout == VkMemoryLayout.TENSOR_HEIGHT_PACKED: + height = (height + 3) // 4 + elif layout == VkMemoryLayout.TENSOR_CHANNELS_PACKED: + channels = (channels + 3) // 4 + else: + raise RuntimeError(f"Unsupported memory layout {layout}") + + return width, height, channels * batch + + +def extents_are_valid(extents: ImageExtents, limits: ImageExtents) -> bool: + return all(extents[i] <= limits[i] for i in range(len(extents))) + + +def valid_texture_memory_layouts( + tensor_sizes: torch.Size, texture_limits: ImageExtents +) -> Set[VkMemoryLayout]: + """ + Given tensor sizes, determine the set of memory layouts which will prodice a texture + that can fit within the specified device limits. + """ + valid_layouts = set() + for layout in list(all_memory_layouts): + extents = required_image_extents(tensor_sizes, layout) + if extents_are_valid(extents, texture_limits): + valid_layouts.add(layout) + + return valid_layouts + + +def possible_node_memory_layouts( + node: torch.fx.Node, texture_limits: ImageExtents +) -> Set[VkMemoryLayout]: + """ + Given a node, determine the set of memory layouts which can be used to represent all + tensors involved in the computation. 
+ """ + assert is_tensor_node(node) + if isinstance(node.meta["val"], FakeTensor): + return valid_texture_memory_layouts(node.meta["val"].shape, texture_limits) + valid_layouts = set() + if isinstance(node.meta["val"], list) or isinstance(node.meta["val"], tuple): + for fake_tensor in node.meta["val"]: + valid_layouts = valid_layouts.union( + valid_texture_memory_layouts(fake_tensor.shape, texture_limits) + ) + + return valid_layouts