From ef3e85a6a3ddcacc5fa97f61273eacad01f82587 Mon Sep 17 00:00:00 2001 From: ssjia Date: Thu, 6 Nov 2025 16:25:07 -0800 Subject: [PATCH 1/3] [ET-VK] Implementation of to_dim_order_copy Pull Request resolved: https://github.com/pytorch/executorch/pull/15619 Title says it all! Previously, to_dim_order_copy was handled by removing the op. However, this is not possible if the op is modifying the dtype of the original tensor, so these instances of the op would be skipped by the partitioner. This diff adds an implementation dtype conversion, which allows to_dim_order_copy to be lowered. ghstack-source-id: 321555048 @exported-using-ghexport Differential Revision: [D86340341](https://our.internmc.facebook.com/intern/diff/D86340341/) --- .../vulkan/_passes/remove_redundant_ops.py | 36 +++++---- backends/vulkan/op_registry.py | 26 +----- .../runtime/graph/ops/glsl/view_buffer.glsl | 21 +++-- .../graph/ops/glsl/view_convert_buffer.glsl | 54 +++++++++++++ .../graph/ops/glsl/view_convert_buffer.yaml | 22 ++++++ .../runtime/graph/ops/impl/Unsqueeze.cpp | 12 +++ .../vulkan/runtime/graph/ops/impl/View.cpp | 79 +++++++++++++++++++ backends/vulkan/runtime/graph/ops/impl/View.h | 13 +++ backends/vulkan/test/test_vulkan_delegate.py | 12 --- 9 files changed, 214 insertions(+), 61 deletions(-) create mode 100644 backends/vulkan/runtime/graph/ops/glsl/view_convert_buffer.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/view_convert_buffer.yaml diff --git a/backends/vulkan/_passes/remove_redundant_ops.py b/backends/vulkan/_passes/remove_redundant_ops.py index 8e602dd17b4..25bdd34de70 100644 --- a/backends/vulkan/_passes/remove_redundant_ops.py +++ b/backends/vulkan/_passes/remove_redundant_ops.py @@ -31,35 +31,37 @@ class RemoveRedundantOpsTransform(ExportPass): exir_ops.edge.aten.lift_fresh_copy.default, exir_ops.edge.dim_order_ops._to_dim_order_copy.default, exir_ops.edge.dim_order_ops._clone_dim_order.default, + exir_ops.edge.aten.expand_copy.default, } def __init__(self) -> None: super(RemoveRedundantOpsTransform, self).__init__() def _should_remove(self, node: torch.fx.Node) -> bool: - if node.target in self.redundant_ops: - return True - - # Only remove to_copy if dtype does not change. Otherwise, memory format changes - # will be handled internally by the backend. - if ( - node.target == exir_ops.edge.aten._to_copy.default - or node.target == torch.ops.aten._to_copy.default - ): - src_dtype = node.meta["val"].dtype - # pyre-ignore - dst_dtype = node.args[0].meta["val"].dtype - return src_dtype == dst_dtype - - return False + if node.target not in self.redundant_ops: + return False + + orig_node = node.args[0] + assert isinstance(orig_node, torch.fx.Node) + + src_dtype = orig_node.meta["val"].dtype + dst_dtype = node.meta["val"].dtype + + # Do not remove if the op is converting the dtype. 
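+        # For example, a _to_dim_order_copy node that casts int32 -> float is
+        # kept so the backend can lower it via the new view_convert_buffer
+        # shader, while a node with the same dtype (and the same shape) only
+        # reorders the dim order and is removed below.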
+ if src_dtype != dst_dtype: + return False + + src_shape = orig_node.meta["val"].shape + dst_shape = node.meta["val"].shape + + return src_shape == dst_shape def _remove(self, graph_module: torch.fx.GraphModule) -> None: for node in graph_module.graph.nodes: if not self._should_remove(node): continue - with graph_module.graph.inserting_after(node): - node.replace_all_uses_with(node.args[0]) + node.replace_all_uses_with(node.args[0]) graph_module.graph.eliminate_dead_code() diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index da127f72528..e487491dfbb 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -7,17 +7,12 @@ # pyre-unsafe import operator - from typing import Any, Callable, Dict, List, Optional, Union import executorch.backends.vulkan.custom_ops_lib # noqa - import executorch.backends.vulkan.utils as utils - import torch - from executorch.exir.dialects._ops import ops as exir_ops - from executorch.exir.dialects.edge._ops import EdgeOpOverload from torch._subclasses.fake_tensor import FakeTensor @@ -129,6 +124,7 @@ def update_features_impl(op: OpKey): # Symbolic integer ops torch.ops.aten.sym_size.int, operator.add, + operator.sub, operator.lt, operator.gt, operator.ge, @@ -297,27 +293,9 @@ def check_to_copy_node(node: torch.fx.Node) -> bool: @update_features(exir_ops.edge.dim_order_ops._to_dim_order_copy.default) def register_to_copy_dim_order_op(): - # Currently there is no "real" implementation for to_dim_order_copy, but it can be - # removed as long as the operator is not changing the dtype, i.e. the operator call - # is modifying the dim order only. Therefore, check that the input and output dtypes - # are the same, if so the operator is safe to remove. - def check_dim_order_copy_node(node: torch.fx.Node) -> bool: - in_arg = node.args[0] - if not isinstance(in_arg, torch.fx.Node): - return False - - in_tensor = in_arg.meta.get("val", None) - out_tensor = node.meta.get("val", None) - - if in_tensor.dtype != out_tensor.dtype: - return False - - return True - return OpFeatures( - inputs_storage=utils.ANY_STORAGE, + inputs_storage=utils.ANY_BUFFER, supports_resize=True, - are_node_inputs_supported_fn=check_dim_order_copy_node, ) diff --git a/backends/vulkan/runtime/graph/ops/glsl/view_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/view_buffer.glsl index 2c02803a9b1..96b9aa85a1f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/view_buffer.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/view_buffer.glsl @@ -18,6 +18,8 @@ ${layout_declare_ubo(B, "BufferMetadata", "inp")} layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +${layout_declare_spec_const(C, "int", "all_contiguous", "0")} + /* * The insight behind the view operation is that the contiguous index of each * tensor element in the input and output tensors are the same. @@ -28,17 +30,20 @@ void main() { return; } - TensorIndex outp_tidx; - linear_idx_to_tensor_idx(outp, outp_bufi, outp_tidx); + uint inp_bufi = outp_bufi; + if (all_contiguous == 0) { + TensorIndex outp_tidx; + linear_idx_to_tensor_idx(outp, outp_bufi, outp_tidx); - // To map the output to the input, find the input element that has the same - // contiguous index as the output element. - const uint contig_idx = tensor_idx_to_contiguous_idx(outp, outp_tidx); + // To map the output to the input, find the input element that has the same + // contiguous index as the output element. 
+ const uint contig_idx = tensor_idx_to_contiguous_idx(outp, outp_tidx); - TensorIndex inp_tidx; - contiguous_idx_to_tensor_idx(inp, contig_idx, inp_tidx); + TensorIndex inp_tidx; + contiguous_idx_to_tensor_idx(inp, contig_idx, inp_tidx); - const uint inp_bufi = tensor_idx_to_linear_idx(inp, inp_tidx); + inp_bufi = tensor_idx_to_linear_idx(inp, inp_tidx); + } t_outp[outp_bufi] = t_inp[inp_bufi]; } diff --git a/backends/vulkan/runtime/graph/ops/glsl/view_convert_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/view_convert_buffer.glsl new file mode 100644 index 00000000000..a926c9fea11 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/view_convert_buffer.glsl @@ -0,0 +1,54 @@ +#version 450 core + +#define PRECISION ${PRECISION} + +#define IN_T ${buffer_scalar_type(IN_DTYPE)} +#define OUT_T ${buffer_scalar_type(OUT_DTYPE)} + +${define_required_extensions(IN_DTYPE)} +${define_required_extensions(OUT_DTYPE)} + +layout(std430) buffer; + +#include "indexing.glslh" + +${layout_declare_buffer(B, "w", "t_outp", OUT_DTYPE)} +${layout_declare_buffer(B, "r", "t_inp", IN_DTYPE)} + +${layout_declare_ubo(B, "BufferMetadata", "outp")} +${layout_declare_ubo(B, "BufferMetadata", "inp")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "int", "all_contiguous", "0")} + +/* + * The insight behind the view_convert operation is that the contiguous index of each + * tensor element in the input and output tensors are the same, but the data types + * may be different and need conversion. + */ +void main() { + const uint outp_bufi = gl_GlobalInvocationID.x; + if (outp_bufi >= numel(outp)) { + return; + } + + uint inp_bufi = outp_bufi; + + if (all_contiguous == 0) { + TensorIndex outp_tidx; + linear_idx_to_tensor_idx(outp, outp_bufi, outp_tidx); + + // To map the output to the input, find the input element that has the same + // contiguous index as the output element. + const uint contig_idx = tensor_idx_to_contiguous_idx(outp, outp_tidx); + + TensorIndex inp_tidx; + contiguous_idx_to_tensor_idx(inp, contig_idx, inp_tidx); + + inp_bufi = tensor_idx_to_linear_idx(inp, inp_tidx); + } + + // Convert data type from input to output + t_outp[outp_bufi] = OUT_T(t_inp[inp_bufi]); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/view_convert_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/view_convert_buffer.yaml new file mode 100644 index 00000000000..11d56cad4a9 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/view_convert_buffer.yaml @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
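+#
+# Each IN_DTYPE/OUT_DTYPE combo below generates one shader variant;
+# add_view_copy_convert_buffer_node in View.cpp selects the variant by
+# appending the input and output dtype suffixes to "view_convert_buffer".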
+ +view_convert_buffer: + parameter_names_with_default_values: + IN_DTYPE: float + OUT_DTYPE: float + STORAGE: buffer + generate_variant_forall: + combination: + parameter_names: [IN_DTYPE, OUT_DTYPE] + combos: + - parameter_values: [int32, float] + - parameter_values: [int32, half] + - parameter_values: [uint8, float] + - parameter_values: [uint8, half] + - parameter_values: [uint8, int32] + shader_variants: + - NAME: view_convert_buffer diff --git a/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp b/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp index 36a8ee4c3b1..602fe1ef129 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp @@ -67,6 +67,18 @@ void resize_unsqueeze_node( std::vector out_sizes = graph->sizes_of(in); + std::vector unsqueezed_dims; + + if (graph->val_is_int_list(dims_ref)) { + const IntListPtr dims = graph->get_int_list(dims_ref); + for (int64_t d : *dims) { + unsqueezed_dims.push_back(d); + } + } else { + const int64_t dim = graph->extract_scalar(dims_ref); + unsqueezed_dims.push_back(dim); + } + // Insert singleton dimensions at the specified positions for (auto dim : dims_vec) { int64_t d = dim; diff --git a/backends/vulkan/runtime/graph/ops/impl/View.cpp b/backends/vulkan/runtime/graph/ops/impl/View.cpp index 8701a6246b0..5e2c898573a 100644 --- a/backends/vulkan/runtime/graph/ops/impl/View.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/View.cpp @@ -60,6 +60,16 @@ void resize_view_node( } } +void resize_to_dim_order_copy_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + const ValueRef out = args.at(0).refs.at(0); + const ValueRef in = args.at(1).refs.at(0); + const std::vector in_sizes = graph->sizes_of(in); + graph->virtual_resize(out, in_sizes); +} + void add_view_node( ComputeGraph& graph, ValueRef in, @@ -98,6 +108,11 @@ void add_view_copy_buffer_node( std::string kernel_name = "view_buffer"; add_dtype_suffix(kernel_name, graph.dtype_of(out)); + bool all_contiguous = graph.is_contiguous_buffer_tensor(in) && + graph.is_contiguous_buffer_tensor(out); + + int32_t all_contiguous_int = all_contiguous ? 1 : 0; + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), @@ -110,7 +125,41 @@ void add_view_copy_buffer_node( // Push Constants {}, // Specialization Constants + {all_contiguous_int}, + // Resize Args + resize_args, + // Resizing Logic + resize_fn)); +} + +void add_view_copy_convert_buffer_node( + ComputeGraph& graph, + ValueRef in, + ValueRef out, + const std::vector& resize_args, + const ExecuteNode::ResizeFunction& resize_fn) { + std::string kernel_name = "view_convert_buffer"; + add_dtype_suffix(kernel_name, graph.dtype_of(in)); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); + + bool all_contiguous = graph.is_contiguous_buffer_tensor(in) && + graph.is_contiguous_buffer_tensor(out); + + int32_t all_contiguous_int = all_contiguous ? 
1 : 0; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + default_pick_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{out, vkapi::kWrite}, {in, vkapi::kRead}}, + // Parameter Buffers + {graph.buffer_meta_ubo(out), graph.buffer_meta_ubo(in)}, + // Push Constants {}, + // Specialization Constants + {all_contiguous_int}, // Resize Args resize_args, // Resizing Logic @@ -132,8 +181,38 @@ void view(ComputeGraph& graph, const std::vector& args) { return add_view_node(graph, in, sizes, out); } +void to_dim_order_copy(ComputeGraph& graph, const std::vector& args) { + int args_idx = 0; + const ValueRef in = args.at(args_idx++); + const ValueRef dtype = args.at(args_idx++); + (void)dtype; + const ValueRef layout = args.at(args_idx++); + (void)layout; + const ValueRef device = args.at(args_idx++); + (void)device; + const ValueRef pin_memory = args.at(args_idx++); + (void)pin_memory; + const ValueRef non_blocking = args.at(args_idx++); + (void)non_blocking; + const ValueRef dim_order = args.at(args_idx++); + (void)dim_order; + + const ValueRef out = args.at(args_idx++); + + VK_CHECK_COND(graph.is_buffer_storage(in) && graph.is_buffer_storage(out)); + + if (graph.dtype_of(in) == graph.dtype_of(out)) { + return add_view_copy_buffer_node( + graph, in, out, {}, resize_to_dim_order_copy_node); + } + + return add_view_copy_convert_buffer_node( + graph, in, out, {}, resize_to_dim_order_copy_node); +} + REGISTER_OPERATORS { VK_REGISTER_OP(aten.view_copy.default, view); + VK_REGISTER_OP(dim_order_ops._to_dim_order_copy.default, to_dim_order_copy); } } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/View.h b/backends/vulkan/runtime/graph/ops/impl/View.h index 7a7a8d57742..c8e52492417 100644 --- a/backends/vulkan/runtime/graph/ops/impl/View.h +++ b/backends/vulkan/runtime/graph/ops/impl/View.h @@ -24,6 +24,19 @@ void add_view_copy_buffer_node( const std::vector& resize_args, const ExecuteNode::ResizeFunction& resize_fn); +/* + * Dispatches the view_convert_buffer compute shader. This can be used to + * implement ops that preserve the "contiguous" indexes of elements between the + * input and output while converting between different data types such as + * view_copy with dtype conversion. 
+ */ +void add_view_copy_convert_buffer_node( + ComputeGraph& graph, + ValueRef in, + ValueRef out, + const std::vector& resize_args, + const ExecuteNode::ResizeFunction& resize_fn); + void add_view_node( ComputeGraph& graph, ValueRef in, diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py index f92cea64767..f38c510a8b1 100644 --- a/backends/vulkan/test/test_vulkan_delegate.py +++ b/backends/vulkan/test/test_vulkan_delegate.py @@ -11,20 +11,14 @@ from typing import Tuple import executorch.backends.vulkan.test.utils as test_utils - import torch - from executorch.backends.transforms.convert_dtype_pass import I64toI32 - from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner - from executorch.backends.vulkan.vulkan_preprocess import VulkanBackend - from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( get_symmetric_quantization_config, XNNPACKQuantizer, ) - from executorch.exir import ( EdgeCompileConfig, EdgeProgramManager, @@ -36,11 +30,8 @@ ) from executorch.extension.pytree import tree_flatten from torch.export import Dim, export, ExportedProgram - from torchao.quantization.granularity import PerGroup - from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e - from torchao.quantization.pt2e.quantizer import Quantizer from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_ from torchao.utils import unwrap_tensor_subclass @@ -69,9 +60,6 @@ def lower_module( edge_program = to_edge_transform_and_lower( program, compile_config=edge_compile_config, - transform_passes=[ - I64toI32(edge_compile_config._skip_dim_order), - ], partitioner=[VulkanPartitioner(compile_options)], ) From fffeb4cb1080d5e22f00676687bc08ca2516dcc5 Mon Sep 17 00:00:00 2001 From: ssjia Date: Thu, 6 Nov 2025 16:25:10 -0800 Subject: [PATCH 2/3] [ET-VK][ez] Ensure that attn_weight buffers do not exceed GPU buffer numel limit Pull Request resolved: https://github.com/pytorch/executorch/pull/15651 Title says it all! To give a concrete example, Llama3.2-1B-Instruct will have attn weights with size `{1, 32, max_seq_len, max_context_len}`. Usually `max_seq_len == max_context_len`, and if `max_context_len = 2048` Then the attention weight tensors will have sizes `{1, 32, 2048, 2048}` which will contain 134217728 elements. The `maxStorageBufferRange` for Adreno 750 is also 134217728 (2^27), so using context length of 2048 will produce incorrect results on Adreno 750. In practice, it is unlikely that the prompt sequence length will be equal to the context length, so the solution is to adjust down the `max_seq_len` dim of the attention weight tensors to ensure that the GPU buffer numel limit is not hit. 
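For the Adreno 750 example above, the adjustment works out as follows (a small
Python sketch of the same arithmetic as the new logic in SDPA.cpp):

```python
# maxStorageBufferRange reported for Adreno 750
max_buffer_numel = 2**27
num_q_heads, max_context_len = 32, 2048
max_seq_len = 2048  # initially equal to max_context_len

if num_q_heads * max_seq_len * max_context_len >= max_buffer_numel:
    max_seq_len = max_buffer_numel // (num_q_heads * max_context_len)  # 2048
    # Round down to a multiple of 4; if already a multiple, step down by 4
    # so the attn weight numel stays strictly below the limit.
    if max_seq_len % 4 != 0:
        max_seq_len = (max_seq_len // 4) * 4
    else:
        max_seq_len -= 4

# max_seq_len is now 2044, and 1 * 32 * 2044 * 2048 = 133,955,584 < 2**27.
```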
ghstack-source-id: 321555042 @exported-using-ghexport Differential Revision: [D86443407](https://our.internmc.facebook.com/intern/diff/D86443407/) --- backends/vulkan/runtime/graph/ComputeGraph.h | 4 +++ .../vulkan/runtime/graph/ops/impl/SDPA.cpp | 35 +++++++++++++++---- 2 files changed, 32 insertions(+), 7 deletions(-) diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index f7de7e183de..b61bd4a51c0 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -639,6 +639,10 @@ class ComputeGraph final { bool device_name_contains(const char* substr); + int64_t max_buffer_numel() { + return static_cast(context_->adapter_ptr()->max_buffer_numel()); + } + // // Graph Building // diff --git a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp index 4eed8b82834..d28d2c90fcb 100644 --- a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp @@ -471,10 +471,31 @@ void sdpa_impl(ComputeGraph& graph, const std::vector& args) { VK_CHECK_COND(graph.val_is_none(attn_mask)); const int64_t num_q_heads = graph.size_at(-2, q_projected); - const int64_t max_seq_len = graph.size_at(-3, q_projected); - + int64_t max_seq_len = graph.size_at(-3, q_projected); const int64_t max_context_len = graph.size_at(-3, k_cache); + const utils::StorageType attn_weights_storage = + graph.storage_type_of(q_projected); + + // If using buffer storage for attn weights, we need to ensure that the buffer + // numel limit is not exceeded. If needed, manually adjust max_seq_len based + // on the buffer numel limit. + if (attn_weights_storage == utils::kBuffer) { + const int64_t max_buffer_numel = graph.max_buffer_numel(); + if (num_q_heads * max_seq_len * max_context_len >= max_buffer_numel) { + // Compute the maximum possible value for max_seq_len that will hit + // the buffer numel limit. + max_seq_len = max_buffer_numel / (num_q_heads * max_context_len); + // Adjust down to the nearest multiple of 4 to make sure the limit is + // not hit. 
+ if (max_seq_len % 4 != 0) { + max_seq_len = (max_seq_len / 4) * 4; + } else { + max_seq_len -= 4; + } + } + } + std::vector attn_weight_full_sizes = { 1, // batch num_q_heads, @@ -485,14 +506,14 @@ void sdpa_impl(ComputeGraph& graph, const std::vector& args) { &graph, attn_weight_full_sizes, graph.dtype_of(q_projected), - graph.storage_type_of(q_projected), + attn_weights_storage, utils::kWidthPacked); TmpTensor attn_weights_softmax( &graph, attn_weight_full_sizes, graph.dtype_of(q_projected), - graph.storage_type_of(q_projected), + attn_weights_storage, utils::kWidthPacked); add_sdpa_compute_attn_weights_node( @@ -528,9 +549,9 @@ void sdpa_with_kv_cache_impl( utils::StorageType cache_storage = graph.storage_type_of(q_projected); const ValueRef k_cache = - prepack_standard(graph, k_cache_data, cache_storage, utils::kWidthPacked); + graph.add_tensor_like(k_cache_data, cache_storage, utils::kWidthPacked); const ValueRef v_cache = - prepack_standard(graph, v_cache_data, cache_storage, utils::kWidthPacked); + graph.add_tensor_like(v_cache_data, cache_storage, utils::kWidthPacked); update_cache_impl(graph, {k_projected, k_cache, input_pos_symint, -1}); update_cache_impl(graph, {v_projected, v_cache, input_pos_symint, -1}); @@ -573,7 +594,7 @@ void compute_attn_weight_with_kv_cache_impl( (void)sequence_len; - utils::StorageType cache_storage = graph.storage_type_of(q_projected); + const utils::StorageType cache_storage = graph.storage_type_of(q_projected); const ValueRef k_cache = graph.add_tensor_like(k_cache_data, cache_storage, utils::kWidthPacked); const ValueRef v_cache = From 962c01e30c7fe5b6c2126b38033f5a5d2960f429 Mon Sep 17 00:00:00 2001 From: ssjia Date: Thu, 6 Nov 2025 16:25:13 -0800 Subject: [PATCH 3/3] [ET-VK] buffer implementation of rotary positional embeddings Pull Request resolved: https://github.com/pytorch/executorch/pull/15620 Title says it all! 
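For reference, the rotation applied by the shader pairs up even/odd elements
along the head dim and rotates each pair by the per-position frequencies. A
minimal NumPy sketch of the math (not the backend API, just the reference
computation the shader reproduces):

```python
import numpy as np

def apply_rotary_emb_reference(x, freqs_cos, freqs_sin):
    # x: (batch, seq_len, n_heads, head_dim); freqs_*: (seq_len, head_dim // 2)
    x_r, x_i = x[..., 0::2], x[..., 1::2]          # even / odd elements
    cos = freqs_cos[None, :, None, :]
    sin = freqs_sin[None, :, None, :]
    out = np.empty_like(x)
    out[..., 0::2] = x_r * cos - x_i * sin         # rotated "real" part
    out[..., 1::2] = x_r * sin + x_i * cos         # rotated "imaginary" part
    return out
```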
ghstack-source-id: 321555043 @exported-using-ghexport Differential Revision: [D86340338](https://our.internmc.facebook.com/intern/diff/D86340338/) --- .../vulkan/_passes/tag_memory_meta_pass.py | 16 +-- backends/vulkan/op_registry.py | 2 +- .../graph/ops/glsl/embedding_texture.glsl | 6 +- .../runtime/graph/ops/glsl/indexing.glslh | 50 +++++++-- .../graph/ops/glsl/rotary_embedding.glsl | 103 +++++++++++++----- .../graph/ops/glsl/rotary_embedding.yaml | 3 + .../graph/ops/impl/RotaryEmbedding.cpp | 21 +++- backends/vulkan/vulkan_preprocess.py | 11 +- 8 files changed, 148 insertions(+), 64 deletions(-) diff --git a/backends/vulkan/_passes/tag_memory_meta_pass.py b/backends/vulkan/_passes/tag_memory_meta_pass.py index 8ed71aa1dae..15449b98f6f 100644 --- a/backends/vulkan/_passes/tag_memory_meta_pass.py +++ b/backends/vulkan/_passes/tag_memory_meta_pass.py @@ -6,22 +6,16 @@ import logging import operator - from typing import Any import executorch.backends.vulkan.utils as utils - import torch - from executorch.backends.vulkan.op_registry import get_op_features, has_impl, OpFeatures - from executorch.backends.vulkan.serialization.vulkan_graph_schema import ( VkMemoryLayout, VkStorageType, ) - from executorch.exir.dialects._ops import ops as exir_ops - from executorch.exir.pass_base import ExportPass, PassResult from executorch.exir.tensor import TensorSpec @@ -130,15 +124,17 @@ def __init__( texture_limits: utils.ImageExtents, default_storage_type: VkStorageType = VkStorageType.TEXTURE_3D, default_memory_layout: VkMemoryLayout = VkMemoryLayout.TENSOR_WIDTH_PACKED, + force_fp16: bool = False, ): super().__init__() self.default_storage: VkStorageType = default_storage_type self.default_layout: VkMemoryLayout = default_memory_layout self.texture_limits = texture_limits + self.force_fp16 = force_fp16 # Magic number to limit "lookahead" when tracing through users of an operator # to constrain the representation of its arguments/outputs. - self.max_trace_search_depth = 20 + self.max_trace_search_depth = None def is_valid_op_node(self, node: Any) -> bool: """ @@ -361,6 +357,12 @@ def constrain_op_arg_repset(self, arg_i: int, op_repsets: utils.OpRepSets) -> No 2. Then, try to trace through the users of the argument to find a representation that can be used for as long as possible without needing a transition. """ + # If forcing fp16, then try to use texture storage whenever possible. This is + # a temporary stopgap measure until all buffer implementations properly account + # for potential overflow of fp16 representation range when doing math in fp16. 
+ if self.force_fp16: + op_repsets.try_constrain_with_arg_repset(arg_i, utils.ANY_TEXTURE) + arg_source_repset = self.get_arg_tensor_source_repset(op_repsets.op_node, arg_i) op_repsets.try_constrain_with_arg_repset(arg_i, arg_source_repset) diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index e487491dfbb..ef41060272c 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -687,7 +687,7 @@ def register_sdpa_ops(): @update_features(exir_ops.edge.et_vk.apply_rotary_emb.default) def register_rotary_emb_op(): return OpFeatures( - inputs_storage=utils.WIDTH_PACKED_TEXTURE, + inputs_storage=utils.CONTIGUOUS_ANY, supports_resize=True, ) diff --git a/backends/vulkan/runtime/graph/ops/glsl/embedding_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/embedding_texture.glsl index b064d8a3295..9a6295a8094 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/embedding_texture.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/embedding_texture.glsl @@ -38,8 +38,8 @@ int load_embedding_idx(const TensorIndex4D out_tidx) { indices_tidx.data.xyz = out_tidx.data.yzw; indices_tidx.data.w = 0; - TextureElementIndex elem_pos = tensor_idx_to_texture_element_idx_simple( - indices_tidx, indices); + TextureElementIndex elem_pos = tensor4d_idx_to_texture_element_idx_simple( + indices, indices_tidx); const ivec4 in_texel = texelFetch(t_indices, elem_pos.pos, 0); return in_texel[elem_pos.comp]; @@ -61,7 +61,7 @@ void main() { return; } - TensorIndex4D out_tidx = texture_pos_to_tensor_idx_simple(out_pos, outp); + TensorIndex4D out_tidx = texture_pos_to_tensor4d_idx_simple(outp, out_pos); const int embedding_idx = load_embedding_idx(out_tidx); const VEC4_T weight_texel = load_weight_texel(embedding_idx, out_tidx.data.x); diff --git a/backends/vulkan/runtime/graph/ops/glsl/indexing.glslh b/backends/vulkan/runtime/graph/ops/glsl/indexing.glslh index 0e30faa5d66..38016547d19 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/indexing.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/indexing.glslh @@ -147,6 +147,20 @@ struct TensorIndex4D { ivec4 data; }; +TensorIndex4D zero_tensor4d_idx() { + TensorIndex4D tidx; + tidx.data = ivec4(0); + return tidx; +} + +bool out_of_bounds(const TensorIndex4D tidx, const BufferMetadata meta) { + return any(greaterThanEqual(tidx.data, meta.sizes[0])); +} + +bool out_of_bounds(const TensorIndex4D tidx, const TextureMetadata meta) { + return any(greaterThanEqual(tidx.data, meta.sizes)); +} + // // TextureElementIndex // @@ -245,15 +259,9 @@ void clamp_tensor_idx(const BufferMetadata meta, inout TensorIndex tidx) { tidx.data[1] = min(tidx.data[1], meta.sizes[1] - 1); } -TensorIndex4D zero_tensor4d_idx() { - TensorIndex4D tidx; - tidx.data = ivec4(0); - return tidx; -} - // Does not account for axis mapping or batches -TensorIndex4D texture_pos_to_tensor_idx_simple( - const ivec3 pos, const TextureMetadata meta) { +TensorIndex4D texture_pos_to_tensor4d_idx_simple( + const TextureMetadata meta, const ivec3 pos) { TensorIndex4D tidx; tidx.data.xyz = pos; tidx.data.w = 0; @@ -262,8 +270,20 @@ TensorIndex4D texture_pos_to_tensor_idx_simple( } // Does not account for axis mapping or batches -TextureElementIndex tensor_idx_to_texture_element_idx_simple( - const TensorIndex4D tidx, const TextureMetadata meta) { +ivec3 tensor4d_idx_to_texel_pos_simple( + const TextureMetadata meta, const TensorIndex4D tidx) { + ivec3 texel_pos; + + const int packed_dim_idx = tidx.data[meta.packed_dim]; + + texel_pos = tidx.data.xyz; + 
texel_pos[meta.packed_dim] = div_4(packed_dim_idx); + return texel_pos; +} + +// Does not account for axis mapping or batches +TextureElementIndex tensor4d_idx_to_texture_element_idx_simple( + const TextureMetadata meta, const TensorIndex4D tidx) { const int packed_dim_idx = tidx.data[meta.packed_dim]; TextureElementIndex tex_idx; tex_idx.pos = tidx.data.xyz; @@ -272,6 +292,16 @@ TextureElementIndex tensor_idx_to_texture_element_idx_simple( return tex_idx; } +uint tensor4d_idx_to_linear_idx( + const BufferMetadata meta, + const TensorIndex4D tidx) { + uint lin_idx = 0; + for (int d = 0; d < 4; ++d) { + lin_idx += meta.strides[0][d] * tidx.data[d]; + } + return lin_idx; +} + // // Debug utilities // diff --git a/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.glsl b/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.glsl index 30375728921..155eda467c4 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.glsl @@ -13,23 +13,29 @@ #define VEC4_T ${texel_load_type(DTYPE, STORAGE)} ${define_required_extensions(DTYPE)} +${define_active_storage_type(STORAGE)} layout(std430) buffer; -${layout_declare_tensor(B, "w", "xqout", DTYPE, STORAGE)} -${layout_declare_tensor(B, "w", "xkout", DTYPE, STORAGE)} -${layout_declare_tensor(B, "r", "xq", DTYPE, STORAGE)} -${layout_declare_tensor(B, "r", "xk", DTYPE, STORAGE)} -${layout_declare_tensor(B, "r", "freqs_cos", DTYPE, STORAGE)} -${layout_declare_tensor(B, "r", "freqs_sin", DTYPE, STORAGE)} -${layout_declare_ubo(B, "ivec3", "xqout_limits")} -${layout_declare_ubo(B, "ivec3", "xkout_limits")} +#include "indexing.glslh" -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +${layout_declare_tensor(B, "w", "t_xqout", DTYPE, STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "w", "t_xkout", DTYPE, STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_xq", DTYPE, STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_xk", DTYPE, STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_freqs_cos", DTYPE, STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_freqs_sin", DTYPE, STORAGE, is_scalar_array=False)} -layout(constant_id = 3) const int packed_dim = 0; +$if STORAGE == "buffer": + ${layout_declare_ubo(B, "BufferMetadata", "xqout")} + ${layout_declare_ubo(B, "BufferMetadata", "xkout")} + ${layout_declare_ubo(B, "BufferMetadata", "freqs_cos")} +$else: + ${layout_declare_ubo(B, "TextureMetadata", "xqout")} + ${layout_declare_ubo(B, "TextureMetadata", "xkout")} + ${layout_declare_ubo(B, "TextureMetadata", "freqs_cos")} -#include "indexing_utils.h" +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; /* * This shader computes rotary positional embeddings which are used in the Llama @@ -39,7 +45,7 @@ layout(constant_id = 3) const int packed_dim = 0; * 1. xq (batch_size, sequence_len, num_heads, head_dim) * 2. xk (batch_size, sequence_len, num_kv_heads, head_dim) * 3. freqs_cos (sequence_len, head_dim / 2) - * 4. freqs_cos (sequence_len, head_dim / 2) + * 4. freqs_sin (sequence_len, head_dim / 2) * * Two output tensors are produced, with the same shapes as xq and xk * respectively. @@ -66,23 +72,43 @@ void main() { // Each thread will write to two output locations to maximize data re-use. // One texel loaded from the freqs_cos/freqs_sin tensors can be used to // calculate two output texels. 
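+  // In the new indexing scheme each invocation covers 8 consecutive elements
+  // along the head dim: two width-packed texels at x = id.x * 8 and
+  // x = id.x * 8 + 4, both rotated with the single freqs texel at x = id.x * 4.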
- const ivec3 x_pos_1 = ivec3( - gl_GlobalInvocationID.x * 2, gl_GlobalInvocationID.yz); - const ivec3 x_pos_2 = ivec3(x_pos_1.x + 1, x_pos_1.yz); + TensorIndex4D out_tidx_1 = zero_tensor4d_idx(); + out_tidx_1.data.x = int(gl_GlobalInvocationID.x) * 8; + out_tidx_1.data.yz = ivec2(gl_GlobalInvocationID.yz); + + TensorIndex4D out_tidx_2 = out_tidx_1; + out_tidx_2.data.x += 4; - if (any(greaterThanEqual(x_pos_2, xqout_limits))) { + if (out_of_bounds(out_tidx_2, xqout)) { return; } - const ivec3 freqs_pos = ivec3(gl_GlobalInvocationID.xz, 0); + TensorIndex4D freqs_tidx = zero_tensor4d_idx(); + freqs_tidx.data.x = int(gl_GlobalInvocationID.x) * 4; + freqs_tidx.data.y = out_tidx_1.data.z; - VEC4_T cos_tex = load_texel(freqs_cos, freqs_pos); - VEC4_T sin_tex = load_texel(freqs_sin, freqs_pos); +#ifdef USING_BUFFER + const uint freqs_texel_bufi = div_4(tensor4d_idx_to_linear_idx(freqs_cos, freqs_tidx)); + VEC4_T cos_tex = t_freqs_cos[freqs_texel_bufi]; + VEC4_T sin_tex = t_freqs_sin[freqs_texel_bufi]; - // Compute xqout + uint x_texel_bufi_1 = div_4(tensor4d_idx_to_linear_idx(xqout, out_tidx_1)); + uint x_texel_bufi_2 = div_4(tensor4d_idx_to_linear_idx(xqout, out_tidx_2)); + VEC4_T x_tex_1 = t_xq[x_texel_bufi_1]; + VEC4_T x_tex_2 = t_xq[x_texel_bufi_2]; + +#else // USING_TEXTURE + const ivec3 freqs_pos = tensor4d_idx_to_texel_pos_simple(freqs_cos, freqs_tidx); + VEC4_T cos_tex = texelFetch(t_freqs_cos, freqs_pos, 0); + VEC4_T sin_tex = texelFetch(t_freqs_sin, freqs_pos, 0); - VEC4_T x_tex_1 = load_texel(xq, x_pos_1); - VEC4_T x_tex_2 = load_texel(xq, x_pos_2); + const ivec3 x_pos_1 = tensor4d_idx_to_texel_pos_simple(xqout, out_tidx_1); + const ivec3 x_pos_2 = tensor4d_idx_to_texel_pos_simple(xqout, out_tidx_2); + VEC4_T x_tex_1 = texelFetch(t_xq, x_pos_1, 0); + VEC4_T x_tex_2 = texelFetch(t_xq, x_pos_2, 0); +#endif + + // Compute xqout // Separate into even and odd elements VEC4_T x_r = VEC4_T(x_tex_1.xz, x_tex_2.xz); @@ -94,20 +120,34 @@ void main() { VEC4_T xout_tex_1 = VEC4_T(xout_r.x, xout_i.x, xout_r.y, xout_i.y); VEC4_T xout_tex_2 = VEC4_T(xout_r.z, xout_i.z, xout_r.w, xout_i.w); - write_texel(xqout, x_pos_1, xout_tex_1); - write_texel(xqout, x_pos_2, xout_tex_2); +#ifdef USING_BUFFER + t_xqout[x_texel_bufi_1] = xout_tex_1; + t_xqout[x_texel_bufi_2] = xout_tex_2; +#else // USING_TEXTURE + imageStore(t_xqout, x_pos_1, xout_tex_1); + imageStore(t_xqout, x_pos_2, xout_tex_2); +#endif // n_heads will be greater than or equal to n_kv_heads, therefore xq and xqout // may have a larger height dim than xk and xkout. Only compute xkout if this // invocation is still within bounds. 
- if (any(greaterThanEqual(x_pos_2, xkout_limits))) { + if (out_of_bounds(out_tidx_2, xkout)) { return; } // Compute xkout - x_tex_1 = load_texel(xk, x_pos_1); - x_tex_2 = load_texel(xk, x_pos_2); +#ifdef USING_BUFFER + x_texel_bufi_1 = div_4(tensor4d_idx_to_linear_idx(xkout, out_tidx_1)); + x_texel_bufi_2 = div_4(tensor4d_idx_to_linear_idx(xkout, out_tidx_2)); + + x_tex_1 = t_xk[x_texel_bufi_1]; + x_tex_2 = t_xk[x_texel_bufi_2]; + +#else // USING_TEXTURE + x_tex_1 = texelFetch(t_xk, x_pos_1, 0); + x_tex_2 = texelFetch(t_xk, x_pos_2, 0); +#endif x_r = VEC4_T(x_tex_1.xz, x_tex_2.xz); x_i = VEC4_T(x_tex_1.yw, x_tex_2.yw); @@ -118,6 +158,11 @@ void main() { xout_tex_1 = VEC4_T(xout_r.x, xout_i.x, xout_r.y, xout_i.y); xout_tex_2 = VEC4_T(xout_r.z, xout_i.z, xout_r.w, xout_i.w); - write_texel(xkout, x_pos_1, xout_tex_1); - write_texel(xkout, x_pos_2, xout_tex_2); +#ifdef USING_BUFFER + t_xkout[x_texel_bufi_1] = xout_tex_1; + t_xkout[x_texel_bufi_2] = xout_tex_2; +#else // USING_TEXTURE + imageStore(t_xkout, x_pos_1, xout_tex_1); + imageStore(t_xkout, x_pos_2, xout_tex_2); +#endif } diff --git a/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.yaml b/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.yaml index a81fd564d10..ba8aa400958 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.yaml @@ -3,6 +3,9 @@ rotary_embedding: DTYPE: float STORAGE: texture3d generate_variant_forall: + STORAGE: + - VALUE: texture3d + - VALUE: buffer DTYPE: - VALUE: half - VALUE: float diff --git a/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp b/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp index fcc8fe4b265..e1914f350b7 100644 --- a/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp @@ -43,10 +43,17 @@ utils::uvec3 rotary_embedding_global_wg_size( const ValueRef xq_out = args.at(0).refs.at(0); - utils::uvec3 global_wg_size = graph->logical_limits_of(xq_out); - global_wg_size[0] /= 2; + // Head dim texel size + const uint32_t D4 = utils::div_up_4(graph->size_at(-1, xq_out)); + // Divide by 2 since each invocation computes 2 output locations + const uint32_t D8 = utils::div_up(D4, uint32_t(2)); - return global_wg_size; + // Number of query heads + const uint32_t QH = graph->size_at(-2, xq_out); + // Input tokens sequence length + const uint32_t S = graph->size_at(-3, xq_out); + + return {D8, QH, S}; } void add_rotary_embedding_node( @@ -73,8 +80,14 @@ void add_rotary_embedding_node( VK_CHECK_COND(graph.has_standard_axis_map(freqs_sin)); std::string kernel_name = "rotary_embedding"; + add_storage_type_suffix(kernel_name, graph.storage_type_of(xq_out)); add_dtype_suffix(kernel_name, graph.dtype_of(xq_out)); + vkapi::ParamsBindList param_ubos = { + graph.meta_ubo(xq_out), + graph.meta_ubo(xk_out), + graph.meta_ubo(freqs_cos)}; + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), @@ -84,7 +97,7 @@ void add_rotary_embedding_node( {{{xq_out, xk_out}, vkapi::kWrite}, {{xq, xk, freqs_cos, freqs_sin}, vkapi::kRead}}, // Parameter buffers - {graph.logical_limits_ubo(xq_out), graph.logical_limits_ubo(xk_out)}, + param_ubos, // Push Constants {}, // Specialization Constants diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index 57863703498..81ee67a596c 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ 
-7,11 +7,9 @@ # pyre-strict from functools import partial - from typing import Any, Callable, Dict, final, List import executorch.backends.vulkan.utils as utils - from executorch.backends.transforms.addmm_mm_to_linear import AddmmToLinearTransform from executorch.backends.transforms.fuse_conv_with_clamp import FuseClampPass from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform @@ -29,7 +27,6 @@ ) from executorch.backends.vulkan._passes.fuse_patterns import FusePatternsPass from executorch.backends.vulkan._passes.remove_asserts import RemoveAssertsTransform - from executorch.backends.vulkan.serialization.vulkan_graph_builder import VkGraphBuilder from executorch.backends.vulkan.serialization.vulkan_graph_schema import ( VkMemoryLayout, @@ -39,7 +36,6 @@ serialize_vulkan_graph, ) from executorch.backends.xnnpack._passes import FuseBatchNormPass - from executorch.exir.backend.backend_details import ( BackendDetails, CompileSpec, @@ -47,18 +43,12 @@ PreprocessResult, ) from executorch.exir.backend.utils import DelegateMappingBuilder - from executorch.exir.memory_planning import greedy, MemoryPlanningAlgorithmSuite from executorch.exir.pass_base import ExportPass, PassBase - from executorch.exir.passes import MemoryPlanningPass, SpecPropPass - from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass - from executorch.exir.program._program import _transform - from torch._export.verifier import Verifier - from torch.export._remove_auto_functionalized_pass import ( unsafe_remove_auto_functionalized_pass, ) @@ -209,6 +199,7 @@ def preprocess( # noqa: C901 texture_limits, default_storage_type=default_storage_type, default_memory_layout=default_memory_layout, + force_fp16=force_fp16, ), ], )