Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 30 additions & 13 deletions backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,46 +71,63 @@ void add_q_8w_linear_node(
const ValueRef q_mat2_data,
const ValueRef scales_data,
const ValueRef out) {
auto viewFn = VK_GET_OP_FN("aten.view_copy.default");
ValueRef mat1_W_packed = mat1;
ValueRef out_W_packed = out;
if (!graph.is_buffer_storage(out) &&
graph.packed_dim_of(mat1) != WHCN::kWidthDim) {
// Ensure mat1 is width packed
mat1_W_packed = graph.add_tensor_like(mat1, utils::kWidthPacked);
viewFn(graph, {mat1, graph.add_none(), mat1_W_packed});
// Ensure out is packed correctly
out_W_packed = graph.add_tensor_like(out, utils::kWidthPacked);
}
ValueRef q_mat2 =
prepack_if_tensor_ref(graph, q_mat2_data, utils::kWidthPacked);
ValueRef scales =
prepack_if_tensor_ref(graph, scales_data, utils::kWidthPacked);

std::string kernel_name = "q_8w_linear";
kernel_name.reserve(kShaderNameReserve);
add_packed_dim_suffix(kernel_name, graph.packed_dim_of(mat1));
add_packed_dim_suffix(kernel_name, graph.packed_dim_of(mat1_W_packed));
add_packed_dim_suffix(kernel_name, graph.packed_dim_of(q_mat2));
add_dtype_suffix(kernel_name, graph.dtype_of(out));
add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
add_dtype_suffix(kernel_name, graph.dtype_of(out_W_packed));
add_storage_type_suffix(kernel_name, graph.storage_type_of(out_W_packed));

vkapi::ParamsBindList ubos({});
if (graph.is_buffer_storage(out)) {
if (graph.is_buffer_storage(out_W_packed)) {
ubos.append(
{graph.sizes_ubo(out),
graph.strides_ubo(out),
graph.numel_ubo(out),
graph.sizes_ubo(mat1),
{graph.sizes_ubo(out_W_packed),
graph.strides_ubo(out_W_packed),
graph.numel_ubo(out_W_packed),
graph.sizes_ubo(mat1_W_packed),
graph.strides_ubo(mat1),
graph.strides_ubo(q_mat2),
graph.strides_ubo(scales)});
} else {
ubos.append({graph.logical_limits_ubo(out), graph.sizes_ubo(mat1)});
ubos.append(
{graph.logical_limits_ubo(out_W_packed),
graph.sizes_ubo(mat1_W_packed)});
}

graph.execute_nodes().emplace_back(new DispatchNode(
graph,
VK_KERNEL_FROM_STR(kernel_name),
graph.create_global_wg_size(out),
graph.create_local_wg_size(out),
graph.create_global_wg_size(out_W_packed),
graph.create_local_wg_size(out_W_packed),
// Inputs and Outputs
{{out, vkapi::MemoryAccessType::WRITE},
{{mat1, q_mat2, scales}, vkapi::MemoryAccessType::READ}},
{{out_W_packed, vkapi::MemoryAccessType::WRITE},
{{mat1_W_packed, q_mat2, scales}, vkapi::MemoryAccessType::READ}},
// Shader params buffers
ubos,
// Specialization Constants
{},
// Resizing Logic
resize_qlinear_node));
if (!graph.is_buffer_storage(out) &&
graph.packed_dim_of(out) != WHCN::kWidthDim) {
viewFn(graph, {out_W_packed, graph.add_none(), out});
}
}

void weight_int8pack_mm(
Expand Down
Loading