From daa1c0f9bc70c6731cee968a8ec66464b2b14bcd Mon Sep 17 00:00:00 2001
From: Kush Rastogi
Date: Thu, 17 Oct 2024 11:21:45 -0700
Subject: [PATCH] Removing q_linear.h and adding tiled q_linear implementation (#5492)

Summary: Removes q_linear.h and moves implementation directly to q_8w_linear.glsl

Reviewed By: nathanaelsee

Differential Revision: D61309097
---
 backends/vulkan/partitioner/supported_ops.py  |   1 +
 .../runtime/graph/ops/glsl/q_8w_linear.glsl   |  60 ++++-
 .../graph/ops/glsl/q_8w_linear_optimized.glsl | 211 ++++++++++++++++++
 .../graph/ops/glsl/q_8w_linear_optimized.yaml |  35 +++
 .../vulkan/runtime/graph/ops/glsl/q_linear.h  |  82 -------
 .../graph/ops/impl/QuantizedLinear.cpp        |  89 +++++++-
 6 files changed, 394 insertions(+), 84 deletions(-)
 create mode 100644 backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_optimized.glsl
 create mode 100644 backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_optimized.yaml
 delete mode 100644 backends/vulkan/runtime/graph/ops/glsl/q_linear.h

diff --git a/backends/vulkan/partitioner/supported_ops.py b/backends/vulkan/partitioner/supported_ops.py
index da50719ba33..83dfb3b7686 100644
--- a/backends/vulkan/partitioner/supported_ops.py
+++ b/backends/vulkan/partitioner/supported_ops.py
@@ -84,6 +84,7 @@ def __contains__(self, op):
     exir_ops.edge.aten.addmm.default,
     exir_ops.edge.aten.linear.default,
     exir_ops.edge.et_vk.linear_weight_int4.default,
+    exir_ops.edge.aten._weight_int8pack_mm.default,
     # Reduction
     exir_ops.edge.aten._log_softmax.default,
     exir_ops.edge.aten._softmax.default,
diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl
index a72df89b634..02cae3ed980 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl
@@ -44,10 +44,38 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 // This header file must be defined after the layout descriptors have been
 // declared because the functions in the header assume some variables have been
 // declared as layout descriptors.
-#include "q_linear.h"

 #ifdef USING_BUFFER

+#ifndef FLOAT_T
+#define FLOAT_T float
+#endif
+
+FLOAT_T q_8w_linear(const ivec4 out_idx, const int K) {
+  const FLOAT_T scale = t_scales[out_idx.x];
+
+  FLOAT_T outval = FLOAT_T(0.0);
+
+  // Initial mat1 tensor idx will be (0, out_idx.y, out_idx.z, 0)
+  int mat1_offset = out_idx.y * mat1_strides.y + out_idx.z * qmat2_strides.z;
+  // Initial qmat2 tensor idx will be (0, out_idx.x, 0, 0); note that the qmat2
+  // tensor is transposed
+  int qmat2_offset = out_idx.x * qmat2_strides.y;
+
+  // TODO(ssjia): optimize memory access pattern by traversing K in inner loop
+  for (int i = 0; i < K; i++) {
+    const FLOAT_T mat1_val = t_mat1[mat1_offset];
+    const FLOAT_T mat2_val = t_qmat2[qmat2_offset] * scale;
+
+    outval += mat1_val * mat2_val;
+
+    mat1_offset++;
+    qmat2_offset++;
+  }
+
+  return outval;
+}
+
 void main() {
   const int out_bufi = int(gl_GlobalInvocationID.x);
   if (out_bufi >= out_numel) {
@@ -61,6 +89,36 @@ void main() {

 #else // USING_TEXTURE

+VEC4_T q_8w_linear(const ivec3 out_pos, const int K) {
+  ivec3 mat1_pos = ivec3(0, out_pos.yz);
+  ivec3 qmat2_pos = ivec3(0, out_pos.x * 4, 0);
+
+  VEC4_T outtex = VEC4_T(0);
+
+  const ivec3 scales_pos = ivec3(out_pos.x, 0, 0);
+  const VEC4_T scales = load_texel(t_scales, scales_pos);
+
+  for (int i = 0; i < K; i += 4) {
+    const VEC4_T mat1_tex = load_texel(t_mat1, mat1_pos);
+
+    const VEC4_T sums = VEC4_T(
+        dot(mat1_tex, load_texel(t_qmat2, qmat2_pos) * scales.x),
+        dot(mat1_tex,
+            load_texel(t_qmat2, qmat2_pos + ivec3(0, 1, 0)) * scales.y),
+        dot(mat1_tex,
+            load_texel(t_qmat2, qmat2_pos + ivec3(0, 2, 0)) * scales.z),
+        dot(mat1_tex,
+            load_texel(t_qmat2, qmat2_pos + ivec3(0, 3, 0)) * scales.w));
+
+    outtex += sums;
+
+    mat1_pos.x++;
+    qmat2_pos.x++;
+  }
+
+  return outtex;
+}
+
 void main() {
   const ivec3 out_pos = ivec3(gl_GlobalInvocationID);
   if (any(greaterThanEqual(out_pos, out_limits))) {
diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_optimized.glsl b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_optimized.glsl
new file mode 100644
index 00000000000..dae2f7e3ab1
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_optimized.glsl
@@ -0,0 +1,211 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+#define PRECISION ${PRECISION}
+
+#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
+#define FLOAT_T ${buffer_scalar_type(DTYPE)}
+
+${define_active_storage_type(STORAGE)}
+
+${define_required_extensions(DTYPE)}
+${define_required_extensions("int8")}
+
+
+$if BATCH_MODE:
+  #define BATCH_MODE
+
+#define TILE_ROWS ${TILE_ROWS}
+#define FOUR 4
+
+// we avoid mat4 and vec4 usage here as they compile to much less efficient
+// SPIR-V
+struct FloatMatrix_2d {
+  float data[TILE_ROWS][FOUR];
+};
+
+struct FloatMatrix_3d {
+  float data[TILE_ROWS][FOUR][FOUR];
+};
+
+#ifdef BATCH_MODE
+  #define FloatMatrix FloatMatrix_3d
+#else
+  #define FloatMatrix FloatMatrix_2d
+#endif
+
+#include "indexing_utils.h"
+
+layout(std430) buffer;
+
+${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)}
+${layout_declare_tensor(1, "r", "t_mat1", DTYPE, STORAGE)}
+${layout_declare_tensor(2, "r", "t_qmat2", "int8", STORAGE)}
+${layout_declare_tensor(3, "r", "t_scales", DTYPE, STORAGE)}
+
+$if STORAGE == "buffer":
+  ${layout_declare_ubo(4, "ivec4", "out_sizes")}
+  ${layout_declare_ubo(5, "ivec4", "out_strides")}
+  ${layout_declare_ubo(6, "int", "out_numel")}
+  ${layout_declare_ubo(7, "ivec4", "mat1_sizes")}
+  ${layout_declare_ubo(8, "ivec4", "mat1_strides")}
+  ${layout_declare_ubo(9, "ivec4", "qmat2_strides")}
+  ${layout_declare_ubo(10, "ivec4", "scales_strides")}
+$else:
+  ${layout_declare_ubo(4, "ivec3", "out_limits")}
+  ${layout_declare_ubo(5, "ivec4", "mat1_sizes")}
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+// This header file must be defined after the layout descriptors have been
+// declared because the functions in the header assume some variables have been
+// declared as layout descriptors.
+
+#ifdef USING_BUFFER
+
+#ifndef FLOAT_T
+#define FLOAT_T float
+#endif
+
+FLOAT_T q_8w_linear(const ivec4 out_idx, const int K) {
+  const FLOAT_T scale = t_scales[out_idx.x];
+
+  FLOAT_T outval = FLOAT_T(0.0);
+
+  // Initial mat1 tensor idx will be (0, out_idx.y, out_idx.z, 0)
+  int mat1_offset = out_idx.y * mat1_strides.y + out_idx.z * qmat2_strides.z;
+  // Initial qmat2 tensor idx will be (0, out_idx.x, 0, 0); note that the qmat2
+  // tensor is transposed
+  int qmat2_offset = out_idx.x * qmat2_strides.y;
+
+  // TODO(ssjia): optimize memory access pattern by traversing K in inner loop
+  for (int i = 0; i < K; i++) {
+    const FLOAT_T mat1_val = t_mat1[mat1_offset];
+    const FLOAT_T mat2_val = t_qmat2[qmat2_offset] * scale;
+
+    outval += mat1_val * mat2_val;
+
+    mat1_offset++;
+    qmat2_offset++;
+  }
+
+  return outval;
+}
+
+void main() {
+  const int out_bufi = int(gl_GlobalInvocationID.x);
+  if (out_bufi >= out_numel) {
+    return;
+  }
+
+  const ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides, 0);
+
+  t_out[out_bufi] = q_8w_linear(out_tidx, mat1_sizes.x);
+}
+
+#else // USING_TEXTURE
+FloatMatrix q_8w_linear_optimized(const ivec3 out_idx_tl) {
+  FloatMatrix results;
+  for (int i = 0; i < TILE_ROWS; i++) {
+    for (int j = 0; j < FOUR; j++) {
+#ifdef BATCH_MODE
+      for (int k = 0; k < FOUR; k++) {
+        results.data[i][j][k] = 0.0f;
+      }
+#else
+      results.data[i][j] = 0.0f;
+#endif // BATCH_MODE
+    }
+  }
+
+  VEC4_T im_mat1_partial_load[TILE_ROWS];
+  VEC4_T im_mat2_partial_load[FOUR];
+
+#ifdef BATCH_MODE
+  for (int batch_idx = 0; batch_idx < FOUR; batch_idx++) {
+    if (out_idx_tl.z + batch_idx >= out_limits.z) {
+      break;
+    }
+#endif
+    for (int k = 0; k < mat1_sizes.x; k++) {
+      for (int r = 0; r < TILE_ROWS; r++) {
+        ivec3 mat1_pos = ivec3(k, out_idx_tl.y * TILE_ROWS + r,
+            0);
+#ifdef BATCH_MODE
+        mat1_pos[2] = out_idx_tl.z + batch_idx;
+#endif
+
+        im_mat1_partial_load[r] = texelFetch(t_mat1, mat1_pos, 0);
+      }
+
+      for (int r = 0; r < FOUR; ++r) {
+        ivec3 qmat2_pos = ivec3(k, FOUR * out_idx_tl.x + r, 0);
+
+        im_mat2_partial_load[r] = texelFetch(t_qmat2, qmat2_pos, 0);
+      }
+
+      vec4 scales = texelFetch(t_scales, ivec3(out_idx_tl.x, 0, 0), 0);
+
+      // perform partial dot products and add partial result to results
+      for (int out_row = 0; out_row < TILE_ROWS; out_row++) {
+        for (int out_col = 0; out_col < FOUR; out_col++) {
+#ifdef BATCH_MODE
+          results.data[out_row][out_col][batch_idx] +=
+#else
+          results.data[out_row][out_col] +=
+#endif
+              dot(im_mat1_partial_load[out_row],
+                  im_mat2_partial_load[out_col] * scales[out_col]);
+        }
+      }
+    }
+#ifdef BATCH_MODE
+  }
+#endif
+  return results;
+}
+
+void main() {
+  const ivec3 out_idx = ivec3(gl_GlobalInvocationID);
+  if (any(greaterThanEqual(out_idx, out_limits))) {
+    return;
+  }
+
+  FloatMatrix results = q_8w_linear_optimized(out_idx);
+
+  ivec3 out_pos = ivec3(
+      out_idx.x,
+      out_idx.y * TILE_ROWS,
+#ifdef BATCH_MODE
+      out_idx.z * 4
+#else
+      out_idx.z
+#endif
+  );
+
+  for (int idx_c = 0; idx_c < TILE_ROWS; idx_c++, out_pos[1]++) {
+    out_pos.x = out_idx.x;
+    $if BATCH_MODE:
+      for (int idx_r = 0; idx_r < FOUR; idx_r++, out_pos[0]++) {
+        write_texel(t_out, out_pos, VEC4_T(
+            results.data[idx_c][idx_r][0],
+            results.data[idx_c][idx_r][1],
+            results.data[idx_c][idx_r][2],
+            results.data[idx_c][idx_r][3]));
+      }
+    $else:
+      write_texel(t_out, out_pos, VEC4_T(
+          results.data[idx_c][0],
+          results.data[idx_c][1],
+          results.data[idx_c][2],
+          results.data[idx_c][3]));
+  }
+}
+
+#endif
diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_optimized.yaml b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_optimized.yaml
new file mode 100644
index 00000000000..52bebf90125
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/q_8w_linear_optimized.yaml
@@ -0,0 +1,35 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+q_8w_linear_optimized:
+  parameter_names_with_default_values:
+    DTYPE: float
+    STORAGE: texture3d
+    MAT1_PACKING: W_packed
+    MAT2_PACKING: W_packed
+    BATCH_MODE: false
+    TILE_ROWS: 4
+  generate_variant_forall:
+    TILE_ROWS:
+      - VALUE: 4
+        SUFFIX: tile_row_4
+      - VALUE: 2
+        SUFFIX: tile_row_2
+    DTYPE:
+      - VALUE: float
+      - VALUE: half
+    STORAGE:
+      - VALUE: texture3d
+      - VALUE: buffer
+  shader_variants:
+    - NAME: q_8w_linear_optimized_W_packed_W_packed
+    - NAME: q_8w_linear_optimized_W_packed_H_packed
+      MAT2_PACKING: H_packed
+    - NAME: batch_q_8w_linear_optimized_W_packed_W_packed
+      BATCH_MODE: true
+    - NAME: batch_q_8w_linear_optimized_W_packed_H_packed
+      MAT2_PACKING: H_packed
+      BATCH_MODE: true
diff --git a/backends/vulkan/runtime/graph/ops/glsl/q_linear.h b/backends/vulkan/runtime/graph/ops/glsl/q_linear.h
deleted file mode 100644
index f6de1e6dcf6..00000000000
--- a/backends/vulkan/runtime/graph/ops/glsl/q_linear.h
+++ /dev/null
@@ -1,82 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-#ifndef Q_LINEAR_H
-#define Q_LINEAR_H
-
-#include "indexing_utils.h"
-
-// The functions in this file assume that some variables have been defined as
-// descriptors, such as t_mat1, t_qmat2, t_scales, etc.
-
-#ifdef USING_BUFFER
-
-#ifndef FLOAT_T
-#define FLOAT_T float
-#endif
-
-FLOAT_T q_8w_linear(const ivec4 out_idx, const int K) {
-  const FLOAT_T scale = t_scales[out_idx.x];
-
-  FLOAT_T outval = FLOAT_T(0.0);
-
-  // Initial mat1 tensor idx will be (0, out_idx.y, out_idx.z, 0)
-  int mat1_offset = out_idx.y * mat1_strides.y + out_idx.z * qmat2_strides.z;
-  // Initial qmat2 tensor idx wil be (0, out_idx.x, 0, 0); note that the qmat2
-  // tensor is transposed
-  int qmat2_offset = out_idx.x * qmat2_strides.y;
-
-  // TODO(ssjia): optimize memory access pattern by traversing K in inner loop
-  for (int i = 0; i < K; i++) {
-    const FLOAT_T mat1_val = t_mat1[mat1_offset];
-    const FLOAT_T mat2_val = t_qmat2[qmat2_offset] * scale;
-
-    outval += mat1_val * mat2_val;
-
-    mat1_offset++;
-    qmat2_offset++;
-  }
-
-  return outval;
-}
-
-#else // USING_TEXTURE
-
-VEC4_T q_8w_linear(const ivec3 out_pos, const int K) {
-  ivec3 mat1_pos = ivec3(0, out_pos.yz);
-  ivec3 qmat2_pos = ivec3(0, out_pos.x * 4, 0);
-
-  VEC4_T outtex = VEC4_T(0);
-
-  const ivec3 scales_pos = ivec3(out_pos.x, 0, 0);
-  const VEC4_T scales = load_texel(t_scales, scales_pos);
-
-  for (int i = 0; i < K; i += 4) {
-    const VEC4_T mat1_tex = load_texel(t_mat1, mat1_pos);
-
-    const VEC4_T sums = VEC4_T(
-        dot(mat1_tex, load_texel(t_qmat2, qmat2_pos) * scales.x),
-        dot(mat1_tex,
-            load_texel(t_qmat2, qmat2_pos + ivec3(0, 1, 0)) * scales.y),
-        dot(mat1_tex,
-            load_texel(t_qmat2, qmat2_pos + ivec3(0, 2, 0)) * scales.z),
-        dot(mat1_tex,
-            load_texel(t_qmat2, qmat2_pos + ivec3(0, 3, 0)) * scales.w));
-
-    outtex += sums;
-
-    mat1_pos.x++;
-    qmat2_pos.x++;
-  }
-
-  return outtex;
-}
-
-#endif // USING_BUFFER
-
-#endif // Q_LINEAR_H
diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp
index 4dd55be4692..5642976b7fe 100644
--- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp
@@ -11,7 +11,6 @@
 #include
 #include
-
 #include

 namespace vkcompute {
@@ -130,6 +129,94 @@ void add_q_8w_linear_node(
   }
 }

+void add_q_8w_linear_optimized_node(
+    ComputeGraph& graph,
+    const ValueRef mat1,
+    const ValueRef q_mat2_data,
+    const ValueRef scales_data,
+    const ValueRef out) {
+  auto viewFn = VK_GET_OP_FN("aten.view_copy.default");
+  ValueRef mat1_W_packed = mat1;
+  ValueRef out_W_packed = out;
+  if (!graph.is_buffer_storage(out) &&
+      graph.packed_dim_of(mat1) != WHCN::kWidthDim) {
+    // Ensure mat1 is width packed
+    mat1_W_packed = graph.add_tensor_like(mat1, utils::kWidthPacked);
+    viewFn(graph, {mat1, graph.add_none(), mat1_W_packed});
+    // Ensure out is packed correctly
+    out_W_packed = graph.add_tensor_like(out, utils::kWidthPacked);
+  }
+  ValueRef q_mat2 =
+      prepack_if_tensor_ref(graph, q_mat2_data, utils::kWidthPacked);
+  ValueRef scales =
+      prepack_if_tensor_ref(graph, scales_data, utils::kWidthPacked);
+
+  std::string kernel_name = "q_8w_linear_optimized";
+  kernel_name.reserve(kShaderNameReserve);
+  add_packed_dim_suffix(kernel_name, graph.packed_dim_of(mat1_W_packed));
+  add_packed_dim_suffix(kernel_name, graph.packed_dim_of(q_mat2));
+  std::vector<int64_t> mat1_sizes = graph.sizes_of(mat1_W_packed);
+  const int mat1_dims = mat1_sizes.size();
+  if (mat1_dims == 3) {
+    kernel_name = "batch_" + kernel_name;
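+    // A 3D mat1 selects the BATCH_MODE ("batch_"-prefixed) shader variant
+    // generated from q_8w_linear_optimized.yaml, which processes up to FOUR
+    // batch entries per invocation.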
+  }
+  if (mat1_sizes.at(mat1_dims - 2) < 8) {
+    kernel_name += "_tile_row_2";
+  } else {
+    kernel_name += "_tile_row_4";
+  }
+
+  add_dtype_suffix(kernel_name, graph.dtype_of(out_W_packed));
+  add_storage_type_suffix(kernel_name, graph.storage_type_of(out_W_packed));
+
+  vkapi::ParamsBindList ubos({});
+
+  utils::uvec3 global_size;
+  utils::uvec3 local_size;
+  if (graph.is_buffer_storage(out)) {
+    ubos.append(
+        {graph.sizes_ubo(out_W_packed),
+         graph.strides_ubo(out_W_packed),
+         graph.numel_ubo(out_W_packed),
+         graph.sizes_ubo(mat1_W_packed),
+         graph.strides_ubo(mat1_W_packed),
+         graph.strides_ubo(q_mat2),
+         graph.strides_ubo(scales)});
+    global_size = graph.create_global_wg_size(out_W_packed);
+    local_size = graph.create_local_wg_size(out_W_packed);
+  } else {
+    global_size = graph.logical_limits_of(out_W_packed);
+    ubos.append(
+        {graph.logical_limits_ubo(out_W_packed),
+         graph.sizes_ubo(mat1_W_packed)});
+    if (mat1_sizes.at(mat1_dims - 2) < 8) {
+      global_size = utils::divup_vec(global_size, {1, 2, 1});
+    } else {
+      global_size = utils::divup_vec(global_size, {1, 4, 1});
+    }
+    local_size = {16, 3, 1};
+  }
+
+  graph.execute_nodes().emplace_back(new DispatchNode(
+      graph,
+      VK_KERNEL_FROM_STR(kernel_name),
+      global_size,
+      local_size,
+      // Inputs and Outputs
+      {{out_W_packed, vkapi::MemoryAccessType::WRITE},
+       {{mat1_W_packed, q_mat2, scales}, vkapi::MemoryAccessType::READ}},
+      // Shader params buffers
+      ubos,
+      // Specialization Constants
+      {}, // spec_vars,
+      // Resizing Logic
+      resize_q_8w_linear_node));
+
+  if (!graph.is_buffer_storage(out)) {
+    viewFn(graph, {out_W_packed, graph.add_none(), out});
+  }
+}
+
 void weight_int8pack_mm(
     ComputeGraph& graph,
     const std::vector<ValueRef>& args) {
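Both the simple and the tiled shader paths above implement the same contract as aten._weight_int8pack_mm: each output element is the dot product of a mat1 row with a row of the transposed int8 weight matrix, dequantized by that output channel's scale. The following scalar sketch is for reference only; the function name and the flat row-major layout are illustrative assumptions, not part of the patch.

#include <cstdint>
#include <vector>

// Scalar reference for the computation dispatched above (illustrative only).
// mat1:   M x K float activations, row-major
// qmat2:  N x K int8 weights (transposed, as both shader paths assume)
// scales: N per-output-channel dequantization scales
std::vector<float> weight_int8pack_mm_reference(
    const std::vector<float>& mat1,
    const std::vector<int8_t>& qmat2,
    const std::vector<float>& scales,
    const size_t M,
    const size_t K,
    const size_t N) {
  std::vector<float> out(M * N, 0.0f);
  for (size_t m = 0; m < M; ++m) {
    for (size_t n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (size_t k = 0; k < K; ++k) {
        // Dequantize the weight on the fly, matching `t_qmat2[...] * scale`
        // in the buffer path and `load_texel(t_qmat2, ...) * scales.x` in
        // the texture path.
        acc += mat1[m * K + k] *
            (static_cast<float>(qmat2[n * K + k]) * scales[n]);
      }
      out[m * N + n] = acc;
    }
  }
  return out;
}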