diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 79448beda65..55c36463b51 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -515,6 +515,11 @@ def register_view_ops_with_buffer_meta(): ) +@update_features(exir_ops.edge.aten.expand_copy.default) +def register_expand(): + return OpFeatures(inputs_storage=utils.ANY_BUFFER, supports_resize=False) + + # Fully featured transfer operators (i.e. operators that copy data from the input # tensor(s) to the output tensor(s)), which have memory layout agnostic implementations # for both texture and buffer storage types. diff --git a/backends/vulkan/runtime/graph/ops/glsl/expand_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/expand_buffer.glsl new file mode 100644 index 00000000000..ce433040b66 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/expand_buffer.glsl @@ -0,0 +1,51 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define VEC4_T ${texel_type(DTYPE)} +#define T ${buffer_scalar_type(DTYPE)} + +${define_required_extensions(DTYPE)} + +layout(std430) buffer; + +#include "indexing.glslh" + +${layout_declare_tensor(B, "w", "t_outp", DTYPE, "buffer")} +${layout_declare_tensor(B, "r", "t_inp", DTYPE, "buffer")} + +${layout_declare_ubo(B, "BufferMetadata", "outp")} +${layout_declare_ubo(B, "BufferMetadata", "inp")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +void main() { + const uint outp_bufi = gl_GlobalInvocationID.x; + if (outp_bufi >= numel(outp)) { + return; + } + + TensorIndex outp_tidx; + linear_idx_to_tensor_idx(outp, outp_bufi, outp_tidx); + + // Map output tensor index to input tensor index by taking modulo + // with input tensor sizes for each dimension + TensorIndex inp_tidx = outp_tidx; + for (int d = 0; d < ndim(inp); ++d) { + uint inp_size = size_at(inp, d); + uint outp_idx = idx_at(outp_tidx, d); + inp_tidx.data[div_4(d)][mod_4(d)] = outp_idx % inp_size; + } + + const uint inp_bufi = tensor_idx_to_linear_idx(inp, inp_tidx); + // Copy data from input to output + t_outp[outp_bufi] = t_inp[inp_bufi]; +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/expand_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/expand_buffer.yaml new file mode 100644 index 00000000000..6d90e1fa8b1 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/expand_buffer.yaml @@ -0,0 +1,10 @@ +expand_buffer: + parameter_names_with_default_values: + DTYPE: float + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + - VALUE: int32 + shader_variants: + - NAME: expand_buffer diff --git a/backends/vulkan/runtime/graph/ops/impl/Expand.cpp b/backends/vulkan/runtime/graph/ops/impl/Expand.cpp new file mode 100644 index 00000000000..1623a26b2a1 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Expand.cpp @@ -0,0 +1,71 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include +#include +#include + +namespace vkcompute { + +void add_expand_buffer_node( + ComputeGraph& graph, + const ValueRef in, + const ValueRef size, + const ValueRef out) { + std::string kernel_name = "expand"; + kernel_name.reserve(kShaderNameReserve); + add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); + add_dtype_suffix(kernel_name, graph.dtype_of(out)); + + vkapi::ParamsBindList param_buffers = { + graph.buffer_meta_ubo(out), + graph.buffer_meta_ubo(in), + }; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + default_pick_global_wg_size, + default_pick_local_wg_size, + {{out, vkapi::kWrite}, {in, vkapi::kRead}}, + // Parameter buffers + param_buffers, + // Push Constants + {}, + // Specialization Constants + {}, + // Resize Args + {size}, + // Resizing Logic + nullptr)); +} + +void expand(ComputeGraph& graph, const std::vector& args) { + int idx = 0; + const ValueRef in = args.at(idx++); + const ValueRef size = args.at(idx++); + const ValueRef implicit = args.at(idx++); + (void)implicit; + const ValueRef out = args.at(idx++); + + if (graph.is_buffer_storage(out)) { + return add_expand_buffer_node(graph, in, size, out); + } + + VK_THROW("Expand operator only supports buffer storage"); +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(aten.expand_copy.default, expand); +} + +} // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index e04ad80aa86..cb29d836056 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -1880,6 +1880,48 @@ def get_flip_inputs(): return test_suite +@register_test_suite("aten.expand_copy.default") +def get_expand_inputs(): + test_suite = VkTestSuite( + [ + # Basic expansion cases + ((1,), [5]), + ((1, 1), [3, 4]), + ((1, 3), [2, 3]), + ((3, 1), [3, 4]), + ((1, 1, 1), [2, 3, 4]), + # Expand with same size (no-op) + ((3, 4), [3, 4]), + ((2, 3, 4), [2, 3, 4]), + # Expand with additional dimensions + ((3,), [2, 3]), + ((3, 4), [2, 3, 4]), + ((2, 3), [1, 2, 3]), + # Mixed expansion cases + ((1, 3, 1, 4), [2, 3, 5, 4]), + ((1, 1, 3, 1), [2, 4, 3, 5]), + # Larger tensor cases + ((1, S1), [M, S1]), + ((S2, 1), [S2, M1]), + ((1, 1, S), [S1, S2, S]), + ((1, S1, 1, S2), [M, S1, M1, S2]), + ] + ) + test_suite.storage_types = [ + "utils::kBuffer", + ] + test_suite.layouts = [ + "utils::kWidthPacked", + "utils::kChannelsPacked", + ] + test_suite.dtypes = [ + "at::kFloat", + "at::kHalf", + ] + test_suite.data_gen = "make_seq_tensor" + return test_suite + + @register_test_suite("aten.where.self") def get_where_inputs(): Test = namedtuple("Where", ["condition", "self", "other"]) diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py index d1feeb0f5ce..ee4a8bcc9fc 100644 --- a/backends/vulkan/utils.py +++ b/backends/vulkan/utils.py @@ -621,6 +621,7 @@ def make_filtered_tensor_repset( CHANNELS_PACKED_TEXTURE = TensorRepSet(set(), {VkMemoryLayout.TENSOR_CHANNELS_PACKED}) ANY_TEXTURE = TensorRepSet(set(), all_memory_layouts) +ANY_BUFFER = TensorRepSet(all_memory_layouts, set()) ANY_STORAGE = TensorRepSet(all_memory_layouts, all_memory_layouts) NO_STORAGE = TensorRepSet(set(), set())