diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp index 1042c23bcb3..ea6601502f1 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp @@ -260,9 +260,6 @@ void check_q_4w_linear_args( const int group_size_val = graph.extract_scalar(group_size); VK_CHECK_COND(K % group_size_val == 0); - VK_CHECK_COND(graph.packed_dim_of(mat1) == WHCN::kWidthDim); - VK_CHECK_COND(graph.packed_dim_of(out) == WHCN::kWidthDim); - VK_CHECK_COND(graph.has_standard_axis_map(mat1)); VK_CHECK_COND(graph.has_standard_axis_map(out)); } @@ -320,13 +317,32 @@ void add_q_4w_linear_node( const uint32_t group_size_val = graph.extract_scalar(group_size); + ValueRef mat1_W_packed = mat1; + ValueRef out_W_packed = out; + auto viewFn = VK_GET_OP_FN("aten.view_copy.default"); + // Create temporary tensors to store the width packed versions of mat1 and out + TmpTensor mat1_tmp( + &graph, graph.sizes_of(mat1), graph.dtype_of(mat1), utils::kWidthPacked); + TmpTensor out_tmp( + &graph, graph.sizes_of(out), graph.dtype_of(out), utils::kWidthPacked); + if (storage_type == utils::kTexture3D) { + if (!graph.is_buffer_storage(out) && + graph.packed_dim_of(mat1) != WHCN::kWidthDim) { + // Ensure mat1 is width packed + mat1_W_packed = mat1_tmp; + viewFn(graph, {mat1, graph.add_none(), mat1_W_packed}); + // Ensure out is packed correctly + out_W_packed = out_tmp; + } + } + vkapi::ParamsBindList ubos({}); - ubos.append(graph.logical_limits_ubo(out)); - ubos.append(graph.sizes_ubo(mat1)); + ubos.append(graph.logical_limits_ubo(out_W_packed)); + ubos.append(graph.sizes_ubo(mat1_W_packed)); ubos.append(graph.strides_ubo(mat2)); ubos.append(graph.strides_ubo(scales_and_zeros)); - utils::uvec3 global_wg_size = graph.logical_limits_of(out); + utils::uvec3 global_wg_size = graph.logical_limits_of(out_W_packed); utils::uvec3 local_wg_size = graph.create_local_wg_size(global_wg_size); graph.execute_nodes().emplace_back(new DispatchNode( @@ -335,8 +351,9 @@ void add_q_4w_linear_node( global_wg_size, local_wg_size, // Inputs and Outputs - {{out, vkapi::MemoryAccessType::WRITE}, - {{mat1, mat2, scales_and_zeros}, vkapi::MemoryAccessType::READ}}, + {{out_W_packed, vkapi::MemoryAccessType::WRITE}, + {{mat1_W_packed, mat2, scales_and_zeros}, + vkapi::MemoryAccessType::READ}}, // Shader params buffers ubos, // Specialization Constants @@ -344,6 +361,10 @@ void add_q_4w_linear_node( // Resizing Logic resize_q_4w_linear_node, {})); + if (!graph.is_buffer_storage(out) && + graph.packed_dim_of(out) != WHCN::kWidthDim) { + viewFn(graph, {out_W_packed, graph.add_none(), out}); + } } void linear_weight_int4(