diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index fe67fdb30cf..3a6191bccb6 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -8,18 +8,31 @@ import operator -from typing import Callable, Dict, List, Optional, Union +from typing import Callable, Dict, Optional, Set, Union import executorch.backends.vulkan.custom_ops_lib # noqa import torch -from executorch.backends.vulkan.serialization.vulkan_graph_schema import VkMemoryLayout +from executorch.backends.vulkan.serialization.vulkan_graph_schema import ( + VkMemoryLayout, + VkStorageType, +) + +from executorch.backends.vulkan.utils import ( + all_memory_layouts, + all_packed_dims, + PackedDim, +) from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.dialects.edge._ops import EdgeOpOverload from torch._subclasses.fake_tensor import FakeTensor +###################### +## OpFeatures class ## +###################### + def allow_node(node: torch.fx.Node) -> bool: return True @@ -27,25 +40,37 @@ def allow_node(node: torch.fx.Node) -> bool: class TextureImplFeatures: __slots__ = [ - # Indicates if the compute shader is agnostic to the packed dimension - "uses_packed_dim", - # Indicates if the compute shader is agnostic to the texture axis mapping + "valid_packed_dims", "uses_axis_map", - # Specifies a specific set of memory layouts that the shader supports. If it is - # and empty list, then the supported memory layouts can be inferred from the - # `uses_packed_dim` and `uses_axis_map` flags. - "supported_layouts", ] def __init__( self, - uses_packed_dim: bool = False, uses_axis_map: bool = False, - supported_layouts: Optional[List[VkMemoryLayout]] = None, + valid_packed_dims: Optional[Set[PackedDim]] = None, ): - self.uses_packed_dim: bool = uses_packed_dim self.uses_axis_map: bool = uses_axis_map - self.supported_layouts: Optional[List[VkMemoryLayout]] = supported_layouts + self.valid_packed_dims = set() + if valid_packed_dims is not None: + self.valid_packed_dims = valid_packed_dims + + def valid_memory_layouts(self) -> Set[VkMemoryLayout]: + """ + Derive the set of memory layouts supported by the texture implementation based + on the valid packed dimensions. + """ + layouts = set() + + if PackedDim.WIDTH in self.valid_packed_dims: + layouts.add(VkMemoryLayout.TENSOR_WIDTH_PACKED) + + if PackedDim.HEIGHT in self.valid_packed_dims: + layouts.add(VkMemoryLayout.TENSOR_HEIGHT_PACKED) + + if PackedDim.CHANNELS in self.valid_packed_dims: + layouts.add(VkMemoryLayout.TENSOR_CHANNELS_PACKED) + + return layouts class OpFeatures: @@ -58,6 +83,9 @@ class OpFeatures: # bool indicating if the operator has a resize function, which allows it to # support dynamic shape tensors. "resize_fn", + # Optimal + "optimal_storage", + "optimal_layout", # bool indicating if the operator handles its own prepacking. If this is True, # then the insert_prepack_nodes pass will not insert prepack nodes for the args # of the op. 
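As a quick illustration of the reworked feature flags above, here is a minimal usage sketch. It assumes this patch is applied; the `propose_storage_type`/`propose_memory_layout` helpers it calls are added further down in this file, and the operator described is hypothetical.

```python
from executorch.backends.vulkan.op_registry import OpFeatures, TextureImplFeatures
from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
    VkMemoryLayout,
    VkStorageType,
)
from executorch.backends.vulkan.utils import PackedDim

# Hypothetical operator: its texture shader only handles width-packed tensors
# and it has no buffer implementation.
features = OpFeatures(
    texture_impl=TextureImplFeatures(valid_packed_dims={PackedDim.WIDTH}),
    buffer_impl=False,
)

# Only one storage type is supported, so the proposal is unambiguous.
assert features.propose_storage_type() == VkStorageType.TEXTURE_3D

# A single valid packed dim maps to a single memory layout.
assert features.texture_impl.valid_memory_layouts() == {
    VkMemoryLayout.TENSOR_WIDTH_PACKED
}
assert (
    features.propose_memory_layout(VkStorageType.TEXTURE_3D)
    == VkMemoryLayout.TENSOR_WIDTH_PACKED
)
```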
@@ -72,17 +100,90 @@ def __init__( texture_impl: Optional[TextureImplFeatures] = None, buffer_impl: bool = False, resize_fn: bool = False, + optimal_storage: Optional[VkStorageType] = None, + optimal_layout: Optional[VkMemoryLayout] = None, handles_own_prepacking: bool = False, check_node_fn: Optional[Callable] = None, ): self.texture_impl: Optional[TextureImplFeatures] = texture_impl self.buffer_impl: bool = buffer_impl self.resize_fn: bool = resize_fn + self.optimal_storage: Optional[VkStorageType] = optimal_storage + self.optimal_layout: Optional[VkMemoryLayout] = optimal_layout self.handles_own_prepacking: bool = handles_own_prepacking self.check_node_fn: Callable = allow_node if check_node_fn is not None: self.check_node_fn = check_node_fn + def propose_storage_type(self) -> Optional[VkStorageType]: + """ + Propose a storage type that should be used for this operator. A proposal can be + made if one of the following is true: + 1. The operator specifies an optimal storage type + 2. Only one storage type is supported. + + If both storage types are supported and no optimal storage type is specified, + then None is returned to indicate that there is no preference in storage type. + """ + if self.optimal_storage is not None: + return self.optimal_storage + + if self.texture_impl is not None and not self.buffer_impl: + return VkStorageType.TEXTURE_3D + elif self.buffer_impl and self.texture_impl is None: + return VkStorageType.BUFFER + + return None + + def supported_storage_types(self) -> Set[VkStorageType]: + """ + Return the set of storage types supported by this operator. + """ + storage_types = set() + if self.texture_impl is not None: + storage_types.add(VkStorageType.TEXTURE_3D) + if self.buffer_impl: + storage_types.add(VkStorageType.BUFFER) + + return storage_types + + def propose_memory_layout(self, storage: VkStorageType) -> Optional[VkMemoryLayout]: + """ + Given a storage type as a precondition, propose a memory layout that should be + used for this operator. A proposal can be made if one of the following is true: + 1. The operator specifies an optimal memory layout + 2. Only one memory layout is supported. + + If multiple memory layouts are supported and no optimal memory layout is + specified then return None to indicate that the "best" memory layout for the + operator is ambiguous. + """ + if self.optimal_layout is not None: + return self.optimal_layout + + if storage == VkStorageType.TEXTURE_3D: + assert self.texture_impl is not None + possible_layouts = self.texture_impl.valid_memory_layouts() + if len(possible_layouts) == 1: + return next(iter(possible_layouts)) + + return None + + def supported_memory_layouts(self, storage: VkStorageType) -> Set[VkMemoryLayout]: + """ + Return the set of memory layouts supported by this operator for a given storage + type. 
+ """ + if storage == VkStorageType.TEXTURE_3D: + assert self.texture_impl is not None + return self.texture_impl.valid_memory_layouts() + else: + return all_memory_layouts + + +####################### +## Operator Registry ## +####################### OpKey = Union[str, torch._ops.OpOverload, EdgeOpOverload] @@ -122,8 +223,8 @@ def update_features_impl(op: OpKey): ) def register_ephemeral_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - uses_packed_dim=True, uses_axis_map=True, + valid_packed_dims=all_packed_dims, ) features.buffer_impl = True features.resize_fn = True @@ -143,8 +244,8 @@ def register_ephemeral_op(features: OpFeatures): ) def register_binary_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - uses_packed_dim=True, uses_axis_map=True, + valid_packed_dims=all_packed_dims, ) features.resize_fn = True return features @@ -170,8 +271,8 @@ def register_binary_op(features: OpFeatures): ) def register_unary_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - uses_packed_dim=True, uses_axis_map=True, + valid_packed_dims=all_packed_dims, ) features.buffer_impl = True features.resize_fn = True @@ -181,8 +282,8 @@ def register_unary_op(features: OpFeatures): @update_features(exir_ops.edge.aten._to_copy.default) def register_to_copy_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - uses_packed_dim=True, uses_axis_map=True, + valid_packed_dims=all_packed_dims, ) features.resize_fn = True @@ -220,15 +321,16 @@ def check_to_copy_node(node: torch.fx.Node) -> bool: ) def register_mm_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - uses_packed_dim=False, uses_axis_map=True, - supported_layouts=[ - VkMemoryLayout.TENSOR_WIDTH_PACKED, - VkMemoryLayout.TENSOR_CHANNELS_PACKED, - ], + valid_packed_dims={ + PackedDim.WIDTH, + PackedDim.CHANNELS, + }, ) features.buffer_impl = True features.resize_fn = True + features.optimal_storage = VkStorageType.TEXTURE_3D + features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED features.handles_own_prepacking = True return features @@ -236,12 +338,13 @@ def register_mm_op(features: OpFeatures): @update_features(exir_ops.edge.aten._weight_int8pack_mm.default) def register_int8_mm_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - uses_packed_dim=False, uses_axis_map=False, - supported_layouts=[VkMemoryLayout.TENSOR_WIDTH_PACKED], + valid_packed_dims={PackedDim.WIDTH}, ) features.buffer_impl = True features.resize_fn = True + features.optimal_storage = VkStorageType.TEXTURE_3D + features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED features.handles_own_prepacking = True return features @@ -249,11 +352,12 @@ def register_int8_mm_op(features: OpFeatures): @update_features(exir_ops.edge.et_vk.linear_weight_int4.default) def register_int4_mm_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - uses_packed_dim=False, uses_axis_map=False, - supported_layouts=[VkMemoryLayout.TENSOR_WIDTH_PACKED], + valid_packed_dims={PackedDim.WIDTH}, ) features.resize_fn = True + features.optimal_storage = VkStorageType.TEXTURE_3D + features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED features.handles_own_prepacking = True return features @@ -266,7 +370,7 @@ def register_int4_mm_op(features: OpFeatures): ) def register_softmax_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - uses_packed_dim=True, + valid_packed_dims=all_packed_dims, ) features.resize_fn = True return features @@ -282,7 
+386,7 @@ def register_softmax_op(features: OpFeatures): ) def register_reduce_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - uses_packed_dim=True, + valid_packed_dims=all_packed_dims, ) features.resize_fn = True @@ -309,7 +413,7 @@ def check_reduce_node(node: torch.fx.Node) -> bool: ) def register_2d_pool_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - supported_layouts=[VkMemoryLayout.TENSOR_CHANNELS_PACKED], + valid_packed_dims={PackedDim.CHANNELS}, ) features.resize_fn = True return features @@ -323,9 +427,11 @@ def register_2d_pool_op(features: OpFeatures): ) def register_convolution_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - supported_layouts=[VkMemoryLayout.TENSOR_CHANNELS_PACKED], + valid_packed_dims={PackedDim.CHANNELS}, ) features.resize_fn = True + features.optimal_storage = VkStorageType.TEXTURE_3D + features.optimal_layout = VkMemoryLayout.TENSOR_CHANNELS_PACKED features.handles_own_prepacking = True return features @@ -333,9 +439,11 @@ def register_convolution_op(features: OpFeatures): @update_features("llama::sdpa_with_kv_cache") def register_sdpa_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - supported_layouts=[VkMemoryLayout.TENSOR_WIDTH_PACKED], + valid_packed_dims={PackedDim.WIDTH}, ) features.resize_fn = True + features.optimal_storage = VkStorageType.TEXTURE_3D + features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED features.handles_own_prepacking = True return features @@ -343,7 +451,7 @@ def register_sdpa_op(features: OpFeatures): @update_features(exir_ops.edge.et_vk.apply_rotary_emb.default) def register_rotary_emb_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - supported_layouts=[VkMemoryLayout.TENSOR_WIDTH_PACKED], + valid_packed_dims={PackedDim.WIDTH}, ) features.resize_fn = True return features @@ -352,7 +460,7 @@ def register_rotary_emb_op(features: OpFeatures): @update_features(exir_ops.edge.aten.view_copy.default) def register_view_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - uses_packed_dim=True, + valid_packed_dims=all_packed_dims, ) features.resize_fn = True return features @@ -393,7 +501,7 @@ def register_view_op(features: OpFeatures): ) def register_ported_op(features: OpFeatures): features.texture_impl = TextureImplFeatures( - supported_layouts=[VkMemoryLayout.TENSOR_CHANNELS_PACKED], + valid_packed_dims={PackedDim.CHANNELS}, ) return features @@ -408,15 +516,24 @@ def register_ported_op(features: OpFeatures): ) def register_ported_ops_with_prepacking(features: OpFeatures): features.texture_impl = TextureImplFeatures( - supported_layouts=[VkMemoryLayout.TENSOR_CHANNELS_PACKED], + valid_packed_dims={PackedDim.CHANNELS}, ) features.handles_own_prepacking = True return features -## -## Utility Functions -## +####################### +## Utility functions ## +####################### + + +def has_impl(target: OpKey) -> bool: + if not isinstance(target, str): + if target not in vulkan_supported_ops: + return target.name() in vulkan_supported_ops + return target in vulkan_supported_ops + else: + return target in vulkan_supported_ops def get_op_features(target: OpKey) -> OpFeatures: diff --git a/backends/vulkan/partitioner/TARGETS b/backends/vulkan/partitioner/TARGETS index d68a82ade05..1d1d29f6fb0 100644 --- a/backends/vulkan/partitioner/TARGETS +++ b/backends/vulkan/partitioner/TARGETS @@ -13,6 +13,7 @@ runtime.python_library( ], deps = [ "//executorch/backends/vulkan:op_registry", + 
"//executorch/backends/vulkan:utils_lib", "//executorch/backends/vulkan:vulkan_preprocess", "//executorch/exir:delegate", "//executorch/exir:lib", diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py index 2e916fd5810..c851eeb4dae 100644 --- a/backends/vulkan/partitioner/vulkan_partitioner.py +++ b/backends/vulkan/partitioner/vulkan_partitioner.py @@ -9,12 +9,23 @@ import logging from typing import Any, Callable, Dict, final, List, Mapping, Optional, Tuple -import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema +import executorch.backends.vulkan.utils as utils import torch -from executorch.backends.vulkan.op_registry import vulkan_supported_ops +from executorch.backends.vulkan.op_registry import ( + get_op_features, + has_impl, + OpFeatures, + vulkan_supported_ops, +) + +from executorch.backends.vulkan.serialization.vulkan_graph_schema import ( + VkMemoryLayout, + VkStorageType, +) from executorch.backends.vulkan.vulkan_preprocess import VulkanBackend + from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.backend.partitioner import ( DelegationSpec, @@ -24,7 +35,6 @@ from executorch.exir.backend.utils import tag_constant_data from executorch.exir.dialects._ops import ops as exir_ops -from torch._subclasses.fake_tensor import FakeTensor from torch.export.exported_program import ExportedProgram from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner @@ -40,104 +50,140 @@ class VulkanSupportedOperators(OperatorSupportBase): - def __init__(self, require_dynamic_shape: bool = False) -> None: + def __init__( + self, texture_limits: utils.ImageExtents, require_dynamic_shape: bool = False + ) -> None: super().__init__() self.require_dynamic_shapes = require_dynamic_shape - # The tensor dim limit is to guard against tensors with one or more - # large dimensions, which cannot be represented by an image texture due - # to the texture axis limits. - self.tensor_dim_limit = 16384 - - # pyre-ignore - def node_val_is_compatible(self, node_val: Any) -> bool: - # Skip nodes that don't have a value - if node_val is None: - return True + self.texture_limits: utils.ImageExtents = texture_limits - # TODO(ssjia) support symbolic ints - if isinstance(node_val, torch.SymInt): - return False - - if isinstance(node_val, FakeTensor): - # Vulkan currently only supports tensors of up to 4D - if len(node_val.shape) > 4: - return False + def op_node_is_compatible( + self, node: torch.fx.Node, features: Optional[OpFeatures] = None + ) -> Tuple[bool, str]: + """ + Check if a given node is compatible with the Vulkan delegate's implementation + of the operator called by the node. Each tensor argument participating in the + operator call must be able to be represented with a (storage type, memory layout) + combination that is supported by the operator implementation. 
+ """ + target = node.target + # Account for custom operators + if node.target == torch.ops.higher_order.auto_functionalized: + first_arg = node.args[0] + assert isinstance(first_arg, torch._ops.OpOverload) + target = first_arg.name() - # bool dtype not currently supported + # Extract the features for the node's operator, if no override was provided - if node_val.dtype == torch.bool: - return False + if features is None: + if not has_impl(target): + return False, "no operator implementation" + features = get_op_features(target) - for dim in node_val.shape: - if dim > self.tensor_dim_limit: - return False + valid_texture_layouts = utils.possible_node_memory_layouts( + node, self.texture_limits + ) + for arg in node.args: + if isinstance(arg, torch.fx.Node) and utils.is_tensor_node(arg): + arg_texture_layouts = utils.possible_node_memory_layouts( + arg, self.texture_limits + ) + valid_texture_layouts = valid_texture_layouts.intersection( + arg_texture_layouts + ) + + # If there are no valid texture memory layouts, then buffer storage must be + # supported by the operator implementation. + if len(valid_texture_layouts) == 0: + # TODO: once memory metadata tagging pass is implemented, check that the + # op impl supports buffers instead + return False, "requires buffer representation" + + op_available_layouts = features.supported_memory_layouts( + VkStorageType.TEXTURE_3D + ) - if isinstance(node_val, (list, tuple)): - for item in node_val: - if not self.node_val_is_compatible(item): - return False + is_compatible = any( + layout in op_available_layouts for layout in valid_texture_layouts + ) + if not is_compatible: + return False, "Required texture memory layout not supported" - return True + return is_compatible, "Op is compatible" - def all_args_compatible(self, node: torch.fx.Node) -> bool: - node_val = node.meta.get("val", None) - if not self.node_val_is_compatible(node_val): - return False + def node_is_compatible( + self, node: torch.fx.Node, features: Optional[OpFeatures] = None + ) -> Tuple[bool, str]: + # TODO(ssjia) support symbolic ints + if utils.is_symint_node(node): + return False, "symint node not supported yet" + elif utils.is_tensor_node(node): + return self.op_node_is_compatible(node, features=features) - for arg in node.args: - if not isinstance(arg, torch.fx.Node): - continue + return False, f"Unsupported node type: {node.format_node()}" - arg_val = arg.meta.get("val", None) - if not self.node_val_is_compatible(arg_val): - return False + def is_linear_permute(self, node: torch.fx.Node) -> Tuple[bool, bool]: + """ + Detect if a node is a permute/transpose that precedes a call to a `mm` or + `addmm` operator. This node can be fused with the `mm` or `addmm` to produce a + `linear` operator. - return True + This function returns two bool values: + 1. The first indicates if this node can be fused into a linear node + 2. The second indicates if the overall linear op can be executed with Vulkan - def is_linear_permute(self, node: torch.fx.Node) -> bool: + The node will be partitioned only if both are true.
+ """ if node.target not in [ exir_ops.edge.aten.t_copy.default, exir_ops.edge.aten.permute_copy.default, ]: - return False + return False, False if len(node.users) != 1: - return False + return False, False first_user = list(node.users.keys())[0] if first_user.target in [ exir_ops.edge.aten.mm.default, exir_ops.edge.aten.addmm.default, ]: - # Only mark this node if the overall linear op is valid - if self.all_args_compatible(first_user): - return True + # Only mark this node if the target linear op is valid + if self.node_is_compatible(first_user)[0]: + return True, True + else: + return True, False - return False + return False, False - def is_in_local_scalar_dense_chain(self, node: torch.fx.Node) -> bool: + def is_in_local_scalar_dense_chain(self, node: torch.fx.Node) -> Tuple[bool, bool]: """ Scalar tensors are usually converted to scalar values in the graph via` scalar_tensor[0].item()` in Python, which translates to a chain of `local_scalar_dense(torch.select.int(scalar_tensor, 0, 0))` in the graph. This function marks the entire chain as supported by the Vulkan delegate. - Later, within vulkan_preprocess there will be a graph transform which - replaces the chain with passing in the scalar tensor directly. + Later, within vulkan_preprocess there will be a graph transform which replaces + the chain with passing in the scalar tensor directly. + + Similar to the `is_linear_permute` function, this function has 2 return values. """ if node.target == exir_ops.edge.aten.select_copy.int: if len(node.users) != 1: - return False + return False, False # pyre-ignore if node.args[0].meta["val"].numel() != 1: - return False + return False, False + + local_scalar_dense = list(node.users.keys())[0] + if local_scalar_dense.target != torch.ops.aten._local_scalar_dense.default: + return False, False - user = list(node.users.keys())[0] - return user.target == torch.ops.aten._local_scalar_dense.default + return self.is_in_local_scalar_dense_chain(local_scalar_dense) if node.target == torch.ops.aten._local_scalar_dense.default: - return True + return True, all(self.node_is_compatible(user)[0] for user in node.users) - return False + return False, False def log_skip(self, node: torch.fx.Node, reason: str) -> None: if node.op == "call_function": @@ -148,26 +194,35 @@ def log_skip(self, node: torch.fx.Node, reason: str) -> None: def is_node_supported( self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node ) -> bool: - r = self._is_node_supported(submodules, node) + r = self._is_node_supported(node) return r - def _is_node_supported( - self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node - ) -> bool: + def _is_node_supported(self, node: torch.fx.Node) -> bool: target = node.target if node.target == torch.ops.higher_order.auto_functionalized: first_arg = node.args[0] assert isinstance(first_arg, torch._ops.OpOverload) target = first_arg.name() - if self.is_linear_permute(node): + is_linear_permute, target_linear_is_compatible = self.is_linear_permute(node) + if is_linear_permute and target_linear_is_compatible: return True + elif is_linear_permute: + # Skip so that the permute can be fused into a linear by another backend + self.log_skip(node, "permute node of non compatible linear node") + return False - if self.is_in_local_scalar_dense_chain(node): + is_in_local_scalar_dense_chain, dst_node_is_compatible = ( + self.is_in_local_scalar_dense_chain(node) + ) + if is_in_local_scalar_dense_chain and dst_node_is_compatible: return True + elif is_in_local_scalar_dense_chain: + 
self.log_skip(node, "local scalar dense of incompatible op node") + return False if target not in vulkan_supported_ops: - self.log_skip(node, "not in vulkan_supported_ops") + self.log_skip(node, "no operator implementation") return False features = vulkan_supported_ops[target] @@ -180,19 +235,38 @@ def _is_node_supported( self.log_skip(node, "no dynamic shape support") return False - return self.all_args_compatible(node) + is_compatible, reason = self.node_is_compatible(node, features=features) + if not is_compatible: + self.log_skip(node, reason) + + return is_compatible def parse_compile_options(compile_options: Dict[str, Any]) -> List[CompileSpec]: compile_specs = [] for key, value in compile_options.items(): - if isinstance( - value, (vk_graph_schema.VkStorageType, vk_graph_schema.VkMemoryLayout) - ): + if isinstance(value, (VkStorageType, VkMemoryLayout)): value_bytes = int(value).to_bytes(4, byteorder="little") compile_specs.append(CompileSpec(key, value_bytes)) + if key == "texture_limits": + compile_specs.append( + CompileSpec( + "texture_limits_x", int(value[0]).to_bytes(4, byteorder="little") + ) + ) + compile_specs.append( + CompileSpec( + "texture_limits_y", int(value[1]).to_bytes(4, byteorder="little") + ) + ) + compile_specs.append( + CompileSpec( + "texture_limits_z", int(value[2]).to_bytes(4, byteorder="little") + ) + ) + # Unhandled options are ignored return compile_specs @@ -200,7 +274,10 @@ def parse_compile_options(compile_options: Dict[str, Any]) -> List[CompileSpec]: @final class VulkanPartitioner(Partitioner): - def __init__(self, compile_options: Optional[Dict[str, Any]] = None) -> None: + def __init__( + self, + compile_options: Optional[Dict[str, Any]] = None, + ) -> None: self.options: Dict[str, Any] = {} if compile_options is not None: self.options = compile_options @@ -218,9 +295,15 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult: # subgraphs containing the nodes with the tags partition_tags = {} + texture_limits: utils.ImageExtents = self.options.get( + "texture_limits", utils.DEFAULT_TEXTURE_LIMITS + ) capability_partitioner = CapabilityBasedPartitioner( exported_program.graph_module, - VulkanSupportedOperators(self.options.get("require_dynamic_shapes", False)), + VulkanSupportedOperators( + texture_limits, + require_dynamic_shape=self.options.get("require_dynamic_shapes", False), + ), allows_single_node_partition=True, ) partition_list = capability_partitioner.propose_partitions() diff --git a/backends/vulkan/targets.bzl b/backends/vulkan/targets.bzl index 0d3b17ccccc..9785b349516 100644 --- a/backends/vulkan/targets.bzl +++ b/backends/vulkan/targets.bzl @@ -253,6 +253,7 @@ def define_common_targets(is_fbcode = False): ], deps = [ ":custom_ops_lib", + ":utils_lib", "//caffe2:torch", "//executorch/exir/dialects:lib", "//executorch/backends/vulkan/serialization:lib", diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py index ae0b8c69406..4264e942719 100644 --- a/backends/vulkan/utils.py +++ b/backends/vulkan/utils.py @@ -4,11 +4,25 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
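Before the `utils.py` additions that follow, here is a hedged sketch of how the partitioner's new `texture_limits` compile option (introduced above) might be supplied. The limits and the commented export calls are placeholders; only the `VulkanPartitioner` API itself is taken from this diff.

```python
from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner

# Hypothetical device whose 3D images max out at 4096 x 4096 x 2048 texels.
# Nodes whose tensors cannot be represented as a texture within these limits
# (for any memory layout) are skipped until buffer support lands.
partitioner = VulkanPartitioner(
    compile_options={
        "texture_limits": (4096, 4096, 2048),
        "require_dynamic_shapes": False,
    }
)

# Typical downstream usage (placeholder names for the model and edge program):
# edge_program = to_edge(torch.export.export(model, example_inputs))
# edge_program = edge_program.to_backend(partitioner)
```

The same limits are serialized by `parse_compile_options` into the `texture_limits_x`/`_y`/`_z` compile specs shown above, so the runtime side can see the values used during partitioning.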
+from enum import IntEnum +from typing import Set, Tuple + import torch + +from executorch.backends.vulkan.serialization.vulkan_graph_schema import ( + VkMemoryLayout, + VkStorageType, +) from torch._export.utils import is_buffer, is_param +from torch._subclasses.fake_tensor import FakeTensor + from torch.export import ExportedProgram +## +## Node type determination +## + def is_get_attr_node(node: torch.fx.Node) -> bool: return isinstance(node, torch.fx.Node) and node.op == "get_attr" @@ -28,3 +42,131 @@ def is_param_node(program: ExportedProgram, node: torch.fx.Node) -> bool: or is_buffer(program, node) or is_constant(program, node) ) + + +def is_symint_node(node: torch.fx.Node) -> bool: + """ + Returns true if the given node produces a SymInt value + """ + if "val" not in node.meta: + return False + + if isinstance(node.meta["val"], torch.SymInt): + return True + + return False + + +def is_tensor_node(node: torch.fx.Node) -> bool: + """ + Returns true if the given node produces a tensor value, or a collection of tensor values + """ + # All nodes with tensor values are tagged by the SpecPropPass transform + if "spec" in node.meta: + return True + + if "val" not in node.meta: + return False + + if isinstance(node.meta["val"], FakeTensor): + return True + + if isinstance(node.meta["val"], list) or isinstance(node.meta["val"], tuple): + return all(isinstance(x, FakeTensor) for x in node.meta["val"]) + + return False + + +## +## Memory Layout, Storage Type Determination +## + +ImageExtents = Tuple[int, int, int] + +DEFAULT_TEXTURE_LIMITS = (16384, 16384, 2048) + + +class PackedDim(IntEnum): + WIDTH = 0 + HEIGHT = 1 + CHANNELS = 2 + + +all_packed_dims: Set[PackedDim] = { + PackedDim.WIDTH, + PackedDim.HEIGHT, + PackedDim.CHANNELS, +} + +all_storage_types: Set[VkStorageType] = { + VkStorageType.BUFFER, + VkStorageType.TEXTURE_3D, +} + +all_memory_layouts: Set[VkMemoryLayout] = { + VkMemoryLayout.TENSOR_WIDTH_PACKED, + VkMemoryLayout.TENSOR_HEIGHT_PACKED, + VkMemoryLayout.TENSOR_CHANNELS_PACKED, +} + + +def required_image_extents(sizes: torch.Size, layout: VkMemoryLayout) -> ImageExtents: + """ + Calculate the image extents that will be used to represent a tensor with the given sizes + and memory layout in the Vulkan Delegate. + """ + width = sizes[-1] if len(sizes) >= 1 else 1 + height = sizes[-2] if len(sizes) >= 2 else 1 + channels = sizes[-3] if len(sizes) >= 3 else 1 + batch = sizes[0] if len(sizes) >= 4 else 1 + + if layout == VkMemoryLayout.TENSOR_WIDTH_PACKED: + width = (width + 3) // 4 + elif layout == VkMemoryLayout.TENSOR_HEIGHT_PACKED: + height = (height + 3) // 4 + elif layout == VkMemoryLayout.TENSOR_CHANNELS_PACKED: + channels = (channels + 3) // 4 + else: + raise RuntimeError(f"Unsupported memory layout {layout}") + + return width, height, channels * batch + + +def extents_are_valid(extents: ImageExtents, limits: ImageExtents) -> bool: + return all(extents[i] <= limits[i] for i in range(len(extents))) + + +def valid_texture_memory_layouts( + tensor_sizes: torch.Size, texture_limits: ImageExtents +) -> Set[VkMemoryLayout]: + """ + Given tensor sizes, determine the set of memory layouts which will produce a texture + that can fit within the specified device limits.
+ """ + valid_layouts = set() + for layout in list(all_memory_layouts): + extents = required_image_extents(tensor_sizes, layout) + if extents_are_valid(extents, texture_limits): + valid_layouts.add(layout) + + return valid_layouts + + +def possible_node_memory_layouts( + node: torch.fx.Node, texture_limits: ImageExtents +) -> Set[VkMemoryLayout]: + """ + Given a node, determine the set of memory layouts which can be used to represent all + tensors involved in the computation. + """ + assert is_tensor_node(node) + if isinstance(node.meta["val"], FakeTensor): + return valid_texture_memory_layouts(node.meta["val"].shape, texture_limits) + valid_layouts = set() + if isinstance(node.meta["val"], list) or isinstance(node.meta["val"], tuple): + for fake_tensor in node.meta["val"]: + valid_layouts = valid_layouts.union( + valid_texture_memory_layouts(fake_tensor.shape, texture_limits) + ) + + return valid_layouts
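To ground the new texture-fitting helpers, a small worked example follows. It assumes this patch is applied; the tensor shape is made up purely for illustration.

```python
import torch

from executorch.backends.vulkan.serialization.vulkan_graph_schema import VkMemoryLayout
from executorch.backends.vulkan.utils import (
    DEFAULT_TEXTURE_LIMITS,
    required_image_extents,
    valid_texture_memory_layouts,
)

# A 4D tensor (N, C, H, W) = (1, 4, 20000, 64) with one very large dimension.
sizes = torch.Size([1, 4, 20000, 64])

# Width-packing folds W by 4 but leaves H at 20000 texels, which exceeds the
# default 16384-per-axis limit; height-packing folds H down to 5000 texels.
print(required_image_extents(sizes, VkMemoryLayout.TENSOR_WIDTH_PACKED))   # (16, 20000, 4)
print(required_image_extents(sizes, VkMemoryLayout.TENSOR_HEIGHT_PACKED))  # (64, 5000, 4)

# Only the height-packed layout yields an image within DEFAULT_TEXTURE_LIMITS,
# i.e. (16384, 16384, 2048).
print(valid_texture_memory_layouts(sizes, DEFAULT_TEXTURE_LIMITS))
# -> {VkMemoryLayout.TENSOR_HEIGHT_PACKED}
```

This is the same check that `VulkanSupportedOperators.op_node_is_compatible` performs per node via `possible_node_memory_layouts`, intersected across the node's tensor arguments.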