diff --git a/backends/vulkan/_passes/remove_redundant_ops.py b/backends/vulkan/_passes/remove_redundant_ops.py
index 8e602dd17b4..25bdd34de70 100644
--- a/backends/vulkan/_passes/remove_redundant_ops.py
+++ b/backends/vulkan/_passes/remove_redundant_ops.py
@@ -31,35 +31,37 @@ class RemoveRedundantOpsTransform(ExportPass):
         exir_ops.edge.aten.lift_fresh_copy.default,
         exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
         exir_ops.edge.dim_order_ops._clone_dim_order.default,
+        exir_ops.edge.aten.expand_copy.default,
     }
 
     def __init__(self) -> None:
         super(RemoveRedundantOpsTransform, self).__init__()
 
     def _should_remove(self, node: torch.fx.Node) -> bool:
-        if node.target in self.redundant_ops:
-            return True
-
-        # Only remove to_copy if dtype does not change. Otherwise, memory format changes
-        # will be handled internally by the backend.
-        if (
-            node.target == exir_ops.edge.aten._to_copy.default
-            or node.target == torch.ops.aten._to_copy.default
-        ):
-            src_dtype = node.meta["val"].dtype
-            # pyre-ignore
-            dst_dtype = node.args[0].meta["val"].dtype
-            return src_dtype == dst_dtype
-
-        return False
+        if node.target not in self.redundant_ops:
+            return False
+
+        orig_node = node.args[0]
+        assert isinstance(orig_node, torch.fx.Node)
+
+        src_dtype = orig_node.meta["val"].dtype
+        dst_dtype = node.meta["val"].dtype
+
+        # Do not remove if the op is converting the dtype.
+        if src_dtype != dst_dtype:
+            return False
+
+        src_shape = orig_node.meta["val"].shape
+        dst_shape = node.meta["val"].shape
+
+        return src_shape == dst_shape
 
     def _remove(self, graph_module: torch.fx.GraphModule) -> None:
         for node in graph_module.graph.nodes:
             if not self._should_remove(node):
                 continue
 
-            with graph_module.graph.inserting_after(node):
-                node.replace_all_uses_with(node.args[0])
+            node.replace_all_uses_with(node.args[0])
 
         graph_module.graph.eliminate_dead_code()
diff --git a/backends/vulkan/_passes/tag_memory_meta_pass.py b/backends/vulkan/_passes/tag_memory_meta_pass.py
index 8ed71aa1dae..15449b98f6f 100644
--- a/backends/vulkan/_passes/tag_memory_meta_pass.py
+++ b/backends/vulkan/_passes/tag_memory_meta_pass.py
@@ -6,22 +6,16 @@
 import logging
 import operator
-
 from typing import Any
 
 import executorch.backends.vulkan.utils as utils
-
 import torch
-
 from executorch.backends.vulkan.op_registry import get_op_features, has_impl, OpFeatures
-
 from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
     VkMemoryLayout,
     VkStorageType,
 )
-
 from executorch.exir.dialects._ops import ops as exir_ops
-
 from executorch.exir.pass_base import ExportPass, PassResult
 from executorch.exir.tensor import TensorSpec
@@ -130,15 +124,17 @@ def __init__(
         texture_limits: utils.ImageExtents,
         default_storage_type: VkStorageType = VkStorageType.TEXTURE_3D,
         default_memory_layout: VkMemoryLayout = VkMemoryLayout.TENSOR_WIDTH_PACKED,
+        force_fp16: bool = False,
     ):
         super().__init__()
         self.default_storage: VkStorageType = default_storage_type
         self.default_layout: VkMemoryLayout = default_memory_layout
         self.texture_limits = texture_limits
+        self.force_fp16 = force_fp16
 
         # Magic number to limit "lookahead" when tracing through users of an operator
         # to constrain the representation of its arguments/outputs.
-        self.max_trace_search_depth = 20
+        self.max_trace_search_depth = None
 
     def is_valid_op_node(self, node: Any) -> bool:
         """
@@ -361,6 +357,12 @@ def constrain_op_arg_repset(self, arg_i: int, op_repsets: utils.OpRepSets) -> No
         2. Then, try to trace through the users of the argument to find a
            representation that can be used for as long as possible without
            needing a transition.
         """
+        # If forcing fp16, then try to use texture storage whenever possible. This is
+        # a temporary stopgap measure until all buffer implementations properly account
+        # for potential overflow of fp16 representation range when doing math in fp16.
+        if self.force_fp16:
+            op_repsets.try_constrain_with_arg_repset(arg_i, utils.ANY_TEXTURE)
+
         arg_source_repset = self.get_arg_tensor_source_repset(op_repsets.op_node, arg_i)
         op_repsets.try_constrain_with_arg_repset(arg_i, arg_source_repset)
diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index da127f72528..ef41060272c 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -7,17 +7,12 @@
 # pyre-unsafe
 
 import operator
-
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import executorch.backends.vulkan.custom_ops_lib  # noqa
-
 import executorch.backends.vulkan.utils as utils
-
 import torch
-
 from executorch.exir.dialects._ops import ops as exir_ops
-
 from executorch.exir.dialects.edge._ops import EdgeOpOverload
 from torch._subclasses.fake_tensor import FakeTensor
@@ -129,6 +124,7 @@ def update_features_impl(op: OpKey):
         # Symbolic integer ops
         torch.ops.aten.sym_size.int,
         operator.add,
+        operator.sub,
         operator.lt,
         operator.gt,
         operator.ge,
@@ -297,27 +293,9 @@ def check_to_copy_node(node: torch.fx.Node) -> bool:
 
 @update_features(exir_ops.edge.dim_order_ops._to_dim_order_copy.default)
 def register_to_copy_dim_order_op():
-    # Currently there is no "real" implementation for to_dim_order_copy, but it can be
-    # removed as long as the operator is not changing the dtype, i.e. the operator call
-    # is modifying the dim order only. Therefore, check that the input and output dtypes
-    # are the same, if so the operator is safe to remove.
-    def check_dim_order_copy_node(node: torch.fx.Node) -> bool:
-        in_arg = node.args[0]
-        if not isinstance(in_arg, torch.fx.Node):
-            return False
-
-        in_tensor = in_arg.meta.get("val", None)
-        out_tensor = node.meta.get("val", None)
-
-        if in_tensor.dtype != out_tensor.dtype:
-            return False
-
-        return True
-
     return OpFeatures(
-        inputs_storage=utils.ANY_STORAGE,
+        inputs_storage=utils.ANY_BUFFER,
         supports_resize=True,
-        are_node_inputs_supported_fn=check_dim_order_copy_node,
     )
@@ -709,7 +687,7 @@ def register_sdpa_ops():
 @update_features(exir_ops.edge.et_vk.apply_rotary_emb.default)
 def register_rotary_emb_op():
     return OpFeatures(
-        inputs_storage=utils.WIDTH_PACKED_TEXTURE,
+        inputs_storage=utils.CONTIGUOUS_ANY,
         supports_resize=True,
     )
diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h
index f7de7e183de..b61bd4a51c0 100644
--- a/backends/vulkan/runtime/graph/ComputeGraph.h
+++ b/backends/vulkan/runtime/graph/ComputeGraph.h
@@ -639,6 +639,10 @@ class ComputeGraph final {
 
   bool device_name_contains(const char* substr);
 
+  int64_t max_buffer_numel() {
+    return static_cast<int64_t>(context_->adapter_ptr()->max_buffer_numel());
+  }
+
   //
   // Graph Building
   //
diff --git a/backends/vulkan/runtime/graph/ops/glsl/embedding_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/embedding_texture.glsl
index b064d8a3295..9a6295a8094 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/embedding_texture.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/embedding_texture.glsl
@@ -38,8 +38,8 @@ int load_embedding_idx(const TensorIndex4D out_tidx) {
   indices_tidx.data.xyz = out_tidx.data.yzw;
   indices_tidx.data.w = 0;
 
-  TextureElementIndex elem_pos = tensor_idx_to_texture_element_idx_simple(
-      indices_tidx, indices);
+  TextureElementIndex elem_pos = tensor4d_idx_to_texture_element_idx_simple(
+      indices, indices_tidx);
 
   const ivec4 in_texel = texelFetch(t_indices, elem_pos.pos, 0);
   return in_texel[elem_pos.comp];
@@ -61,7 +61,7 @@ void main() {
     return;
   }
 
-  TensorIndex4D out_tidx = texture_pos_to_tensor_idx_simple(out_pos, outp);
+  TensorIndex4D out_tidx = texture_pos_to_tensor4d_idx_simple(outp, out_pos);
 
   const int embedding_idx = load_embedding_idx(out_tidx);
   const VEC4_T weight_texel = load_weight_texel(embedding_idx, out_tidx.data.x);
diff --git a/backends/vulkan/runtime/graph/ops/glsl/indexing.glslh b/backends/vulkan/runtime/graph/ops/glsl/indexing.glslh
index 0e30faa5d66..38016547d19 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/indexing.glslh
+++ b/backends/vulkan/runtime/graph/ops/glsl/indexing.glslh
@@ -147,6 +147,20 @@ struct TensorIndex4D {
   ivec4 data;
 };
 
+TensorIndex4D zero_tensor4d_idx() {
+  TensorIndex4D tidx;
+  tidx.data = ivec4(0);
+  return tidx;
+}
+
+bool out_of_bounds(const TensorIndex4D tidx, const BufferMetadata meta) {
+  return any(greaterThanEqual(tidx.data, meta.sizes[0]));
+}
+
+bool out_of_bounds(const TensorIndex4D tidx, const TextureMetadata meta) {
+  return any(greaterThanEqual(tidx.data, meta.sizes));
+}
+
 //
 // TextureElementIndex
 //
@@ -245,15 +259,9 @@ void clamp_tensor_idx(const BufferMetadata meta, inout TensorIndex tidx) {
   tidx.data[1] = min(tidx.data[1], meta.sizes[1] - 1);
 }
 
-TensorIndex4D zero_tensor4d_idx() {
-  TensorIndex4D tidx;
-  tidx.data = ivec4(0);
-  return tidx;
-}
-
 // Does not account for axis mapping or batches
-TensorIndex4D texture_pos_to_tensor_idx_simple(
-    const ivec3 pos, const TextureMetadata meta) {
+TensorIndex4D texture_pos_to_tensor4d_idx_simple(
+    const TextureMetadata meta, const ivec3 pos) {
   TensorIndex4D tidx;
   tidx.data.xyz = pos;
   tidx.data.w = 0;
@@ -262,8 +270,20 @@ TensorIndex4D texture_pos_to_tensor_idx_simple(
 }
 
 // Does not account for axis mapping or batches
-TextureElementIndex tensor_idx_to_texture_element_idx_simple(
-    const TensorIndex4D tidx, const TextureMetadata meta) {
+ivec3 tensor4d_idx_to_texel_pos_simple(
+    const TextureMetadata meta, const TensorIndex4D tidx) {
+  ivec3 texel_pos;
+
+  const int packed_dim_idx = tidx.data[meta.packed_dim];
+
+  texel_pos = tidx.data.xyz;
+  texel_pos[meta.packed_dim] = div_4(packed_dim_idx);
+  return texel_pos;
+}
+
+// Does not account for axis mapping or batches
+TextureElementIndex tensor4d_idx_to_texture_element_idx_simple(
+    const TextureMetadata meta, const TensorIndex4D tidx) {
   const int packed_dim_idx = tidx.data[meta.packed_dim];
   TextureElementIndex tex_idx;
   tex_idx.pos = tidx.data.xyz;
@@ -272,6 +292,16 @@ TextureElementIndex tensor_idx_to_texture_element_idx_simple(
   return tex_idx;
 }
 
+uint tensor4d_idx_to_linear_idx(
+    const BufferMetadata meta,
+    const TensorIndex4D tidx) {
+  uint lin_idx = 0;
+  for (int d = 0; d < 4; ++d) {
+    lin_idx += meta.strides[0][d] * tidx.data[d];
+  }
+  return lin_idx;
+}
+
 //
 // Debug utilities
 //
diff --git a/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.glsl b/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.glsl
index 30375728921..155eda467c4 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.glsl
@@ -13,23 +13,29 @@
 #define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
 
 ${define_required_extensions(DTYPE)}
+${define_active_storage_type(STORAGE)}
 
 layout(std430) buffer;
 
-${layout_declare_tensor(B, "w", "xqout", DTYPE, STORAGE)}
-${layout_declare_tensor(B, "w", "xkout", DTYPE, STORAGE)}
-${layout_declare_tensor(B, "r", "xq", DTYPE, STORAGE)}
-${layout_declare_tensor(B, "r", "xk", DTYPE, STORAGE)}
-${layout_declare_tensor(B, "r", "freqs_cos", DTYPE, STORAGE)}
-${layout_declare_tensor(B, "r", "freqs_sin", DTYPE, STORAGE)}
-${layout_declare_ubo(B, "ivec3", "xqout_limits")}
-${layout_declare_ubo(B, "ivec3", "xkout_limits")}
+#include "indexing.glslh"
 
-layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+${layout_declare_tensor(B, "w", "t_xqout", DTYPE, STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "w", "t_xkout", DTYPE, STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_xq", DTYPE, STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_xk", DTYPE, STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_freqs_cos", DTYPE, STORAGE, is_scalar_array=False)}
+${layout_declare_tensor(B, "r", "t_freqs_sin", DTYPE, STORAGE, is_scalar_array=False)}
 
-layout(constant_id = 3) const int packed_dim = 0;
+$if STORAGE == "buffer":
+  ${layout_declare_ubo(B, "BufferMetadata", "xqout")}
+  ${layout_declare_ubo(B, "BufferMetadata", "xkout")}
+  ${layout_declare_ubo(B, "BufferMetadata", "freqs_cos")}
+$else:
+  ${layout_declare_ubo(B, "TextureMetadata", "xqout")}
+  ${layout_declare_ubo(B, "TextureMetadata", "xkout")}
+  ${layout_declare_ubo(B, "TextureMetadata", "freqs_cos")}
 
-#include "indexing_utils.h"
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 
 /*
  * This shader computes rotary positional embeddings which are used in the Llama
@@ -39,7 +45,7 @@ layout(constant_id = 3) const int packed_dim = 0;
  * 1. xq (batch_size, sequence_len, num_heads, head_dim)
 * 2. xk (batch_size, sequence_len, num_kv_heads, head_dim)
 * 3. freqs_cos (sequence_len, head_dim / 2)
- * 4. freqs_cos (sequence_len, head_dim / 2)
+ * 4. freqs_sin (sequence_len, head_dim / 2)
 *
 * Two output tensors are produced, with the same shapes as xq and xk
 * respectively.
@@ -66,23 +72,43 @@ void main() {
   // Each thread will write to two output locations to maximize data re-use.
   // One texel loaded from the freqs_cos/freqs_sin tensors can be used to
   // calculate two output texels.
-  const ivec3 x_pos_1 = ivec3(
-      gl_GlobalInvocationID.x * 2, gl_GlobalInvocationID.yz);
-  const ivec3 x_pos_2 = ivec3(x_pos_1.x + 1, x_pos_1.yz);
+  TensorIndex4D out_tidx_1 = zero_tensor4d_idx();
+  out_tidx_1.data.x = int(gl_GlobalInvocationID.x) * 8;
+  out_tidx_1.data.yz = ivec2(gl_GlobalInvocationID.yz);
+
+  TensorIndex4D out_tidx_2 = out_tidx_1;
+  out_tidx_2.data.x += 4;
 
-  if (any(greaterThanEqual(x_pos_2, xqout_limits))) {
+  if (out_of_bounds(out_tidx_2, xqout)) {
     return;
   }
 
-  const ivec3 freqs_pos = ivec3(gl_GlobalInvocationID.xz, 0);
+  TensorIndex4D freqs_tidx = zero_tensor4d_idx();
+  freqs_tidx.data.x = int(gl_GlobalInvocationID.x) * 4;
+  freqs_tidx.data.y = out_tidx_1.data.z;
 
-  VEC4_T cos_tex = load_texel(freqs_cos, freqs_pos);
-  VEC4_T sin_tex = load_texel(freqs_sin, freqs_pos);
+#ifdef USING_BUFFER
+  const uint freqs_texel_bufi = div_4(tensor4d_idx_to_linear_idx(freqs_cos, freqs_tidx));
+  VEC4_T cos_tex = t_freqs_cos[freqs_texel_bufi];
+  VEC4_T sin_tex = t_freqs_sin[freqs_texel_bufi];
 
-  // Compute xqout
+  uint x_texel_bufi_1 = div_4(tensor4d_idx_to_linear_idx(xqout, out_tidx_1));
+  uint x_texel_bufi_2 = div_4(tensor4d_idx_to_linear_idx(xqout, out_tidx_2));
+  VEC4_T x_tex_1 = t_xq[x_texel_bufi_1];
+  VEC4_T x_tex_2 = t_xq[x_texel_bufi_2];
+
+#else // USING_TEXTURE
+  const ivec3 freqs_pos = tensor4d_idx_to_texel_pos_simple(freqs_cos, freqs_tidx);
+  VEC4_T cos_tex = texelFetch(t_freqs_cos, freqs_pos, 0);
+  VEC4_T sin_tex = texelFetch(t_freqs_sin, freqs_pos, 0);
 
-  VEC4_T x_tex_1 = load_texel(xq, x_pos_1);
-  VEC4_T x_tex_2 = load_texel(xq, x_pos_2);
+  const ivec3 x_pos_1 = tensor4d_idx_to_texel_pos_simple(xqout, out_tidx_1);
+  const ivec3 x_pos_2 = tensor4d_idx_to_texel_pos_simple(xqout, out_tidx_2);
+  VEC4_T x_tex_1 = texelFetch(t_xq, x_pos_1, 0);
+  VEC4_T x_tex_2 = texelFetch(t_xq, x_pos_2, 0);
+#endif
+
+  // Compute xqout
 
   // Separate into even and odd elements
   VEC4_T x_r = VEC4_T(x_tex_1.xz, x_tex_2.xz);
@@ -94,20 +120,34 @@ void main() {
   VEC4_T xout_tex_1 = VEC4_T(xout_r.x, xout_i.x, xout_r.y, xout_i.y);
   VEC4_T xout_tex_2 = VEC4_T(xout_r.z, xout_i.z, xout_r.w, xout_i.w);
 
-  write_texel(xqout, x_pos_1, xout_tex_1);
-  write_texel(xqout, x_pos_2, xout_tex_2);
+#ifdef USING_BUFFER
+  t_xqout[x_texel_bufi_1] = xout_tex_1;
+  t_xqout[x_texel_bufi_2] = xout_tex_2;
+#else // USING_TEXTURE
+  imageStore(t_xqout, x_pos_1, xout_tex_1);
+  imageStore(t_xqout, x_pos_2, xout_tex_2);
+#endif
 
   // n_heads will be greater than or equal to n_kv_heads, therefore xq and xqout
   // may have a larger height dim than xk and xkout. Only compute xkout if this
   // invocation is still within bounds.
-  if (any(greaterThanEqual(x_pos_2, xkout_limits))) {
+  if (out_of_bounds(out_tidx_2, xkout)) {
     return;
   }
 
   // Compute xkout
 
-  x_tex_1 = load_texel(xk, x_pos_1);
-  x_tex_2 = load_texel(xk, x_pos_2);
+#ifdef USING_BUFFER
+  x_texel_bufi_1 = div_4(tensor4d_idx_to_linear_idx(xkout, out_tidx_1));
+  x_texel_bufi_2 = div_4(tensor4d_idx_to_linear_idx(xkout, out_tidx_2));
+
+  x_tex_1 = t_xk[x_texel_bufi_1];
+  x_tex_2 = t_xk[x_texel_bufi_2];
+
+#else // USING_TEXTURE
+  x_tex_1 = texelFetch(t_xk, x_pos_1, 0);
+  x_tex_2 = texelFetch(t_xk, x_pos_2, 0);
+#endif
 
   x_r = VEC4_T(x_tex_1.xz, x_tex_2.xz);
   x_i = VEC4_T(x_tex_1.yw, x_tex_2.yw);
@@ -118,6 +158,11 @@ void main() {
   xout_tex_1 = VEC4_T(xout_r.x, xout_i.x, xout_r.y, xout_i.y);
   xout_tex_2 = VEC4_T(xout_r.z, xout_i.z, xout_r.w, xout_i.w);
 
-  write_texel(xkout, x_pos_1, xout_tex_1);
-  write_texel(xkout, x_pos_2, xout_tex_2);
+#ifdef USING_BUFFER
+  t_xkout[x_texel_bufi_1] = xout_tex_1;
+  t_xkout[x_texel_bufi_2] = xout_tex_2;
+#else // USING_TEXTURE
+  imageStore(t_xkout, x_pos_1, xout_tex_1);
+  imageStore(t_xkout, x_pos_2, xout_tex_2);
+#endif
 }
diff --git a/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.yaml b/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.yaml
index a81fd564d10..ba8aa400958 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.yaml
+++ b/backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.yaml
@@ -3,6 +3,9 @@ rotary_embedding:
     DTYPE: float
     STORAGE: texture3d
   generate_variant_forall:
+    STORAGE:
+      - VALUE: texture3d
+      - VALUE: buffer
    DTYPE:
      - VALUE: half
      - VALUE: float
diff --git a/backends/vulkan/runtime/graph/ops/glsl/view_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/view_buffer.glsl
index 2c02803a9b1..96b9aa85a1f 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/view_buffer.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/view_buffer.glsl
@@ -18,6 +18,8 @@ ${layout_declare_ubo(B, "BufferMetadata", "inp")}
 
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 
+${layout_declare_spec_const(C, "int", "all_contiguous", "0")}
+
 /*
 * The insight behind the view operation is that the contiguous index of each
 * tensor element in the input and output tensors are the same.
@@ -28,17 +30,20 @@ void main() {
    return;
  }
 
-  TensorIndex outp_tidx;
-  linear_idx_to_tensor_idx(outp, outp_bufi, outp_tidx);
+  uint inp_bufi = outp_bufi;
+  if (all_contiguous == 0) {
+    TensorIndex outp_tidx;
+    linear_idx_to_tensor_idx(outp, outp_bufi, outp_tidx);
 
-  // To map the output to the input, find the input element that has the same
-  // contiguous index as the output element.
+    // To map the output to the input, find the input element that has the same
+    // contiguous index as the output element.
+    const uint contig_idx = tensor_idx_to_contiguous_idx(outp, outp_tidx);
 
-  TensorIndex inp_tidx;
-  contiguous_idx_to_tensor_idx(inp, contig_idx, inp_tidx);
+    TensorIndex inp_tidx;
+    contiguous_idx_to_tensor_idx(inp, contig_idx, inp_tidx);
 
-  const uint inp_bufi = tensor_idx_to_linear_idx(inp, inp_tidx);
+    inp_bufi = tensor_idx_to_linear_idx(inp, inp_tidx);
+  }
 
  t_outp[outp_bufi] = t_inp[inp_bufi];
}
diff --git a/backends/vulkan/runtime/graph/ops/glsl/view_convert_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/view_convert_buffer.glsl
new file mode 100644
index 00000000000..a926c9fea11
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/view_convert_buffer.glsl
@@ -0,0 +1,54 @@
+#version 450 core
+
+#define PRECISION ${PRECISION}
+
+#define IN_T ${buffer_scalar_type(IN_DTYPE)}
+#define OUT_T ${buffer_scalar_type(OUT_DTYPE)}
+
+${define_required_extensions(IN_DTYPE)}
+${define_required_extensions(OUT_DTYPE)}
+
+layout(std430) buffer;
+
+#include "indexing.glslh"
+
+${layout_declare_buffer(B, "w", "t_outp", OUT_DTYPE)}
+${layout_declare_buffer(B, "r", "t_inp", IN_DTYPE)}
+
+${layout_declare_ubo(B, "BufferMetadata", "outp")}
+${layout_declare_ubo(B, "BufferMetadata", "inp")}
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+${layout_declare_spec_const(C, "int", "all_contiguous", "0")}
+
+/*
+ * The insight behind the view_convert operation is that the contiguous index of each
+ * tensor element in the input and output tensors are the same, but the data types
+ * may be different and need conversion.
+ */
+void main() {
+  const uint outp_bufi = gl_GlobalInvocationID.x;
+  if (outp_bufi >= numel(outp)) {
+    return;
+  }
+
+  uint inp_bufi = outp_bufi;
+
+  if (all_contiguous == 0) {
+    TensorIndex outp_tidx;
+    linear_idx_to_tensor_idx(outp, outp_bufi, outp_tidx);
+
+    // To map the output to the input, find the input element that has the same
+    // contiguous index as the output element.
+    const uint contig_idx = tensor_idx_to_contiguous_idx(outp, outp_tidx);
+
+    TensorIndex inp_tidx;
+    contiguous_idx_to_tensor_idx(inp, contig_idx, inp_tidx);
+
+    inp_bufi = tensor_idx_to_linear_idx(inp, inp_tidx);
+  }
+
+  // Convert data type from input to output
+  t_outp[outp_bufi] = OUT_T(t_inp[inp_bufi]);
+}
diff --git a/backends/vulkan/runtime/graph/ops/glsl/view_convert_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/view_convert_buffer.yaml
new file mode 100644
index 00000000000..11d56cad4a9
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/view_convert_buffer.yaml
@@ -0,0 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+view_convert_buffer:
+  parameter_names_with_default_values:
+    IN_DTYPE: float
+    OUT_DTYPE: float
+    STORAGE: buffer
+  generate_variant_forall:
+    combination:
+      parameter_names: [IN_DTYPE, OUT_DTYPE]
+      combos:
+        - parameter_values: [int32, float]
+        - parameter_values: [int32, half]
+        - parameter_values: [uint8, float]
+        - parameter_values: [uint8, half]
+        - parameter_values: [uint8, int32]
+  shader_variants:
+    - NAME: view_convert_buffer
diff --git a/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp b/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp
index fcc8fe4b265..e1914f350b7 100644
--- a/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/RotaryEmbedding.cpp
@@ -43,10 +43,17 @@ utils::uvec3 rotary_embedding_global_wg_size(
 
   const ValueRef xq_out = args.at(0).refs.at(0);
 
-  utils::uvec3 global_wg_size = graph->logical_limits_of(xq_out);
-  global_wg_size[0] /= 2;
+  // Head dim texel size
+  const uint32_t D4 = utils::div_up_4(graph->size_at<uint32_t>(-1, xq_out));
+  // Divide by 2 since each invocation computes 2 output locations
+  const uint32_t D8 = utils::div_up(D4, uint32_t(2));
 
-  return global_wg_size;
+  // Number of query heads
+  const uint32_t QH = graph->size_at<uint32_t>(-2, xq_out);
+  // Input tokens sequence length
+  const uint32_t S = graph->size_at<uint32_t>(-3, xq_out);
+
+  return {D8, QH, S};
 }
 
 void add_rotary_embedding_node(
@@ -73,8 +80,14 @@ void add_rotary_embedding_node(
   VK_CHECK_COND(graph.has_standard_axis_map(freqs_sin));
 
   std::string kernel_name = "rotary_embedding";
+  add_storage_type_suffix(kernel_name, graph.storage_type_of(xq_out));
   add_dtype_suffix(kernel_name, graph.dtype_of(xq_out));
 
+  vkapi::ParamsBindList param_ubos = {
+      graph.meta_ubo(xq_out),
+      graph.meta_ubo(xk_out),
+      graph.meta_ubo(freqs_cos)};
+
   graph.execute_nodes().emplace_back(new DynamicDispatchNode(
       graph,
       VK_KERNEL_FROM_STR(kernel_name),
@@ -84,7 +97,7 @@ void add_rotary_embedding_node(
       {{{xq_out, xk_out}, vkapi::kWrite},
        {{xq, xk, freqs_cos, freqs_sin}, vkapi::kRead}},
       // Parameter buffers
-      {graph.logical_limits_ubo(xq_out), graph.logical_limits_ubo(xk_out)},
+      param_ubos,
       // Push Constants
       {},
       // Specialization Constants
diff --git a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp
index 4eed8b82834..d28d2c90fcb 100644
--- a/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/SDPA.cpp
@@ -471,10 +471,31 @@ void sdpa_impl(ComputeGraph& graph, const std::vector<ValueRef>& args) {
   VK_CHECK_COND(graph.val_is_none(attn_mask));
 
   const int64_t num_q_heads = graph.size_at<int64_t>(-2, q_projected);
-  const int64_t max_seq_len = graph.size_at<int64_t>(-3, q_projected);
-
+  int64_t max_seq_len = graph.size_at<int64_t>(-3, q_projected);
   const int64_t max_context_len = graph.size_at<int64_t>(-3, k_cache);
 
+  const utils::StorageType attn_weights_storage =
+      graph.storage_type_of(q_projected);
+
+  // If using buffer storage for attn weights, we need to ensure that the buffer
+  // numel limit is not exceeded. If needed, manually adjust max_seq_len based
+  // on the buffer numel limit.
+  if (attn_weights_storage == utils::kBuffer) {
+    const int64_t max_buffer_numel = graph.max_buffer_numel();
+    if (num_q_heads * max_seq_len * max_context_len >= max_buffer_numel) {
+      // Compute the maximum possible value for max_seq_len that will hit
+      // the buffer numel limit.
+      max_seq_len = max_buffer_numel / (num_q_heads * max_context_len);
+      // Adjust down to the nearest multiple of 4 to make sure the limit is
+      // not hit.
+      if (max_seq_len % 4 != 0) {
+        max_seq_len = (max_seq_len / 4) * 4;
+      } else {
+        max_seq_len -= 4;
+      }
+    }
+  }
+
   std::vector<int64_t> attn_weight_full_sizes = {
       1, // batch
       num_q_heads,
@@ -485,14 +506,14 @@ void sdpa_impl(ComputeGraph& graph, const std::vector<ValueRef>& args) {
       &graph,
       attn_weight_full_sizes,
       graph.dtype_of(q_projected),
-      graph.storage_type_of(q_projected),
+      attn_weights_storage,
       utils::kWidthPacked);
 
   TmpTensor attn_weights_softmax(
      &graph,
      attn_weight_full_sizes,
      graph.dtype_of(q_projected),
-      graph.storage_type_of(q_projected),
+      attn_weights_storage,
      utils::kWidthPacked);
 
  add_sdpa_compute_attn_weights_node(
@@ -528,9 +549,9 @@ void sdpa_with_kv_cache_impl(
 
  utils::StorageType cache_storage = graph.storage_type_of(q_projected);
  const ValueRef k_cache =
-      prepack_standard(graph, k_cache_data, cache_storage, utils::kWidthPacked);
+      graph.add_tensor_like(k_cache_data, cache_storage, utils::kWidthPacked);
  const ValueRef v_cache =
-      prepack_standard(graph, v_cache_data, cache_storage, utils::kWidthPacked);
+      graph.add_tensor_like(v_cache_data, cache_storage, utils::kWidthPacked);
 
  update_cache_impl(graph, {k_projected, k_cache, input_pos_symint, -1});
  update_cache_impl(graph, {v_projected, v_cache, input_pos_symint, -1});
@@ -573,7 +594,7 @@ void compute_attn_weight_with_kv_cache_impl(
 
  (void)sequence_len;
 
-  utils::StorageType cache_storage = graph.storage_type_of(q_projected);
+  const utils::StorageType cache_storage = graph.storage_type_of(q_projected);
  const ValueRef k_cache =
      graph.add_tensor_like(k_cache_data, cache_storage, utils::kWidthPacked);
  const ValueRef v_cache =
diff --git a/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp b/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp
index 36a8ee4c3b1..602fe1ef129 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Unsqueeze.cpp
@@ -67,6 +67,18 @@ void resize_unsqueeze_node(
 
   std::vector<int64_t> out_sizes = graph->sizes_of(in);
 
+  std::vector<int64_t> unsqueezed_dims;
+
+  if (graph->val_is_int_list(dims_ref)) {
+    const IntListPtr dims = graph->get_int_list(dims_ref);
+    for (int64_t d : *dims) {
+      unsqueezed_dims.push_back(d);
+    }
+  } else {
+    const int64_t dim = graph->extract_scalar<int64_t>(dims_ref);
+    unsqueezed_dims.push_back(dim);
+  }
+
   // Insert singleton dimensions at the specified positions
   for (auto dim : dims_vec) {
     int64_t d = dim;
diff --git a/backends/vulkan/runtime/graph/ops/impl/View.cpp b/backends/vulkan/runtime/graph/ops/impl/View.cpp
index 8701a6246b0..5e2c898573a 100644
--- a/backends/vulkan/runtime/graph/ops/impl/View.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/View.cpp
@@ -60,6 +60,16 @@ void resize_view_node(
   }
 }
 
+void resize_to_dim_order_copy_node(
+    ComputeGraph* graph,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& extra_args) {
+  const ValueRef out = args.at(0).refs.at(0);
+  const ValueRef in = args.at(1).refs.at(0);
+  const std::vector<int64_t> in_sizes = graph->sizes_of(in);
+  graph->virtual_resize(out, in_sizes);
+}
+
 void add_view_node(
     ComputeGraph& graph,
     ValueRef in,
@@ -98,6 +108,11 @@ void add_view_copy_buffer_node(
   std::string kernel_name = "view_buffer";
   add_dtype_suffix(kernel_name, graph.dtype_of(out));
 
+  bool all_contiguous = graph.is_contiguous_buffer_tensor(in) &&
+      graph.is_contiguous_buffer_tensor(out);
+
+  int32_t all_contiguous_int = all_contiguous ? 1 : 0;
+
   graph.execute_nodes().emplace_back(new DynamicDispatchNode(
       graph,
       VK_KERNEL_FROM_STR(kernel_name),
@@ -110,7 +125,41 @@ void add_view_copy_buffer_node(
       // Push Constants
       {},
       // Specialization Constants
+      {all_contiguous_int},
+      // Resize Args
+      resize_args,
+      // Resizing Logic
+      resize_fn));
+}
+
+void add_view_copy_convert_buffer_node(
+    ComputeGraph& graph,
+    ValueRef in,
+    ValueRef out,
+    const std::vector<ValueRef>& resize_args,
+    const ExecuteNode::ResizeFunction& resize_fn) {
+  std::string kernel_name = "view_convert_buffer";
+  add_dtype_suffix(kernel_name, graph.dtype_of(in));
+  add_dtype_suffix(kernel_name, graph.dtype_of(out));
+
+  bool all_contiguous = graph.is_contiguous_buffer_tensor(in) &&
+      graph.is_contiguous_buffer_tensor(out);
+
+  int32_t all_contiguous_int = all_contiguous ? 1 : 0;
+
+  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
+      graph,
+      VK_KERNEL_FROM_STR(kernel_name),
+      default_pick_global_wg_size,
+      default_pick_local_wg_size,
+      // Inputs and Outputs
+      {{out, vkapi::kWrite}, {in, vkapi::kRead}},
+      // Parameter Buffers
+      {graph.buffer_meta_ubo(out), graph.buffer_meta_ubo(in)},
+      // Push Constants
       {},
+      // Specialization Constants
+      {all_contiguous_int},
       // Resize Args
       resize_args,
       // Resizing Logic
       resize_fn));
@@ -132,8 +181,38 @@ void view(ComputeGraph& graph, const std::vector<ValueRef>& args) {
   return add_view_node(graph, in, sizes, out);
 }
 
+void to_dim_order_copy(ComputeGraph& graph, const std::vector<ValueRef>& args) {
+  int args_idx = 0;
+  const ValueRef in = args.at(args_idx++);
+  const ValueRef dtype = args.at(args_idx++);
+  (void)dtype;
+  const ValueRef layout = args.at(args_idx++);
+  (void)layout;
+  const ValueRef device = args.at(args_idx++);
+  (void)device;
+  const ValueRef pin_memory = args.at(args_idx++);
+  (void)pin_memory;
+  const ValueRef non_blocking = args.at(args_idx++);
+  (void)non_blocking;
+  const ValueRef dim_order = args.at(args_idx++);
+  (void)dim_order;
+
+  const ValueRef out = args.at(args_idx++);
+
+  VK_CHECK_COND(graph.is_buffer_storage(in) && graph.is_buffer_storage(out));
+
+  if (graph.dtype_of(in) == graph.dtype_of(out)) {
+    return add_view_copy_buffer_node(
+        graph, in, out, {}, resize_to_dim_order_copy_node);
+  }
+
+  return add_view_copy_convert_buffer_node(
+      graph, in, out, {}, resize_to_dim_order_copy_node);
+}
+
 REGISTER_OPERATORS {
   VK_REGISTER_OP(aten.view_copy.default, view);
+  VK_REGISTER_OP(dim_order_ops._to_dim_order_copy.default, to_dim_order_copy);
 }
 
 } // namespace vkcompute
diff --git a/backends/vulkan/runtime/graph/ops/impl/View.h b/backends/vulkan/runtime/graph/ops/impl/View.h
index 7a7a8d57742..c8e52492417 100644
--- a/backends/vulkan/runtime/graph/ops/impl/View.h
+++ b/backends/vulkan/runtime/graph/ops/impl/View.h
@@ -24,6 +24,19 @@ void add_view_copy_buffer_node(
     const std::vector<ValueRef>& resize_args,
     const ExecuteNode::ResizeFunction& resize_fn);
 
+/*
+ * Dispatches the view_convert_buffer compute shader. This can be used to
+ * implement ops that preserve the "contiguous" indexes of elements between the
+ * input and output while converting between different data types, such as
+ * view_copy with dtype conversion.
+ */
+void add_view_copy_convert_buffer_node(
+    ComputeGraph& graph,
+    ValueRef in,
+    ValueRef out,
+    const std::vector<ValueRef>& resize_args,
+    const ExecuteNode::ResizeFunction& resize_fn);
+
 void add_view_node(
     ComputeGraph& graph,
     ValueRef in,
diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py
index f92cea64767..f38c510a8b1 100644
--- a/backends/vulkan/test/test_vulkan_delegate.py
+++ b/backends/vulkan/test/test_vulkan_delegate.py
@@ -11,20 +11,14 @@
 from typing import Tuple
 
 import executorch.backends.vulkan.test.utils as test_utils
-
 import torch
-
 from executorch.backends.transforms.convert_dtype_pass import I64toI32
-
 from executorch.backends.vulkan.partitioner.vulkan_partitioner import VulkanPartitioner
-
 from executorch.backends.vulkan.vulkan_preprocess import VulkanBackend
-
 from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
     get_symmetric_quantization_config,
     XNNPACKQuantizer,
 )
-
 from executorch.exir import (
     EdgeCompileConfig,
     EdgeProgramManager,
@@ -36,11 +30,8 @@
 )
 from executorch.extension.pytree import tree_flatten
 from torch.export import Dim, export, ExportedProgram
-
 from torchao.quantization.granularity import PerGroup
-
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
-
 from torchao.quantization.pt2e.quantizer import Quantizer
 from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_
 from torchao.utils import unwrap_tensor_subclass
@@ -69,9 +60,6 @@ def lower_module(
     edge_program = to_edge_transform_and_lower(
         program,
         compile_config=edge_compile_config,
-        transform_passes=[
-            I64toI32(edge_compile_config._skip_dim_order),
-        ],
         partitioner=[VulkanPartitioner(compile_options)],
     )
diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py
index 57863703498..81ee67a596c 100644
--- a/backends/vulkan/vulkan_preprocess.py
+++ b/backends/vulkan/vulkan_preprocess.py
@@ -7,11 +7,9 @@
 # pyre-strict
 
 from functools import partial
-
 from typing import Any, Callable, Dict, final, List
 
 import executorch.backends.vulkan.utils as utils
-
 from executorch.backends.transforms.addmm_mm_to_linear import AddmmToLinearTransform
 from executorch.backends.transforms.fuse_conv_with_clamp import FuseClampPass
 from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
@@ -29,7 +27,6 @@
 )
 from executorch.backends.vulkan._passes.fuse_patterns import FusePatternsPass
 from executorch.backends.vulkan._passes.remove_asserts import RemoveAssertsTransform
-
 from executorch.backends.vulkan.serialization.vulkan_graph_builder import VkGraphBuilder
 from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
     VkMemoryLayout,
@@ -39,7 +36,6 @@
     serialize_vulkan_graph,
 )
 from executorch.backends.xnnpack._passes import FuseBatchNormPass
-
 from executorch.exir.backend.backend_details import (
     BackendDetails,
     CompileSpec,
@@ -47,18 +43,12 @@
     PreprocessResult,
 )
 from executorch.exir.backend.utils import DelegateMappingBuilder
-
 from executorch.exir.memory_planning import greedy, MemoryPlanningAlgorithmSuite
 from executorch.exir.pass_base import ExportPass, PassBase
-
 from executorch.exir.passes import MemoryPlanningPass, SpecPropPass
-
 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
-
 from executorch.exir.program._program import _transform
-
 from torch._export.verifier import Verifier
-
 from torch.export._remove_auto_functionalized_pass import (
     unsafe_remove_auto_functionalized_pass,
 )
@@ -209,6 +199,7 @@ def preprocess(  # noqa: C901
                 texture_limits,
                 default_storage_type=default_storage_type,
                 default_memory_layout=default_memory_layout,
+                force_fp16=force_fp16,
             ),
         ],
     )
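
Reviewer note: the snippet below is an illustrative, standalone Python sketch (not the pass itself) that restates the tightened removal criterion in RemoveRedundantOpsTransform. A copy-like node is dropped only when neither dtype nor shape changes between its input and output, which is what makes it safe to add aten.expand_copy.default to the redundant-op set.

    import torch

    def is_removable_copy(node: torch.fx.Node) -> bool:
        # Illustrative check mirroring _should_remove: keep the op whenever it
        # converts dtype or changes shape (e.g. an expand_copy that actually expands).
        src = node.args[0]
        if not isinstance(src, torch.fx.Node):
            return False
        src_val, dst_val = src.meta["val"], node.meta["val"]
        return src_val.dtype == dst_val.dtype and src_val.shape == dst_val.shape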
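Also illustrative (a hedged sketch under the assumption that it mirrors the C++ added to sdpa_impl, not backend API): the max_seq_len clamping arithmetic expressed in Python, with a worked example of how a configuration that would exceed the device's buffer numel limit is rounded down to a multiple of 4.

    def clamp_max_seq_len(num_q_heads, max_seq_len, max_context_len, max_buffer_numel):
        # Shrink max_seq_len until the attention-weights tensor
        # (num_q_heads * max_seq_len * max_context_len) fits in one buffer.
        if num_q_heads * max_seq_len * max_context_len >= max_buffer_numel:
            max_seq_len = max_buffer_numel // (num_q_heads * max_context_len)
            if max_seq_len % 4 != 0:
                max_seq_len = (max_seq_len // 4) * 4
            else:
                max_seq_len -= 4
        return max_seq_len

    # Example: with 32 query heads, a context length of 2048, and a 2**27 numel limit,
    # a requested max_seq_len of 4096 is clamped to 2044.
    assert clamp_max_seq_len(32, 4096, 2048, 2**27) == 2044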