diff --git a/backends/vulkan/runtime/api/containers/Tensor.cpp b/backends/vulkan/runtime/api/containers/Tensor.cpp
index be44679f3b0..578898ad194 100644
--- a/backends/vulkan/runtime/api/containers/Tensor.cpp
+++ b/backends/vulkan/runtime/api/containers/Tensor.cpp
@@ -652,6 +652,17 @@ void vTensorStorage::transition(
   last_access_.access = cur_access;
 }
 
+bool vTensorStorage::is_copy_of(const vTensorStorage& other) const {
+  if (storage_type_ != other.storage_type_) { // mismatched storage types are reported as not-a-copy
+    return false;
+  }
+  // Copies are only enabled for buffer storage at the moment
+  if (storage_type_ != utils::kBuffer) {
+    return false;
+  }
+  return buffer_.is_copy_of(other.buffer_); // defer to the buffer-level check
+}
+
 void vTensorStorage::discard_and_reallocate(
     const std::vector& padded_sizes,
     const utils::GPUMemoryLayout gpu_memory_layout,
diff --git a/backends/vulkan/runtime/api/containers/Tensor.h b/backends/vulkan/runtime/api/containers/Tensor.h
index 8186ef1bd66..d37628e4adc 100644
--- a/backends/vulkan/runtime/api/containers/Tensor.h
+++ b/backends/vulkan/runtime/api/containers/Tensor.h
@@ -152,6 +152,11 @@ class vTensorStorage final {
     return image_.format();
   }
 
+  /*
+   * Returns true if this vTensorStorage is a buffer-backed copy of `other`
+   */
+  bool is_copy_of(const vTensorStorage& other) const;
+
   void discard_and_reallocate(
       const std::vector& padded_sizes,
       const utils::GPUMemoryLayout gpu_memory_layout,
@@ -458,6 +463,13 @@ class vTensor final {
    * tensor sizes
    */
   void reallocate(const std::vector& new_sizes);
+
+  /*
+   * Check if this vTensor instance is a view of another vTensor instance
+   */
+  inline bool is_view_of(const vTensor& other) const {
+    return storage_.is_copy_of(other.storage_); // decided entirely by the storage-level is_copy_of check
+  }
 };
 
 } // namespace api
diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp
index 48e1ebf0a83..9fa0091b298 100644
--- a/backends/vulkan/runtime/graph/ComputeGraph.cpp
+++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp
@@ -368,7 +368,7 @@ utils::uvec3 ComputeGraph::create_local_wg_size(
 }
 
 utils::uvec3 ComputeGraph::create_local_wg_size(const ValueRef idx) {
-  return create_local_wg_size(image_extents_of(idx));
+  return create_local_wg_size(create_global_wg_size(idx)); // local size is derived from the global work group size of idx
 }
 
 void ComputeGraph::copy_into_staging(
diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h
index faa2f4107ec..5740d24a448 100644
--- a/backends/vulkan/runtime/graph/ComputeGraph.h
+++ b/backends/vulkan/runtime/graph/ComputeGraph.h
@@ -186,6 +186,21 @@ class ComputeGraph final {
 
   std::vector sizes_of(const ValueRef idx) const;
 
+  /*
+   * Returns the size of the tensor at `idx` along the specified dimension.
+   * Negative indexing is allowed.
+   */
+  template // NOTE(review): template parameter list looks stripped in this copy (T is used below) - confirm against upstream
+  T size_at(const int64_t dim, const ValueRef idx) const {
+    const Value& val = values_.at(idx);
+    if (val.isTensor()) {
+      return static_cast(utils::val_at(dim, val.toConstTensor().sizes())); // NOTE(review): cast target type also appears stripped - presumably T
+    } else if (val.isTensorRef()) {
+      return static_cast(utils::val_at(dim, val.toConstTensorRef().sizes));
+    }
+    VK_THROW("Could not get sizes of value with type ", val.type()); // reached when val is neither a Tensor nor a TensorRef
+  }
+
   vkapi::ScalarType dtype_of(const ValueRef idx) const;
 
   inline utils::uvec3 image_extents_of(const ValueRef idx) const {
@@ -204,6 +219,13 @@ class ComputeGraph final {
     return values_.at(idx).toConstTensor().has_buffer_storage();
   }
 
+  inline bool val_is_view_of(const ValueRef maybe_view, const ValueRef base)
+      const { // both values are accessed as tensors via toConstTensor()
+    return values_.at(maybe_view)
+        .toConstTensor()
+        .is_view_of(values_.at(base).toConstTensor());
+  }
+
   inline utils::GPUMemoryLayout memory_layout_of(const ValueRef idx) const {
     return values_.at(idx).toConstTensor().gpu_memory_layout();
   }
diff --git a/backends/vulkan/runtime/graph/ops/glsl/matmul_naive_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/matmul_naive_buffer.glsl
new file mode 100644
index 00000000000..25a6a742779
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/matmul_naive_buffer.glsl
@@ -0,0 +1,60 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+#define PRECISION ${PRECISION}
+
+#define T ${buffer_scalar_type(DTYPE)}
+
+${define_required_extensions(DTYPE)}
+
+layout(std430) buffer;
+
+${layout_declare_tensor(0, "w", "t_out", DTYPE, "buffer")}
+${layout_declare_tensor(1, "r", "t_mat1", DTYPE, "buffer")}
+${layout_declare_tensor(2, "r", "t_mat2", DTYPE, "buffer")}
+${layout_declare_ubo(3, "ivec4", "out_sizes")}
+${layout_declare_ubo(4, "ivec4", "out_strides")}
+${layout_declare_ubo(5, "ivec4", "mat1_sizes")}
+${layout_declare_ubo(6, "ivec4", "mat1_strides")}
+${layout_declare_ubo(7, "ivec4", "mat2_sizes")}
+${layout_declare_ubo(8, "ivec4", "mat2_strides")}
+${layout_declare_ubo(9, "int", "out_numel")} // NOTE(review): out_numel is not referenced in main() below - confirm it is needed
+
+#include "indexing_utils.h"
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+void main() {
+  const ivec4 out_idx = ivec4( // one invocation per output element
+      gl_GlobalInvocationID.x,
+      gl_GlobalInvocationID.y,
+      gl_GlobalInvocationID.z % out_sizes.z, // global z is split into a (z, w) index pair using out_sizes.z
+      gl_GlobalInvocationID.z / out_sizes.z);
+
+  if (any(greaterThanEqual(out_idx, out_sizes))) {
+    return; // invocation falls outside the output tensor
+  }
+
+  int mat1_id = to_buffer_id( // starting element of mat1: x = 0, remaining indices taken from out_idx
+      ivec4(0, out_idx.y, out_idx.z, out_idx.w), mat1_strides);
+  int mat2_id = to_buffer_id( // starting element of mat2: y = 0, remaining indices taken from out_idx
+      ivec4(out_idx.x, 0, out_idx.z, out_idx.w), mat2_strides);
+
+  T sum = T(0.0);
+  for (int i = 0; i < mat1_sizes.x; ++i) { // reduce over mat1's x dimension
+    sum += t_mat1[mat1_id] * t_mat2[mat2_id];
+
+    mat1_id += mat1_strides.x; // step mat1 along x
+    mat2_id += mat2_strides.y; // step mat2 along y
+  }
+
+  const int out_id = to_buffer_id(out_idx, out_strides);
+  t_out[out_id] = T(sum);
+}
diff --git a/backends/vulkan/runtime/graph/ops/glsl/matmul_naive_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/matmul_naive_buffer.yaml
new file mode 100644
index 00000000000..54eb444f73d
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/matmul_naive_buffer.yaml
@@ -0,0 +1,16 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+matmul_naive_buffer:
+  parameter_names_with_default_values:
+    DTYPE: float
+    STORAGE: buffer
+  generate_variant_forall:
+    DTYPE:
+      - VALUE: float
+      - VALUE: half
+  shader_variants:
+    - NAME: matmul_naive_buffer
diff --git a/backends/vulkan/runtime/graph/ops/glsl/matmul_naive.glsl b/backends/vulkan/runtime/graph/ops/glsl/matmul_naive_texture3d.glsl
similarity index 72%
rename from backends/vulkan/runtime/graph/ops/glsl/matmul_naive.glsl
rename to backends/vulkan/runtime/graph/ops/glsl/matmul_naive_texture3d.glsl
index 37a9b60f3c5..7225f2c64a0 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/matmul_naive.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/matmul_naive_texture3d.glsl
@@ -16,17 +16,11 @@ $if MAT2_IS_TRANSPOSED:
 #include "indexing_utils.h"
 #include "matmul.h"
 
-layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly image3D im_out;
-layout(set = 0, binding = 1) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} im_mat1;
-layout(set = 0, binding = 2) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} im_mat2;
-
-layout(set = 0, binding = 3) uniform PRECISION restrict OutLimits {
-  ivec3 out_limits;
-};
-
-layout(set = 0, binding = 4) uniform PRECISION restrict InSizes {
-  ivec4 in_sizes;
-};
+${layout_declare_tensor(0, "w", "im_out", DTYPE, "texture3d")}
+${layout_declare_tensor(1, "r", "im_mat1", DTYPE, "texture3d")}
+${layout_declare_tensor(2, "r", "im_mat2", DTYPE, "texture3d")}
+${layout_declare_ubo(3, "ivec3", "out_limits")}
+${layout_declare_ubo(4, "ivec4", "in_sizes")}
 
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 
diff --git a/backends/vulkan/runtime/graph/ops/glsl/matmul_naive.yaml b/backends/vulkan/runtime/graph/ops/glsl/matmul_naive_texture3d.yaml
similarity index 71%
rename from backends/vulkan/runtime/graph/ops/glsl/matmul_naive.yaml
rename to backends/vulkan/runtime/graph/ops/glsl/matmul_naive_texture3d.yaml
index 1c4db3f0ce9..bb1eed494a5 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/matmul_naive.yaml
+++ b/backends/vulkan/runtime/graph/ops/glsl/matmul_naive_texture3d.yaml
@@ -4,10 +4,10 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-matmul_naive:
+matmul_naive_texture3d:
   parameter_names_with_default_values:
     DTYPE: float
-    NDIM: 3
+    STORAGE: texture3d
     MAT1_PACKING: W_packed
     MAT2_PACKING: H_packed
     MAT2_IS_TRANSPOSED: false
@@ -16,9 +16,9 @@ matmul_naive:
       - VALUE: float
       - VALUE: half
   shader_variants:
-    - NAME: matmul_naive_W_packed_H_packed
-    - NAME: matmul_naive_W_packed_W_packed
+    - NAME: matmul_naive_texture3d_W_packed_H_packed
+    - NAME: matmul_naive_texture3d_W_packed_W_packed
       MAT2_PACKING: W_packed
-    - NAME: matmul_transposed_naive_W_packed_W_packed
+    - NAME: matmul_transposed_naive_texture3d_W_packed_W_packed
       MAT2_PACKING: W_packed
       MAT2_IS_TRANSPOSED: true
diff --git a/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp b/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp
index d1d3ad47d76..a25a602e38f 100644
--- a/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp
@@ -62,7 +62,48 @@ void resize_matmul_node(
   out->virtual_resize(new_out_sizes);
 }
 
-void add_matmul_naive_node(
+void add_matmul_naive_buffer_node( // schedules the matmul_naive_buffer shader for `out`
+    ComputeGraph& graph,
+    const ValueRef mat1,
+    const ValueRef mat2_data,
+    const ValueRef out,
+    const ValueRef mat2_is_transposed) { // NOTE(review): only forwarded to resize_matmul_node below; kernel selection here does not depend on it - confirm transposed mat2 never reaches this path
+  ValueRef mat2 = prepack_if_tensor_ref(graph, mat2_data, utils::kHeightPacked);
+
+  std::string kernel_name = "matmul_naive_buffer";
+  add_dtype_suffix(kernel_name, graph.dtype_of(out));
+
+  utils::uvec3 global_size = { // one invocation per output element
+      graph.size_at(-1, out),
+      graph.size_at(-2, out),
+      graph.size_at(-3, out) * graph.size_at(-4, out)}; // the dims at -3 and -4 are folded together into z
+
+  graph.execute_nodes().emplace_back(new ExecuteNode(
+      graph,
+      VK_KERNEL_FROM_STR(kernel_name),
+      global_size,
+      graph.create_local_wg_size(global_size),
+      // Inputs and Outputs
+      {{out, vkapi::MemoryAccessType::WRITE},
+       {{mat1, mat2}, vkapi::MemoryAccessType::READ}},
+      // Shader params buffers
+      { // order follows UBO bindings 3..9 declared in matmul_naive_buffer.glsl
+          graph.sizes_ubo(out),
+          graph.strides_ubo(out),
+          graph.sizes_ubo(mat1),
+          graph.strides_ubo(mat1),
+          graph.sizes_ubo(mat2),
+          graph.strides_ubo(mat2),
+          graph.numel_ubo(out),
+      },
+      // Specialization Constants
+      {},
+      // Resizing Logic
+      resize_matmul_node,
+      {mat2_is_transposed}));
+}
+
+void add_matmul_naive_texture3d_node(
     ComputeGraph& graph,
     const ValueRef mat1,
     const ValueRef mat2_data,
@@ -74,6 +115,7 @@ void add_matmul_naive_node(
       ? "matmul_transposed_naive"
       : "matmul_naive";
   kernel_name.reserve(kShaderNameReserve);
+  add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); // the renamed shader variants carry a storage-type component (see matmul_naive_texture3d.yaml)
   add_memory_layout_suffix(kernel_name, graph.memory_layout_of(mat1));
   add_memory_layout_suffix(kernel_name, graph.memory_layout_of(mat2));
   add_dtype_suffix(kernel_name, graph.dtype_of(out));
@@ -174,12 +216,16 @@ void add_matmul_node(
     const ValueRef mat2_data,
     const ValueRef out,
     const ValueRef mat2_is_transposed) {
-  if (graph.memory_layout_of(mat1) == utils::kChannelsPacked) {
+  if (graph.is_buffer_storage(out)) { // buffer-backed outputs take the naive buffer path regardless of mat1's memory layout
+    add_matmul_naive_buffer_node(
+        graph, mat1, mat2_data, out, mat2_is_transposed);
+  } else if (graph.memory_layout_of(mat1) == utils::kChannelsPacked) {
     add_matmul_optimized_node(graph, mat1, mat2_data, out, mat2_is_transposed);
   } else if (graph.memory_layout_of(mat1) == utils::kWidthPacked) {
-    add_matmul_naive_node(graph, mat1, mat2_data, out, mat2_is_transposed);
+    add_matmul_naive_texture3d_node(
+        graph, mat1, mat2_data, out, mat2_is_transposed);
   } else {
-    VK_THROW("Input should be channel packed or width packed.");
+    VK_THROW("Input texture should be channel packed or width packed.");
   }
 }
 
diff --git a/backends/vulkan/runtime/vk_api/memory/Buffer.h b/backends/vulkan/runtime/vk_api/memory/Buffer.h
index 3f69d1f2237..9302048f861 100644
--- a/backends/vulkan/runtime/vk_api/memory/Buffer.h
+++ b/backends/vulkan/runtime/vk_api/memory/Buffer.h
@@ -150,6 +150,10 @@ class VulkanBuffer final {
     return (handle_ != VK_NULL_HANDLE);
   }
 
+  inline bool is_copy_of(const VulkanBuffer& other) const {
+    return (handle_ == other.handle_) && is_copy_; // same handle as `other`, and this object is flagged as a copy
+  }
+
   inline void bind_allocation(const Allocation& memory) {
     VK_CHECK_COND(!memory_, "Cannot bind an already bound allocation!");
     VK_CHECK(vmaBindBufferMemory(allocator_, memory.allocation, handle_));
diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py
index ff5c7a60e0f..7f9f1842adf 100644
--- a/backends/vulkan/test/op_tests/cases.py
+++ b/backends/vulkan/test/op_tests/cases.py
@@ -70,6 +70,7 @@ def get_mm_inputs():
     test_suite.prepacked_args = ["mat2"]
     # ATen matmul doesn't support half
     test_suite.dtypes = ["at::kFloat"]
+    test_suite.storage_types = ["utils::kTexture3D", "utils::kBuffer"]
     test_suite.layouts = [
         "utils::kWidthPacked",
         "utils::kChannelsPacked",
diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp
index 1ac74e29ef4..e24e2ea4e06 100644
--- a/backends/vulkan/test/vulkan_compute_api_test.cpp
+++ b/backends/vulkan/test/vulkan_compute_api_test.cpp
@@ -611,6 +611,7 @@ TEST_F(VulkanComputeAPITest, tensor_copy_test) {
   vTensor original = CREATE_FLOAT_BUFFER(sizes, /*allocate_memory=*/true);
   vTensor copy = vTensor(original, sizes, dim_order);
   EXPECT_TRUE(get_vma_allocation_count() == 1);
+  EXPECT_TRUE(copy.is_view_of(original)); // a tensor constructed from `original` must be reported as its view
 
   // Fill original tensor with some data
   fill_vtensor(original, 2.5f, true);
@@ -1166,6 +1167,8 @@ TEST(VulkanComputeGraphTest, test_simple_graph_with_view) {
   ValueRef slice =
       graph.add_tensor_view(orig.value, slice_sizes, dim_order, offset);
 
+  EXPECT_TRUE(graph.val_is_view_of(slice, orig.value)); // same check, exercised through the ComputeGraph API
+
   IOValueRef out = {};
   out.value = graph.add_tensor(slice_sizes, vkapi::kFloat);
 
@@ -2282,24 +2285,28 @@ void test_binary_op(
   }
 }
 
-#define CALL_TEST_FN_FORALL_CONDITIONS(_) \
-  _(vkapi::kFloat, utils::GPUMemoryLayout::TENSOR_WIDTH_PACKED, false) \
-  _(vkapi::kFloat, utils::GPUMemoryLayout::TENSOR_HEIGHT_PACKED, false) \
-  _(vkapi::kFloat, utils::GPUMemoryLayout::TENSOR_CHANNELS_PACKED, false) \
-  _(vkapi::kFloat, utils::GPUMemoryLayout::TENSOR_WIDTH_PACKED, true) \
-  _(vkapi::kFloat, utils::GPUMemoryLayout::TENSOR_HEIGHT_PACKED, true) \
-  _(vkapi::kFloat, utils::GPUMemoryLayout::TENSOR_CHANNELS_PACKED, true)
-
-#define CALL_TEST_FN_FOR_W_PACKED(_) \
-  _(vkapi::kFloat, utils::GPUMemoryLayout::TENSOR_WIDTH_PACKED, false) \
-  _(vkapi::kFloat, utils::GPUMemoryLayout::TENSOR_WIDTH_PACKED, true)
-
-#define CALL_TEST_FN_FOR_C_PACKED(_) \
-  _(vkapi::kFloat, utils::GPUMemoryLayout::TENSOR_CHANNELS_PACKED, false) \
-  _(vkapi::kFloat, utils::GPUMemoryLayout::TENSOR_CHANNELS_PACKED, true)
+#define CALL_TEST_FN_FORALL_CONDITIONS(_) \
+  _(vkapi::kFloat, utils::kTexture3D, utils::kWidthPacked, false) \
+  _(vkapi::kFloat, utils::kTexture3D, utils::kHeightPacked, false) \
+  _(vkapi::kFloat, utils::kTexture3D, utils::kChannelsPacked, false) \
+  _(vkapi::kFloat, utils::kTexture3D, utils::kWidthPacked, true) \
+  _(vkapi::kFloat, utils::kTexture3D, utils::kHeightPacked, true) \
+  _(vkapi::kFloat, utils::kTexture3D, utils::kChannelsPacked, true)
+
+#define CALL_TEST_FN_FOR_W_PACKED(_) \
+  _(vkapi::kFloat, utils::kTexture3D, utils::kWidthPacked, false) \
+  _(vkapi::kFloat, utils::kTexture3D, utils::kWidthPacked, true) \
+  _(vkapi::kFloat, utils::kBuffer, utils::kWidthPacked, false) \
+  _(vkapi::kFloat, utils::kBuffer, utils::kWidthPacked, true)
+
+#define CALL_TEST_FN_FOR_C_PACKED(_) \
+  _(vkapi::kFloat, utils::kTexture3D, utils::kChannelsPacked, false) \
+  _(vkapi::kFloat, utils::kTexture3D, utils::kChannelsPacked, true) \
+  _(vkapi::kFloat, utils::kBuffer, utils::kChannelsPacked, false) \
+  _(vkapi::kFloat, utils::kBuffer, utils::kChannelsPacked, true)
 
 TEST(VulkanComputeGraphOpsTest, add_smoke_test) {
-#define RUN_TESTS(dtype, layout, prepack) \
+#define RUN_TESTS(dtype, storage, layout, prepack) \
   test_binary_op("add", {17, 21}, {17, 21}, dtype, layout, prepack); \
   test_binary_op("add", {17, 21}, {1, 1}, dtype, layout, prepack); \
   test_binary_op("sub", {11, 22}, {11, 22}, dtype, layout, prepack); \
@@ -2320,9 +2327,11 @@ void test_mm(
     int K,
     int N,
     vkapi::ScalarType dtype,
+    utils::StorageType storage_type,
     utils::GPUMemoryLayout memory_layout,
     bool prepack = true) {
   GraphConfig config;
+  config.set_storage_type_override(storage_type); // applied to the config before the graph is constructed from it
   ComputeGraph graph(config);
 
   std::vector mat1_size = {M, K};
@@ -2379,38 +2388,42 @@ void test_mm(
 }
 
 TEST(VulkanComputeGraphOpsTest, mm_smoke_test) {
-#define RUN_TESTS(dtype, layout, prepack) \
-  test_mm( \
-      /*B = */ 1, \
-      /*M = */ 31, \
-      /*K = */ 127, \
-      /*N = */ 23, \
-      dtype, \
-      layout, \
-      prepack); \
-  test_mm( \
-      /*B = */ 5, \
-      /*M = */ 31, \
-      /*K = */ 127, \
-      /*N = */ 23, \
-      dtype, \
-      layout, \
-      prepack); \
-  test_mm( \
-      /*B = */ 7, \
-      /*M = */ 13, \
-      /*K = */ 89, \
-      /*N = */ 17, \
-      dtype, \
-      layout, \
-      prepack); \
-  test_mm( \
-      /*B = */ 1, \
-      /*M = */ 13, \
-      /*K = */ 89, \
-      /*N = */ 17, \
-      dtype, \
-      layout, \
+#define RUN_TESTS(dtype, storage_type, layout, prepack) \
+  test_mm( \
+      /*B = */ 1, \
+      /*M = */ 31, \
+      /*K = */ 127, \
+      /*N = */ 23, \
+      dtype, \
+      storage_type, \
+      layout, \
+      prepack); \
+  test_mm( \
+      /*B = */ 5, \
+      /*M = */ 31, \
+      /*K = */ 127, \
+      /*N = */ 23, \
+      dtype, \
+      storage_type, \
+      layout, \
+      prepack); \
+  test_mm( \
+      /*B = */ 7, \
+      /*M = */ 13, \
+      /*K = */ 89, \
+      /*N = */ 17, \
+      dtype, \
+      storage_type, \
+      layout, \
+      prepack); \
+  test_mm( \
+      /*B = */ 1, \
+      /*M = */ 13, \
+      /*K = */ 89, \
+      /*N = */ 17, \
+      dtype, \
+      storage_type, \
+      layout, \
       prepack);
 
   CALL_TEST_FN_FOR_W_PACKED(RUN_TESTS);