diff --git a/backends/vulkan/runtime/graph/ops/glsl/matmul_naive_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/matmul_naive_buffer.glsl index e4064eed2fa..dd91685c8a7 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/matmul_naive_buffer.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/matmul_naive_buffer.glsl @@ -16,21 +16,23 @@ ${define_required_extensions(DTYPE)} layout(std430) buffer; -${layout_declare_tensor(0, "w", "t_out", DTYPE, "buffer")} -${layout_declare_tensor(1, "r", "t_mat1", DTYPE, "buffer")} -${layout_declare_tensor(2, "r", "t_mat2", DTYPE, "buffer")} -${layout_declare_ubo(3, "ivec4", "out_sizes")} -${layout_declare_ubo(4, "ivec4", "out_strides")} -${layout_declare_ubo(5, "ivec4", "mat1_sizes")} -${layout_declare_ubo(6, "ivec4", "mat1_strides")} -${layout_declare_ubo(7, "ivec4", "mat2_sizes")} -${layout_declare_ubo(8, "ivec4", "mat2_strides")} -${layout_declare_ubo(9, "int", "out_numel")} +${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")} +${layout_declare_tensor(B, "r", "t_mat1", DTYPE, "buffer")} +${layout_declare_tensor(B, "r", "t_mat2", DTYPE, "buffer")} +${layout_declare_ubo(B, "ivec4", "out_sizes")} +${layout_declare_ubo(B, "ivec4", "out_strides")} +${layout_declare_ubo(B, "ivec4", "mat1_sizes")} +${layout_declare_ubo(B, "ivec4", "mat1_strides")} +${layout_declare_ubo(B, "ivec4", "mat2_sizes")} +${layout_declare_ubo(B, "ivec4", "mat2_strides")} +${layout_declare_ubo(B, "int", "out_numel")} #include "indexing_utils.h" layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +${layout_declare_spec_const(C, "int", "mat2_is_transposed", "0")} + void main() { const ivec4 out_bufix = ivec4( gl_GlobalInvocationID.x, @@ -44,15 +46,28 @@ void main() { int mat1_bufi = tidx_to_bufi( ivec4(0, out_bufix.y, out_bufix.z, out_bufix.w), mat1_strides); - int mat2_bufi = tidx_to_bufi( - ivec4(out_bufix.x, 0, out_bufix.z, out_bufix.w), mat2_strides); + int mat2_bufi; + if (mat2_is_transposed > 0) { + mat2_bufi = tidx_to_bufi( + ivec4(0, out_bufix.x, 0, 0), mat2_strides); + } else { + mat2_bufi = tidx_to_bufi( + ivec4(out_bufix.x, 0, out_bufix.z, out_bufix.w), mat2_strides); + } + + int mat2_stride; + if (mat2_is_transposed > 0) { + mat2_stride = mat2_strides.x; + } else { + mat2_stride = mat2_strides.y; + } T sum = T(0.0); for (int i = 0; i < mat1_sizes.x; ++i) { sum += t_mat1[mat1_bufi] * t_mat2[mat2_bufi]; mat1_bufi += mat1_strides.x; - mat2_bufi += mat2_strides.y; + mat2_bufi += mat2_stride; } const int out_bufi = tidx_to_bufi(out_bufix, out_strides); diff --git a/backends/vulkan/runtime/graph/ops/impl/Linear.cpp b/backends/vulkan/runtime/graph/ops/impl/Linear.cpp index e2d6fc25519..1cba6de851c 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Linear.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Linear.cpp @@ -279,9 +279,12 @@ void linear(ComputeGraph& graph, const std::vector& args) { ValueRef weight = prepack_standard( graph, weight_data, graph.storage_type_of(out), utils::kWidthPacked); ValueRef mat2_is_transposed = graph.add_scalar(true); + if (graph.val_is_none(bias)) { return add_matmul_node(graph, input, weight, out, mat2_is_transposed); } else { + // Buffer implementation does not yet support biases + VK_CHECK_COND(!graph.is_buffer_storage(out)); return add_addmm_node( graph, bias, diff --git a/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp b/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp index 8ca9858d884..a852a30d087 100644 --- a/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp @@ -77,6 +77,11 @@ void add_matmul_naive_buffer_node( graph.size_at(-2, out), graph.size_at(-3, out) * graph.size_at(-4, out)}; + int mat2_is_transposed_val = (mat2_is_transposed != kDummyValueRef && + graph.get_bool(mat2_is_transposed)) + ? 1 + : 0; + graph.execute_nodes().emplace_back(new DispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), @@ -96,7 +101,7 @@ void add_matmul_naive_buffer_node( graph.numel_ubo(out), }, // Specialization Constants - {}, + {mat2_is_transposed_val}, // Resizing Logic resize_matmul_node, {mat2_is_transposed})); diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index 304636f2fb5..85732d77011 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -126,8 +126,7 @@ def get_addmm_inputs(): ] -@register_test_suite("aten.linear.default") -def get_linear_inputs(): +def get_linear_texture_inputs(): MKN_list = common_MKN_list inputs_list = [((M, K), (N, K), None) for M, K, N in MKN_list] @@ -142,9 +141,32 @@ def get_linear_inputs(): "utils::kWidthPacked", "utils::kChannelsPacked", ] + test_suite.test_name_suffix = "texture" + return test_suite + + +def get_linear_buffer_inputs(): + MKN_list = common_MKN_list + + inputs_list = [((M, K), (N, K), None) for M, K, N in MKN_list] + inputs_list += [((3, M, K), (N, K), None) for M, K, N in MKN_list] + + test_suite = VkTestSuite(inputs_list) + test_suite.dtypes = ["at::kFloat"] + test_suite.layouts = [ + "utils::kWidthPacked", + "utils::kChannelsPacked", + ] + test_suite.storage_types = ["utils::kBuffer"] + test_suite.test_name_suffix = "buffer" return test_suite +@register_test_suite("aten.linear.default") +def get_linear_test_suites(): + return [get_linear_texture_inputs(), get_linear_buffer_inputs()] + + @register_test_suite("aten._weight_int8pack_mm.default") def get_weight_int8pack_mm_inputs(): MKN_list = common_MKN_list