From 83fd30b2bf652921475c37505d8c87fde6df49de Mon Sep 17 00:00:00 2001 From: ssjia Date: Mon, 10 Nov 2025 13:32:26 -0800 Subject: [PATCH 1/4] [ET-VK][ez] Apply quantize op replacement to all argument nodes Pull Request resolved: https://github.com/pytorch/executorch/pull/15702 Title says it all! With the way the pass is currently written only the first arg will be inspected for q/dq node replacement. As a consequence, the second arg for i.e. binary ops may not have the quantized op be replaced. ghstack-source-id: 322214453 @exported-using-ghexport Differential Revision: [D86674169](https://our.internmc.facebook.com/intern/diff/D86674169/) --- backends/vulkan/_passes/replace_qdq.py | 33 +++++++++++++------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/backends/vulkan/_passes/replace_qdq.py b/backends/vulkan/_passes/replace_qdq.py index fcfcdfc4c18..2c5331eb213 100644 --- a/backends/vulkan/_passes/replace_qdq.py +++ b/backends/vulkan/_passes/replace_qdq.py @@ -32,24 +32,23 @@ def call(self, graph_module: torch.fx.GraphModule): exir_ops.edge.et_vk.conv2d_q8ta_q8csw_q8to_dw.default, exir_ops.edge.et_vk.add_q8ta_q8ta_q8to.default, ]: - # Replace quantize op feeding into conv2d (first argument is the quantized input) - quantized_input_node = node.args[0] - if isinstance( - quantized_input_node, torch.fx.Node - ) and utils.is_quant_node(quantized_input_node): - # Get the arguments from the original quantize node - input_tensor = quantized_input_node.args[0] - scale = quantized_input_node.args[1] - zero_point = quantized_input_node.args[2] + for quantized_input_node in node.args: + if isinstance( + quantized_input_node, torch.fx.Node + ) and utils.is_quant_node(quantized_input_node): + # Get the arguments from the original quantize node + input_tensor = quantized_input_node.args[0] + scale = quantized_input_node.args[1] + zero_point = quantized_input_node.args[2] - nodes_to_replace.append( - { - "old_node": quantized_input_node, - "new_target": exir_ops.edge.et_vk.quantize_q8ta_for_conv2d.default, - "args": (input_tensor, scale, zero_point), - "node_type": "quantize_input", - } - ) + nodes_to_replace.append( + { + "old_node": quantized_input_node, + "new_target": exir_ops.edge.et_vk.quantize_q8ta_for_conv2d.default, + "args": (input_tensor, scale, zero_point), + "node_type": "quantize_input", + } + ) # Find dequantize ops that consume the output of this conv2d for user in node.users: From 164cec7c1bb496f9d0c96646647cf311fb63ca67 Mon Sep 17 00:00:00 2001 From: ssjia Date: Mon, 10 Nov 2025 13:32:28 -0800 Subject: [PATCH 2/4] [ET-VK] Allow buffer input/output for quantize/dequantize for conv2d ops Pull Request resolved: https://github.com/pytorch/executorch/pull/15703 Title says it all! This diff allows quantize/dequantize ops to consume/produce tensors in the `CONTIGUOUS_BUFFER` layout. This can help reduce the number of memory layout transitions needed to execute a model. 
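For reference, the new `INPUT_BUFFER` / `OUTPUT_BUFFER` paths in the shaders below address the tensor as a contiguous (width-packed) NCHW array: the texel covering channels `4*z .. 4*z+3` at spatial position `(x, y)` is gathered from buffer elements spaced `H * W` apart. A minimal NumPy sketch of that addressing (the helper name and shapes are illustrative only, not part of the change):

```
# Illustrative only: mirrors the contiguous-buffer addressing used by the new
# INPUT_BUFFER path in conv2d_fp_input_tile_load.glslh.
import numpy as np

def load_texel_from_contiguous_buffer(buf, sizes, x, y, z):
    # sizes = (W, H, C), matching the shader's input_sizes.x/y/z ordering
    W, H, C = sizes
    c_idx = 4 * z                # first channel covered by this texel
    c_stride = H * W             # distance between adjacent channels in the buffer
    base = c_idx * c_stride + y * W + x
    texel = np.zeros(4, dtype=buf.dtype)
    limit = min(C - c_idx, 4)    # guard against C not being a multiple of 4
    for i in range(limit):
        texel[i] = buf[base + i * c_stride]
    return texel

# Usage: a (1, 6, 3, 5) NCHW tensor flattened contiguously
t = np.arange(6 * 3 * 5, dtype=np.float32)
print(load_texel_from_contiguous_buffer(t, (5, 3, 6), x=2, y=1, z=1))
```

Reading four channels at a stride of `H * W` is what lets the buffer variant assemble the same packed texel that the texture variant obtains from a single `texelFetch`.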
ghstack-source-id: 322214457 @exported-using-ghexport Differential Revision: [D86674166](https://our.internmc.facebook.com/intern/diff/D86674166/) --- backends/vulkan/op_registry.py | 4 +-- .../ops/glsl/conv2d_fp_input_tile_load.glslh | 18 ++++++++++++- .../quantize_and_pack_q8ta_conv2d_input.glsl | 2 +- .../quantize_and_pack_q8ta_conv2d_input.yaml | 1 + ...ack_and_dequantize_q8ta_conv2d_output.glsl | 18 +++++++++++-- ...ack_and_dequantize_q8ta_conv2d_output.yaml | 1 + .../custom_ops/q8ta_q8csw_q8to_conv2d.cpp | 16 ++++++++---- .../test/custom_ops/q8ta_q8ta_q8to_add.cpp | 25 +++++++------------ backends/vulkan/utils.py | 8 ++++++ 9 files changed, 66 insertions(+), 27 deletions(-) diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index ef41060272c..403e747141f 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -636,7 +636,7 @@ def register_quantized_binary_op(): def register_quantize_for_conv2d_op(): return OpFeatures( inputs_storage=[ - utils.CHANNELS_PACKED_TEXTURE, + utils.CHANNELS_PACKED_TEXTURE_OR_CONTIGUOUS_BUFFER, ], outputs_storage=[ utils.PACKED_INT8_4W4C_BUFFER, @@ -656,7 +656,7 @@ def register_dequantize_for_conv2d_op(): utils.PACKED_INT8_4W4C_BUFFER, ], outputs_storage=[ - utils.CHANNELS_PACKED_TEXTURE, + utils.CHANNELS_PACKED_TEXTURE_OR_CONTIGUOUS_BUFFER, ], supports_resize=False, ) diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_input_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_input_tile_load.glslh index be8a76421a5..a3934422e27 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_input_tile_load.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_input_tile_load.glslh @@ -14,7 +14,21 @@ #include "linear_fp_input_tile.glslh" VEC4_T load_fp_input_texel(const Conv2dTensorIndex tidx) { +#ifdef INPUT_BUFFER + VEC4_T texel = VEC4_T(0); + const int c_idx = mul_4(tidx.data.z); + const int c_stride = input_sizes.y * input_sizes.x; + + const int base_buf_i = c_idx * c_stride + tidx.data.y * input_sizes.x + tidx.data.x; + const int limit = min(input_sizes.z - c_idx, 4); + + for (int i = 0; i < limit; i++) { + texel[i] = t_fp_input[base_buf_i + i * c_stride]; + } + return texel; +#else return texelFetch(t_fp_input, tidx.data, 0); +#endif } void load_fp_input_tile( @@ -23,7 +37,9 @@ void load_fp_input_tile( #if TILE_M == 4 && TILE_K4 == 1 Conv2dTensorIndex load_tidx = block_idx_to_tensor_idx(block_idx); [[unroll]] for (int w = 0; w < TILE_M; w++) { - tile.data[w][0] = load_fp_input_texel(load_tidx); + if (load_tidx.data.x < input_sizes.x) { + tile.data[w][0] = load_fp_input_texel(load_tidx); + } load_tidx.data.x++; } #else diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_q8ta_conv2d_input.glsl b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_q8ta_conv2d_input.glsl index d485523709b..dfa0b5a95bf 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_q8ta_conv2d_input.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_q8ta_conv2d_input.glsl @@ -31,7 +31,7 @@ layout(std430) buffer; #include "conv2d_common.glslh" ${layout_declare_tensor(B, "w", "t_packed_int8_input", "int", OUTPUT_STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_fp_input", DTYPE, INPUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_fp_input", DTYPE, INPUT_STORAGE)} ${layout_declare_ubo(B, "ivec4", "input_sizes")} diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_q8ta_conv2d_input.yaml 
b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_q8ta_conv2d_input.yaml index 712d3156e2e..929567a2595 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_q8ta_conv2d_input.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_q8ta_conv2d_input.yaml @@ -15,6 +15,7 @@ quantize_and_pack_q8ta_conv2d_input: combos: - parameter_values: [texture3d, texture3d] - parameter_values: [buffer, texture3d] + - parameter_values: [buffer, buffer] DTYPE: - VALUE: float shader_variants: diff --git a/backends/vulkan/runtime/graph/ops/glsl/unpack_and_dequantize_q8ta_conv2d_output.glsl b/backends/vulkan/runtime/graph/ops/glsl/unpack_and_dequantize_q8ta_conv2d_output.glsl index 798366b523a..be0a39bac3c 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/unpack_and_dequantize_q8ta_conv2d_output.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/unpack_and_dequantize_q8ta_conv2d_output.glsl @@ -30,7 +30,7 @@ layout(std430) buffer; #include "conv2d_common.glslh" -${layout_declare_tensor(B, "w", "t_fp_output", DTYPE, OUTPUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "w", "t_fp_output", DTYPE, OUTPUT_STORAGE)} ${layout_declare_tensor(B, "r", "t_packed_int8_output", "int", INPUT_STORAGE, is_scalar_array=False)} ${layout_declare_ubo(B, "ivec4", "output_sizes")} @@ -84,7 +84,19 @@ void unpack_and_dequantize( void store_fp_output_texel( const Conv2dTensorIndex tidx, const VEC4_T out_texel) { +#ifdef OUTPUT_BUFFER + const int c_idx = mul_4(tidx.data.z); + const int c_stride = output_sizes.y * output_sizes.x; + + const int base_buf_i = c_idx * c_stride + tidx.data.y * output_sizes.x + tidx.data.x; + const int limit = min(output_sizes.z - c_idx, 4); + + for (int i = 0; i < limit; ++i) { + t_fp_output[base_buf_i + i * c_stride] = out_texel[i]; + } +#else imageStore(t_fp_output, tidx.data, out_texel); +#endif } void store_fp_tile( @@ -92,7 +104,9 @@ void store_fp_tile( const Conv2dBlockIndex block_idx) { Conv2dTensorIndex store_tidx = block_idx_to_tensor_idx(block_idx); [[unroll]] for (int w = 0; w < 4; w++) { - store_fp_output_texel(store_tidx, block.data[w][0]); + if (store_tidx.data.x < output_sizes.x) { + store_fp_output_texel(store_tidx, block.data[w][0]); + } store_tidx.data.x++; } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/unpack_and_dequantize_q8ta_conv2d_output.yaml b/backends/vulkan/runtime/graph/ops/glsl/unpack_and_dequantize_q8ta_conv2d_output.yaml index 24b253da343..ff1da144027 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/unpack_and_dequantize_q8ta_conv2d_output.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/unpack_and_dequantize_q8ta_conv2d_output.yaml @@ -15,6 +15,7 @@ unpack_and_dequantize_q8ta_conv2d_output: combos: - parameter_values: [texture3d, texture3d] - parameter_values: [texture3d, buffer] + - parameter_values: [buffer, buffer] DTYPE: - VALUE: float shader_variants: diff --git a/backends/vulkan/test/custom_ops/q8ta_q8csw_q8to_conv2d.cpp b/backends/vulkan/test/custom_ops/q8ta_q8csw_q8to_conv2d.cpp index 8762fe4c0d1..bbd4af7579c 100644 --- a/backends/vulkan/test/custom_ops/q8ta_q8csw_q8to_conv2d.cpp +++ b/backends/vulkan/test/custom_ops/q8ta_q8csw_q8to_conv2d.cpp @@ -47,11 +47,15 @@ TestCase create_test_case_from_config( std::vector input_size = { 1, config.channels.in, config.input_size.h, config.input_size.w}; + utils::GPUMemoryLayout io_memory_layout = storage_type == utils::kBuffer + ? 
utils::kWidthPacked + : utils::kChannelsPacked; + ValueSpec input_tensor( input_size, input_dtype, storage_type, - utils::kChannelsPacked, + io_memory_layout, DataGenType::RANDOM); if (debugging()) { @@ -139,7 +143,7 @@ TestCase create_test_case_from_config( {1, config.channels.out, H_out, W_out}, input_dtype, storage_type, - utils::kChannelsPacked, + io_memory_layout, DataGenType::ZEROS); // Add all specs to test case for q8ta_q8csw_q8to operation @@ -182,7 +186,8 @@ std::vector generate_quantized_conv2d_easy_cases() { config.op_name = "conv2d_q8ta_q8csw_q8to"; // Test with both storage types and data types for completeness - std::vector storage_types = {utils::kTexture3D}; + std::vector storage_types = { + utils::kTexture3D, utils::kBuffer}; std::vector float_types = {vkapi::kFloat}; // Generate test cases for each combination @@ -341,7 +346,8 @@ std::vector generate_quantized_conv2d_test_cases() { 4}}; // Test with different storage types and data types - std::vector storage_types = {utils::kTexture3D}; + std::vector storage_types = { + utils::kTexture3D, utils::kBuffer}; // Generate test cases for each combination for (auto& config : configs) { @@ -621,7 +627,7 @@ int main(int argc, char* argv[]) { quantized_conv2d_flop_calculator, "QuantizedConv2dQ8ToQ8To", 0, - 10, + 1, ref_fn); return 0; diff --git a/backends/vulkan/test/custom_ops/q8ta_q8ta_q8to_add.cpp b/backends/vulkan/test/custom_ops/q8ta_q8ta_q8to_add.cpp index 5799bc194c9..eb8e6908060 100644 --- a/backends/vulkan/test/custom_ops/q8ta_q8ta_q8to_add.cpp +++ b/backends/vulkan/test/custom_ops/q8ta_q8ta_q8to_add.cpp @@ -38,21 +38,17 @@ TestCase create_quantized_add_test_case( // Set the operator name for the test case test_case.set_operator_name("et_vk.add_q8ta_q8ta_q8to.test"); + utils::GPUMemoryLayout io_memory_layout = storage_type == utils::kBuffer + ? 
utils::kWidthPacked + : utils::kChannelsPacked; + // Input tensor A (float/half) ValueSpec input_a( - sizes, - input_dtype, - storage_type, - utils::kChannelsPacked, - DataGenType::RANDOM); + sizes, input_dtype, storage_type, io_memory_layout, DataGenType::RANDOM); // Input tensor B (float/half) ValueSpec input_b( - sizes, - input_dtype, - storage_type, - utils::kChannelsPacked, - DataGenType::RANDOM); + sizes, input_dtype, storage_type, io_memory_layout, DataGenType::RANDOM); // Quantization parameters for input A float input_a_scale_val = 0.007843; // 2/255 approximately @@ -81,11 +77,7 @@ TestCase create_quantized_add_test_case( // Output tensor (float/half) ValueSpec output( - sizes, - input_dtype, - storage_type, - utils::kChannelsPacked, - DataGenType::ZEROS); + sizes, input_dtype, storage_type, io_memory_layout, DataGenType::ZEROS); // Add all specs to test case for q8ta_q8ta_q8to add operation test_case.add_input_spec(input_a); @@ -119,7 +111,8 @@ std::vector generate_quantized_add_test_cases() { }; // Storage types to test - std::vector storage_types = {utils::kTexture3D}; + std::vector storage_types = { + utils::kTexture3D, utils::kBuffer}; // Data types to test std::vector data_types = {vkapi::kFloat}; diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py index 9c527cbc36a..1fab11f1ac7 100644 --- a/backends/vulkan/utils.py +++ b/backends/vulkan/utils.py @@ -772,6 +772,14 @@ def make_filtered_tensor_repset( HEIGHT_PACKED_TEXTURE = TensorRepSet(set(), {VkMemoryLayout.TENSOR_HEIGHT_PACKED}) CHANNELS_PACKED_TEXTURE = TensorRepSet(set(), {VkMemoryLayout.TENSOR_CHANNELS_PACKED}) +CHANNELS_PACKED_ANY = TensorRepSet( + {VkMemoryLayout.TENSOR_CHANNELS_PACKED}, {VkMemoryLayout.TENSOR_CHANNELS_PACKED} +) + +CHANNELS_PACKED_TEXTURE_OR_CONTIGUOUS_BUFFER = TensorRepSet( + {VkMemoryLayout.TENSOR_WIDTH_PACKED}, {VkMemoryLayout.TENSOR_CHANNELS_PACKED} +) + ANY_TEXTURE = TensorRepSet(set(), all_memory_layouts) ANY_BUFFER = TensorRepSet(all_memory_layouts, set()) From 27bed93dcdd6a95e17b34532d0a42d7ec08fcbed Mon Sep 17 00:00:00 2001 From: ssjia Date: Mon, 10 Nov 2025 13:32:30 -0800 Subject: [PATCH 3/4] [ET-VK][ez] Constrain out repsets individually Pull Request resolved: https://github.com/pytorch/executorch/pull/15704 Address the TODO comment in the `tag_memory_meta_pass.py` graph pass. ``` # TODO(ssjia): For most ops, inputs and outputs must be synchronized, so there # is no need to constrain output repsets explicitly. Currently, the exceptions # (i.e. choose qparams) already define constrined repsets for the output, so # there is again no need to explicitly constrain the outputs. If an operator # appears later on that does not sync input and output representations, and # defines ambiguous repsets for the output tensor(s), then we will need to add # additional logic to this function to constrain the output repsets separately # from the input repsets. ``` This condition is now fulfilled with the below diff. 
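A toy sketch of the situation that now triggers the new logic (the set-of-pairs representation below is illustrative; it is not the real `TensorRepSet` API): after the previous diff, the dequantize op for conv2d has a fully constrained packed int8 buffer input but an ambiguous output (channels-packed texture or contiguous buffer) that is not synced with the input, so the output repset has to be narrowed by intersecting it with whatever the op's users can accept.

```
# Illustrative only: repsets modeled as sets of (storage, layout) pairs rather
# than the actual TensorRepSet class.
def constrain_with_users(out_repset, users_repset):
    # Mirror the any_in_common guard: only narrow if the intersection is non-empty.
    common = out_repset & users_repset
    return common if common else out_repset

# Ambiguous output repset of the dequantize op after the previous diff
out_repset = {("texture3d", "CHANNELS_PACKED"), ("buffer", "WIDTH_PACKED")}
# Representation derived by tracing the op's users (e.g. a consumer that
# wants a contiguous buffer)
users_repset = {("buffer", "WIDTH_PACKED")}

print(constrain_with_users(out_repset, users_repset))
# {('buffer', 'WIDTH_PACKED')}
```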
ghstack-source-id: 322214459 @exported-using-ghexport Differential Revision: [D86674164](https://our.internmc.facebook.com/intern/diff/D86674164/) --- .../vulkan/_passes/tag_memory_meta_pass.py | 30 +++++++++++++------ backends/vulkan/utils.py | 19 ++++++++++++ 2 files changed, 40 insertions(+), 9 deletions(-) diff --git a/backends/vulkan/_passes/tag_memory_meta_pass.py b/backends/vulkan/_passes/tag_memory_meta_pass.py index 15449b98f6f..43796c043c8 100644 --- a/backends/vulkan/_passes/tag_memory_meta_pass.py +++ b/backends/vulkan/_passes/tag_memory_meta_pass.py @@ -368,7 +368,7 @@ def constrain_op_arg_repset(self, arg_i: int, op_repsets: utils.OpRepSets) -> No arg_repset = op_repsets.get_arg_repset(arg_i) if arg_repset.is_constrained(): - return arg_repset + return arg_node = op_repsets.op_node.args[arg_i] @@ -378,6 +378,20 @@ def constrain_op_arg_repset(self, arg_i: int, op_repsets: utils.OpRepSets) -> No arg_repset = self.trace_node_users_to_constrain_repset(arg_node, arg_repset) op_repsets.try_constrain_with_arg_repset(arg_i, arg_repset) + def constrain_op_out_repset(self, op_repsets: utils.OpRepSets) -> None: + """ + Similar to the `constrain_op_arg_repset` function, but for the output repset of + the operator. + """ + out_repset = op_repsets.get_out_repset(0) + if out_repset.is_constrained(): + return + + op_node = op_repsets.op_node + out_respset = self.trace_node_users_to_constrain_repset(op_node, out_repset) + + op_repsets.try_constrain_with_out_repset(out_respset) + def constrain_op_repsets(self, op_repsets: utils.OpRepSets) -> None: # For most ops, constraining the argument repsets will also contrain the output # repset due to OpRepSets maintaining synchronization rules. @@ -385,14 +399,12 @@ def constrain_op_repsets(self, op_repsets: utils.OpRepSets) -> None: if utils.is_tensor_arg_node(op_repsets.op_node.args[i]): self.constrain_op_arg_repset(i, op_repsets) - # TODO(ssjia): For most ops, inputs and outputs must be synchronized, so there - # is no need to constrain output repsets explicitly. Currently, the exceptions - # (i.e. choose qparams) already define constrined repsets for the output, so - # there is again no need to explicitly constrain the outputs. If an operator - # appears later on that does not sync input and output representations, and - # defines ambiguous repsets for the output tensor(s), then we will need to add - # additional logic to this function to constrain the output repsets separately - # from the input repsets. + # However, some operators do not sync input and output representations and also + # define ambiguous repsets for the output tensor(s). In those cases we will need + # to execute additional logic to constrain the output repsets separately from + # the input repsets. 
+ if not op_repsets.sync_primary_io_repr and op_repsets.sync_outs_repr: + self.constrain_op_out_repset(op_repsets) def set_op_node_tensor_reprs( self, graph_module: torch.fx.GraphModule, op_node: torch.fx.Node diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py index 1fab11f1ac7..fca8173ffb7 100644 --- a/backends/vulkan/utils.py +++ b/backends/vulkan/utils.py @@ -1106,6 +1106,25 @@ def try_constrain_with_arg_repset( self.assert_sync_contraints() return True + def try_constrain_with_out_repset(self, repset: TensorRepSet): + # Skip for operators that must synchronize the input and output representations + # or operators that have more than one output repset + if self.sync_primary_io_repr or len(self.outs_repset_list) > 1: + return False + + out_current_repset = self.outs_repset_list[0] + + if out_current_repset == repset: + return False + + if not out_current_repset.any_in_common(repset): + return False + + self.outs_repset_list[0] = out_current_repset.make_intersect(repset) + + self.assert_sync_contraints() + return True + def pick_representations(self) -> Tuple[TensorReprList, TensorReprList]: """ For each tensor participating in the op, pick a representation for it among the From 9415950304c00ed50149f37e0302dad30c775cdf Mon Sep 17 00:00:00 2001 From: ssjia Date: Mon, 10 Nov 2025 13:32:32 -0800 Subject: [PATCH 4/4] [ET-VK] Implementation of gather Pull Request resolved: https://github.com/pytorch/executorch/pull/15705 Title says it all! This diff implements the gather op in ET-VK https://docs.pytorch.org/docs/stable/generated/torch.gather.html ghstack-source-id: 322214454 @exported-using-ghexport Differential Revision: [D86674167](https://our.internmc.facebook.com/intern/diff/D86674167/) --- backends/vulkan/op_registry.py | 1 + .../runtime/graph/ops/glsl/gather_buffer.glsl | 57 ++++++++++++ .../runtime/graph/ops/glsl/gather_buffer.yaml | 16 ++++ .../graph/ops/glsl/gather_texture.glsl | 67 ++++++++++++++ .../graph/ops/glsl/gather_texture.yaml | 15 ++++ .../vulkan/runtime/graph/ops/impl/Gather.cpp | 89 +++++++++++++++++++ backends/vulkan/test/op_tests/cases.py | 49 ++++++++++ .../op_tests/utils/gen_correctness_base.py | 4 +- 8 files changed, 296 insertions(+), 2 deletions(-) create mode 100644 backends/vulkan/runtime/graph/ops/glsl/gather_buffer.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/gather_buffer.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/gather_texture.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/gather_texture.yaml create mode 100644 backends/vulkan/runtime/graph/ops/impl/Gather.cpp diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 403e747141f..e51bc8ea12a 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -711,6 +711,7 @@ def register_view_ops(): exir_ops.edge.aten.unsqueeze_copy.default, exir_ops.edge.aten.clone.default, exir_ops.edge.aten.permute_copy.default, + exir_ops.edge.aten.gather.default, ] ) def register_view_ops_with_buffer_meta(): diff --git a/backends/vulkan/runtime/graph/ops/glsl/gather_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/gather_buffer.glsl new file mode 100644 index 00000000000..318631a160f --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/gather_buffer.glsl @@ -0,0 +1,57 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define T ${buffer_scalar_type(DTYPE)} + +${define_active_storage_type("buffer")} +${define_required_extensions(DTYPE)} + +#extension GL_EXT_control_flow_attributes : require + +layout(std430) buffer; + +#include "indexing.glslh" + +${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")} +${layout_declare_tensor(B, "r", "t_input", DTYPE, "buffer")} +${layout_declare_tensor(B, "r", "t_index", "int", "buffer")} + +${layout_declare_ubo(B, "BufferMetadata", "outp")} +${layout_declare_ubo(B, "BufferMetadata", "inp")} +${layout_declare_ubo(B, "BufferMetadata", "index")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +layout(constant_id = 3) const int gather_dim = 0; + +void main() { + const uint out_bufi = gl_GlobalInvocationID.x; + if (out_of_bounds(out_bufi, outp)) { + return; + } + + TensorIndex out_tidx = linear_idx_to_tensor_idx(outp, out_bufi); + + // Load the index value at the same position in the index tensor + const uint index_bufi = tensor_idx_to_linear_idx(index, out_tidx); + const int gather_idx = t_index[index_bufi]; + + // Construct the input tensor index by replacing the gather dimension + // with the gathered index value + TensorIndex input_tidx = out_tidx; + input_tidx.data[div_4(gather_dim)][mod_4(gather_dim)] = gather_idx; + + // Load from input tensor and store to output + const uint input_bufi = tensor_idx_to_linear_idx(inp, input_tidx); + + t_out[out_bufi] = t_input[input_bufi]; +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/gather_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/gather_buffer.yaml new file mode 100644 index 00000000000..8138e255b58 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/gather_buffer.yaml @@ -0,0 +1,16 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +gather_buffer: + parameter_names_with_default_values: + DTYPE: float + STORAGE: buffer + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: gather_buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/gather_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/gather_texture.glsl new file mode 100644 index 00000000000..71e352a7875 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/gather_texture.glsl @@ -0,0 +1,67 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define VEC4_T ${texel_load_type(DTYPE, "texture3d")} +#define T ${texel_load_component_type(DTYPE, "texture3d")} + +${define_active_storage_type("texture3d")} +${define_required_extensions(DTYPE)} + +#extension GL_EXT_control_flow_attributes : require + +layout(std430) buffer; + +#include "common.glslh" +#include "indexing.glslh" + +${layout_declare_tensor(B, "w", "t_out", DTYPE, "texture3d")} +${layout_declare_tensor(B, "r", "t_input", DTYPE, "texture3d")} +${layout_declare_tensor(B, "r", "t_index", "int", "texture3d")} + +${layout_declare_ubo(B, "TextureMetadata", "outp")} +${layout_declare_ubo(B, "TextureMetadata", "inp")} +${layout_declare_ubo(B, "TextureMetadata", "index")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +layout(constant_id = 3) const int gather_dim = 0; + +void main() { + const ivec3 out_pos = ivec3(gl_GlobalInvocationID); + + if (out_of_bounds(out_pos, outp)) { + return; + } + + TensorIndex4D out_tidx = texture_pos_to_tensor4d_idx_simple(outp, out_pos); + ivec4 idx_texel = texelFetch(t_index, out_pos, 0); + + VEC4_T out_texel = VEC4_T(0); + + int limit = min( + 4, outp.sizes[outp.packed_dim] - out_tidx.data[outp.packed_dim]); + for (int comp = 0; comp < 4; comp++) { + TensorIndex4D input_tidx = out_tidx; + int gather_idx = idx_texel[comp]; + input_tidx.data[gather_dim] = gather_idx; + + TextureElementIndex input_elem_pos = tensor4d_idx_to_texture_element_idx_simple( + inp, input_tidx); + + VEC4_T input_texel = texelFetch(t_input, input_elem_pos.pos, 0); + out_texel[comp] = input_texel[input_elem_pos.comp]; + + out_tidx.data[outp.packed_dim]++; + } + + imageStore(t_out, out_pos, out_texel); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/gather_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/gather_texture.yaml new file mode 100644 index 00000000000..f8e26a97351 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/gather_texture.yaml @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +gather_texture: + parameter_names_with_default_values: + DTYPE: float + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: gather_texture3d diff --git a/backends/vulkan/runtime/graph/ops/impl/Gather.cpp b/backends/vulkan/runtime/graph/ops/impl/Gather.cpp new file mode 100644 index 00000000000..584a8d0437b --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Gather.cpp @@ -0,0 +1,89 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include +#include + +#include +#include + +#include + +#include + +namespace vkcompute { + +using utils::GPUMemoryLayout; +using utils::StorageType; + +void resize_gather_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + (void)resize_args; + const ValueRef out = args.at(0).refs.at(0); + const ValueRef index = args.at(1).refs.at(1); + + // Output shape is the same as index shape + std::vector out_sizes = graph->sizes_of(index); + graph->virtual_resize(out, out_sizes); +} + +void add_gather_node( + ComputeGraph& graph, + const ValueRef input, + const int64_t dim, + const ValueRef index, + const ValueRef out) { + std::string kernel_name = "gather"; + kernel_name.reserve(kShaderNameReserve); + add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); + + vkapi::ParamsBindList param_ubos = { + graph.meta_ubo(out), graph.meta_ubo(input), graph.meta_ubo(index)}; + + const int64_t dim_whcn = graph.dim_of(input) - dim - 1; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + default_pick_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{out, vkapi::kWrite}, {{input, index}, vkapi::kRead}}, + // Shader params buffers + param_ubos, + // Push Constants + {}, + // Specialization Constants + {static_cast(dim_whcn)}, + // Resize Args + {}, + // Resizing Logic + resize_gather_node)); +} + +void gather(ComputeGraph& graph, const std::vector& args) { + ValueRef input = args[0]; + ValueRef dim_ref = args[1]; + ValueRef index = args[2]; + ValueRef out = args[4]; + + int64_t dim = graph.extract_scalar(dim_ref); + + add_gather_node(graph, input, dim, index, out); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(aten.gather.default, gather); +} + +} // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index dfb9a2865ba..f59c3e30aeb 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -1140,6 +1140,55 @@ def get_embedding_inputs(): return test_suite_wpack +@register_test_suite("aten.gather.default") +def get_gather_inputs(): + Test = namedtuple("GatherTest", ["input", "dim", "index"]) + Test.__new__.__defaults__ = (None, None, None) + + test_cases = [ + # Simple 2D case + Test(input=[4, 4], dim=1, index=[[1, 2], [2, 1], [3, 3], [3, 1]]), + # # 1D cases + Test(input=[10], dim=0, index=[0, 2, 5, 7, 9]), + Test(input=[8], dim=0, index=[1, 3, 5]), + # # 2D cases with different dims + Test(input=[5, 8], dim=0, index=[[0, 1], [2, 3], [4, 0]]), + Test( + input=[5, 8], + dim=1, + index=[[0, 2, 4], [1, 3, 5], [6, 7, 0], [1, 2, 3], [4, 5, 6]], + ), + # # 3D cases + Test( + input=[3, 4, 5], + dim=0, + index=[ + [[0, 1, 2, 0, 1], [1, 2, 0, 1, 2], [2, 0, 1, 2, 0], [0, 1, 2, 0, 1]] + ], + ), + Test( + input=[3, 4, 5], + dim=1, + index=[ + [[0, 1, 2, 3], [1, 2, 3, 0], [2, 3, 0, 1], [3, 0, 1, 2], [0, 1, 2, 3]] + ], + ), + Test( + input=[3, 4, 5], dim=2, index=[[[0, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4, 0]]] + ), + ] + + test_suite = VkTestSuite( + [tuple(tc) + (False, "false", "false") for tc in test_cases] + ) + + test_suite.dtypes = ["at::kFloat"] + test_suite.layouts = ["utils::kWidthPacked", "utils::kChannelsPacked"] + test_suite.storage_types = ["utils::kBuffer", "utils::kTexture3D"] + + return test_suite + + @register_test_suite("aten.unsqueeze_copy.default") def get_unsqueeze_inputs(): test_suite = VkTestSuite( diff 
--git a/backends/vulkan/test/op_tests/utils/gen_correctness_base.py b/backends/vulkan/test/op_tests/utils/gen_correctness_base.py index 26371bc41ff..49419a50399 100644 --- a/backends/vulkan/test/op_tests/utils/gen_correctness_base.py +++ b/backends/vulkan/test/op_tests/utils/gen_correctness_base.py @@ -363,7 +363,7 @@ def generate_suite_cpp(self) -> str: static_cast(indices[0].size())}}; // Flatten indices as from_blob reads garbage otherwise. - std::vector acc; + std::vector acc; for (auto& vec: indices) {{ acc.insert(acc.end(), vec.begin(), vec.end()); }} @@ -380,7 +380,7 @@ def generate_suite_cpp(self) -> str: static_cast(indices[0][0].size())}}; // Flatten indices as from_blob reads garbage otherwise. - std::vector acc; + std::vector acc; for (auto& v: indices) {{ for (auto& vv: v) {{ acc.insert(acc.end(), vv.begin(), vv.end());
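As a reference for the gather op added in PATCH 4, the semantics described in the linked torch.gather documentation can be sanity-checked with a small NumPy sketch (a reference implementation for verification only, not the shader logic):

```
import numpy as np

# The output has the shape of `index`; each element is read from `inp` by
# replacing the coordinate along `dim` with the value stored in `index` at
# that position.
def gather_ref(inp, dim, index):
    out = np.empty(index.shape, dtype=inp.dtype)
    for out_idx in np.ndindex(index.shape):
        src_idx = list(out_idx)
        src_idx[dim] = index[out_idx]
        out[out_idx] = inp[tuple(src_idx)]
    return out

# Mirrors the "Simple 2D case" from cases.py: input [4, 4], dim=1
inp = np.arange(16, dtype=np.float32).reshape(4, 4)
idx = np.array([[1, 2], [2, 1], [3, 3], [3, 1]])
print(gather_ref(inp, 1, idx))
```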