From 1cf46a14bd4bbc216752686836d9359cc9f2a246 Mon Sep 17 00:00:00 2001 From: Kush Rastogi Date: Tue, 15 Oct 2024 10:26:47 -0700 Subject: [PATCH] Width Packing Mat1 input for Quantized Linear (#6149) Summary: Width-pack the mat1 input for Quantized Linear, because the ASR model provides a channel-packed matrix and the operator does not support channel-packed inputs yet. Reviewed By: nathanaelsee, jorgep31415 Differential Revision: D64065606 --- .../graph/ops/impl/QuantizedLinear.cpp | 43 +++++++++++++------ 1 file changed, 30 insertions(+), 13 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp index 206a4eafa36..838605f05f3 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp @@ -71,6 +71,17 @@ void add_q_8w_linear_node( const ValueRef q_mat2_data, const ValueRef scales_data, const ValueRef out) { + auto viewFn = VK_GET_OP_FN("aten.view_copy.default"); + ValueRef mat1_W_packed = mat1; + ValueRef out_W_packed = out; + if (!graph.is_buffer_storage(out) && + graph.packed_dim_of(mat1) != WHCN::kWidthDim) { + // Ensure mat1 is width packed + mat1_W_packed = graph.add_tensor_like(mat1, utils::kWidthPacked); + viewFn(graph, {mat1, graph.add_none(), mat1_W_packed}); + // Ensure out is packed correctly + out_W_packed = graph.add_tensor_like(out, utils::kWidthPacked); + } ValueRef q_mat2 = prepack_if_tensor_ref(graph, q_mat2_data, utils::kWidthPacked); ValueRef scales = @@ -78,39 +89,45 @@ void add_q_8w_linear_node( std::string kernel_name = "q_8w_linear"; kernel_name.reserve(kShaderNameReserve); - add_packed_dim_suffix(kernel_name, graph.packed_dim_of(mat1)); + add_packed_dim_suffix(kernel_name, graph.packed_dim_of(mat1_W_packed)); add_packed_dim_suffix(kernel_name, graph.packed_dim_of(q_mat2)); - add_dtype_suffix(kernel_name, graph.dtype_of(out)); - add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); + 
add_dtype_suffix(kernel_name, graph.dtype_of(out_W_packed)); + add_storage_type_suffix(kernel_name, graph.storage_type_of(out_W_packed)); vkapi::ParamsBindList ubos({}); - if (graph.is_buffer_storage(out)) { + if (graph.is_buffer_storage(out_W_packed)) { ubos.append( - {graph.sizes_ubo(out), - graph.strides_ubo(out), - graph.numel_ubo(out), - graph.sizes_ubo(mat1), + {graph.sizes_ubo(out_W_packed), + graph.strides_ubo(out_W_packed), + graph.numel_ubo(out_W_packed), + graph.sizes_ubo(mat1_W_packed), graph.strides_ubo(mat1), graph.strides_ubo(q_mat2), graph.strides_ubo(scales)}); } else { - ubos.append({graph.logical_limits_ubo(out), graph.sizes_ubo(mat1)}); + ubos.append( + {graph.logical_limits_ubo(out_W_packed), + graph.sizes_ubo(mat1_W_packed)}); } graph.execute_nodes().emplace_back(new DispatchNode( graph, VK_KERNEL_FROM_STR(kernel_name), - graph.create_global_wg_size(out), - graph.create_local_wg_size(out), + graph.create_global_wg_size(out_W_packed), + graph.create_local_wg_size(out_W_packed), // Inputs and Outputs - {{out, vkapi::MemoryAccessType::WRITE}, - {{mat1, q_mat2, scales}, vkapi::MemoryAccessType::READ}}, + {{out_W_packed, vkapi::MemoryAccessType::WRITE}, + {{mat1_W_packed, q_mat2, scales}, vkapi::MemoryAccessType::READ}}, // Shader params buffers ubos, // Specialization Constants {}, // Resizing Logic resize_qlinear_node)); + if (!graph.is_buffer_storage(out) && + graph.packed_dim_of(out) != WHCN::kWidthDim) { + viewFn(graph, {out_W_packed, graph.add_none(), out}); + } } void weight_int8pack_mm(