From a55bc0cd1714b2ca931d0b65a89fb0a365bb2026 Mon Sep 17 00:00:00 2001
From: Stephen Jia
Date: Wed, 30 Jul 2025 17:43:59 -0700
Subject: [PATCH] Rewrite Memory Metadata Tagging Pass (#12927)

Summary:
## Context

In ET-VK, tensors may be stored with either a GPU buffer or a GPU texture. They may also be stored with a specific memory layout: width packed, height packed, or channels packed. The memory layout controls which dimension will have its elements adjacent in physical memory. In this way, the "representation" of a tensor in ET-VK may be described with a (storage type, memory layout) pair.

Operator implementations may only support certain tensor representations for inputs and outputs. Furthermore, implementations typically have expectations around which input/output tensors will share the same representation. Some examples:

* Binary operators:
  * I/O tensors may use any representation; however, all tensors in the op must use the same representation. I.e. if the first input tensor uses buffer storage, so must the other input tensor and the output tensor.
* Native group norm:
  * Input tensors must be channels packed textures. However, the op produces 3 outputs: the normalized tensor, the running mean, and the running stddev. The normalized tensor must use the same representation as the first input, while the mean and stddev tensors are expected to be contiguous buffers.
* Choose qparams:
  * The input tensor can use any representation. However, the two output tensors (zero points and scales) will always be contiguous buffers.
* Dynamically quantized linear:
  * The input tensor can be either buffer or texture, but must be contiguous/width packed. The scales and zeros tensors for the inputs and weights must all be contiguous buffers. The output tensor must use the same representation as the input tensor.

The operator registry (`op_registry.py`) is responsible for denoting these representational requirements for each op, and the `tag_memory_metadata_pass.py` graph pass is responsible for determining the representation that each tensor in each operator should use. The graph pass is also responsible for inserting nodes to move input arguments to a required representation if they have been created with an unsupported representation.

## Current Method

Currently, the operator registry indicates the following:

* Whether texture inputs are supported for the op
* If yes, which texture memory layouts are supported for inputs to the op
* Whether buffer inputs are supported for the op
* An "optimal" storage type and memory layout to use for inputs/outputs of the operator

The underlying assumption is that all tensors participating in an operator will use the same representation. Although this assumption holds true for most operators, it is clearly insufficient for some of the example operators described above, where certain inputs must use representations that differ from those of the other tensors.

During export, the memory metadata tagging pass goes through each op and marks the tensors participating in the op with a valid representation for that op. It ensures that all tensors participating in an op use the same representation.
To determine the representation to use, it accounts for three things in order of priority:

* The "optimal" storage type and memory layout marked for the op in the operator registry
* Any existing representations that have already been determined for input tensors
* Which representations are supported by users of the output tensor(s) of the current op

## Goals of this diff

The main goal of this diff is to address the problem that the current method of annotating tensor representation requirements for operators is insufficient for describing the requirements of operator implementations. Critically, for operators like choose_qparams and dynamically quantized linear, the current system cannot ensure that all input/output tensors use representations that are supported by the op impl, since it tries to make all tensors participating in an operator use the same representation.

## Changes

### `utils.py`

First, in `utils.py` I introduce several classes to abstract the concept of tensor representations and sets of possible tensor representations:

* `TensorRepr` represents a (storage type, memory layout) pair, which describes the representation to use for a single tensor.
* `TensorRepSet` represents the set of possible representations that may be used for a single tensor.
* `OpRepSets` manages the set of possible representations (i.e. `TensorRepSet`s) for all tensors participating in an operation.

To do this, `OpRepSets` accounts for 3 things:

* The supported tensor representations for inputs/outputs that are denoted by the operator registration
* The actual sizes of the tensors; some tensors may have dims that are too large to fit into a texture
* Sync requirements, i.e. requirements regarding which tensors in the operation must use the same representation

For the last point, `OpRepSets` accounts for three "rules" internally:

* All input tensors must use the same representation
* All output tensors must use the same representation
* The "primary" (i.e. first) input and output tensors must use the same representation

I have settled on these three rules for now since they adequately describe the possible requirements of all operators. These three rules are validated to be true at all times within `OpRepSets`.

Since a `TensorRepSet` may be ambiguous (i.e. there are multiple possible representations that could be used), `OpRepSets` also provides utility functions to constrain the possible representation set of an input tensor while maintaining the synchronization rules.

I have also defined `TensorRepSet` instances like:

* `utils.ANY_STORAGE`
* `utils.CONTIGUOUS_BUFFER`
* `utils.CHANNELS_PACKED_TEXTURE`

as convenience definitions for common representation set configurations.

### `op_registry.py`

Now, in `op_registry.py` operator registrations only need to define two things: `inputs_storage` and optionally `outputs_storage`, which describe the possible representation sets that may be used for input and output tensors.
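To make these abstractions concrete, here is a minimal illustrative sketch of how the representation classes relate to each other. This is a simplified stand-in, not the actual code in `utils.py`; the enum members, fields, and method names are assumptions for illustration only:

```python
# Illustrative sketch only; the real TensorRepr/TensorRepSet in utils.py also
# handle texture size limits, TensorRepSetList, and the OpRepSets sync rules.
from dataclasses import dataclass
from enum import Enum, auto
from itertools import product
from typing import FrozenSet, Tuple


class StorageType(Enum):
    BUFFER = auto()
    TEXTURE_3D = auto()


class MemoryLayout(Enum):
    WIDTH_PACKED = auto()
    HEIGHT_PACKED = auto()
    CHANNELS_PACKED = auto()


@dataclass(frozen=True)
class TensorRepr:
    """A concrete (storage type, memory layout) pair for a single tensor."""

    storage: StorageType
    layout: MemoryLayout


@dataclass(frozen=True)
class TensorRepSet:
    """The set of representations that a single tensor is allowed to use."""

    reprs: FrozenSet[Tuple[StorageType, MemoryLayout]]

    def make_intersect(self, other: "TensorRepSet") -> "TensorRepSet":
        # Constraining a repset is just set intersection
        return TensorRepSet(self.reprs & other.reprs)

    def is_constrained(self) -> bool:
        # A repset is unambiguous once exactly one representation remains
        return len(self.reprs) == 1


# Convenience instances analogous to utils.CONTIGUOUS_BUFFER, etc.
CONTIGUOUS_BUFFER = TensorRepSet(
    frozenset({(StorageType.BUFFER, MemoryLayout.WIDTH_PACKED)})
)
CHANNELS_PACKED_TEXTURE = TensorRepSet(
    frozenset({(StorageType.TEXTURE_3D, MemoryLayout.CHANNELS_PACKED)})
)
ANY_STORAGE = TensorRepSet(frozenset(product(StorageType, MemoryLayout)))

# An ambiguous repset becomes constrained after intersecting with a
# single-element repset, which is how the tagging pass resolves choices.
assert not ANY_STORAGE.is_constrained()
assert ANY_STORAGE.make_intersect(CONTIGUOUS_BUFFER).is_constrained()
```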
The registrations for each example operator would be:

```
# binary ops
def register_binary_op():
    return OpFeatures(
        inputs_storage=utils.ANY_STORAGE,
        supports_resize=True,
    )


# group norm
def register_native_group_norm():
    return OpFeatures(
        inputs_storage=utils.CHANNELS_PACKED_TEXTURE,
        outputs_storage=[
            utils.CHANNELS_PACKED_TEXTURE,
            utils.CONTIGUOUS_BUFFER,
            utils.CONTIGUOUS_BUFFER,
        ],
        supports_prepacking=True,
    )


# choose qparams
@update_features(
    [
        exir_ops.edge.torchao.choose_qparams_affine.default,
    ]
)
def register_torchao_quantization_op():
    return OpFeatures(
        inputs_storage=utils.CONTIGUOUS_ANY,
        outputs_storage=utils.CONTIGUOUS_BUFFER,
        supports_resize=True,
    )


# DQ-Linear
def register_linear_qta8a_qga4w_op():
    return OpFeatures(
        inputs_storage=[
            utils.CONTIGUOUS_ANY,  # input
            utils.CONTIGUOUS_BUFFER,  # mat1 scales
            utils.CONTIGUOUS_BUFFER,  # mat1 zeros
            utils.NO_STORAGE,  # weight (prepacked)
            utils.NO_STORAGE,  # group size (non tensor)
            utils.CONTIGUOUS_BUFFER,  # mat2 scales
            utils.CONTIGUOUS_BUFFER,  # mat2 zeros
        ],
        supports_resize=True,
        supports_prepacking=True,
    )
```

The 3 synchronization rules are inferred from the defined `inputs_storage` and `outputs_storage`:

* If no `outputs_storage` is defined, then assume that the output storage is the same as the first `TensorRepSet` in `inputs_storage`. This also implies that the primary input and output need to be synced.
* If `inputs_storage` only contains a single `TensorRepSet`, it is assumed that all input tensors need to be synchronized.
* Similarly, if `outputs_storage` only contains a single `TensorRepSet`, it is assumed that all output tensors need to be synchronized.
* If the first entry in `inputs_storage` and `outputs_storage` are the same, assume that the primary input and output need to be synced.

### `tag_memory_metadata_pass.py`

The `tag_memory_metadata_pass.py` pass maintains the same scope and behaviour as before, although it is almost completely rewritten to use the `OpRepSets` utility class. It goes through the same steps as before:

* For each operator, determine the initial `OpRepSets`
* Constrain the initial `OpRepSets` by checking any existing representations of input tensors, and by checking future uses of the output tensor(s) to try and reduce the number of representation transitions needed
* Set the representation of each input/output tensor in the operator. If an input tensor requires a different representation than it currently has, insert a clone node to transition the arg to the required representation.

Reviewed By: trivedivivek

Differential Revision: D79116560
---
 .../vulkan/_passes/insert_prepack_nodes.py    |   2 +-
 .../_passes/remove_local_scalar_dense_ops.py  |   4 +-
 .../vulkan/_passes/tag_memory_meta_pass.py    | 607 ++++++++------
 backends/vulkan/op_registry.py                | 570 +++++--------
 .../vulkan/partitioner/vulkan_partitioner.py  |  70 +-
 .../serialization/vulkan_graph_builder.py     |  12 +-
 backends/vulkan/test/test_vulkan_delegate.py  | 112 ++-
 backends/vulkan/utils.py                      | 761 +++++++++++++++++-
 8 files changed, 1422 insertions(+), 716 deletions(-)

diff --git a/backends/vulkan/_passes/insert_prepack_nodes.py b/backends/vulkan/_passes/insert_prepack_nodes.py
index ed736438cbb..c45ed4ea25d 100644
--- a/backends/vulkan/_passes/insert_prepack_nodes.py
+++ b/backends/vulkan/_passes/insert_prepack_nodes.py
@@ -35,7 +35,7 @@ def insert_prepack_nodes(program: ExportedProgram) -> ExportedProgram:
         # Mark that this node is going to be represented as a TensorRef type in the
         # Vulkan compute graph. This annotation is used in later graph passes.
- node.meta["vkdg_tensorref"] = True + node.meta["etvk_tensorref"] = True # Get the list of node users that do not handle their own prepacking nodes_to_replace_input = [] diff --git a/backends/vulkan/_passes/remove_local_scalar_dense_ops.py b/backends/vulkan/_passes/remove_local_scalar_dense_ops.py index 4c4b8c265af..6ce3572ec0c 100644 --- a/backends/vulkan/_passes/remove_local_scalar_dense_ops.py +++ b/backends/vulkan/_passes/remove_local_scalar_dense_ops.py @@ -52,7 +52,7 @@ def tag_node_if_scalar_tensor(node: torch.fx.Node) -> None: for user in node.users: if node_is_local_scalar_dense_chain(user): - node.meta["vkdg_is_scalar_tensor"] = True + node.meta["etvk_is_scalar_tensor"] = True def remove_local_scalar_dense_chain(graph: torch.fx.Graph, node: torch.fx.Node) -> None: @@ -74,7 +74,7 @@ def remove_local_scalar_dense_chain(graph: torch.fx.Graph, node: torch.fx.Node) if replace_node.args[0].meta["val"].numel() == 1: replace_node = replace_node.args[0] assert isinstance(replace_node, torch.fx.Node) - assert replace_node.meta.get("vkdg_is_scalar_tensor", True) + assert replace_node.meta.get("etvk_is_scalar_tensor", True) with graph.inserting_after(node): node.replace_all_uses_with(replace_node) diff --git a/backends/vulkan/_passes/tag_memory_meta_pass.py b/backends/vulkan/_passes/tag_memory_meta_pass.py index 0bd8dae0b66..db53cc666a8 100644 --- a/backends/vulkan/_passes/tag_memory_meta_pass.py +++ b/backends/vulkan/_passes/tag_memory_meta_pass.py @@ -5,13 +5,15 @@ # LICENSE file in the root directory of this source tree. import logging -from typing import Any, Optional, Set +import operator + +from typing import Any import executorch.backends.vulkan.utils as utils import torch -from executorch.backends.vulkan.op_registry import get_op_features, has_impl +from executorch.backends.vulkan.op_registry import get_op_features, has_impl, OpFeatures from executorch.backends.vulkan.serialization.vulkan_graph_schema import ( VkMemoryLayout, @@ -27,23 +29,16 @@ logger.setLevel(logging.INFO) -def set_memory_metadata( - node: torch.fx.Node, storage: VkStorageType, layout: VkMemoryLayout -) -> None: - utils.set_node_spec_attr(node, "vk_storage_type", storage) - utils.set_node_spec_attr(node, "vk_memory_layout", layout) - - def insert_transition_node( graph_module: torch.fx.GraphModule, node: torch.fx.Node, arg: torch.fx.Node, - storage: VkStorageType, - layout: VkMemoryLayout, + arg_node_repr: utils.TensorRepr, ) -> None: """ - Insert a clone node to copy the original tensor to a tensor with the desired storage - type and memory layout. + Insert a clone node to transition the tensor associated with `arg` to a tensor with + the requested representation `arg_node_repr`, and use the cloned node as an argument + to `node` instead of `arg`. """ with graph_module.graph.inserting_before(node): clone_node = graph_module.graph.create_node( @@ -54,30 +49,80 @@ def insert_transition_node( clone_node.meta["val"] = arg.meta["val"] clone_node.meta["spec"] = TensorSpec.from_tensor(clone_node.meta["val"]) clone_node.meta["spec"].const = False - set_memory_metadata(clone_node, storage, layout) + utils.set_node_repr(clone_node, arg_node_repr) arg.replace_all_uses_with(clone_node, lambda x, y=node: x == y) -class TagMemoryMetaPass(ExportPass): +def set_arg_node_repr_or_transition( + graph_module: torch.fx.GraphModule, + op_node: torch.fx.Node, + arg_i: int, + arg_node_repr: utils.TensorRepr, + dirty: bool, +) -> bool: """ - There are a variety of ways that tensors can be represented in Vulkan. 
The two main - descriptors for how a tensor is laid out in memory is: + Does one of following: + 1. Sets the `node_repr` of the argument at `arg_i` of `op_node` if the argument node + does not currently have a `node_repr` + 2. No-op if the current `node_repr` is already the same as the requested represetnation. + 3. Insert a transition node to create a copy of the argument with the desired `node_repr` + if the current `node_repr` is different than what is needed. + """ + arg_node = op_node.args[arg_i] + + def single_node_impl(node: torch.fx.Node) -> bool: + # Case where the arg node has not been touched yet; in this case, simply set it and + # return. + if not utils.has_node_repr(node): + utils.set_node_repr(node, arg_node_repr) + return False + + # Case where the current node representation is the same as the new one. + cur_node_repr = utils.get_node_repr(node) + assert isinstance(cur_node_repr, utils.TensorRepr) + + if cur_node_repr == arg_node_repr: + return False + + if not dirty: + logger.info( + f"[Vulkan Delegate] Inserting transition(s) for {op_node.format_node()}:" + ) + + # Existing node representation is different; insert a transition node + # Currently, the transition node insertion logic can only handle single tensor nodes + assert utils.is_single_tensor_node(node) + insert_transition_node(graph_module, op_node, node, arg_node_repr) + + logger.info(f" arg {arg_i} ({node}): ({cur_node_repr}) -> ({arg_node_repr})") + + return True + + if isinstance(arg_node, torch.fx.Node): + return single_node_impl(arg_node) + elif isinstance(arg_node, (list, tuple)): + ret: bool = False + for n in arg_node: + assert isinstance(n, torch.fx.Node) + assert utils.is_single_tensor_node(n) + ret = single_node_impl(n) or ret - 1. Storage Type (buffer or texture) - 2. Memory Layout (which dim is packed along a texel / has a stride of 1, etc.) + return ret - Due to the differences between buffers and textures, and the differences between - different memory layouts, an implementation for an operator may only support a - specific set of (storage type, memory layout) combinations. + raise NotImplementedError(f"Unhandled node type {arg_node}") - Furthermore, if an operator implementation supports multiple (storage type, memory - layout) combinations, there may be a "preferred" setting which results in optimal - performance. - This pass is responsible for ensuring that all tensors participating in an operator - call have a valid/optimal (storage type, memory layout) setting, and insert - transition operators to transfer input tensors to the correct memory settings when - necessary. +class TagMemoryMetaPass(ExportPass): + """ + Operator implementations in the Vulkan delegate may require that input and output + tensors use a specific representation. Representation in this case refers to a + combination of storage type (buffer or texture) and memory layout (width, height, or + channels packed). + + The tag memory metadata pass is responsible for marking each tensor in the graph + with the appropriate representation to use. It is also responsible for inserting + operators to transition argument tensors to a required/compatible representation if + a mismatch has been detected. 
""" def __init__( @@ -91,241 +136,331 @@ def __init__( self.default_layout: VkMemoryLayout = default_memory_layout self.texture_limits = texture_limits - def propose_node_storage( # noqa: C901 - self, - node: torch.fx.Node, - ) -> Optional[VkStorageType]: + # Magic number to limit "lookahead" when tracing through users of an operator + # to constrain the representation of its arguments/outputs. + self.max_trace_search_depth = 20 + + def is_valid_op_node(self, node: Any) -> bool: """ - Uses the operator registry to determine the storage type that should be used for - a given node. The storage type is determined with the following priorities: - 1. In some cases, a tensor involved in the computation may be too large to be - represented as a texture. If this is the case, the node is "opinionated" and - buffer representation must be used. - 1. If the operator called by the node indicates an optimal storage type, or only - supports a single storage type, use that storage type. If either is true, - then the node is considered to be opinionated as well. If multiple storage - and no preferred storage type is indicated, then the node is not opinionated; - go to the next step. - 2. If the node's arguments already have memory metadata annotations, then - preserve the settings of the first argument. Otherwise, proceed to the next - step. - 3. Recursively search the node's uses to see if any subsequent uses are - opinionated; inherit the settings of the first opinionated node. If no - opinionated user can be found, then proceed to the last step. - 4. Use the default storage type setting. + Fails the check for: + * nodes that are not associated with a tensor + * nodes that are associated with a constant tensor + * nodes that are not associated with a supported operator """ - if not utils.is_tensor_node(node): - return None - - # The node may have an input/output tensor that is too big to be stored in a - # texture. In this case, buffer storage must be used. Note that the partitioner - # has already checked for the fact that buffer storage is supported by the - # operator. - if len(utils.possible_node_memory_layouts(node, self.texture_limits)) == 0: - return VkStorageType.BUFFER - - valid_storage_types: Set[VkStorageType] = utils.all_storage_types - - # pyre-ignore - if has_impl(node.target): - # pyre-ignore - features = get_op_features(node.target) - valid_storage_types = features.supported_storage_types() - storage = features.propose_storage_type() - if storage is not None: - return storage - - for arg in node.args: - if isinstance(arg, torch.fx.Node) and utils.is_tensor_node(arg): - storage = utils.get_node_storage_type(arg) - # Some operators which return multiple output tensors may specify a - # different storage type for each output. In this case, the storage type - # for the first output is used. - if isinstance(storage, (list, tuple)): - storage = storage[0] - if storage is not None and storage in valid_storage_types: - return storage - - # If no storage type has been resolved yet, assume the optimal storage type of - # the first opinionated user. This search is recursive. 
-        for user in node.users:
-            storage = self.propose_node_storage(user)
-            # See above
-            if isinstance(storage, (list, tuple)):
-                storage = storage[0]
-            if storage is not None:
-                return storage
-
-        if self.default_storage in valid_storage_types:
-            return self.default_storage
-        else:
-            return next(iter(valid_storage_types))
+        if not isinstance(node, torch.fx.Node) or not utils.is_tensor_node(node):
+            return False
+        if node.meta.get("etvk_tensorref", False):
+            return False
+        if not has_impl(node.target):
+            return False

-    def propose_node_layout(
-        self,
-        node: torch.fx.Node,
-        storage: VkStorageType,
-    ) -> Optional[VkMemoryLayout]:
+        return True
+
+    def is_non_constant_tensor_node(self, node: Any) -> bool:
         """
-        Performs the same steps as propose_node_storage, but detects the memory layout
-        that should be used for the specific storage type. The same prioritization logic
-        is applied.
+        Fails the check for:
+        * Nodes that are not associated with tensor values
+        * Nodes associated with constant tensors
         """
-        if not utils.is_tensor_node(node):
-            return None
-
-        valid_layouts: Set[VkMemoryLayout] = utils.all_memory_layouts
-        # pyre-ignore
-        if has_impl(node.target):
-            # pyre-ignore
-            features = get_op_features(node.target)
-            valid_layouts = features.supported_memory_layouts(storage)
-            layout = features.propose_memory_layout(storage)
-            if layout is not None:
-                return layout
-
-        for arg in node.args:
-            if isinstance(arg, torch.fx.Node) and utils.is_tensor_node(arg):
-                layout = utils.get_node_memory_layout(arg)
-                # Some operators which return multiple output tensors may specify a
-                # different memory layout for each output. In this case, the storage
-                # type for the first output is used.
-                if isinstance(layout, (list, tuple)):
-                    layout = layout[0]
-                if layout is not None and layout in valid_layouts:
-                    return layout
-
-        # If no memory layout has been resolved yet, assume the optimal layout of the
-        # first opinionated user. This search is recursive.
-        for user in node.users:
-            layout = self.propose_node_layout(user, storage)
-            # See above comment
-            if isinstance(layout, (list, tuple)):
-                layout = layout[0]
-            if layout is not None:
-                return layout
-
-        # As a last resort, return the default storage type that should be used.
-        if self.default_layout in valid_layouts:
-            return self.default_layout
-        else:
-            return next(iter(valid_layouts))
-
-    def should_annotate(self, node) -> bool:
         if isinstance(node, torch.fx.Node):
             if not utils.is_tensor_node(node):
                 return False
-
-            # Storage type and memory layout for tensorref will be determined at runtime
-            # so there's no use in setting those attributes ahead of time.
-            if node.meta.get("vkdg_tensorref", False):
+            if node.meta.get("etvk_tensorref", False):
                 return False
+            return True

-            # Skip annotating output node. The output tensors should be annotated by the
-            # time the output node is observed.
-            if node.op == "output":
-                return False
-        elif isinstance(node, (list, tuple)):
-            return all(
-                isinstance(n, torch.fx.Node) and self.should_annotate(n) for n in node
-            )
+        if isinstance(node, (tuple, list)):
+            for n in node:
+                if not isinstance(n, torch.fx.Node):
+                    return False
+                if not self.is_non_constant_tensor_node(n):
+                    return False
+
+            return True
+
+        # Return false by default
+        return False
+
+    def get_node_cached_repsets(self, op_node: torch.fx.Node) -> utils.OpRepSets:
+        """
+        Implements a cache layer for getting the OpRepSets for a given operator node.
+ """ + assert self.is_valid_op_node(op_node) + + if "etvk_node_repsets" in op_node.meta: + op_repsets = op_node.meta["etvk_node_repsets"] + assert isinstance(op_repsets, utils.OpRepSets) + return op_repsets else: - return False + # Special case for getitem - set the input and output to the repset of the + # tensor value being extracted + if op_node.target == operator.getitem: + src_node = op_node.args[0] + assert isinstance(src_node, torch.fx.Node) + idx = op_node.args[1] + assert isinstance(idx, int) + + arg_node_repsets = self.get_node_cached_repsets(src_node) + out_tensor_repset = arg_node_repsets.get_out_repset(idx) + + op_repsets = utils.OpRepSets( + utils.TensorRepSetList(out_tensor_repset), + utils.TensorRepSetList(out_tensor_repset), + op_node, + self.texture_limits, + ) + else: + features: OpFeatures = get_op_features(op_node.target) # noqa + op_repsets = features.make_op_repsets(op_node, self.texture_limits) - return True + op_node.meta["etvk_node_repsets"] = op_repsets + return op_repsets - def should_delay_annotation(self, node: torch.fx.Node) -> bool: - # For prepack nodes, delay setting the storage type and memory layout as long as - # possible. This is to minimize the number of transitions, since it can be - # difficult to predict what storage type and memory layout should be used at the - # time the prepack node is observed. - return node.target == exir_ops.edge.et_vk.prepack.default + def get_arg_tensor_source_repset( + self, op_node: torch.fx.Node, arg_i: int + ) -> utils.TensorRepSet: + """ + Get the "source RepSet" for the tensor argument at index `arg_i` of `op_node`. + The source repset is obtained in one of two ways: - def set_or_transition_arg_node( + 1. If the tensor argument already has a representation determined for it, return + a repset that contains that representation. + 2. 
Otherwise, return the output repset of the operator that produces the tensor + """ + arg_node = op_node.args[arg_i] + + # Special case for cat - use the first tensor in the list as representative + if isinstance(arg_node, list): + arg_node = arg_node[0] + + if utils.has_node_repr(arg_node): + arg_node_repr = utils.get_node_repr(arg_node) + assert isinstance(arg_node_repr, utils.TensorRepr) + return utils.make_tensor_repset(arg_node_repr) + elif self.is_valid_op_node(arg_node): + # Special case for getitem - propagate the node representation of the original node + if op_node.target == operator.getitem: + src_node = op_node.args[0] + assert isinstance(src_node, torch.fx.Node) + idx = op_node.args[1] + assert isinstance(idx, int) + + src_node_repsets = self.get_node_cached_repsets(src_node) + return src_node_repsets.get_out_repset(idx) + + src_node_repsets = self.get_node_cached_repsets(arg_node) + return src_node_repsets.get_out_repset(0) + + # default return + return utils.ANY_STORAGE + + def constrain_repset_with_user( self, - i: int, - arg: torch.fx.Node, - node: torch.fx.Node, - graph_module: torch.fx.GraphModule, - dirty: bool, - ) -> bool: - assert isinstance(arg, torch.fx.Node) - - storage = utils.get_node_storage_type(node) - assert storage is not None - layout = utils.get_node_memory_layout(node) - assert layout is not None - - arg_storage = utils.get_node_storage_type(arg) - arg_layout = utils.get_node_memory_layout(arg) - - if arg_storage is None: - utils.set_node_spec_attr(arg, "vk_storage_type", storage) - arg_storage = storage - if arg_layout is None: - utils.set_node_spec_attr(arg, "vk_memory_layout", layout) - arg_layout = layout - - if arg_storage == storage and arg_layout == layout: - return False + current_node: torch.fx.Node, + arg_i: int, + arg_repset: utils.TensorRepSet, + search_depth: int = 0, + ) -> utils.TensorRepSet: + """ + Attempts to constrain `arg_repset` based on the required repset of the argument + at index `arg_i` of `current_node`. This tries to find a representation for the + argument that can be used for as long as possible without needing a transition. + """ + # The repset is already constrained; return it + if arg_repset.is_constrained(): + return arg_repset + + # The current node is not a valid op node, so no OpRepSets object can be created + # for it. + if not self.is_valid_op_node(current_node): + return arg_repset + + cur_node_repsets = self.get_node_cached_repsets(current_node) + + # Intersect with the repset required by the current operator; otherwise, return + # since a transition will be required anyways + req_arg_repset = cur_node_repsets.get_arg_repset(arg_i) + if req_arg_repset.any_in_common(arg_repset): + arg_repset = arg_repset.make_intersect(req_arg_repset) + else: + return arg_repset - if not dirty: - logger.info( - f"[Vulkan Delegate] Inserting transition(s) for {node.format_node()}:" - ) + # Check if the argument at `arg_i` will influence the output representation of + # the current operator. 
+        repset_propagates_to_output = cur_node_repsets.sync_primary_io_repr and (
+            cur_node_repsets.sync_args_repr or arg_i == cur_node_repsets.primary_arg_idx
+        )

-        insert_transition_node(graph_module, node, arg, storage, layout)
+        # If not, then no point in continuing to trace the users of the current node
+        if not repset_propagates_to_output:
+            return arg_repset

-        logger.info(
-            f" args {i} ({arg}): ({arg_storage}, {arg_layout}) -> ({storage}, {layout})"
+        return self.trace_node_users_to_constrain_repset(
+            current_node, arg_repset, search_depth
         )

-        return True
-
-    def set_or_transition_arg(
+    def trace_node_users_to_constrain_repset(
         self,
-        i: int,
-        arg: Any,
-        node: torch.fx.Node,
-        graph_module: torch.fx.GraphModule,
-        dirty: bool,
-    ) -> bool:
-        if isinstance(arg, torch.fx.Node):
-            return self.set_or_transition_arg_node(i, arg, node, graph_module, dirty)
-        elif isinstance(arg, (list, tuple)):
-            need_transition = False
-            for arg_node in arg:
-                need_transition = (
-                    self.set_or_transition_arg_node(
-                        i, arg_node, node, graph_module, need_transition
-                    )
-                    or need_transition
+        origin_node: torch.fx.Node,
+        repset: utils.TensorRepSet,
+        search_depth: int = 0,
+    ) -> utils.TensorRepSet:
+        """
+        For an ambiguous repset, try to constrain the repset by tracing the required
+        repsets of the users of `origin_node`. The idea is to try to find a representation
+        that can be used the longest without needing user nodes to insert a transition
+        for its arguments.
+        """
+        # Optionally limit the search depth to improve export time
+        if self.max_trace_search_depth is not None:
+            if search_depth > self.max_trace_search_depth:
+                return repset
+
+        users_to_trace = origin_node.users
+
+        sync_outs_repr = True
+        if self.is_valid_op_node(origin_node):
+            sync_outs_repr = self.get_node_cached_repsets(origin_node).sync_outs_repr
+
+        if utils.num_tensors_in_node(origin_node) > 1 and not sync_outs_repr:
+            users_to_trace = []
+            for usage_node in origin_node.users:
+                if usage_node.target == operator.getitem and usage_node.args[1] == 1:
+                    users_to_trace.append(usage_node)
+
+        for usage_node in users_to_trace:
+            arg_i_in_user = None
+            for i in range(len(usage_node.args)):
+                if origin_node == usage_node.args[i]:
+                    arg_i_in_user = i
+                    break
+
+            if arg_i_in_user is not None:
+                repset = self.constrain_repset_with_user(
+                    usage_node, arg_i_in_user, repset, search_depth + 1
                 )
-            return need_transition
-        else:
-            return False

-    # noqa
-    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
-        for node in graph_module.graph.nodes:
-            if not self.should_annotate(node) or self.should_delay_annotation(node):
-                continue
+            if repset.is_constrained():
+                return repset
+
+        return repset
+
+    def constrain_op_arg_repset(self, arg_i: int, op_repsets: utils.OpRepSets) -> None:
+        """
+        Attempts to constrain the repset of the argument at index `arg_i` of the op
+        associated with `op_repsets`. Does this with two stages:
+
+        1. First, account for any existing representation that has already been determined
+           for the argument. If no existing representation has been determined, then use
+           the output repset of the operator that produces the argument.
+        2. Then, try to trace through the users of the argument to find a representation
+           that can be used for as long as possible without needing a transition.
+ """ + arg_source_repset = self.get_arg_tensor_source_repset(op_repsets.op_node, arg_i) + op_repsets.try_constrain_with_arg_repset(arg_i, arg_source_repset) + + arg_repset = op_repsets.get_arg_repset(arg_i) + if arg_repset.is_constrained(): + return arg_repset + + arg_node = op_repsets.op_node.args[arg_i] + + if isinstance(arg_node, list): + arg_node = arg_node[0] + + arg_repset = self.trace_node_users_to_constrain_repset(arg_node, arg_repset) + op_repsets.try_constrain_with_arg_repset(arg_i, arg_repset) + + def constrain_op_repsets(self, op_repsets: utils.OpRepSets) -> None: + # For most ops, constraining the argument repsets will also contrain the output + # repset due to OpRepSets maintaining synchronization rules. + for i in range(len(op_repsets.op_node.args)): + if utils.is_tensor_arg_node(op_repsets.op_node.args[i]): + self.constrain_op_arg_repset(i, op_repsets) + + # TODO(ssjia): For most ops, inputs and outputs must be synchronized, so there + # is no need to constrain output repsets explicitly. Currently, the exceptions + # (i.e. choose qparams) already define constrined repsets for the output, so + # there is again no need to explicitly constrain the outputs. If an operator + # appears later on that does not sync input and output representations, and + # defines ambiguous repsets for the output tensor(s), then we will need to add + # additional logic to this function to constrain the output repsets separately + # from the input repsets. + + def set_op_node_tensor_reprs( + self, graph_module: torch.fx.GraphModule, op_node: torch.fx.Node + ) -> None: + """ + For an operator representated by `op_node`, get the OpRepSets associated with + the operation and try to constrain the repsets by accounting for existing + representations and tracing through the users of the operator. + + Then, determine a tensor representation for all tensors participating in the + operation and mark it in the node metadata. If the requested representation is + different than an already determined representation, then insert a transition + node to create a copy of the tensor with the desired representation. + """ + if not self.is_valid_op_node(op_node): + return + + # Special case for getitem - propagate the node representation of the original node + if op_node.target == operator.getitem: + src_node = op_node.args[0] + assert isinstance(src_node, torch.fx.Node) + idx = op_node.args[1] + assert isinstance(idx, int) - storage = self.propose_node_storage(node) - layout = self.propose_node_layout(node, storage) + arg_node_repr = utils.get_node_repr(src_node) + assert isinstance(arg_node_repr, list) + utils.set_node_repr(op_node, arg_node_repr[idx]) + return - set_memory_metadata(node, storage, layout) + # Get a "fresh" OpRepSets object instead of using the cache. Do this because this + # class instance will go through the constraining process which may modify it. 
+        features: OpFeatures = get_op_features(op_node.target)
+        op_repsets = features.make_op_repsets(op_node, self.texture_limits)

-            need_transition = False
-            for i, arg in enumerate(node.args):
-                if not self.should_annotate(arg):
-                    continue
+        self.constrain_op_repsets(op_repsets)

-                need_transition = (
-                    self.set_or_transition_arg(
-                        i, arg, node, graph_module, need_transition
+        args_repr_list, outs_repr_list = op_repsets.pick_representations()
+
+        if len(outs_repr_list) == 1:
+            utils.set_node_repr(op_node, outs_repr_list[0])
+        else:
+            utils.set_node_repr(op_node, outs_repr_list)
+
+        transitions_inserted = False
+        for i, arg_node in enumerate(op_node.args):
+            if not self.is_non_constant_tensor_node(arg_node):
+                continue
+
+            arg_node_repr = args_repr_list[i]
+
+            if isinstance(arg_node, torch.fx.Node):
+                transitions_inserted = (
+                    set_arg_node_repr_or_transition(
+                        graph_module, op_node, i, arg_node_repr, transitions_inserted
                     )
-                    or need_transition
+                    or transitions_inserted
                 )
+            elif isinstance(arg_node, (list, tuple)):
+                for n in arg_node:
+                    assert isinstance(n, torch.fx.Node)
+                    assert utils.is_single_tensor_node(n)
+                    transitions_inserted = (
+                        set_arg_node_repr_or_transition(
+                            graph_module,
+                            op_node,
+                            i,
+                            arg_node_repr,
+                            transitions_inserted,
+                        )
+                        or transitions_inserted
+                    )
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        for node in graph_module.graph.nodes:
+            self.set_op_node_tensor_reprs(graph_module, node)

         return PassResult(graph_module, True)
diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index 33ed3150535..2e0be1d68d7 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -8,22 +8,14 @@

 import operator

-from typing import Callable, Dict, Optional, Set, Union
+from typing import Any, Callable, Dict, List, Optional, Union

 import executorch.backends.vulkan.custom_ops_lib  # noqa

-import torch
+import executorch.backends.vulkan.utils as utils

-from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
-    VkMemoryLayout,
-    VkStorageType,
-)
+import torch

-from executorch.backends.vulkan.utils import (
-    all_memory_layouts,
-    all_packed_dims,
-    PackedDim,
-)
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.dialects.edge._ops import EdgeOpOverload

@@ -38,156 +30,60 @@ def allow_node(node: torch.fx.Node) -> bool:
     return True


-class TextureImplFeatures:
-    __slots__ = [
-        "valid_packed_dims",
-        "uses_axis_map",
-    ]
-
-    def __init__(
-        self,
-        uses_axis_map: bool = False,
-        valid_packed_dims: Optional[Set[PackedDim]] = None,
-    ):
-        self.uses_axis_map: bool = uses_axis_map
-        self.valid_packed_dims = set()
-        if valid_packed_dims is not None:
-            self.valid_packed_dims = valid_packed_dims
-
-    def valid_memory_layouts(self) -> Set[VkMemoryLayout]:
-        """
-        Derive the set of memory layouts supported by the texture implementation based
-        on the valid packed dimensions.
-        """
-        layouts = set()
-
-        if PackedDim.WIDTH in self.valid_packed_dims:
-            layouts.add(VkMemoryLayout.TENSOR_WIDTH_PACKED)
-
-        if PackedDim.HEIGHT in self.valid_packed_dims:
-            layouts.add(VkMemoryLayout.TENSOR_HEIGHT_PACKED)
-
-        if PackedDim.CHANNELS in self.valid_packed_dims:
-            layouts.add(VkMemoryLayout.TENSOR_CHANNELS_PACKED)
-
-        return layouts
-
-
 class OpFeatures:
     __slots__ = [
-        # None or TextureImplFeatures to specify implementation details of the texture
-        # based operator implementation.
-        "texture_impl",
-        # bool indicating if the operator has a buffer based implementation.
- "buffer_impl", + # Sets of possible (storage types, memory layouts) to use for the input tensor(s) + "inputs_storage", + # Sets of possible (storage types, memory layouts) to use for the output tensor(s) + "outputs_storage", # bool indicating if the operator has a resize function, which allows it to - # support dynamic shape tensors. - "resize_fn", - # Optimal - "optimal_storage", - "optimal_layout", + # support models with dynamic shape + "supports_resize", # bool indicating if the operator handles its own prepacking. If this is True, # then the insert_prepack_nodes pass will not insert prepack nodes for the args # of the op. - "handles_own_prepacking", - # Optional dictionary to specify a custom function to calculate the required - # image extents for a particular argument index. - "skip_limits_check", + "supports_prepacking", # Optional check function used during partitioning to determine if a node's # inputs are supported by the operator implementation. - "check_node_fn", + "are_node_inputs_supported_fn", ] def __init__( self, - texture_impl: Optional[TextureImplFeatures] = None, - buffer_impl: bool = False, - resize_fn: bool = False, - optimal_storage: Optional[VkStorageType] = None, - optimal_layout: Optional[VkMemoryLayout] = None, - handles_own_prepacking: bool = False, - skip_limits_check: Optional[Set[int]] = None, - check_node_fn: Optional[Callable] = None, + inputs_storage: Optional[ + Union[utils.TensorRepSet, List[utils.TensorRepSet]] + ] = None, + outputs_storage: Optional[ + Union[utils.TensorRepSet, List[utils.TensorRepSet]] + ] = None, + supports_resize: bool = False, + supports_prepacking: bool = False, + are_node_inputs_supported_fn: Optional[Callable] = allow_node, ): - self.texture_impl: Optional[TextureImplFeatures] = texture_impl - self.buffer_impl: bool = buffer_impl - self.resize_fn: bool = resize_fn - self.optimal_storage: Optional[VkStorageType] = optimal_storage - self.optimal_layout: Optional[VkMemoryLayout] = optimal_layout - self.handles_own_prepacking: bool = handles_own_prepacking - - self.skip_limits_check: Set[int] = set() - if skip_limits_check is not None: - self.skip_limits_check = skip_limits_check - - self.check_node_fn: Callable = allow_node - if check_node_fn is not None: - self.check_node_fn = check_node_fn - - def propose_storage_type(self) -> Optional[VkStorageType]: - """ - Propose a storage type that should be used for this operator. A proposal can be - made if one of the following is true: - 1. The operator specifies an optimal storage type - 2. Only one storage type is supported. - - If both storage types are supported and no optimal storage type is specified, - then None is returned to indicate that there is no preference in storage type. - """ - if self.optimal_storage is not None: - return self.optimal_storage - - if self.texture_impl is not None and not self.buffer_impl: - return VkStorageType.TEXTURE_3D - elif self.buffer_impl and self.texture_impl is None: - return VkStorageType.BUFFER - - return None - - def supported_storage_types(self) -> Set[VkStorageType]: - """ - Return the set of storage types supported by this operator. - """ - storage_types = set() - if self.texture_impl is not None: - storage_types.add(VkStorageType.TEXTURE_3D) - if self.buffer_impl: - storage_types.add(VkStorageType.BUFFER) - - return storage_types - - def propose_memory_layout(self, storage: VkStorageType) -> Optional[VkMemoryLayout]: - """ - Given a storage type as a precondition, propose a memory layout that should be - used for this operator. 
A proposal can be made if one of the following is true: - 1. The operator specifies an optimal memory layout - 2. Only one memory layout is supported. - - If multiple memory layouts are supported and no optimal memory layout is - specified then return None to indicate that the "best" memory layout for the - operator is ambiguous. - """ - if self.optimal_layout is not None: - return self.optimal_layout - - if storage == VkStorageType.TEXTURE_3D: - assert self.texture_impl is not None - possible_layouts = self.texture_impl.valid_memory_layouts() - if len(possible_layouts) == 1: - return next(iter(possible_layouts)) - - return None - - def supported_memory_layouts(self, storage: VkStorageType) -> Set[VkMemoryLayout]: - """ - Return the set of memory layouts supported by this operator for a given storage - type. - """ - if storage == VkStorageType.TEXTURE_3D: - assert self.texture_impl is not None - return self.texture_impl.valid_memory_layouts() - else: - return all_memory_layouts + self.inputs_storage: utils.TensorRepSetList = utils.TensorRepSetList( + inputs_storage if inputs_storage is not None else [] + ) + self.outputs_storage: utils.TensorRepSetList = utils.TensorRepSetList( + outputs_storage if outputs_storage is not None else [] + ) + + # If output storage is not set, assume that it is derived from the first input + if self.outputs_storage.any_is_empty(): + self.outputs_storage = utils.TensorRepSetList(self.inputs_storage[0]) + + self.supports_resize = supports_resize + self.supports_prepacking = supports_prepacking + + self.are_node_inputs_supported_fn = are_node_inputs_supported_fn + + def make_op_repsets( + self, + op_node: torch.fx.Node, + texture_limits: utils.ImageExtents = utils.DEFAULT_TEXTURE_LIMITS, + ) -> utils.OpRepSets: + return utils.OpRepSets( + self.inputs_storage, self.outputs_storage, op_node, texture_limits + ) ####################### @@ -204,8 +100,7 @@ def features_decorator(fn: Callable): def update_features_impl(op: OpKey): if op in vulkan_supported_ops: raise RuntimeError(f"[Vulkan delegate] duplicate registration of {op}!") - vulkan_supported_ops[op] = OpFeatures() - vulkan_supported_ops[op] = fn(vulkan_supported_ops[op]) + vulkan_supported_ops[op] = fn() if isinstance(aten_op, list): for op in aten_op: @@ -233,14 +128,11 @@ def update_features_impl(op: OpKey): torch.ops.aten.sym_constrain_range_for_size.default, ] ) -def register_ephemeral_op(features: OpFeatures): - features.texture_impl = TextureImplFeatures( - uses_axis_map=True, - valid_packed_dims=all_packed_dims, +def register_ephemeral_op(): + return OpFeatures( + inputs_storage=utils.ANY_STORAGE, + supports_resize=True, ) - features.buffer_impl = True - features.resize_fn = True - return features @update_features( @@ -253,23 +145,13 @@ def register_ephemeral_op(features: OpFeatures): exir_ops.edge.quantized_decomposed.dequantize_per_channel.default, exir_ops.edge.quantized_decomposed.quantize_per_token.default, exir_ops.edge.quantized_decomposed.dequantize_per_token.default, - exir_ops.edge.quantized_decomposed.choose_qparams.tensor, - exir_ops.edge.quantized_decomposed.choose_qparams_per_token_asymmetric.default, ] ) -def register_quantization_op(features: OpFeatures): - # Quantization requires buffer storage and width packing for scales/zero_points - # but we need to provide texture impl features for the partitioner to work properly - features.texture_impl = TextureImplFeatures( - uses_axis_map=True, - valid_packed_dims={ - PackedDim.WIDTH, - }, +def register_quantization_op(): + return OpFeatures( 
+        inputs_storage=utils.CONTIGUOUS_BUFFER,
+        supports_resize=True,
     )
-    features.buffer_impl = True
-    features.resize_fn = True
-    features.optimal_storage = VkStorageType.BUFFER
-    return features


 @update_features(
@@ -278,39 +160,25 @@ def register_quantization_op(features: OpFeatures):
         exir_ops.edge.torchao.quantize_affine.default,
         exir_ops.edge.torchao.dequantize_affine.default,
     ]
 )
-def register_affine_quantization_op(features: OpFeatures):
-    features.texture_impl = TextureImplFeatures(
-        uses_axis_map=False,
-        valid_packed_dims={PackedDim.WIDTH},
+def register_affine_quantization_op():
+    return OpFeatures(
+        inputs_storage=utils.CONTIGUOUS_BUFFER,
+        supports_resize=True,
     )
-    features.buffer_impl = True
-    features.resize_fn = True
-    features.optimal_storage = VkStorageType.TEXTURE_3D
-    features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
-    features.handles_own_prepacking = True
-
-    return features


 @update_features(
     [
         exir_ops.edge.torchao.choose_qparams_affine.default,
+        exir_ops.edge.quantized_decomposed.choose_qparams.tensor,
+        exir_ops.edge.quantized_decomposed.choose_qparams_per_token_asymmetric.default,
     ]
 )
-def register_choose_qparams_affine_op(features: OpFeatures):
-    # Currently only created a rudimentary buffer implementation for choose_qparams_affine
-    # since the reduction logic for blocks in texture3d is not trivial to implement in vulkan.
-    features.texture_impl = TextureImplFeatures(
-        uses_axis_map=False,
-        valid_packed_dims={
-            PackedDim.WIDTH,
-        },
+def register_torchao_quantization_op():
+    return OpFeatures(
+        inputs_storage=utils.CONTIGUOUS_BUFFER,
+        supports_resize=True,
     )
-    features.buffer_impl = True
-    features.resize_fn = True
-    features.optimal_storage = VkStorageType.BUFFER
-
-    return features


 @update_features(
@@ -329,13 +197,11 @@ def register_choose_qparams_affine_op(features: OpFeatures):
         exir_ops.edge.aten.ge.Tensor,
     ]
 )
-def register_binary_op(features: OpFeatures):
-    features.texture_impl = TextureImplFeatures(
-        uses_axis_map=True,
-        valid_packed_dims=all_packed_dims,
+def register_binary_op():
+    return OpFeatures(
+        inputs_storage=utils.ANY_STORAGE,
+        supports_resize=True,
     )
-    features.resize_fn = True
-    return features


 @update_features(
@@ -358,24 +224,15 @@ def register_binary_op(features: OpFeatures):
         exir_ops.edge.aten.leaky_relu.default,
     ]
 )
-def register_unary_op(features: OpFeatures):
-    features.texture_impl = TextureImplFeatures(
-        uses_axis_map=True,
-        valid_packed_dims=all_packed_dims,
+def register_unary_op():
+    return OpFeatures(
+        inputs_storage=utils.ANY_STORAGE,
+        supports_resize=True,
     )
-    features.buffer_impl = True
-    features.resize_fn = True
-    return features


 @update_features(exir_ops.edge.aten._to_copy.default)
-def register_to_copy_op(features: OpFeatures):
-    features.texture_impl = TextureImplFeatures(
-        uses_axis_map=True,
-        valid_packed_dims=all_packed_dims,
-    )
-    features.resize_fn = True
-
+def register_to_copy_op():
     def check_to_copy_node(node: torch.fx.Node) -> bool:
         float_dtypes = [torch.float16, torch.float32]

@@ -395,20 +252,15 @@ def check_to_copy_node(node: torch.fx.Node) -> bool:

         return False

-    features.check_node_fn = check_to_copy_node
-
-    return features
+    return OpFeatures(
+        inputs_storage=utils.ANY_STORAGE,
+        supports_resize=True,
+        are_node_inputs_supported_fn=check_to_copy_node,
+    )


 @update_features(exir_ops.edge.dim_order_ops._to_dim_order_copy.default)
-def register_to_copy_dim_order_op(features: OpFeatures):
-    features.texture_impl = TextureImplFeatures(
-        uses_axis_map=True,
-        valid_packed_dims=all_packed_dims,
-    )
-    features.buffer_impl = True
-    features.resize_fn = True
-
+def register_to_copy_dim_order_op():
     # Currently there is no "real" implementation for to_dim_order_copy, but it can be
     # removed as long as the operator is not changing the dtype, i.e. the operator call
     # is modifying the dim order only. Therefore, check that the input and output dtypes
@@ -426,9 +278,11 @@ def check_dim_order_copy_node(node: torch.fx.Node) -> bool:

         return True

-    features.check_node_fn = check_dim_order_copy_node
-
-    return features
+    return OpFeatures(
+        inputs_storage=utils.ANY_STORAGE,
+        supports_resize=True,
+        are_node_inputs_supported_fn=check_dim_order_copy_node,
+    )


 @update_features(
     [
@@ -439,20 +293,12 @@ def check_dim_order_copy_node(node: torch.fx.Node) -> bool:
         exir_ops.edge.aten.linear.default,
     ]
 )
-def register_mm_op(features: OpFeatures):
-    features.texture_impl = TextureImplFeatures(
-        uses_axis_map=True,
-        valid_packed_dims={
-            PackedDim.WIDTH,
-            PackedDim.CHANNELS,
-        },
+def register_mm_op():
+    return OpFeatures(
+        inputs_storage=utils.CONTIGUOUS_ANY,
+        supports_resize=True,
+        supports_prepacking=True,
     )
-    features.buffer_impl = True
-    features.resize_fn = True
-    features.optimal_storage = VkStorageType.TEXTURE_3D
-    features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
-    features.handles_own_prepacking = True
-    return features


 @update_features(
     [
@@ -461,37 +307,46 @@ def register_mm_op(features: OpFeatures):
         exir_ops.edge.et_vk.linear_qcs4w.default,
     ]
 )
-def register_int8_mm_op(features: OpFeatures):
-    features.texture_impl = TextureImplFeatures(
-        uses_axis_map=False,
-        valid_packed_dims={PackedDim.WIDTH},
+def register_int8_mm_op():
+    return OpFeatures(
+        inputs_storage=utils.CONTIGUOUS_ANY,
+        supports_resize=True,
+        supports_prepacking=True,
     )
-    features.buffer_impl = True
-    features.resize_fn = True
-    features.optimal_storage = VkStorageType.TEXTURE_3D
-    features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
-    features.handles_own_prepacking = True
-    return features


 @update_features(
     [
         exir_ops.edge.et_vk.linear_weight_int4.default,
+    ]
+)
+def register_int4_mm_op():
+    return OpFeatures(
+        inputs_storage=utils.CONTIGUOUS_ANY,
+        supports_resize=True,
+        supports_prepacking=True,
+    )
+
+
+@update_features(
+    [
         exir_ops.edge.et_vk.linear_qta8a_qga4w.default,
     ]
 )
-def register_int4_mm_op(features: OpFeatures):
-    features.buffer_impl = True
-    features.texture_impl = TextureImplFeatures(
-        uses_axis_map=False,
-        valid_packed_dims={PackedDim.WIDTH},
+def register_dqlinear_op():
+    return OpFeatures(
+        inputs_storage=[
+            utils.CONTIGUOUS_ANY,  # input
+            utils.CONTIGUOUS_BUFFER,  # mat1 scales
+            utils.CONTIGUOUS_BUFFER,  # mat1 zeros
+            utils.NO_STORAGE,  # weight (prepacked)
+            utils.NO_STORAGE,  # group size (non tensor)
+            utils.CONTIGUOUS_BUFFER,  # mat2 scales
+            utils.CONTIGUOUS_BUFFER,  # mat2 zeros
+        ],
+        supports_resize=True,
+        supports_prepacking=True,
     )
-    features.resize_fn = True
-    features.optimal_storage = VkStorageType.TEXTURE_3D
-    features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
-    features.handles_own_prepacking = True
-    features.skip_limits_check = {1}
-    return features


 @update_features(
     [
@@ -500,12 +355,11 @@ def register_int4_mm_op(features: OpFeatures):
         exir_ops.edge.aten._softmax.default,
     ]
 )
-def register_softmax_op(features: OpFeatures):
-    features.texture_impl = TextureImplFeatures(
-        valid_packed_dims=all_packed_dims,
+def register_softmax_op():
+    return OpFeatures(
+        inputs_storage=utils.ANY_TEXTURE,
+        supports_resize=True,
     )
-    features.resize_fn = True
-    return features


 @update_features(
@@ -516,25 +370,24 @@ def register_softmax_op(features: OpFeatures):
         exir_ops.edge.aten.amin.default,
     ]
 )
-def register_reduce_op(features: OpFeatures):
-    features.texture_impl = TextureImplFeatures(
-        valid_packed_dims=all_packed_dims,
-    )
-    features.resize_fn = True
-
+def register_reduce_op():
     def check_reduce_node(node: torch.fx.Node) -> bool:
         dim_list = node.args[1]
         if isinstance(dim_list, list) and len(dim_list) != 1:
             return False

-        keepdim = node.args[2]
-        if isinstance(keepdim, bool) and not keepdim:
-            return False
+        if len(node.args) > 2:
+            keepdim = node.args[2]
+            if isinstance(keepdim, bool) and not keepdim:
+                return False

         return True

-    features.check_node_fn = check_reduce_node
-    return features
+    return OpFeatures(
+        inputs_storage=utils.ANY_TEXTURE,
+        supports_resize=True,
+        are_node_inputs_supported_fn=check_reduce_node,
+    )


 @update_features(
@@ -543,12 +396,11 @@ def check_reduce_node(node: torch.fx.Node) -> bool:
         exir_ops.edge.aten.max_pool2d_with_indices.default,
     ]
 )
-def register_2d_pool_op(features: OpFeatures):
-    features.texture_impl = TextureImplFeatures(
-        valid_packed_dims={PackedDim.CHANNELS},
+def register_2d_pool_op():
+    return OpFeatures(
+        inputs_storage=utils.CHANNELS_PACKED_TEXTURE,
+        supports_resize=True,
     )
-    features.resize_fn = True
-    return features


 @update_features(
@@ -557,28 +409,21 @@ def register_2d_pool_op(features: OpFeatures):
         exir_ops.edge.et_vk.conv_with_clamp.default,
     ]
 )
-def register_convolution_op(features: OpFeatures):
-    features.texture_impl = TextureImplFeatures(
-        valid_packed_dims={PackedDim.CHANNELS},
+def register_convolution_op():
+    return OpFeatures(
+        inputs_storage=utils.CHANNELS_PACKED_TEXTURE,
+        supports_resize=True,
+        supports_prepacking=True,
     )
-    features.resize_fn = True
-    features.optimal_storage = VkStorageType.TEXTURE_3D
-    features.optimal_layout = VkMemoryLayout.TENSOR_CHANNELS_PACKED
-    features.handles_own_prepacking = True
-    features.skip_limits_check = {1, 2}
-    return features


 @update_features("llama::sdpa_with_kv_cache")
-def register_sdpa_with_kv_cache_op(features: OpFeatures):
-    features.texture_impl = TextureImplFeatures(
-        valid_packed_dims={PackedDim.WIDTH},
+def register_sdpa_with_kv_cache_op():
+    return OpFeatures(
+        inputs_storage=utils.WIDTH_PACKED_TEXTURE,
+        supports_resize=True,
+        supports_prepacking=True,
    )
-    features.resize_fn = True
-    features.optimal_storage = VkStorageType.TEXTURE_3D
-    features.optimal_layout = VkMemoryLayout.TENSOR_WIDTH_PACKED
-    features.handles_own_prepacking = True
-    return features


 @update_features(
@@ -587,23 +432,19 @@ def register_sdpa_with_kv_cache_op(features: OpFeatures):
         "llama::custom_sdpa",
     ]
 )
-def register_sdpa_ops(features: OpFeatures):
-    features.resize_fn = False
-    features.buffer_impl = False
-    features.texture_impl = TextureImplFeatures(
-        valid_packed_dims={PackedDim.WIDTH},
+def register_sdpa_ops():
+    return OpFeatures(
+        inputs_storage=utils.WIDTH_PACKED_TEXTURE,
+        supports_resize=True,
     )
-    features.resize_fn = True
-    return features


 @update_features(exir_ops.edge.et_vk.apply_rotary_emb.default)
-def register_rotary_emb_op(features: OpFeatures):
-    features.texture_impl = TextureImplFeatures(
-        valid_packed_dims={PackedDim.WIDTH},
+def register_rotary_emb_op():
+    return OpFeatures(
+        inputs_storage=utils.WIDTH_PACKED_TEXTURE,
+        supports_resize=True,
     )
-    features.resize_fn = True
-    return features


 @update_features(
     [
@@ -614,25 +455,18 @@ def register_rotary_emb_op(features: OpFeatures):
         exir_ops.edge.aten.view_copy.default,
     ]
 )
-def register_view_ops(features: OpFeatures):
-    features.texture_impl = TextureImplFeatures(
-        valid_packed_dims=all_packed_dims,
+def register_view_ops():
+    return OpFeatures(
+        inputs_storage=utils.ANY_TEXTURE,
+        supports_resize=True,
    )
-    features.resize_fn = True
-    return features


 # Fully featured transfer operators (i.e. operators that copy data from the input
 # tensor(s) to the output tensor(s)), which have memory layout agnostic implementations
 # for both texture and buffer storage types.
 @update_features(exir_ops.edge.aten.cat.default)
-def register_cat_op(features: OpFeatures):
-    features.texture_impl = TextureImplFeatures(
-        valid_packed_dims=all_packed_dims,
-    )
-    features.buffer_impl = True
-    features.resize_fn = True
-
+def register_cat_op():
     def check_cat_node(node: torch.fx.Node) -> bool:
         inputs = node.args[0]
         if isinstance(inputs, (list, tuple)) and len(inputs) <= 3:
@@ -640,9 +474,11 @@ def check_cat_node(node: torch.fx.Node) -> bool:

         return False

-    features.check_node_fn = check_cat_node
-
-    return features
+    return OpFeatures(
+        inputs_storage=utils.ANY_STORAGE,
+        supports_resize=True,
+        are_node_inputs_supported_fn=check_cat_node,
+    )


 # Fully featured transfer operators (i.e. operators that copy data from the input
@@ -654,14 +490,11 @@ def check_cat_node(node: torch.fx.Node) -> bool:
         exir_ops.edge.aten.slice_copy.Tensor,
     ]
 )
-def register_transfer_ops(features: OpFeatures):
-    features.texture_impl = TextureImplFeatures(
-        valid_packed_dims=all_packed_dims,
+def register_transfer_ops():
+    return OpFeatures(
+        inputs_storage=utils.ANY_STORAGE,
+        supports_resize=True,
     )
-    features.buffer_impl = True
-    features.resize_fn = True
-
-    return features


 # Ops ported from PyTorch Vulkan backend. These ops commonly support channels
@@ -688,14 +521,13 @@ def register_transfer_ops(features: OpFeatures):
         exir_ops.edge.et_vk.grid_priors.default,
     ]
 )
-def register_ported_op(features: OpFeatures):
-    features.texture_impl = TextureImplFeatures(
-        valid_packed_dims={PackedDim.CHANNELS},
+def register_ported_op():
+    return OpFeatures(
+        inputs_storage=utils.CHANNELS_PACKED_TEXTURE,
     )
-    return features


-# Ops ported from PyTorch Vulkan backend. These ops are in a separate registry becasue they support all packed dimensions
+# Ops ported from PyTorch Vulkan backend. These ops are in a separate registry because they support all packed dimensions
 @update_features(
     [
         # Shape Manipulation
@@ -707,11 +539,10 @@ def register_ported_op(features: OpFeatures):
         exir_ops.edge.aten.split.Tensor,
     ]
 )
-def register_ported_op_all_packed_dims(features: OpFeatures):
-    features.texture_impl = TextureImplFeatures(
-        valid_packed_dims=all_packed_dims,
+def register_ported_op_all_packed_dims():
+    return OpFeatures(
+        inputs_storage=utils.ANY_TEXTURE,
     )
-    return features


 # Ported ops that support their own prepacking.
@@ -721,12 +552,11 @@ def register_ported_op_all_packed_dims(features: OpFeatures):
         exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
     ]
 )
-def register_ported_ops_with_prepacking(features: OpFeatures):
-    features.texture_impl = TextureImplFeatures(
-        valid_packed_dims={PackedDim.CHANNELS},
+def register_ported_ops_with_prepacking():
+    return OpFeatures(
+        inputs_storage=utils.CHANNELS_PACKED_TEXTURE,
+        supports_prepacking=True,
     )
-    features.handles_own_prepacking = True
-    return features
 
 
 @update_features(
@@ -734,25 +564,16 @@ def register_ported_ops_with_prepacking(features: OpFeatures):
         exir_ops.edge.aten.native_group_norm.default,
     ]
 )
-def register_native_group_norm(features: OpFeatures):
-    features.texture_impl = TextureImplFeatures(
-        valid_packed_dims={PackedDim.CHANNELS},
+def register_native_group_norm():
+    return OpFeatures(
+        inputs_storage=utils.CHANNELS_PACKED_TEXTURE,
+        outputs_storage=[
+            utils.CHANNELS_PACKED_TEXTURE,
+            utils.CONTIGUOUS_BUFFER,
+            utils.CONTIGUOUS_BUFFER,
+        ],
+        supports_prepacking=True,
     )
-    features.handles_own_prepacking = True
-
-    features.optimal_storage = [
-        VkStorageType.TEXTURE_3D,
-        VkStorageType.BUFFER,
-        VkStorageType.BUFFER,
-    ]
-
-    features.optimal_layout = [
-        VkMemoryLayout.TENSOR_CHANNELS_PACKED,
-        VkMemoryLayout.TENSOR_WIDTH_PACKED,
-        VkMemoryLayout.TENSOR_WIDTH_PACKED,
-    ]
-
-    return features
 
 
 # Ported ops that support their own prepacking.
@@ -761,12 +582,11 @@ def register_native_group_norm(features: OpFeatures):
         exir_ops.edge.aten.native_layer_norm.default,
     ]
 )
-def register_ported_ops_with_prepacking_all_dims(features: OpFeatures):
-    features.texture_impl = TextureImplFeatures(
-        valid_packed_dims=all_packed_dims,
+def register_ported_ops_with_prepacking_all_dims():
+    return OpFeatures(
+        inputs_storage=utils.ANY_TEXTURE,
+        supports_prepacking=True,
    )
-    features.handles_own_prepacking = True
-    return features
 
 
 #######################
@@ -774,7 +594,7 @@ def register_ported_ops_with_prepacking_all_dims(features: OpFeatures):
 #######################
 
 
-def has_impl(target: OpKey) -> bool:
+def has_impl(target: Any) -> bool:
     if not isinstance(target, str):
         if target not in vulkan_supported_ops:
             return target.name() in vulkan_supported_ops
@@ -783,7 +603,7 @@ def has_impl(target: OpKey) -> bool:
     return target in vulkan_supported_ops
 
 
-def get_op_features(target: OpKey) -> OpFeatures:
+def get_op_features(target: Any) -> OpFeatures:
     if not isinstance(target, str):
         if target not in vulkan_supported_ops:
             # Try the op's name
@@ -795,4 +615,4 @@ def handles_own_prepacking(target: OpKey) -> bool:
-    return get_op_features(target).handles_own_prepacking
+    return get_op_features(target).supports_prepacking
diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py
index 9b76f6acd33..776d1d6e168 100644
--- a/backends/vulkan/partitioner/vulkan_partitioner.py
+++ b/backends/vulkan/partitioner/vulkan_partitioner.py
@@ -83,61 +83,18 @@ def op_node_is_compatible(  # noqa: C901: Function is too complex
             return False, "no operator implementation"
         features = get_op_features(target)
 
-        # Check for high dimensional tensors
-        if utils.is_tensor_node(node) and utils.tensor_node_is_high_dim(node):
-            return False, "contains high dim tensor"
-
-        valid_texture_layouts = utils.possible_node_memory_layouts(
+        # Get the possible tensor representations for each tensor participating in
+        # this operator. Then check that all tensors are representable as either a
+        # buffer or texture.
+        op_repsets: utils.OpRepSets = features.make_op_repsets(
             node, self.texture_limits
         )
 
-        can_use_buffers = utils.within_buffer_limit(node, self.buffer_limit)
-        for i, arg in enumerate(node.args):
-            if (
-                isinstance(arg, torch.fx.Node)
-                and utils.is_tensor_node(arg)
-                and i not in features.skip_limits_check
-            ):
-                # Check for bool inputs
-                if utils.tensor_node_is_bool(arg):
-                    return False, "contains bool tensor"
-
-                # Check for high dimensional tensors
-                if utils.tensor_node_is_high_dim(arg):
-                    return False, "contains high dim tensor"
-
-                arg_texture_layouts = utils.possible_node_memory_layouts(
-                    arg, self.texture_limits
-                )
-                valid_texture_layouts = valid_texture_layouts.intersection(
-                    arg_texture_layouts
-                )
-                can_use_buffers = can_use_buffers and utils.within_buffer_limit(
-                    arg, self.buffer_limit
-                )
-
-        op_available_layouts = features.supported_memory_layouts(
-            VkStorageType.TEXTURE_3D
-        )
-
-        can_use_texture = any(
-            layout in op_available_layouts for layout in valid_texture_layouts
-        )
-
-        # If there are no valid texture memory layouts, then buffer storage must be
-        # supported by the operator implementation.
-        if not can_use_texture:
-            if not can_use_buffers:
-                return (
-                    False,
-                    f"op requires buffers that exceed the buffer limit ({self.buffer_limit})",
-                )
-
-            compatible = VkStorageType.BUFFER in features.supported_storage_types()
-            reason = "op is compatible"
-            if not compatible:
-                reason = "op requires buffers which is not supported by op impl"
-            return compatible, reason
+        if op_repsets.any_is_empty():
+            return (
+                False,
+                "No valid representations for a tensor in the operation",
+            )
 
         return True, "Op is compatible"
 
@@ -266,11 +223,11 @@ def _is_node_supported(self, node: torch.fx.Node) -> bool:
 
         assert features is not None
 
-        if not features.check_node_fn(node):
+        if not features.are_node_inputs_supported_fn(node):
             self.log_skip(node, "op args not supported")
             return False
 
-        if self.require_dynamic_shapes and not features.resize_fn:
+        if self.require_dynamic_shapes and not features.supports_resize:
             self.log_skip(node, "no dynamic shape support")
             return False
 
@@ -331,7 +288,10 @@ def __init__(
     def ops_to_not_decompose(
         self, ep: ExportedProgram
    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
-        return (ops_not_to_decompose, None)
+        def filter_fn(node: torch.fx.Node) -> bool:
+            return True
+
+        return (ops_not_to_decompose, filter_fn)
 
     def partition(self, exported_program: ExportedProgram) -> PartitionResult:
         # Run the CapabilityBasedPartitioner to return the largest possible
diff --git a/backends/vulkan/serialization/vulkan_graph_builder.py b/backends/vulkan/serialization/vulkan_graph_builder.py
index cd876bd6305..b74a7fb1f8e 100644
--- a/backends/vulkan/serialization/vulkan_graph_builder.py
+++ b/backends/vulkan/serialization/vulkan_graph_builder.py
@@ -23,6 +23,7 @@
     is_mutable_buffer_node,
     is_param_node,
     is_symint_node,
+    TensorRepr,
 )
 
 from executorch.exir.backend.utils import DelegateMappingBuilder
@@ -135,7 +136,7 @@ def maybe_add_constant_tensor(self, node: Node) -> int:
 
     def create_node_value(self, node: Node) -> int:
         # If the node has been marked as a scalar tensor, create a SymInt instead of a tensor
-        if is_symint_node(node) or node.meta.get("vkdg_is_scalar_tensor", False):
+        if is_symint_node(node) or node.meta.get("etvk_is_scalar_tensor", False):
             new_id = self.create_symint_value()
             self.node_to_value_ids[node] = new_id
             return new_id
@@ -197,12 +198,11 @@ def create_tensor_value(self, spec: TensorSpec, constant_id: int = -1) -> int:
         storage_type = VkStorageType.DEFAULT_STORAGE
         memory_layout = VkMemoryLayout.DEFAULT_LAYOUT
-        if hasattr(spec, "vk_storage_type"):
+        if hasattr(spec, "etvk_node_repr"):
             # pyre-ignore[16]
-            storage_type = spec.vk_storage_type
-        if hasattr(spec, "vk_memory_layout"):
-            # pyre-ignore[16]
-            memory_layout = spec.vk_memory_layout
+            assert isinstance(spec.etvk_node_repr, TensorRepr)
+            storage_type = spec.etvk_node_repr.storage_type
+            memory_layout = spec.etvk_node_repr.memory_layout
 
         # Apply downcast logic before getting VK datatype
         effective_dtype = spec.dtype
diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py
index 926452dd388..4799a22882d 100644
--- a/backends/vulkan/test/test_vulkan_delegate.py
+++ b/backends/vulkan/test/test_vulkan_delegate.py
@@ -1790,25 +1790,21 @@ def forward(self, x):
 
     def test_vulkan_backend_large_linear_layer(self):
         class LinearModel(torch.nn.Module):
-            def __init__(
-                self, n_pca_basis: int, n_sh_basis: int, n_gaussians: int
-            ) -> None:
+            def __init__(self, large_out_channels: int) -> None:
                 super(LinearModel, self).__init__()
-                self.fc1 = torch.nn.Linear(
-                    n_pca_basis, (n_sh_basis + 3 + 3 + 4) * n_gaussians
-                )
+                self.fc0 = torch.nn.Linear(1024, 128)
+                self.fc1 = torch.nn.Linear(128, large_out_channels)
 
             def forward(self, x: torch.Tensor):
+                x = self.fc0(x)
                 out = self.fc1(x)
                 return out
 
-        n_pca_basis = 64
-        n_sh_basis = 6
-        n_gaussians = 2**16
+        large_out_channels = 2**16
 
         self.lower_module_and_test_output(
-            LinearModel(n_pca_basis, n_sh_basis, n_gaussians),
-            (torch.ones(n_pca_basis),),
+            LinearModel(large_out_channels),
+            (torch.ones(1024),),
         )
 
     def test_vulkan_backend_sym_size_int(self):
@@ -2060,3 +2056,97 @@ def forward(self, x):
         self.lower_module_and_test_output(
             full_per_token_workflow_module, sample_inputs, atol=5e-3, rtol=5e-3
         )
+
+    def test_vulkan_backend_different_required_reprs(self):
+        class ComplexModule(torch.nn.Module):
+            """
+            This module tests the tag memory metadata pass. The first few ops executed
+            are binary ops, which don't require any specific representation for input
+            and output tensors.
+
+            This is followed by a linear layer, which requires the input tensor to be
+            width packed.
+
+            Three linear layer outputs are then concatenated, and the result is passed
+            to a convolution layer which requires channels packing. Finally, group norm
+            is called and the output is postprocessed by a binary op before returning.
+
+            In addition to requiring memory layout transitions between the linear and
+            conv stages, the module also contains ops which have "non-standard"
+            torch.fx.Nodes; cat will contain an argument node that is a list of nodes,
+            and group norm's node will be associated with multiple output tensors.
+ """ + + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 10) + self.conv = torch.nn.Conv2d( + in_channels=3, # Assuming concatenation triples the channels + out_channels=16, + kernel_size=3, + padding=1, + ) + self.group_norm = torch.nn.GroupNorm(num_groups=4, num_channels=16) + + def forward(self, x, a, b, c, d): + w = a + b + y = a + c + z = a + d + + b1 = x + y + b2 = x + z + b3 = x + w + + l1 = self.linear(b1).unsqueeze(0) + l2 = self.linear(b2).unsqueeze(0) + l3 = self.linear(b3).unsqueeze(0) + + concat = torch.cat([l1, l2, l3], dim=0) # Concatenate along channels + conv = self.conv(concat + a) + g = self.group_norm(conv.unsqueeze(0)) + return g + x + + complex_module = ComplexModule() + sample_inputs = ( + torch.rand(size=(10, 10), dtype=torch.float32), # x + torch.rand(size=(10, 10), dtype=torch.float32), # a + torch.rand(size=(10, 10), dtype=torch.float32), # b + torch.rand(size=(10, 10), dtype=torch.float32), # c + torch.rand(size=(10, 10), dtype=torch.float32), # d + ) + + self.lower_module_and_test_output(complex_module, sample_inputs) + + def test_vulkan_backend_cat_different_reprs(self): + class CustomComplexModule(torch.nn.Module): + """ + This test validates that the memory metadata tagging pass can handle + transitioning arguments to the cat operator. Linear layers require width + packing, while conv layers require channels packing. Before executing the + cat operator, all input tensors should use the same representation. + """ + + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(10, 10) + self.linear2 = torch.nn.Linear(10, 10) + self.conv = torch.nn.Conv2d( + in_channels=4, # Assuming input b has 3 channels + out_channels=8, + kernel_size=3, + padding=1, + ) + + def forward(self, a, b): + x1 = self.linear1(a).unsqueeze(0) + x2 = self.linear2(a).unsqueeze(0) + y = self.conv(b) + return torch.cat([x1, x2, y], dim=0) + + custom_complex_module = CustomComplexModule() + sample_inputs = ( + torch.rand(size=(10, 10), dtype=torch.float32), # a + torch.rand(size=(4, 10, 10), dtype=torch.float32), # b + ) + + self.lower_module_and_test_output(custom_complex_module, sample_inputs) diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py index 9086b2d0792..fa45063a4d3 100644 --- a/backends/vulkan/utils.py +++ b/backends/vulkan/utils.py @@ -4,8 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py
index 9086b2d0792..fa45063a4d3 100644
--- a/backends/vulkan/utils.py
+++ b/backends/vulkan/utils.py
@@ -4,8 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from enum import IntEnum
-from typing import Optional, Set, Tuple
+import operator
+from typing import Any, List, Optional, Set, Tuple, Union
 
 import torch
 
@@ -50,6 +50,9 @@
 ## Node type determination
 ##
 
+# Convenience type
+MaybeNodeList = Union[torch.fx.Node, List[torch.fx.Node], Tuple[torch.fx.Node]]
+
 
 def is_dequant_node(node: torch.fx.Node) -> bool:
     if node.op != "call_function":
@@ -121,10 +124,42 @@ def is_symint_node(node: torch.fx.Node) -> bool:
     return False
 
 
-def is_tensor_node(node: torch.fx.Node) -> bool:
+def is_single_tensor_node(node: torch.fx.Node) -> bool:
+    """
+    Returns true if the given node produces a single tensor value
+    """
+    if "val" not in node.meta:
+        return False
+
+    if isinstance(node.meta["val"], FakeTensor):
+        return True
+
+    return False
+
+
+def is_tensor_collection_node(node: Any) -> bool:
+    """
+    Returns true if the given node produces a collection of tensor values
+    """
+    if not isinstance(node, torch.fx.Node):
+        return False
+
+    if "val" not in node.meta:
+        return False
+
+    if isinstance(node.meta["val"], list) or isinstance(node.meta["val"], tuple):
+        return all(isinstance(x, FakeTensor) for x in node.meta["val"])
+
+    return False
+
+
+def is_tensor_node(node: Any) -> bool:
     """
     Returns true if the given node produces a tensor value, or a collection of tensor
     values
     """
+    if not isinstance(node, torch.fx.Node):
+        return False
+
     if "val" not in node.meta:
         return False
 
@@ -137,6 +172,47 @@ def is_tensor_node(node: torch.fx.Node) -> bool:
     return False
 
 
+def is_tensor_arg_node(node: Any) -> bool:
+    if isinstance(node, torch.fx.Node):
+        return is_tensor_node(node)
+    elif isinstance(node, (list, tuple)):
+        return all(is_tensor_node(n) for n in node)
+
+    return False
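A quick illustration of how these predicates relate (a toy helper, not part of the diff; it assumes `node.meta["val"]` was populated during export, as a FakeTensor for single-output ops or a list/tuple of FakeTensors for multi-output ops such as native_group_norm):

```python
def describe_node(node: torch.fx.Node) -> str:
    # Toy helper showing how the predicates relate to one another.
    if is_single_tensor_node(node):
        return "produces one tensor"
    if is_tensor_collection_node(node):
        return "produces a collection of tensors (e.g. out, mean, rstd)"
    return "produces no tensors"
```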
+ """ + count = 0 + for arg_node in node.args: + if not isinstance(arg_node, torch.fx.Node): + continue + if is_tensor_node(arg_node): + count += 1 + + return count + + +def num_tensors_in_node(node: torch.fx.Node) -> int: + """ + Returns the number of tensors associated a given node + """ + if "val" not in node.meta: + return 0 + + if isinstance(node.meta["val"], FakeTensor): + return 1 + + if isinstance(node.meta["val"], list) or isinstance(node.meta["val"], tuple): + if all(isinstance(x, FakeTensor) for x in node.meta["val"]): + return len(node.meta["val"]) + + return 0 + + def tensor_node_is_bool(node: torch.fx.Node) -> bool: """ Returns true if a given node contains a tensor with bool dtype @@ -151,6 +227,15 @@ def tensor_node_is_bool(node: torch.fx.Node) -> bool: return False +def get_primary_arg_idx(self, node: torch.fx.Node) -> Optional[int]: + primary_arg_idx: Optional[int] = None + for i, arg_node in enumerate(node.args): + if self.is_non_constant_tensor_node(arg_node): + return i + + return primary_arg_idx + + ## ## Memory Layout, Storage Type Determination ## @@ -160,19 +245,6 @@ def tensor_node_is_bool(node: torch.fx.Node) -> bool: DEFAULT_TEXTURE_LIMITS = (16384, 16384, 2048) DEFAULT_BUFFER_LIMIT = 128 * (1024 * 1024) - -class PackedDim(IntEnum): - WIDTH = 0 - HEIGHT = 1 - CHANNELS = 2 - - -all_packed_dims: Set[PackedDim] = { - PackedDim.WIDTH, - PackedDim.HEIGHT, - PackedDim.CHANNELS, -} - all_storage_types: Set[VkStorageType] = { VkStorageType.BUFFER, VkStorageType.TEXTURE_3D, @@ -184,6 +256,9 @@ class PackedDim(IntEnum): VkMemoryLayout.TENSOR_CHANNELS_PACKED, } +MemoryLayoutSet = Set[VkMemoryLayout] +MemoryLayoutSetList = Union[MemoryLayoutSet, List[MemoryLayoutSet]] + def within_buffer_limit(node: torch.fx.Node, buffer_limit: int) -> int: """ @@ -257,24 +332,622 @@ def valid_texture_memory_layouts( return valid_layouts -def possible_node_memory_layouts( - node: torch.fx.Node, texture_limits: ImageExtents -) -> Set[VkMemoryLayout]: +class TensorRepr: """ - Given a node, determine the set of memory layouts which can be used to represent all - tensors involved in the computation. + This class is a wrapper around a pair of VkStorageType and VkMemoryLayout which + describes how a tensor should be represented in the Vulkan Delegate. """ - assert is_tensor_node(node) - if isinstance(node.meta["val"], FakeTensor): - return valid_texture_memory_layouts(node.meta["val"].shape, texture_limits) - valid_layouts = set() - if isinstance(node.meta["val"], list) or isinstance(node.meta["val"], tuple): - for fake_tensor in node.meta["val"]: - valid_layouts = valid_layouts.union( - valid_texture_memory_layouts(fake_tensor.shape, texture_limits) + + def __init__(self, storage_type: VkStorageType, memory_layout: VkMemoryLayout): + self.storage_type = storage_type + self.memory_layout = memory_layout + + def __str__(self) -> str: + return f"TensorRepr({self.storage_type}, {self.memory_layout})" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, TensorRepr): + return NotImplemented + return ( + self.storage_type == other.storage_type + and self.memory_layout == other.memory_layout + ) + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) + + +class TensorReprList: + """ + This class is a wrapper around a list of TensorRepr instances that automatically + applies a "broadcasting" mechanism. The broadcasting mechanism allows for a single + underlying TensorRepr to be used to represent multiple tensors. 
+ """ + + def __init__(self, tensor_reprs: Union[TensorRepr, List[TensorRepr]]): + self.vals: List[TensorRepr] = ( + tensor_reprs if isinstance(tensor_reprs, list) else [tensor_reprs] + ) + + def __len__(self): + return len(self.vals) + + def __getitem__(self, idx: int) -> TensorRepr: + if idx > 0 and len(self) == 1: + return self.vals[0] + else: + return self.vals[idx] + + def __setitem__(self, idx: int, val: TensorRepr) -> None: + if idx > 0 and len(self) == 1: + self.vals[0] = val + else: + self.vals[idx] = val + + def __str__(self) -> str: + return f"[{', '.join(str(ts) for ts in self.vals)}]" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, TensorReprList): + return NotImplemented + + if len(self) == len(other): + for self_val, other_val in zip(self.vals, other.vals): + if self_val != other_val: + return False + + return True + + return False + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) + + def append(self, val: TensorRepr) -> None: + self.vals.append(val) + + def storage_type(self, idx: int = 0) -> VkStorageType: + return self.vals[idx].storage_type + + def memory_layout(self, idx: int = 0) -> VkMemoryLayout: + return self.vals[idx].memory_layout + + +class TensorRepSet: + """ + This class describes the possible set of representations (i.e. TensorRepr) that may + be used to represent a tensor. This set is determined by the implementation of the + operator that the tensor participates in as well as the texture extents of the GPU. + """ + + def __init__( + self, + buffer_memory_layouts: Set[VkMemoryLayout], + texture_memory_layouts: Set[VkMemoryLayout], + ): + self.valid_buffer_layouts = buffer_memory_layouts + self.valid_texture_layouts = texture_memory_layouts + + def __str__(self) -> str: + buffer_layouts = ", ".join(layout.name for layout in self.valid_buffer_layouts) + texture_layouts = ", ".join( + layout.name for layout in self.valid_texture_layouts + ) + return f"TensorRepSet(Buffer Layouts: [{buffer_layouts}], Texture Layouts: [{texture_layouts}])" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, TensorRepSet): + return NotImplemented + return ( + self.valid_buffer_layouts == other.valid_buffer_layouts + and self.valid_texture_layouts == other.valid_texture_layouts + ) + + def __ne__(self, other: object) -> bool: + return not self.__eq__(other) + + def is_empty(self) -> bool: + """ + A TensorRepSet is "empty" if there are no valid representations of the tensor. + """ + return ( + len(self.valid_buffer_layouts) == 0 and len(self.valid_texture_layouts) == 0 + ) + + def make_intersect(self, other: "TensorRepSet") -> "TensorRepSet": + """ + Merge this TensorRepr with another TensorRepr, returning a new TensorRepr + with the intersection of the two. + """ + return TensorRepSet( + self.valid_buffer_layouts & other.valid_buffer_layouts, + self.valid_texture_layouts & other.valid_texture_layouts, + ) + + def is_compatible(self, storage: TensorRepr) -> bool: + """ + Check if this TensorRepr is compatible with the given TensorRepSet. + """ + if storage.storage_type == VkStorageType.BUFFER: + return storage.memory_layout in self.valid_buffer_layouts + elif storage.storage_type == VkStorageType.TEXTURE_3D: + return storage.memory_layout in self.valid_texture_layouts + else: + raise RuntimeError(f"Unsupported storage type {storage.storage_type}") + + def any_in_common(self, other: "TensorRepSet") -> bool: + """ + Check if this TensorRepr has any representations in common with another + TensorRepr. 
+ """ + return ( + len(self.valid_buffer_layouts & other.valid_buffer_layouts) > 0 + or len(self.valid_texture_layouts & other.valid_texture_layouts) > 0 + ) + + def texture_is_valid(self): + return len(self.valid_texture_layouts) > 0 + + def buffer_is_valid(self): + return len(self.valid_buffer_layouts) > 0 + + def first_valid_buffer_layout(self): + return list(self.valid_buffer_layouts)[0] + + def first_valid_texture_layout(self): + return list(self.valid_texture_layouts)[0] + + def make_tensor_repr(self) -> TensorRepr: + """ + Pick a representation (i.e. TensorRepr) from the set of possible representations. + If there are multiple valid representations, then: + 1. Prefer texture storage over buffer storage + 2. Pick the first available memory layout. + """ + if self.is_empty(): + # An empty repset typically means that it is associated with a weight tensor + # or non tensor argument. In this case, just return default storage and + # layout as placeholder. + return TensorRepr( + VkStorageType.DEFAULT_STORAGE, VkMemoryLayout.DEFAULT_LAYOUT ) - return valid_layouts + if self.texture_is_valid(): + return TensorRepr( + VkStorageType.TEXTURE_3D, self.first_valid_texture_layout() + ) + + else: + return TensorRepr(VkStorageType.BUFFER, self.first_valid_buffer_layout()) + + def is_constrained(self) -> bool: + """ + A "constrained" RepSet is one that has either: + 1. A single valid texture memory layout, and no valid buffer memory layouts + 2. No valid texture memory layouts, and a single valid buffer memory layout + 3. Is empty + + In this case, it is unambiguous which representation should be used for the + tensor. + """ + if self.is_empty(): + return True + elif ( + len(self.valid_texture_layouts) == 1 and len(self.valid_buffer_layouts) == 0 + ): + return True + elif ( + len(self.valid_texture_layouts) == 0 and len(self.valid_buffer_layouts) == 1 + ): + return True + else: + return False + + def is_ambiguous(self) -> bool: + """ + An "ambiguous" RepSet is one that is not constrained. + """ + return not self.is_constrained() + + +def make_tensor_repset(tensor_repr: TensorRepr) -> TensorRepSet: + """ + Given a TensorRepr, return a TensorRepSet that contains only that TensorRepr + """ + if tensor_repr.storage_type == VkStorageType.BUFFER: + return TensorRepSet({tensor_repr.memory_layout}, set()) + elif tensor_repr.storage_type == VkStorageType.TEXTURE_3D: + return TensorRepSet(set(), {tensor_repr.memory_layout}) + else: + raise RuntimeError(f"Unsupported storage type {tensor_repr.storage_type}") + + +def make_filtered_tensor_repset( + tensor_val: FakeTensor, + tensor_repset: TensorRepSet, + texture_limits: ImageExtents, +) -> TensorRepSet: + """ + `tensor_val` represents an actual tensor participating in some operator computation. + + `tensor_repset` represents the set of valid tensor representations that may be used + for that tensor that is supported by the op implementation. + + `texture_limits` represents the maximum texture sizes that is supported by the GPU. + + Given the above, return a new TensorRepSet that contains only texture layouts that + can be used to produce a valid image texture for the given tensor (i.e. fits within + texture limits). 
+ """ + valid_texture_layouts = set() + for memory_layout in tensor_repset.valid_texture_layouts: + extents = required_image_extents(tensor_val.shape, memory_layout) + if extents_are_valid(extents, texture_limits): + valid_texture_layouts.add(memory_layout) + + # High dimensional tensors are currently not supported + if len(tensor_val.shape) > 4: + return NO_STORAGE + + # Bool tensors are currently not supported + if tensor_val.dtype == torch.bool: + return NO_STORAGE + + return TensorRepSet(tensor_repset.valid_buffer_layouts, valid_texture_layouts) + + +## Convenience TensorRepSet definitions + +CONTIGUOUS_ANY = TensorRepSet( + {VkMemoryLayout.TENSOR_WIDTH_PACKED}, {VkMemoryLayout.TENSOR_WIDTH_PACKED} +) +CONTIGUOUS_BUFFER = TensorRepSet({VkMemoryLayout.TENSOR_WIDTH_PACKED}, set()) + +WIDTH_PACKED_TEXTURE = TensorRepSet(set(), {VkMemoryLayout.TENSOR_WIDTH_PACKED}) +CHANNELS_PACKED_TEXTURE = TensorRepSet(set(), {VkMemoryLayout.TENSOR_CHANNELS_PACKED}) + +ANY_TEXTURE = TensorRepSet(set(), all_memory_layouts) + +ANY_STORAGE = TensorRepSet(all_memory_layouts, all_memory_layouts) +NO_STORAGE = TensorRepSet(set(), set()) + + +class TensorRepSetList: + """ + This class is a wrapper around a list of TensorRepSet instances that automatically + applies a "broadcasting" mechanism. The broadcasting mechanism allows for a single + underlying TensorRepSet to be used for multiple tensors. + """ + + def __init__( + self, + tensor_repsets: Union[TensorRepSet, List[TensorRepSet]], + ): + self.vals: List[TensorRepSet] = ( + tensor_repsets if isinstance(tensor_repsets, list) else [tensor_repsets] + ) + + def __len__(self): + return len(self.vals) + + def __getitem__(self, idx: int) -> TensorRepSet: + if idx > 0 and len(self) == 1: + return self.vals[0] + else: + return self.vals[idx] + + def __setitem__(self, idx: int, val: TensorRepSet) -> None: + if idx > 0 and len(self.vals) == 1: + self.vals[0] = val + else: + self.vals[idx] = val + + def __str__(self) -> str: + return f"[{', '.join(str(ts) for ts in self.vals)}]" + + def append(self, val: TensorRepSet) -> None: + return self.vals.append(val) + + def any_is_empty(self) -> bool: + if len(self.vals) == 0: + return True + + return any(tensor_repr.is_empty() for tensor_repr in self.vals) + + +class OpRepSets: + """ + This class is responsible for representing and managing the set of valid tensor + representations that may be used for all input and output tensors of an operator. + It is also responsible for maintaining synchronization rules between tensors + participating in the computation. + + Currently, three synchronization rules exist: + 1. All input tensors must use the same representation (e.g. binary ops) + 2. The "primary" input and output tensors must use the same representation + (e.g. group norm; the output is a tuple of out, mean, rstd; out must be the same + representation as the first input x, but mean and rstd may use different + representations as out) + 3. All output tensors must use the same representation (e.g. choose qparams) + + Note that "primary" input and output tensor refers to the first non-weight input + tensor and the first output tensor. Note that Some operators (such as arange) do not + have any tensor inputs. + + Currently, the above three synchronization rules are sufficient to describe the + representation requirements of all ET-VK operators. + + This class also provides utilities to constrain the repsets; when applying the + constraints, the synchronization rules will be maintained. 
+ """ + + def __init__( # noqa: C901 + self, + inputs_repsets: TensorRepSetList, + outputs_repsets: TensorRepSetList, + op_node: torch.fx.Node, + texture_limits: ImageExtents, + ): + self.op_node = op_node + + # inputs_repset_list is received from the operator registration. If a different + # repset is defined for each input tensor, then assume that the input tensor + # representations do not need to be synchronized. + if len(inputs_repsets) > 1: + self.sync_args_repr = False + # Otherwise, default to True + else: + self.sync_args_repr = True + + # outputs_repset_list is received from the operator registration. If a different + # repset is defined for each output tensor, then assume that the output tensor + # representations do not need to be synchronized. + if len(outputs_repsets) > 1: + self.sync_outs_repr = False + else: + self.sync_outs_repr = True + + # Try to determine the index of the "primary" argument, i.e. the first non + # constant tensor argument. For the vast majority of operators with tensor + # arguments, this will be the first argument. + self.primary_arg_idx: Optional[int] = None + for i, arg_node in enumerate(self.op_node.args): + arg_node_repset = inputs_repsets[i] + if not is_tensor_arg_node(arg_node): + continue + if arg_node_repset is None: + continue + if arg_node_repset.is_empty(): + continue + + self.primary_arg_idx = i + break + + # If the repset of the primary input and the primary output are the same, then + # assume they need to be the same. + self.sync_primary_io_repr = self.primary_arg_idx is not None + if self.primary_arg_idx is not None: + if inputs_repsets[self.primary_arg_idx] != outputs_repsets[0]: + self.sync_primary_io_repr = False + + # Now, go through the arguments of the operator and create a filtered repset + # for each based on the actual tensor value. + args_repset_list = TensorRepSetList([]) + common_arg_repset = ANY_STORAGE + for i, arg_node in enumerate(op_node.args): + arg_repset = inputs_repsets[i] + + # Use ANY_STORAGE for non-tensor nodes so they don't cause the op repsets to + # appear empty + if not is_tensor_arg_node(arg_node): + args_repset_list.append(ANY_STORAGE) + # NO_STORAGE is used to denote that an input is either a non tensor arg or + # a weight tensor that is not prepacked. Similar to the above, use + # ANY_STORAGE in this case. + elif arg_repset.is_empty(): + args_repset_list.append(ANY_STORAGE) + else: + assert not arg_repset.is_empty() + + arg_repset = self.make_valid_tensor_repset_for_arg( + arg_repset, arg_node, texture_limits + ) + + args_repset_list.append(arg_repset) + common_arg_repset = common_arg_repset.make_intersect(arg_repset) + + # Repeat for output tensors. + outs_repset_list = TensorRepSetList([]) + common_out_repset = ANY_STORAGE + if num_tensors_in_node(op_node) == 1: + common_out_repset = make_filtered_tensor_repset( + op_node.meta["val"], outputs_repsets[0], texture_limits + ) + outs_repset_list.append(common_out_repset) + # Multiple output tensors + else: + for i, val in enumerate(op_node.meta["val"]): + assert isinstance(val, FakeTensor) + out_repset = make_filtered_tensor_repset( + val, outputs_repsets[i], texture_limits + ) + + outs_repset_list.append(out_repset) + common_out_repset = common_out_repset.make_intersect(out_repset) + + # Apply synchronization rules; if either all inputs/outputs must use the same + # representation, then only use a single underlying repset. 
+        if self.sync_args_repr:
+            args_repset_list = TensorRepSetList([common_arg_repset])
+
+        if self.sync_outs_repr:
+            outs_repset_list = TensorRepSetList([common_out_repset])
+
+        # Finally, apply synchronization rules that sync inputs and outputs. If input
+        # or output repsets are updated, then maintain synchronization rules.
+        if self.sync_primary_io_repr:
+            assert self.primary_arg_idx is not None
+
+            primary_in_repset = args_repset_list[self.primary_arg_idx]
+            primary_out_repset = outs_repset_list[0]
+
+            primary_repset = primary_in_repset.make_intersect(primary_out_repset)
+
+            if self.sync_args_repr:
+                args_repset_list = TensorRepSetList([primary_repset])
+            else:
+                assert self.primary_arg_idx is not None
+                args_repset_list[self.primary_arg_idx] = primary_repset
+
+            if self.sync_outs_repr:
+                outs_repset_list = TensorRepSetList([primary_repset])
+            else:
+                assert self.primary_arg_idx is not None
+                outs_repset_list[0] = primary_repset
+
+        # Save the resulting repsets
+        self.args_repset_list = args_repset_list
+        self.outs_repset_list = outs_repset_list
+
+        # Check that synchronization rules are respected.
+        self.assert_sync_constraints()
+
+    def __str__(self) -> str:
+        return f"OpRepSets(ins={self.args_repset_list}, outs={self.outs_repset_list})"
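How a consumer is expected to drive this class (a sketch based on the partitioner changes above; `features` and `node` are assumed to come from `get_op_features` and graph iteration respectively):

```python
op_repsets = features.make_op_repsets(node, DEFAULT_TEXTURE_LIMITS)
if op_repsets.any_is_empty():
    pass  # some tensor has no valid representation; the op cannot be delegated
else:
    # Otherwise, pick one concrete TensorRepr per input/output tensor.
    args_reprs, outs_reprs = op_repsets.pick_representations()
```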
+
+    def make_valid_tensor_repset_for_node_list_arg(
+        self,
+        arg_repsets: TensorRepSet,
+        arg_node: List[torch.fx.Node],
+        texture_limits: ImageExtents,
+    ) -> TensorRepSet:
+        """
+        Wrapper around make_filtered_tensor_repset for a list of nodes. This will happen
+        for the cat operator, where the first argument is a list of nodes.
+        """
+        # Variable length args are assumed to all use the same representation, so
+        # only one repset should be defined for them.
+        common_tensor_repsets = arg_repsets
+
+        for n in arg_node:
+            assert isinstance(n, torch.fx.Node)
+            common_tensor_repsets = common_tensor_repsets.make_intersect(
+                make_filtered_tensor_repset(
+                    n.meta["val"], common_tensor_repsets, texture_limits
+                )
+            )
+
+        return common_tensor_repsets
+
+    def make_valid_tensor_repset_for_arg(
+        self, arg_repsets: TensorRepSet, arg_node: Any, texture_limits: ImageExtents
+    ) -> TensorRepSet:
+        """
+        Helper function to call make_filtered_tensor_repset
+        """
+        if isinstance(arg_node, torch.fx.Node) and is_single_tensor_node(arg_node):
+            return make_filtered_tensor_repset(
+                arg_node.meta["val"], arg_repsets, texture_limits
+            )
+        elif isinstance(arg_node, list) and all(
+            is_single_tensor_node(n) for n in arg_node
+        ):
+            return self.make_valid_tensor_repset_for_node_list_arg(
+                arg_repsets, arg_node, texture_limits
+            )
+        # Special case for getitem; return the repset of the particular val in the
+        # list of tensors that is being extracted.
+        elif (
+            self.op_node.target == operator.getitem and arg_node == self.op_node.args[0]
+        ):
+            idx = self.op_node.args[1]
+            assert isinstance(idx, int)
+            return make_filtered_tensor_repset(
+                arg_node.meta["val"][idx], arg_repsets, texture_limits
+            )
+
+        raise NotImplementedError(f"Unhandled node type {arg_node}")
+
+    def assert_sync_constraints(self) -> None:
+        if self.sync_args_repr:
+            assert len(self.args_repset_list) == 1
+
+        if self.sync_outs_repr:
+            assert len(self.outs_repset_list) == 1
+
+        if self.sync_primary_io_repr:
+            assert (
+                self.args_repset_list[self.primary_arg_idx] == self.outs_repset_list[0]
+            )
+
+    def any_is_empty(self) -> bool:
+        return (
+            self.args_repset_list.any_is_empty() or self.outs_repset_list.any_is_empty()
+        )
+
+    def get_arg_repset(self, i: int):
+        return self.args_repset_list[i]
+
+    def get_out_repset(self, i: int):
+        return self.outs_repset_list[i]
+
+    def try_constrain_with_arg_repset(
+        self, arg_i: int, source_repset: TensorRepSet
+    ) -> bool:
+        """
+        Attempt to constrain the repsets of the tensors participating in this operator
+        based on an "existing" repset of an argument. The existing repset can have two
+        sources:
+        * A representation may have been determined for the argument already from a
+          prior operator
+        * The output repset of the operator which produces the argument
+
+        If the existing repset of the argument is compatible with the current operator,
+        then constrain the repsets of this operator and apply synchronization rules.
+
+        This process tries to minimize the number of transition nodes that will need to
+        be inserted by tag_memory_meta_pass.py by maintaining existing representations
+        for as long as possible.
+        """
+        arg_current_repset = self.args_repset_list[arg_i]
+
+        if arg_current_repset == source_repset:
+            return False
+
+        if not arg_current_repset.any_in_common(source_repset):
+            return False
+
+        if self.sync_primary_io_repr:
+            if not self.get_out_repset(0).any_in_common(source_repset):
+                return False
+
+        # If this point is reached, then it is possible to constrain
+        self.args_repset_list[arg_i] = arg_current_repset.make_intersect(source_repset)
+        if self.sync_primary_io_repr and (
+            arg_i == self.primary_arg_idx or self.sync_args_repr
+        ):
+            self.outs_repset_list[0] = arg_current_repset.make_intersect(source_repset)
+
+        self.assert_sync_constraints()
+        return True
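For instance (illustrative, assuming an `op_repsets` instance built as above):

```python
# Keep argument 0 in the representation its producer was already tagged with,
# rather than forcing a transition node before this op.
changed = op_repsets.try_constrain_with_arg_repset(0, WIDTH_PACKED_TEXTURE)
if not changed:
    pass  # repsets incompatible (or already equal); a transition may be needed
```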
+ """ + args_repr_list = TensorReprList([]) + outs_repr_list = TensorReprList([]) + + for i in range(len(self.op_node.args)): + arg_repset = self.args_repset_list[i] + args_repr_list.append(arg_repset.make_tensor_repr()) + + for i in range(num_tensors_in_node(self.op_node)): + out_repset = self.outs_repset_list[i] + outs_repr_list.append(out_repset.make_tensor_repr()) + + return args_repr_list, outs_repr_list ## @@ -282,6 +955,10 @@ def possible_node_memory_layouts( ## +def has_node_spec_attr(node: torch.fx.Node, attr: str) -> bool: + return "spec" in node.meta and hasattr(node.meta["spec"], attr) + + def set_node_spec_attr(node: torch.fx.Node, attr: str, value): assert "spec" in node.meta spec = node.meta["spec"] @@ -327,6 +1004,30 @@ def get_node_memory_layout(node: torch.fx.Node) -> Optional[VkMemoryLayout]: return get_node_spec_attr(node, "vk_memory_layout") +def has_node_repr(node) -> bool: + if isinstance(node, (list, tuple)): + return all(has_node_spec_attr(n, "etvk_node_repr") for n in node) + else: + return has_node_spec_attr(node, "etvk_node_repr") + + +def set_node_repr(node: torch.fx.Node, node_repr: Union[TensorRepr, TensorReprList]): + if isinstance(node_repr, TensorReprList): + # Convert to a regular list so taht `set_node_spec_attr` can attach each entry + # to a separate TensorSpec + node_repr_list = [node_repr[i] for i in range(num_tensors_in_node(node))] + set_node_spec_attr(node, "etvk_node_repr", node_repr_list) + else: + set_node_spec_attr(node, "etvk_node_repr", node_repr) + + +def get_node_repr(node) -> Union[TensorRepr, TensorReprList]: + if isinstance(node, (list, tuple)): + raise NotImplementedError("get_node_repr not implemented for list of nodes") + else: + return get_node_spec_attr(node, "etvk_node_repr", False) + + ## ## Misc ##