From 58ab978fbbe49504d2d6668afe30f40a1270f22b Mon Sep 17 00:00:00 2001 From: ssjia Date: Mon, 8 Sep 2025 09:07:50 -0700 Subject: [PATCH 1/7] [ET-VK] Implemement linear_dq8ta_q4gsw Title says it all! Build upon the support for quantized linear introduced in the previous diffs to enable dynamically quantized linear. Also included in this diff is a cleanup of the glslh files used across quantized linear implementations. Differential Revision: [D81931060](https://our.internmc.facebook.com/intern/diff/D81931060/) [ghstack-poisoned] --- .github/workflows/pull.yml | 2 + backends/vulkan/custom_ops_lib.py | 30 ++ backends/vulkan/op_registry.py | 22 +- backends/vulkan/patterns/quantized_linear.py | 142 ++++++- .../vulkan/runtime/graph/ops/DispatchNode.cpp | 7 + .../graph/ops/glsl/addmm_naive_texture3d.glsl | 3 + .../ops/glsl/choose_qparams_per_row.glsl | 74 ++-- .../ops/glsl/choose_qparams_per_row.yaml | 7 +- .../ops/glsl/conv2d_q8csw_linear_tiled.glsl | 4 +- .../ops/glsl/linear_dq8ca_q4gsw_tiled.glsl | 151 +++++++ .../ops/glsl/linear_dq8ca_q4gsw_tiled.yaml | 28 ++ .../graph/ops/glsl/linear_fp_bias_load.glslh | 6 - .../graph/ops/glsl/linear_fp_input_tile.glslh | 9 +- .../ops/glsl/linear_fp_input_tile_load.glslh | 41 +- .../ops/glsl/linear_fp_output_tile.glslh | 31 +- .../linear_fp_output_tile_fp_compute.glslh | 58 +-- ...inear_fp_output_tile_fp_int4_compute.glslh | 68 ++-- ...inear_fp_output_tile_fp_int8_compute.glslh | 41 +- ...ear_fp_output_tile_int8_int4_compute.glslh | 83 ++++ ...ear_fp_output_tile_int8_int8_compute.glslh | 115 +----- .../glsl/linear_fp_output_tile_store.glslh | 37 +- .../linear_fp_per_out_channel_params.glslh | 5 - .../glsl/linear_fp_weight_scales_load.glslh | 12 - .../ops/glsl/linear_fp_weight_tile.glslh | 64 +-- .../ops/glsl/linear_int4_weight_block.glslh | 36 +- .../ops/glsl/linear_int4_weight_tile.glslh | 59 +-- .../glsl/linear_int4_weight_tile_load.glslh | 44 +- .../ops/glsl/linear_int8_input_block.glslh | 31 +- .../glsl/linear_int8_input_scales_zps.glslh | 59 +++ .../linear_int8_input_scales_zps_load.glslh | 27 ++ .../glsl/linear_int8_input_sums_load.glslh | 33 ++ .../glsl/linear_int8_input_sums_store.glslh | 26 ++ .../ops/glsl/linear_int8_input_tile.glslh | 9 +- .../glsl/linear_int8_input_tile_load.glslh | 39 +- .../ops/glsl/linear_int8_weight_block.glslh | 32 +- .../ops/glsl/linear_int8_weight_tile.glslh | 15 +- .../glsl/linear_int8_weight_tile_load.glslh | 44 +- .../ops/glsl/linear_int_accumulator.glslh | 46 +++ .../linear_int_per_in_channel_params.glslh | 22 + .../linear_int_per_out_channel_params.glslh | 7 +- .../glsl/linear_int_weight_sums_load.glslh | 21 +- .../graph/ops/glsl/linear_q4gsw_coop.glsl | 22 +- .../graph/ops/glsl/linear_q4gsw_coop.yaml | 15 + .../graph/ops/glsl/linear_q8csw_tiled.glsl | 4 +- ...ntize_and_pack_linear_input_with_sums.glsl | 122 ++++++ ...ntize_and_pack_linear_input_with_sums.yaml | 30 ++ .../runtime/graph/ops/impl/ChooseQParams.cpp | 16 +- .../graph/ops/impl/QuantizedLinear.cpp | 383 ++++++++++++++++-- .../custom_ops/choose_qparams_per_row.cpp | 4 +- .../vulkan/test/custom_ops/q4gsw_linear.cpp | 220 +++++++++- backends/vulkan/test/custom_ops/utils.cpp | 59 +++ backends/vulkan/test/custom_ops/utils.h | 8 + backends/vulkan/test/test_vulkan_delegate.py | 82 ++++ backends/vulkan/utils.py | 25 +- 54 files changed, 1815 insertions(+), 765 deletions(-) create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_dq8ca_q4gsw_tiled.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_dq8ca_q4gsw_tiled.yaml create mode 100644 
backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int4_compute.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_scales_zps.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_scales_zps_load.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_sums_load.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_sums_store.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_int_accumulator.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/linear_int_per_in_channel_params.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input_with_sums.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input_with_sums.yaml diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 37c6623ca97..455f427b386 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -933,6 +933,8 @@ jobs: PYTHON_EXECUTABLE=python bash backends/vulkan/test/custom_ops/build_and_run.sh add ./cmake-out/backends/vulkan/test/custom_ops/q8csw_linear ./cmake-out/backends/vulkan/test/custom_ops/q8csw_conv2d + ./cmake-out/backends/vulkan/test/custom_ops/q4gsw_conv2d + ./cmake-out/backends/vulkan/test/custom_ops/choose_qparams_per_row # Run e2e testing for selected operators. More operators will be tested via this # route in the future. diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py index 4312971f5f1..336ca2117a0 100644 --- a/backends/vulkan/custom_ops_lib.py +++ b/backends/vulkan/custom_ops_lib.py @@ -293,6 +293,19 @@ def linear_q4gsw( return out +def linear_dq8ca_q4gsw( + x: torch.Tensor, + input_scale: torch.Tensor, + input_zero_point: torch.Tensor, + weights: torch.Tensor, + weight_sums: torch.Tensor, + weight_scales: torch.Tensor, + group_size: int, + bias: Optional[torch.Tensor] = None, +): + return linear_q4gsw(x, weights, weight_scales, group_size) + + name = "linear_q4gsw" lib.define( f""" @@ -307,6 +320,23 @@ def linear_q4gsw( lib.impl(name, linear_q4gsw, "CompositeExplicitAutograd") linear_qc4w_op = getattr(getattr(torch.ops, namespace), name) +name = "linear_dq8ca_q4gsw" +lib.define( + f""" + {name}( + Tensor input, + Tensor input_scales, + Tensor input_zp, + Tensor weights, + Tensor weight_sums, + Tensor weight_scales, + int group_size, + Tensor? 
bias = None) -> Tensor + """ +) +lib.impl(name, linear_dq8ca_q4gsw, "CompositeExplicitAutograd") +linear_dq8ca_q4gsw_op = getattr(getattr(torch.ops, namespace), name) + ######################## ## linear_qta8a_qga4w ## ######################## diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 1b74ef1ac65..2fd5bbef250 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -190,8 +190,8 @@ def register_torchao_choose_qparams_affine(): return OpFeatures( inputs_storage=utils.CONTIGUOUS_ANY, outputs_storage=[ - utils.CONTIGUOUS_BUFFER, # scales - utils.CONTIGUOUS_BUFFER, # zero_points + utils.WIDTH_PACKED_TEXTURE, # scales + utils.WIDTH_PACKED_TEXTURE, # zero_points ], supports_resize=True, ) @@ -341,7 +341,23 @@ def register_quantized_linear_ops(): return OpFeatures( inputs_storage=utils.CONTIGUOUS_ANY, supports_prepacking=True, - supports_resize=False, + ) + + +@update_features(exir_ops.edge.et_vk.linear_dq8ca_q4gsw.default) +def register_linear_dqa_qw_ops(): + return OpFeatures( + inputs_storage=[ + utils.CONTIGUOUS_ANY, # input + utils.WIDTH_PACKED_TEXTURE, # input_scale + utils.WIDTH_PACKED_TEXTURE, # input_zero_point + utils.NO_STORAGE, # weight (prepacked) + utils.NO_STORAGE, # weight_sums (prepacked) + utils.NO_STORAGE, # weight_scales (prepacked) + utils.NO_STORAGE, # group_size (scalar) + utils.NO_STORAGE, # bias (prepacked) + ], + supports_prepacking=True, ) diff --git a/backends/vulkan/patterns/quantized_linear.py b/backends/vulkan/patterns/quantized_linear.py index ee1c7ee2d2a..fd2708327c9 100644 --- a/backends/vulkan/patterns/quantized_linear.py +++ b/backends/vulkan/patterns/quantized_linear.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import operator + from typing import Optional import executorch.backends.vulkan.utils as utils @@ -114,11 +116,23 @@ def __init__(self, mm_node: torch.fx.Node) -> None: # If input is not quantized, then we are done if self.quantize_input_node is None: + raise Exception("Input is not quantized") self.match_found = True return - self.input_scales_node = self.quantize_input_node.args[1] - self.input_zeros_node = self.quantize_input_node.args[2] + scales_arg_idx = 1 + zeros_arg_idx = 2 + + # torchao op has a slightly different function schema + if ( + self.quantize_input_node.target + == exir_ops.edge.torchao.quantize_affine.default + ): + scales_arg_idx = 2 + zeros_arg_idx = 3 + + self.input_scales_node = self.quantize_input_node.args[scales_arg_idx] + self.input_zeros_node = self.quantize_input_node.args[zeros_arg_idx] assert dq_node is not None self.all_nodes.extend( @@ -164,6 +178,24 @@ def is_input_static_per_tensor_quantized(self) -> bool: # are scalars. 
return isinstance(self.input_scales_node, float) + def is_input_dynamic_perchannel_quantized(self) -> bool: + if self.quantize_input_node is None: + return False + + # For dynamic quantization, input scale node should be a getitem operator + # retrieving the output of a choose_qparams op + if self.input_scales_node.target != operator.getitem: + return False + + # The getitem node should be retrieving from a choose_qparams op + if not utils.is_choose_qparams_node(self.input_scales_node.args[0]): + return False + + scales_shape = self.input_scales_node.meta["val"].shape + input_shape = self.fp_input_node.meta["val"].shape + + return input_shape[-2] == scales_shape[-1] + linear_anchor_nodes = { exir_ops.edge.aten.linear.default, @@ -230,6 +262,34 @@ def pack_4bit_weight_tensor(weight_tensor: torch.Tensor) -> torch.Tensor: return weight_tensor[::, 1::2] << 4 | weight_tensor[::, ::2] +def compute_per_group_sums(weight_tensor: torch.Tensor, group_size: int): + """ + Compute the sum of weights per quantization group. + + Args: + weight_tensor (torch.Tensor): Tensor of shape [out_channels, in_channels], dtype int8. + group_size (int): Number of input channels per quantization group. + + Returns: + torch.Tensor: Tensor of shape [num_groups, out_channels], where num_groups = in_channels // group_size. + """ + out_channels, in_channels = weight_tensor.shape + num_groups = in_channels // group_size + # Reshape to [out_channels, num_groups, group_size] + reshaped = weight_tensor.view(out_channels, num_groups, group_size) + # Sum over group_size dimension to get [out_channels, num_groups] + sums = reshaped.sum(dim=2) + # Transpose to [num_groups, out_channels] + sums = sums.transpose(0, 1).contiguous() + # Pad out_channels dim (dim=1) to be a multiple of 8 if needed + out_channels = sums.shape[1] + if out_channels % 8 != 0: + num_pad = 8 - (out_channels % 8) + sums = F.pad(sums, (0, num_pad)) + + return sums.to(torch.int32).contiguous() + + ## ## Pattern Replacement ## @@ -281,6 +341,73 @@ def make_linear_q4gsw_op( match.output_node.replace_all_uses_with(linear_q4gsw_node) +def make_linear_dq8ca_q4gsw_op( + ep: ExportedProgram, + graph_module: torch.fx.GraphModule, + match: QuantizedLinearMatch, + weight_tensor: torch.Tensor, + weight_scales_tensor: torch.Tensor, +): + num_groups = weight_scales_tensor.shape[-1] + in_channels = weight_tensor.shape[-1] + group_size = in_channels // num_groups + + # Compute per quant group sums before packing the weight tensor + sum_per_quant_group = compute_per_group_sums(weight_tensor, group_size) + + weight_tensor = pack_4bit_weight_tensor(weight_tensor) + # Use this function for convenience to update the state dict with the packed + # weight tensor. Alignment will already have been done in the above function. + weight_tensor = utils.align_width_and_update_state_dict( + ep, match.weight_node, weight_tensor, align_to=1, force_update=True + ) + + # Also transpose the weight scales tensor to shape [num_groups, N] + weight_scales_tensor = weight_scales_tensor.transpose(0, 1).contiguous() + utils.align_width_and_update_state_dict( + ep, + match.weight_scales_node, + weight_scales_tensor, + align_to=1, + force_update=True, + ) + + first_graph_node = list(graph_module.graph.nodes)[0] + with graph_module.graph.inserting_before(first_graph_node): + weight_tensor_name = utils.get_tensor_name(ep, match.weight_node) + # Pre-compute the weight sums which are needed to apply activation zero point + # when using integer accumulation. 
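As a rough illustration of why these sums are needed: a minimal standalone sketch, separate from this pass. All names below are local to the example, and it assumes, as the naming suggests, that quantize_and_pack_linear_input_with_sums produces per-row, per-group sums of the quantized activations.

import torch

group_size = 8
w_q = torch.randint(-8, 8, (group_size,), dtype=torch.int32)      # signed 4-bit weight values
x_q = torch.randint(-128, 128, (group_size,), dtype=torch.int32)  # quantized int8 activations
s_w, s_x, zp_x = 0.02, 0.1, 3                                      # weight scale, activation scale / zero point

# The shader accumulates against the packed unsigned nibbles (w_q + 8), so the raw
# integer accumulator still contains the +8 offset and the activation zero point.
accum = int(torch.sum(x_q * (w_q + 8)))
weight_sum = int(torch.sum(w_q))  # what compute_per_group_sums precomputes per group
input_sum = int(torch.sum(x_q))   # per-row, per-group sum of the quantized activations

# sum((x_q - zp_x) * w_q) == accum - 8 * input_sum - zp_x * weight_sum, which is the
# same correction accumulate_out_tile_with_int_accum_from_int4_weights applies as
#   accum - (weight_sum + 8 * group_size) * zp_x + 8 * (group_size * zp_x - input_sum)
adjusted = accum - 8 * input_sum - zp_x * weight_sum
reference = torch.sum((x_q - zp_x).float() * s_x * w_q.float() * s_w).item()
assert abs(s_x * s_w * adjusted - reference) < 1e-3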
+ sums_name = weight_tensor_name + "_sums" + # Sanitize the name + sums_name = sums_name.replace(".", "_") + + weight_sums_node = create_constant_placeholder( + exp_program=ep, + graph=graph_module.graph, + kind=InputKind.CONSTANT_TENSOR, + name=sums_name, + data=sum_per_quant_group, + ) + + with graph_module.graph.inserting_before(match.output_node): + qlinear_node = graph_module.graph.create_node( + "call_function", + exir_ops.edge.et_vk.linear_dq8ca_q4gsw.default, + args=( + match.fp_input_node, + match.input_scales_node, + match.input_zeros_node, + match.weight_node, + weight_sums_node, + match.weight_scales_node, + group_size, + ), + ) + + qlinear_node.meta["val"] = match.output_node.meta["val"] + match.output_node.replace_all_uses_with(qlinear_node) + + def make_linear_q8ta_q8csw_custom_op( ep: ExportedProgram, graph_module: torch.fx.GraphModule, @@ -351,13 +478,20 @@ def replace_quantized_linear_patterns( and match.is_weight_pergroup_quantized() and utils.is_in_4bit_range(weight_tensor) ): + raise Exception("Unsupported pattern") make_linear_q4gsw_op( ep, graph_module, match, weight_tensor, weight_scales_tensor ) + elif ( + match.is_input_dynamic_perchannel_quantized() + and match.is_weight_pergroup_quantized() + and utils.is_in_4bit_range(weight_tensor) + ): + make_linear_dq8ca_q4gsw_op( + ep, graph_module, match, weight_tensor, weight_scales_tensor + ) elif ( match.is_input_static_per_tensor_quantized() and match.is_weight_perchannel_quantized() ): make_linear_q8ta_q8csw_custom_op(ep, graph_module, match, weight_tensor) - - # No-op for unsupported quant patterns diff --git a/backends/vulkan/runtime/graph/ops/DispatchNode.cpp b/backends/vulkan/runtime/graph/ops/DispatchNode.cpp index 898a3415b7e..d1add8227de 100644 --- a/backends/vulkan/runtime/graph/ops/DispatchNode.cpp +++ b/backends/vulkan/runtime/graph/ops/DispatchNode.cpp @@ -44,6 +44,13 @@ void DispatchNode::encode(ComputeGraph* graph) { if (!shader_) { return; } + + // If any global wg size element is 0, then skip encoding this shader + if (global_workgroup_size_[0] == 0 || global_workgroup_size_[1] == 0 || + global_workgroup_size_[2] == 0) { + return; + } + api::Context* const context = graph->context(); vkapi::PipelineBarrier pipeline_barrier{}; diff --git a/backends/vulkan/runtime/graph/ops/glsl/addmm_naive_texture3d.glsl b/backends/vulkan/runtime/graph/ops/glsl/addmm_naive_texture3d.glsl index a4ed494fe6d..bd210e210ce 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/addmm_naive_texture3d.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/addmm_naive_texture3d.glsl @@ -136,6 +136,9 @@ vec4 matmul_naive_k_dim_packed_row_dim_packed(const ivec3 out_lpos) { const vec4 mat1_tex = texelFetch(mat1_tensor, mat1_pos, 0); for (int r = 0; r < 4; ++r) { + if (4 * i + r >= mat2_sizes.y) { + continue; + } // On-demand construction of mat2_pos appears to provide the lowest // latency. Surprisingly, this doesn't translate to mat1_pos. 
ivec3 mat2_pos = ivec3(0); diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.glsl b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.glsl index 653b0a251c0..c95bb66f164 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.glsl @@ -12,8 +12,8 @@ #define VEC4_T ${texel_load_type(DTYPE, STORAGE)} #define T ${texel_load_component_type(DTYPE, STORAGE)} -#define NUM_OUTPUTS_PER_WG ${NUM_OUTPUTS_PER_WG} -#define NUM_WORKERS_PER_OUTPUT ${NUM_WORKERS_PER_OUTPUT} +#define NUM_OUTPUTS_PER_WG 1 +#define NUM_WORKERS_PER_OUTPUT 64 // Maximum total threads in a work group #define MAX_THREADS 256 @@ -27,8 +27,8 @@ layout(std430) buffer; #include "common.glslh" -${layout_declare_tensor(B, "w", "t_scales", "float", "buffer")} -${layout_declare_tensor(B, "w", "t_zps", "int", "buffer")} +${layout_declare_tensor(B, "w", "t_scales", DTYPE, "texture3d")} +${layout_declare_tensor(B, "w", "t_zps", "int8", "texture3d")} ${layout_declare_tensor(B, "r", "t_input", DTYPE, STORAGE, is_scalar_array=False)} ${layout_declare_ubo(B, "ivec4", "input_sizes")} @@ -40,6 +40,8 @@ layout(push_constant) uniform PushConstants { int quant_max; }; +#extension GL_EXT_debug_printf : enable + // Shared memory for cooperative min/max finding shared T shared_min[NUM_OUTPUTS_PER_WG][NUM_WORKERS_PER_OUTPUT]; shared T shared_max[NUM_OUTPUTS_PER_WG][NUM_WORKERS_PER_OUTPUT]; @@ -52,7 +54,7 @@ void calculate_scale_and_zero_point( int qmin, int qmax, out float scale, - out int8_t zero_point) { + out int zero_point) { // Extend the [min, max] interval to ensure it contains 0 min_val = min(min_val, 0.0); @@ -102,29 +104,21 @@ void calculate_scale_and_zero_point( nudged_zero_point = int(round(initial_zero_point)); } - zero_point = int8_t(nudged_zero_point); + zero_point = nudged_zero_point; } -#ifdef USING_BUFFER - VEC4_T load_input_x4(const int x4, const int y, const int ntexels_x) { +#ifdef USING_BUFFER return t_input[(y * ntexels_x) + x4]; -} - -#else // USING_TEXTURE - -VEC4_T load_input_x4(const int x4, const int y, const int ntexels_x) { +#else return texelFetch(t_input, ivec3(x4, y, 0), 0); +#endif } -#endif // USING_BUFFER - -void main() { +void find_min_max_for_row(const int output_y) { const int worker_id = int(gl_LocalInvocationID.x); const int output_id = int(gl_LocalInvocationID.y); - const int output_y = int(gl_GlobalInvocationID.y); - if (output_y >= input_sizes.y) { return; } @@ -167,18 +161,42 @@ void main() { memoryBarrierShared(); barrier(); } +} - // Only first thread will write out result - if (worker_id == 0) { - local_min = shared_min[output_id][0]; - local_max = shared_max[output_id][0]; +void main() { + const int worker_id = int(gl_LocalInvocationID.x); + const int output_id = int(gl_LocalInvocationID.y); + + const int output_y4 = int(gl_GlobalInvocationID.y); + const int output_y = mul_4(output_y4); + + + VEC4_T scales_out = VEC4_T(0.0); + ivec4 zps_out = ivec4(0); + + int limit = min(input_sizes.y - output_y, 4); + for (int i = 0; i < limit; i++) { + find_min_max_for_row(output_y + i); + + // Only the first thread in the work group will compute the result + if (worker_id == 0) { + float local_min = shared_min[output_id][0]; + float local_max = shared_max[output_id][0]; - float scale; - int8_t zero_point; - calculate_scale_and_zero_point( - local_min, local_max, quant_min, quant_max, scale, zero_point); + float scale; + int zero_point; - t_scales[output_y] = scale; - t_zps[output_y] = 
zero_point; + calculate_scale_and_zero_point( + local_min, local_max, quant_min, quant_max, scale, zero_point); + + scales_out[i] = scale; + zps_out[i] = zero_point; + } + } + + if (worker_id == 0) { + imageStore(t_scales, ivec3(output_y4, 0, 0), scales_out); + imageStore(t_zps, ivec3(output_y4, 0, 0), zps_out); } + } diff --git a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.yaml b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.yaml index 3608f7193bf..1594bb574bd 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/choose_qparams_per_row.yaml @@ -8,8 +8,6 @@ choose_qparams_per_row: parameter_names_with_default_values: DTYPE: float STORAGE: texture3d - NUM_OUTPUTS_PER_WG: 1 - NUM_WORKERS_PER_OUTPUT: 64 generate_variant_forall: STORAGE: - VALUE: texture3d @@ -17,7 +15,4 @@ choose_qparams_per_row: DTYPE: - VALUE: float shader_variants: - - NAME: choose_qparams_per_row_o1w64 - - NAME: choose_qparams_per_row_o4w16 - NUM_OUTPUTS_PER_WG: 4 - NUM_WORKERS_PER_OUTPUT: 16 + - NAME: choose_qparams_per_row diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8csw_linear_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8csw_linear_tiled.glsl index e2b239800a8..3615d423230 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8csw_linear_tiled.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_q8csw_linear_tiled.glsl @@ -110,10 +110,10 @@ void main() { FPPerOutChannelParams bias_tile; load_bias_tile(bias_tile, n4); - apply_scales_and_biases(out_tile, weight_scales_tile, bias_tile); + apply_weight_scales_and_biases(out_tile, weight_scales_tile, bias_tile); } else { - apply_scales(out_tile, weight_scales_tile); + apply_weight_scales(out_tile, weight_scales_tile); } write_im2col_tile_as_image(out_tile, n4, m); diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_dq8ca_q4gsw_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_dq8ca_q4gsw_tiled.glsl new file mode 100644 index 00000000000..c7df1b429c7 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_dq8ca_q4gsw_tiled.glsl @@ -0,0 +1,151 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, IO_STORAGE)} +#define T ${texel_load_component_type(DTYPE, IO_STORAGE)} + +$if IO_STORAGE == "buffer": + #define OUTPUT_BUFFER + #define INPUT_BUFFER +$if PACKED_INT8_INPUT_STORAGE == "buffer": + #define PACKED_INT8_INPUT_BUFFER +$if WEIGHT_STORAGE == "buffer": + #define WEIGHT_BUFFER + +#define TILE_N8 ${TILE_N8} + +#define TILE_M4 ${TILE_M4} +#define TILE_K4 ${TILE_K4} +#define TILE_N4 ${TILE_N8 * 2} + +#define TILE_M ${TILE_M4 * 4} +#define TILE_K ${TILE_K4 * 4} +#define TILE_N ${TILE_N8 * 8} + +${define_required_extensions(DTYPE)} +${define_required_extensions("int8")} + +layout(std430) buffer; + +#include "common.glslh" + +${layout_declare_tensor(B, "w", "t_output", DTYPE, IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_input", DTYPE, IO_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", PACKED_INT8_INPUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_int8_input_sums", "int", "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_int8_input_scales", DTYPE, "texture3d")} +${layout_declare_tensor(B, "r", "t_int8_input_zps", "int8", "texture3d")} +${layout_declare_tensor(B, "r", "t_packed_int4_weight", "int", WEIGHT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_sums", "int", "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)} + +${layout_declare_ubo(B, "ivec4", "output_sizes")} +${layout_declare_ubo(B, "ivec4", "input_sizes")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "int", "apply_bias", "0")} +${layout_declare_spec_const(C, "int", "K4_per_group", "0")} + +#include "linear_fp_input_tile_load.glslh" +#include "linear_int8_input_tile_load.glslh" +#include "linear_int8_input_scales_zps_load.glslh" +#include "linear_int4_weight_tile_load.glslh" +#include "linear_int_weight_sums_load.glslh" +#include "linear_fp_weight_scales_load.glslh" +#include "linear_int8_input_sums_load.glslh" +#include "linear_fp_output_tile_int8_int8_compute.glslh" +#include "linear_fp_output_tile_int8_int4_compute.glslh" +#include "linear_fp_output_tile_fp_int4_compute.glslh" +#include "linear_fp_output_tile_fp_compute.glslh" +#include "linear_fp_output_tile_store.glslh" +#include "linear_fp_bias_load.glslh" + +void main() { + const int out_tile_x = int(gl_GlobalInvocationID.x); + const int out_tile_y = int(gl_GlobalInvocationID.y); + + const int n = out_tile_x * TILE_N; + const int m = out_tile_y * TILE_M; + + const int n8 = div_8(n); + const int n4 = div_4(n); + const int m4 = div_4(m); + + if (n >= output_sizes.x || m >= output_sizes.y) { + return; + } + + const int M = input_sizes.y; + const int K4 = div_up_4(input_sizes.x); + const int M4 = div_up_4(M); + const int N4 = div_up_4(output_sizes.x); + const int N8 = div_up_8(output_sizes.x); + + FPOutTile out_tile; + initialize(out_tile); + + Int32Accum out_accum; + initialize(out_accum); + + Int8InputTile int8_in_tile; + Int4WeightTile int4_weight_tile; + + Int8InputScales input_scales; + Int8InputZeroPoints input_zps; + load_int8_input_scales_and_zps(input_scales, input_zps, m4); + + FPPerOutChannelParams weight_scales_tile; + IntPerOutChannelParams weight_sums_tile; + + IntPerInChannelParams 
int8_input_sums_tile; + + const int num_groups = K4 / K4_per_group; + + for (int group_i = 0; group_i < num_groups; ++group_i) { + // Reset int accumulator + initialize(out_accum); + for (int k4_inner = 0; k4_inner < K4_per_group; k4_inner++) { + const int k4 = group_i * K4_per_group + k4_inner; + + load_int8_input_tile(int8_in_tile, k4, m4, K4); + load_int4_weight_tile(int4_weight_tile, k4, n8, K4); + // load_int4_weight_tile(int4_weight_tile, n8, k4, N8); + + int_accumulate_with_int4_weight( + out_accum, int8_in_tile, int4_weight_tile); + } + + load_weight_scales_tile_for_group(weight_scales_tile, n4, group_i, N4); + load_weight_sums_tile_for_group(weight_sums_tile, n4, group_i, N4); + load_int8_input_sums_tile_for_group(int8_input_sums_tile, m4, group_i, M4); + + const int group_size = mul_4(K4_per_group); + + // // Update output tile with accumulated values + // accumulate_out_tile_with_int_accum_from_int4_weights_test( + // out_tile, + // out_accum); + + accumulate_out_tile_with_int_accum_from_int4_weights( + out_tile, + out_accum, + int8_input_sums_tile, + input_scales, + input_zps, + weight_sums_tile, + weight_scales_tile, + group_size); + } + + write_output_tile_with_checks(out_tile, n4, m, N4, M); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_dq8ca_q4gsw_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/linear_dq8ca_q4gsw_tiled.yaml new file mode 100644 index 00000000000..cb9cdc4a046 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_dq8ca_q4gsw_tiled.yaml @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +linear_dq8ca_q4gsw_tiled: + parameter_names_with_default_values: + DTYPE: float + IO_STORAGE: texture3d + WEIGHT_STORAGE: texture2d + PACKED_INT8_INPUT_STORAGE: buffer + TILE_M4: 1 + TILE_K4: 1 + TILE_N8: 1 + generate_variant_forall: + DTYPE: + - VALUE: float + shader_variants: + - NAME: linear_dq8ca_q4gsw_tiled_texture3d_texture2d + - NAME: linear_dq8ca_q4gsw_tiled_texture3d_buffer + WEIGHT_STORAGE: buffer + - NAME: linear_dq8ca_q4gsw_tiled_buffer_texture2d + IO_STORAGE: buffer + WEIGHT_STORAGE: texture2d + - NAME: linear_dq8ca_q4gsw_tiled_buffer_buffer + IO_STORAGE: buffer + WEIGHT_STORAGE: buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_bias_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_bias_load.glslh index f3d32be8b3d..c1ab8999420 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_bias_load.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_bias_load.glslh @@ -16,15 +16,9 @@ VEC4_T load_bias_x4(const int n4) { } void load_bias_tile(out FPPerOutChannelParams bias, const int n4_start) { -#if TILE_N4 == 1 - bias.data[0] = load_bias_x4(n4_start); - -#else [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { bias.data[n4] = load_bias_x4(n4_start + n4); } - -#endif } #endif // LINEAR_FP_BIAS_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile.glslh index 68eee57a132..72b5bdb812e 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile.glslh @@ -10,12 +10,9 @@ #define LINEAR_FP_INPUT_TILE_GLSLH /* - * Defines the FPInputTile struct, which is used to represent a tile of the - * input matrix of a matrix multiplication operation. 
- * - * Settings: - * - TILE_M: number of rows in the tile - * - TILE_K4: number of (groups of 4) columns in the tile + * Macro Settings: + * - TILE_M + * - TILE_K4 */ #extension GL_EXT_control_flow_attributes : require diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile_load.glslh index 6697003935f..358379b3efd 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile_load.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_input_tile_load.glslh @@ -7,14 +7,11 @@ */ /* - * Defines functions to load a FPInputTile from input buffer/texture. + * Assume the following variables are defined in the shader layout: + * - t_input * - * Requires: - * - t_input to be declared in the shader layout (input buffer/texture) - * - * Settings: - * - INPUT_BUFFER to indicate input resource is a buffer, otherwise texture is - * assumed. + * Macro Settings: + * - INPUT_BUFFER */ #ifndef LINEAR_FP_INPUT_TILE_LOAD_GLSLH @@ -24,58 +21,33 @@ #include "linear_fp_input_tile.glslh" -#ifdef INPUT_BUFFER - VEC4_T load_input_x4(const int k4, const int m, const int ntexels_k) { +#ifdef INPUT_BUFFER return t_input[(m * ntexels_k) + k4]; -} - #else - -VEC4_T load_input_x4(const int k4, const int m, const int ntexels_k) { return texelFetch(t_input, ivec3(k4, m, 0), 0); +#endif } -#endif // INPUT_BUFFER - -// To be used if (M - m_start >= TILE_M) || (K4 - k4_start >= TILE_K4) void load_input_tile_no_checks( out FPInputTile in_tile, const int k4_start, const int m_start, const int K4, const int M) { -#if TILE_K4 == 1 - [[unroll]] for (int m = 0; m < TILE_M; ++m) { - in_tile.data[m][0] = load_input_x4(k4_start, m_start + m, K4); - } - -#else [[unroll]] for (int m = 0; m < TILE_M; ++m) { [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { in_tile.data[m][k4] = load_input_x4(k4_start + k4, m_start + m, K4); } } -#endif } -// To be used if near tensor boundaries void load_input_tile_with_checks( out FPInputTile in_tile, const int k4_start, const int m_start, const int K4, const int M) { -#if TILE_K4 == 1 - [[unroll]] for (int m = 0; m < TILE_M; ++m) { - if (m_start + m < M) { - in_tile.data[m][0] = load_input_x4(k4_start, m_start + m, K4); - } else { - in_tile.data[m][0] = VEC4_T(0.0); - } - } - -#else [[unroll]] for (int m = 0; m < TILE_M; ++m) { [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { if (m_start + m < M && k4_start + k4 < K4) { @@ -85,7 +57,6 @@ void load_input_tile_with_checks( } } } -#endif } #endif // LINEAR_FP_INPUT_TILE_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile.glslh index dd571229a9c..ca466447084 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile.glslh @@ -7,12 +7,9 @@ */ /* - * Defines the FPOutTile struct, which is used to represent a tile of the output - * matrix of a matrix multiplication operation. 
- * - * Settings: - * - TILE_M: number of rows in the output tile - * - TILE_N4: number of (groups of 4) columns in the output tile + * Macro Settings: + * - TILE_M + * - TILE_N4 */ #ifndef LINEAR_FP_OUTPUT_TILE_GLSLH @@ -25,33 +22,11 @@ struct FPOutTile { }; void initialize(out FPOutTile out_tile) { -#if TILE_N4 == 1 - [[unroll]] for (int m = 0; m < TILE_M; ++m) { - out_tile.data[m][0] = VEC4_T(0); - } - -#else [[unroll]] for (int m = 0; m < TILE_M; ++m) { [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { out_tile.data[m][n4] = VEC4_T(0); } } -#endif -} - -void add(inout FPOutTile out_tile, const FPOutTile other_out_tile) { -#if TILE_M > 1 && TILE_N4 == 1 - [[unroll]] for (int m = 0; m < TILE_M; ++m) { - out_tile.data[m][0] += other_out_tile.data[m][0]; - } - -#else - [[unroll]] for (int m = 0; m < TILE_M; ++m) { - [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { - out_tile.data[m][n4] += other_out_tile.data[m][n4]; - } - } -#endif } #ifdef DEBUG_MODE diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_compute.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_compute.glslh index ee50ad87f74..01b3c762e39 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_compute.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_compute.glslh @@ -21,41 +21,12 @@ #include "linear_fp_per_out_channel_params.glslh" #include "linear_fp_weight_tile.glslh" -/* - * Accumulates floating point input tile and floating point weight tile into - * floating point output tile. - */ void fp_accumulate_with_fp_weight( inout FPOutTile accum, FPInputTile in_tile, FPWeightTile w_tile) { -#if TILE_N4 == 1 && TILE_K4 == 1 - [[unroll]] for (int m = 0; m < TILE_M; ++m) { - accum.data[m][0] = - fma(VEC4_T(in_tile.data[m][0][0]), - w_tile.data[mul_4(0)][0], - accum.data[m][0]); - - accum.data[m][0] = - fma(VEC4_T(in_tile.data[m][0][1]), - w_tile.data[mul_4(0) + 1][0], - accum.data[m][0]); - - accum.data[m][0] = - fma(VEC4_T(in_tile.data[m][0][2]), - w_tile.data[mul_4(0) + 2][0], - accum.data[m][0]); - - accum.data[m][0] = - fma(VEC4_T(in_tile.data[m][0][3]), - w_tile.data[mul_4(0) + 3][0], - accum.data[m][0]); - } - -#else [[unroll]] for (int m = 0; m < TILE_M; ++m) { [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { - const int n = mul_4(n4); [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { accum.data[m][n4] = fma(VEC4_T(in_tile.data[m][k4][0]), @@ -79,48 +50,27 @@ void fp_accumulate_with_fp_weight( } } } - -#endif } -/* - * Applies per output channel weight scales to the output tile. - */ -void apply_scales(inout FPOutTile tile, const FPPerOutChannelParams scales) { -#if TILE_M > 1 && TILE_N4 == 1 - [[unroll]] for (int m = 0; m < TILE_M; ++m) { - tile.data[m][0] = tile.data[m][0] * scales.data[0]; - } - -#else +void apply_weight_scales( + inout FPOutTile tile, + const FPPerOutChannelParams scales) { [[unroll]] for (int m = 0; m < TILE_M; ++m) { [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { tile.data[m][n4] = tile.data[m][n4] * scales.data[n4]; } } -#endif } -/* - * Applies per output channel weight scales and per output channel biases to the - * output tile. 
- */ -void apply_scales_and_biases( +void apply_weight_scales_and_biases( inout FPOutTile tile, const FPPerOutChannelParams scales, const FPPerOutChannelParams bias) { -#if TILE_M > 1 && TILE_N4 == 1 - [[unroll]] for (int m = 0; m < TILE_M; ++m) { - tile.data[m][0] = tile.data[m][0] * scales.data[0] + bias.data[0]; - } - -#else [[unroll]] for (int m = 0; m < TILE_M; ++m) { [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { tile.data[m][n4] = tile.data[m][n4] * scales.data[n4] + bias.data[n4]; } } -#endif } void accumulate_out_tile_with_out_tile( diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_int4_compute.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_int4_compute.glslh index 0606759e393..1f99b6672c0 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_int4_compute.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_int4_compute.glslh @@ -6,10 +6,6 @@ * LICENSE file in the root directory of this source tree. */ -/* - * Defines functions to compute a FPOutTile using fp input and weight tiles. - */ - #ifndef LINEAR_FP_OUTPUT_TILE_FP_INT4_COMPUTE_GLSLH #define LINEAR_FP_OUTPUT_TILE_FP_INT4_COMPUTE_GLSLH @@ -21,8 +17,6 @@ #include "linear_fp_per_out_channel_params.glslh" #include "linear_int4_weight_tile.glslh" -// Unpacks a int containing 4 packed 8-bit integers into a vec4 containing each -// of the 4 unpacked 8-bit integers. VEC4_T unpack_packed_4xint4(const int int8x4, const int n4_group) { return VEC4_T( extract_4bit_from_packed_int_le(int8x4, n4_group + 0), @@ -54,39 +48,43 @@ void fp_accumulate_with_int4_weight( // (n, k), (n, k + 1), (n, k + 2), (n, k + 3), // (n + 4, k), (n + 4, k + 1), (n + 4, k + 2), (n + 4, k + 3) VEC4_T weight_texels[2]; -#if TILE_K4 == 1 && TILE_N8 == 1 - [[unroll]] for (int k = 0; k < 4; ++k) { - const int base_col_1 = mul_2(k); - const int base_col_2 = base_col_1 + 1; - weight_texels[0] = VEC4_T( - extract_4bit_from_weight_block(w_tile.data[0][0], base_col_1, 0), - extract_4bit_from_weight_block(w_tile.data[0][0], base_col_1, 1), - extract_4bit_from_weight_block(w_tile.data[0][0], base_col_1, 2), - extract_4bit_from_weight_block(w_tile.data[0][0], base_col_1, 3)); - weight_texels[1] = VEC4_T( - extract_4bit_from_weight_block(w_tile.data[0][0], base_col_2, 0), - extract_4bit_from_weight_block(w_tile.data[0][0], base_col_2, 1), - extract_4bit_from_weight_block(w_tile.data[0][0], base_col_2, 2), - extract_4bit_from_weight_block(w_tile.data[0][0], base_col_2, 3)); + [[unroll]] for (int n8 = 0; n8 < TILE_N8; ++n8) { + const int n4 = mul_2(n8); + [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { + [[unroll]] for (int k4i = 0; k4i < 4; ++k4i) { + const int base_col_1 = mul_2(k4i); + const int base_col_2 = base_col_1 + 1; + weight_texels[0] = VEC4_T( + extract_4bit_from_weight_block(w_tile.data[k4][n8], base_col_1, 0), + extract_4bit_from_weight_block(w_tile.data[k4][n8], base_col_1, 1), + extract_4bit_from_weight_block(w_tile.data[k4][n8], base_col_1, 2), + extract_4bit_from_weight_block(w_tile.data[k4][n8], base_col_1, 3)); + weight_texels[1] = VEC4_T( + extract_4bit_from_weight_block(w_tile.data[k4][n8], base_col_2, 0), + extract_4bit_from_weight_block(w_tile.data[k4][n8], base_col_2, 1), + extract_4bit_from_weight_block(w_tile.data[k4][n8], base_col_2, 2), + extract_4bit_from_weight_block(w_tile.data[k4][n8], base_col_2, 3)); - weight_texels[0] = - fma(weight_texels[0], scales_tile.data[0], zeros_tile.data[0]); - weight_texels[1] = - fma(weight_texels[1], 
scales_tile.data[1], zeros_tile.data[1]); + weight_texels[0] = + fma(weight_texels[0], scales_tile.data[n4], zeros_tile.data[n4]); + weight_texels[1] = + fma(weight_texels[1], + scales_tile.data[n4 + 1], + zeros_tile.data[n4 + 1]); - [[unroll]] for (int m = 0; m < TILE_M; ++m) { - accum.data[m][0] = fma( - VEC4_T(in_tile.data[m][0][k]), weight_texels[0], accum.data[m][0]); - accum.data[m][1] = fma( - VEC4_T(in_tile.data[m][0][k]), weight_texels[1], accum.data[m][1]); + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + accum.data[m][n4] = + fma(VEC4_T(in_tile.data[m][k4][k4i]), + weight_texels[0], + accum.data[m][n4]); + accum.data[m][n4 + 1] = + fma(VEC4_T(in_tile.data[m][k4][k4i]), + weight_texels[1], + accum.data[m][n4 + 1]); + } + } } } - -#else - // TODO(ssjia): Implement generic case - not implemented - -#endif } #endif // LINEAR_FP_OUTPUT_TILE_FP_INT4_COMPUTE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_int8_compute.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_int8_compute.glslh index b2ab64a1573..cb0ded2f9ec 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_int8_compute.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_fp_int8_compute.glslh @@ -6,10 +6,6 @@ * LICENSE file in the root directory of this source tree. */ -/* - * Defines functions to compute a FPOutTile using fp input and weight tiles. - */ - #ifndef LINEAR_FP_OUTPUT_TILE_FP_INT8_COMPUTE_GLSLH #define LINEAR_FP_OUTPUT_TILE_FP_INT8_COMPUTE_GLSLH @@ -20,8 +16,6 @@ #include "linear_fp_output_tile.glslh" #include "linear_int8_weight_tile.glslh" -// Unpacks a int containing 4 packed 8-bit integers into a vec4 containing each -// of the 4 unpacked 8-bit integers. VEC4_T unpack_packed_4xint8(int int8x4) { return VEC4_T( extract_8bit_from_packed_int_le(int8x4, 0), @@ -43,26 +37,25 @@ void fp_accumulate_with_int8_weight( // -> gives packed integer containing the 4x 8-bit quantized values at index // (n, k), (n, k + 1), (n, k + 2), (n, k + 3) VEC4_T weight_texel; -#if TILE_K4 == 1 && TILE_N4 == 1 - [[unroll]] for (int k = 0; k < 4; ++k) { - // Unpack one column of weights - weight_texel = VEC4_T( - extract_8bit_from_packed_int_le(w_tile.data[0][0][0], k), - extract_8bit_from_packed_int_le(w_tile.data[0][0][1], k), - extract_8bit_from_packed_int_le(w_tile.data[0][0][2], k), - extract_8bit_from_packed_int_le(w_tile.data[0][0][3], k)); - - for (int m = 0; m < TILE_M; ++m) { - accum.data[m][0] = - fma(VEC4_T(in_tile.data[m][0][k]), weight_texel, accum.data[m][0]); + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { + [[unroll]] for (int k4i = 0; k4i < 4; ++k4i) { + // Unpack one column of weights + weight_texel = VEC4_T( + extract_8bit_from_packed_int_le(w_tile.data[k4][n4][0], k4i), + extract_8bit_from_packed_int_le(w_tile.data[k4][n4][1], k4i), + extract_8bit_from_packed_int_le(w_tile.data[k4][n4][2], k4i), + extract_8bit_from_packed_int_le(w_tile.data[k4][n4][3], k4i)); + + for (int m = 0; m < TILE_M; ++m) { + accum.data[m][n4] = + fma(VEC4_T(in_tile.data[m][k4][k4i]), + weight_texel, + accum.data[m][n4]); + } + } } } - -#else - // TODO(ssjia): implement the general case - not implemented - -#endif } #endif // LINEAR_FP_OUTPUT_TILE_FP_INT8_COMPUTE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int4_compute.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int4_compute.glslh new file mode 100644 index 
00000000000..ac886f78bfb --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int4_compute.glslh @@ -0,0 +1,83 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef LINEAR_FP_OUTPUT_TILE_INT8_INT4_COMPUTE_GLSLH +#define LINEAR_FP_OUTPUT_TILE_INT8_INT4_COMPUTE_GLSLH + +#extension GL_EXT_control_flow_attributes : require +#extension GL_EXT_integer_dot_product : require + +#include "linear_int4_weight_tile.glslh" +#include "linear_int8_input_tile.glslh" +#include "linear_int_accumulator.glslh" + +void int_accumulate_with_int4_weight( + inout Int32Accum accum, + Int8InputTile in_tile, + Int4WeightTile w_tile) { + // Accum tile is indexed as accum[m][n4][n4i] + // -> gives integer accumulator for output tile element at (x = n, y = m) + // Input tile is indexed as in_tile.data[m4][k4][m4i] + // -> gives packed integer containing the 4x 8-bit quantized values at index + // (k, m), (k + 1, m), (k + 2, m), (k + 3, m) + // Weight tile is indexed as w_tile.data[k4][n8][n4i] + // -> gives packed integer containing the 8x 4-bit quantized values + [[unroll]] for (int n8 = 0; n8 < TILE_N8; ++n8) { + const int n4_base = mul_2(n8); + [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { + ivec4 unpacked_block_1 = w_tile.data[k4][n8] & 0x0F0F0F0F; + ivec4 unpacked_block_2 = (w_tile.data[k4][n8] >> 4) & 0x0F0F0F0F; + + [[unroll]] for (int n4i = 0; n4i < 4; ++n4i) { + // Accumulate unpacked_block_1[n4i] and unpacked_block_2[n4i] with + // each row of the input tile + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + const int m4 = div_4(m); + const int m4i = mod_4(m); + accum.data[m][n4_base][n4i] = dotPacked4x8AccSatEXT( + in_tile.data[m4][k4][m4i], + unpacked_block_1[n4i], + accum.data[m][n4_base][n4i]); + accum.data[m][n4_base + 1][n4i] = dotPacked4x8AccSatEXT( + in_tile.data[m4][k4][m4i], + unpacked_block_2[n4i], + accum.data[m][n4_base + 1][n4i]); + } + } + } + } +} + +void accumulate_out_tile_with_int_accum_from_int4_weights( + inout FPOutTile out_tile, + const Int32Accum accum, + const IntPerInChannelParams input_sums, + const Int8InputScales input_scales, + const Int8InputZeroPoints input_zps, + const IntPerOutChannelParams weight_sums, + const FPPerOutChannelParams weight_scales, + const int group_size) { + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + float input_scale_m = input_scales.data[0][m]; + int input_zp_m = input_zps.data[0][m]; + int input_sum_m = input_sums.data[0][m]; + + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + ivec4 accum_adjusted = accum.data[m][n4] - + (weight_sums.data[n4] + 8 * group_size) * input_zp_m + + mul_8(group_size * input_zp_m - input_sum_m); + + out_tile.data[m][n4] = + fma(VEC4_T(accum_adjusted), + input_scale_m * weight_scales.data[n4], + out_tile.data[m][n4]); + } + } +} + +#endif // LINEAR_FP_OUTPUT_TILE_INT8_INT4_COMPUTE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int8_compute.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int8_compute.glslh index b04074eba75..68ac269e9d7 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int8_compute.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_int8_int8_compute.glslh @@ -23,31 +23,13 @@ #include "linear_common.glslh" #include "linear_fp_output_tile.glslh" #include 
"linear_fp_per_out_channel_params.glslh" +#include "linear_int8_input_scales_zps.glslh" #include "linear_int8_input_tile.glslh" #include "linear_int8_weight_tile.glslh" +#include "linear_int_accumulator.glslh" +#include "linear_int_per_in_channel_params.glslh" #include "linear_int_per_out_channel_params.glslh" -// Stores integer accumulators for an output tile. -struct Int32Accum { - ivec4 data[TILE_M][TILE_N4]; -}; - -// Initialize values to 0 -void initialize(out Int32Accum out_accum) { -#if TILE_N4 == 1 - [[unroll]] for (int y = 0; y < TILE_M; ++y) { - out_accum.data[y][0] = ivec4(0); - } - -#else - [[unroll]] for (int y = 0; y < TILE_M; ++y) { - [[unroll]] for (int x4 = 0; x4 < TILE_K4; ++x4) { - out_accum.data[y][x4] = ivec4(0); - } - } -#endif -} - // Accumulate int8 input and weight tiles into integer accumulator tile void int_accumulate_with_int8_weight( inout Int32Accum accum, @@ -61,23 +43,6 @@ void int_accumulate_with_int8_weight( // Weight tile is indexed as w_tile.data[k4][n4][n4i] // -> gives packed integer containing the 4x 8-bit quantized values at index // (n, k), (n, k + 1), (n, k + 2), (n, k + 3) -#if TILE_M4 == 1 && TILE_K4 == 1 && TILE_N4 == 1 - [[unroll]] for (int m = 0; m < TILE_M; ++m) { - // n = 0 - accum.data[m][0][0] = dotPacked4x8AccSatEXT( - in_tile.data[0][0][m], w_tile.data[0][0][0], accum.data[m][0][0]); - // n = 1 - accum.data[m][0][1] = dotPacked4x8AccSatEXT( - in_tile.data[0][0][m], w_tile.data[0][0][1], accum.data[m][0][1]); - // n = 2 - accum.data[m][0][2] = dotPacked4x8AccSatEXT( - in_tile.data[0][0][m], w_tile.data[0][0][2], accum.data[m][0][2]); - // n = 3 - accum.data[m][0][3] = dotPacked4x8AccSatEXT( - in_tile.data[0][0][m], w_tile.data[0][0][3], accum.data[m][0][3]); - } - -#else [[unroll]] for (int m = 0; m < TILE_M; ++m) { const int m4 = div_4(m); const int m4i = mod_4(m); @@ -92,16 +57,8 @@ void int_accumulate_with_int8_weight( } } } - -#endif } -/* - * Computes final weight matrix output tile using: - * - int8 accumulator tile - * - per output channel weight sums - * - per output channel scales - */ void accumulate_out_tile_with_int_accum( inout FPOutTile out_tile, const Int32Accum accum, @@ -110,25 +67,21 @@ void accumulate_out_tile_with_int_accum( const IntPerOutChannelParams weight_sums, const FPPerOutChannelParams weight_scales) { ivec4 input_zp_vec = ivec4(-input_q_zp); -#if TILE_N4 == 1 [[unroll]] for (int m = 0; m < TILE_M; ++m) { - // Unfortunately fma doesn't work with ivec4. Prefer to preserve integer - // format for as long as possible to avoid precision loss. - ivec4 accum_adjusted = - input_zp_vec * weight_sums.data[0] + accum.data[m][0]; - out_tile.data[m][0] = - fma(VEC4_T(accum_adjusted), - input_q_scale * weight_scales.data[0], - out_tile.data[m][0]); + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + // Unfortunately fma doesn't work with ivec4. Prefer to preserve integer + // format for as long as possible to avoid precision loss. 
+ ivec4 accum_adjusted = + input_zp_vec * weight_sums.data[n4] + accum.data[m][n4]; + out_tile.data[m][n4] = + fma(VEC4_T(accum_adjusted), + input_q_scale * weight_scales.data[0], + out_tile.data[m][n4]); + } } - -#else - // TODO(ssjia): Implement the general case - not implemented - -#endif } +// overload of the above but with bias void accumulate_out_tile_with_int_accum( inout FPOutTile out_tile, const Int32Accum accum, @@ -138,42 +91,18 @@ void accumulate_out_tile_with_int_accum( const FPPerOutChannelParams weight_scales, const FPPerOutChannelParams bias) { ivec4 input_zp_vec = ivec4(-input_q_zp); -#if TILE_N4 == 1 - [[unroll]] for (int m = 0; m < TILE_M; ++m) { - // Apply scale and zero points to the int accumulator - ivec4 accum_adjusted = - input_zp_vec * weight_sums.data[0] + accum.data[m][0]; - out_tile.data[m][0] = - fma(VEC4_T(accum_adjusted), - input_q_scale * weight_scales.data[0], - out_tile.data[m][0]); - out_tile.data[m][0] += bias.data[0]; - } - -#else - // TODO(ssjia): Implement the general case - not implemented - -#endif -} - -#ifdef DEBUG_MODE - -void printInt32Accum(const Int32Accum tile) { - debugPrintfEXT("int accum: \\n"); [[unroll]] for (int m = 0; m < TILE_M; ++m) { [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { - debugPrintfEXT( - " %d, %d, %d, %d,", - tile.data[m][n4].x, - tile.data[m][n4].y, - tile.data[m][n4].z, - tile.data[m][n4].w); + // Apply scale and zero points to the int accumulator + ivec4 accum_adjusted = + input_zp_vec * weight_sums.data[n4] + accum.data[m][n4]; + out_tile.data[m][n4] = + fma(VEC4_T(accum_adjusted), + input_q_scale * weight_scales.data[n4], + out_tile.data[m][n4]); + out_tile.data[m][n4] += bias.data[n4]; } - debugPrintfEXT("\\n"); } } -#endif - #endif // LINEAR_FP_OUTPUT_TILE_INT8_INT8_COMPUTE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_store.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_store.glslh index a4019204cc3..6fb399ff99b 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_store.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_output_tile_store.glslh @@ -24,45 +24,28 @@ #include "linear_fp_output_tile.glslh" -#ifdef OUTPUT_BUFFER - void write_output_x4( const VEC4_T out_texel, const int n4, const int m, const int N4) { +#ifdef OUTPUT_BUFFER t_output[m * N4 + n4] = out_texel; -} - #else - -void write_output_x4( - const VEC4_T out_texel, - const int n4, - const int m, - const int N4) { imageStore(t_output, ivec3(n4, m, 0), out_texel); +#endif } -#endif // OUTPUT_BUFFER - void write_output_tile( const FPOutTile out_tile, const int n4_start, const int m_start, const int N4) { -#if TILE_K4 == 1 - [[unroll]] for (int m = 0; m < TILE_M; ++m) { - write_output_x4(out_tile.data[m][0], n4_start, m_start + m, N4); - } - -#else [[unroll]] for (int m = 0; m < TILE_M; ++m) { [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { write_output_x4(out_tile.data[m][n4], n4_start + n4, m_start + m, N4); } } -#endif } // To be used if M - m >= TILE_M && N4 - n4 >= TILE_N4 @@ -72,18 +55,11 @@ void write_output_tile_no_checks( const int m_start, const int N4, const int M) { -#if TILE_N4 == 1 - [[unroll]] for (int m = 0; m < TILE_M; ++m) { - write_output_x4(out_tile.data[m][0], n4_start, m_start + m, N4); - } - -#else [[unroll]] for (int m = 0; m < TILE_M; ++m) { [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { write_output_x4(out_tile.data[m][n4], n4_start + n4, m_start + m, N4); } } -#endif } // To be used if close to tensor boundaries @@ -93,14 +69,6 @@ 
void write_output_tile_with_checks( const int m_start, const int N4, const int M) { -#if TILE_N4 == 1 - [[unroll]] for (int m = 0; m < TILE_M; ++m) { - if (m_start + m < M) { - write_output_x4(out_tile.data[m][0], n4_start, m_start + m, N4); - } - } - -#else [[unroll]] for (int m = 0; m < TILE_M; ++m) { [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { if (m_start + m < M && n4_start + n4 < N4) { @@ -108,7 +76,6 @@ void write_output_tile_with_checks( } } } -#endif } #endif // LINEAR_FP_OUTPUT_TILE_STORE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_per_out_channel_params.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_per_out_channel_params.glslh index 72b22988414..96fb3e9900b 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_per_out_channel_params.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_per_out_channel_params.glslh @@ -6,11 +6,6 @@ * LICENSE file in the root directory of this source tree. */ -/* - * Defines common functions and structs to be used across matrix multiplication - * operators. - */ - #ifndef LINEAR_FP_PER_OUT_CHANNEL_PARAMS_GLSLH #define LINEAR_FP_PER_OUT_CHANNEL_PARAMS_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_weight_scales_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_weight_scales_load.glslh index 1286c1d082f..6011a7c6201 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_weight_scales_load.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_weight_scales_load.glslh @@ -22,15 +22,9 @@ VEC4_T load_scale_x4(const int n4, const int quant_group_idx, const int N4) { void load_weight_scales_tile( out FPPerOutChannelParams scales, const int n4_start) { -#if TILE_N4 == 1 - scales.data[0] = load_weight_scale_x4(n4_start); - -#else [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { scales.data[n4] = load_weight_scale_x4(n4_start + n4); } - -#endif } void load_weight_scales_tile_for_group( @@ -38,15 +32,9 @@ void load_weight_scales_tile_for_group( const int n4_start, const int quant_group_idx, const int N4) { -#if TILE_N4 == 1 - scales.data[0] = load_scale_x4(n4_start, quant_group_idx, N4); - -#else [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { scales.data[n4] = load_scale_x4(n4_start + n4, quant_group_idx, N4); } - -#endif } #endif // LINEAR_FP_WEIGHT_SCALES_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_weight_tile.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_weight_tile.glslh index f44bbbc1565..5e010442540 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_fp_weight_tile.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_fp_weight_tile.glslh @@ -7,12 +7,9 @@ */ /* - * Defines the FPWeightTile struct, which is used to represent a fp tile of a - * weight matrix in matrix multiplication. 
- * - * Settings: - * - TILE_K: number of rows in the output tile - * - TILE_N4: number of (groups of 4) columns in the output tile + * Macro Settings: + * - TILE_K + * - TILE_N4 */ #ifndef LINEAR_FP_WEIGHT_TILE_GLSLH @@ -26,61 +23,6 @@ struct FPWeightTile { VEC4_T data[TILE_K][TILE_N4]; }; -#ifdef LINEAR_INT8_WEIGHT_TILE_GLSLH - -int sign_extend(const int val) { - if ((val & 0x80) != 0) { - return val | (~0xFF); - } - return val; -} - -T extract_8bit_value(const Int8WeightTile w_tile, const int k, const int n) { -#if TILE_K4 == 1 && TILE_N4 == 1 - const int k4i = k; - const int n4i = n; - ivec4 block = w_tile.data[0][0]; - -#else - const int k4 = div_4(k); - const int k4i = mod_4(k); - - const int n4 = div_4(n); - const int n4i = mod_4(n); - - ivec4 block = w_tile.data[k4][n4]; -#endif - - int col = block[n4i]; - int val = (col >> (k4i * 8)) & 0xFF; - - return T(sign_extend(val)); -} - -void unpack(out FPWeightTile fp_w_tile, const Int8WeightTile w_tile) { -#if TILE_K > 1 && TILE_N4 == 1 - [[unroll]] for (int k = 0; k < TILE_K; ++k) { - fp_w_tile.data[k][0][0] = extract_8bit_value(w_tile, k, 0); - fp_w_tile.data[k][0][1] = extract_8bit_value(w_tile, k, 1); - fp_w_tile.data[k][0][2] = extract_8bit_value(w_tile, k, 2); - fp_w_tile.data[k][0][3] = extract_8bit_value(w_tile, k, 3); - } - -#else - [[unroll]] for (int k = 0; k < TILE_M; ++k) { - [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { - const int n = mul_4(n4); - fp_w_tile.data[k][n4][0] = extract_8bit_value(w_tile, k, n); - fp_w_tile.data[k][n4][1] = extract_8bit_value(w_tile, k, n + 1); - fp_w_tile.data[k][n4][2] = extract_8bit_value(w_tile, k, n + 2); - fp_w_tile.data[k][n4][3] = extract_8bit_value(w_tile, k, n + 3); - } - } -#endif -} - -#endif // LINEAR_INT8_WEIGHT_TILE_GLSLH - #ifdef DEBUG_MODE void printFPWeightTile(const FPWeightTile tile) { diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int4_weight_block.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int4_weight_block.glslh index d813224c3aa..38b805a7c49 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_int4_weight_block.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int4_weight_block.glslh @@ -10,20 +10,12 @@ #define LINEAR_INT4_WEIGHT_BLOCK_GLSLH /* - * This file defines utilties to perform weight prepacking of quantized int4 - * matrix multiplation weights. It also defines utilities to load source - * weight data from inputbuffer, and write out a packed weight block to output - * texture/buffer. + * Assumes the following variables are defined in shader layout + * - t_packed_int4_weight + * - t_int4_weight * - * Note: 2 4-bit values are packed into each 8-bit value in the source data. - * - * Requires: - * - t_packed_int4_weight to be defined in shader layout (output texture/buffer) - * - t_int4_weight to be defined in shader layout (input buffer) - * - * Settings: - * - USING_BUFFER to indicate if output resource is a buffer. Otherwise texture - * is assumed. + * Macro Settings: + * - USING_BUFFER */ #extension GL_EXT_control_flow_attributes : require @@ -67,6 +59,8 @@ void load_block_source_data_with_checks( if (n_start + n < N) { src_data.data[n] = t_int4_weight[(n_start + n) * ntexels_K + k8]; } else { + // Equivalent to a row of zeros since int4 weights have an implicit zero + // point of -8. 
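      // (Every nibble of 0x88888888 is 0x8, which encodes 0 once the implicit -8
      // offset is applied: 8 - 8 = 0.)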
src_data.data[n] = 0x88888888; } } @@ -116,27 +110,17 @@ void create_packed_blocks( } } -#ifdef USING_BUFFER - void write_packed_block( const Int4WeightBlockPacked block, const int k4, const int n8, const int nblocks_K) { +#ifdef USING_BUFFER t_packed_int4_weight[n8 * nblocks_K + k4] = block.data; -} - -#else // USING_TEXTURE - -void write_packed_block( - const Int4WeightBlockPacked block, - const int k4, - const int n8, - const int nblocks_K) { +#else imageStore(t_packed_int4_weight, ivec2(k4, n8), block.data); -} - #endif // USING_BUFFER +} #ifdef DEBUG_MODE diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int4_weight_tile.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int4_weight_tile.glslh index 559459f14a8..c51259e1ae8 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_int4_weight_tile.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int4_weight_tile.glslh @@ -11,6 +11,7 @@ #include "linear_common.glslh" #include "linear_fp_weight_tile.glslh" +#include "linear_int8_weight_tile.glslh" /* * Defines the Int4WeightTile struct, which is used to represent a tile of the @@ -27,64 +28,6 @@ struct Int4WeightTile { ivec4 data[TILE_K4][TILE_N8]; }; -void unpack_int4_weight_tile( - out FPWeightTile int8_tile, - const Int4WeightTile int4_tile) { -#if TILE_K4 == 1 && TILE_N8 == 1 - for (int k = 0; k < TILE_K; ++k) { - const int col_idx_1 = 2 * k; - const int col_idx_2 = 2 * k + 1; - int8_tile.data[k][0][0] = - T(extract_4bit_from_packed_int_le(int4_tile.data[0][0][0], col_idx_1)); - int8_tile.data[k][0][1] = - T(extract_4bit_from_packed_int_le(int4_tile.data[0][0][1], col_idx_1)); - int8_tile.data[k][0][2] = - T(extract_4bit_from_packed_int_le(int4_tile.data[0][0][2], col_idx_1)); - int8_tile.data[k][0][3] = - T(extract_4bit_from_packed_int_le(int4_tile.data[0][0][3], col_idx_1)); - - // n4 = 1 - int8_tile.data[k][1][0] = - T(extract_4bit_from_packed_int_le(int4_tile.data[0][0][0], col_idx_2)); - int8_tile.data[k][1][1] = - T(extract_4bit_from_packed_int_le(int4_tile.data[0][0][1], col_idx_2)); - int8_tile.data[k][1][2] = - T(extract_4bit_from_packed_int_le(int4_tile.data[0][0][2], col_idx_2)); - int8_tile.data[k][1][3] = - T(extract_4bit_from_packed_int_le(int4_tile.data[0][0][3], col_idx_2)); - } - -#else - for (int k = 0; k < TILE_K; ++k) { - const int k4 = div_4(k); - const int k4i = mod_4(k); - for (int n8 = 0; n8 < TILE_N8; ++n8) { - const int n4 = mul_2(n8); - const int col_idx_1 = 2 * k4i; - const int col_idx_2 = 2 * k4i + 1; - int8_tile.data[k][n4][0] = T(extract_4bit_from_packed_int_le( - int4_tile.data[k4][n8][0], col_idx_1)); - int8_tile.data[k][n4][1] = T(extract_4bit_from_packed_int_le( - int4_tile.data[k4][n8][1], col_idx_1)); - int8_tile.data[k][n4][2] = T(extract_4bit_from_packed_int_le( - int4_tile.data[k4][n8][2], col_idx_1)); - int8_tile.data[k][n4][3] = T(extract_4bit_from_packed_int_le( - int4_tile.data[k4][n8][3], col_idx_1)); - - int8_tile.data[k][n4 + 1][0] = T(extract_4bit_from_packed_int_le( - int4_tile.data[k4][n8][0], col_idx_2)); - int8_tile.data[k][n4 + 1][1] = T(extract_4bit_from_packed_int_le( - int4_tile.data[k4][n8][1], col_idx_2)); - int8_tile.data[k][n4 + 1][2] = T(extract_4bit_from_packed_int_le( - int4_tile.data[k4][n8][2], col_idx_2)); - int8_tile.data[k][n4 + 1][3] = T(extract_4bit_from_packed_int_le( - int4_tile.data[k4][n8][3], col_idx_2)); - } - } - -#endif -} - #ifdef DEBUG_MODE void printInt4WeightTile(const Int4WeightTile block) { diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int4_weight_tile_load.glslh 
b/backends/vulkan/runtime/graph/ops/glsl/linear_int4_weight_tile_load.glslh index 033e0082436..965496d70d7 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_int4_weight_tile_load.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int4_weight_tile_load.glslh @@ -10,69 +10,39 @@ #define LINEAR_INT4_WEIGHT_TILE_LOAD_GLSLH /* - * Defines functions to load a Int4WeightTile from input buffer/texture. + * Assumes the following variables are declared in shader layout: + * - t_packed_int4_weight * - * Requires: - * - t_packed_int4_weight to be declared in the shader layout (input - * buffer/texture) - * - * Settings: - * - WEIGHT_BUFFER to indicate t_packed_int4_weight is a buffer, otherwise - * texture storage is assumed. + * Macro Settings: + * - WEIGHT_BUFFER */ #extension GL_EXT_control_flow_attributes : require #include "linear_int4_weight_tile.glslh" -#ifdef WEIGHT_BUFFER - ivec4 load_int4_weight_block( const int block_x, const int block_y, const int nblocks_x) { +#ifdef WEIGHT_BUFFER return t_packed_int4_weight[(block_y * nblocks_x) + block_x]; -} - -#else // WEIGHT_TEXTURE - -ivec4 load_int4_weight_block( - const int block_x, - const int block_y, - const int nblocks_x) { +#else return texelFetch(t_packed_int4_weight, ivec2(block_x, block_y), 0); +#endif } -#endif // WEIGHT_BUFFER - void load_int4_weight_tile( out Int4WeightTile weight_tile, const int block_x, const int block_y, const int nblocks_x) { -#if TILE_K4 == 1 && TILE_N8 == 1 - weight_tile.data[0][0] = load_int4_weight_block(block_x, block_y, nblocks_x); - -#elif TILE_K4 == 1 && TILE_N8 > 1 - [[unroll]] for (int x = 0; x < TILE_N8; ++x) { - weight_tile.data[0][x] = - load_int4_weight_block(block_x + x, block_y, nblocks_x); - } - -#elif TILE_K4 > 1 && TILE_N8 == 1 - [[unroll]] for (int y = 0; y < TILE_K4; ++y) { - weight_tile.data[y][0] = - load_int4_weight_block(block_x, block_y + y, nblocks_x); - } - -#else [[unroll]] for (int y = 0; y < TILE_K4; ++y) { [[unroll]] for (int x = 0; x < TILE_N8; ++x) { weight_tile.data[y][x] = load_int4_weight_block(block_x + x, block_y + y, nblocks_x); } } -#endif } #endif // LINEAR_INT4_WEIGHT_TILE_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_block.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_block.glslh index 9535de21f7b..a6dbd7e78a2 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_block.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_block.glslh @@ -22,10 +22,13 @@ #ifndef LINEAR_INT8_INPUT_BLOCK_GLSLH #define LINEAR_INT8_INPUT_BLOCK_GLSLH -#define TILE_M 4 +#define TILE_M4 1 #define TILE_K4 1 +#define TILE_M 4 + #include "linear_fp_input_tile.glslh" +#include "linear_int8_input_scales_zps.glslh" struct Int8InputBlock { ivec4 data; @@ -59,26 +62,30 @@ void quantize_and_pack( } } -#ifdef OUTPUT_BUFFER +void quantize_and_pack( + out Int8InputBlock packed, + const FPInputTile in_block, + const Int8InputScales scales, + const Int8InputZeroPoints zps) { + for (int m = 0; m < 4; ++m) { + const float q_inv_scale = 1.0 / float(scales.data[0][m]); + const int q_zero_point = zps.data[0][m]; + ivec4 quantized_inputs = + quantize(in_block.data[m][0], q_inv_scale, q_zero_point); + packed.data[m] = pack_into_int32(quantized_inputs); + } +} void write_block( const Int8InputBlock block, const int block_x, const int block_y, const int nblocks_x) { +#ifdef OUTPUT_BUFFER t_packed_int8_input[block_y * nblocks_x + block_x] = block.data; -} - #else // OUTPUT_TEXTURE - -void write_block( - const 
Int8InputBlock block, - const int block_x, - const int block_y, - const int nblocks_x) { imageStore(t_packed_int8_input, ivec3(block_x, block_y, 0), block.data); -} - #endif // OUTPUT_BUFFER +} #endif // LINEAR_INT8_INPUT_BLOCK_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_scales_zps.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_scales_zps.glslh new file mode 100644 index 00000000000..6c0bb0c2add --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_scales_zps.glslh @@ -0,0 +1,59 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Defines common functions and structs to be used across matrix multiplication + * operators. + */ + +#ifndef LINEAR_INT8_INPUT_SCALES_ZPS_GLSLH +#define LINEAR_INT8_INPUT_SCALES_ZPS_GLSLH + +#include "common.glslh" + +#extension GL_EXT_control_flow_attributes : require + +struct Int8InputScales { + VEC4_T data[TILE_M4]; +}; + +struct Int8InputZeroPoints { + ivec4 data[TILE_M4]; +}; + +#ifdef DEBUG_MODE + +void printInt8InputScales(const Int8InputScales scales) { + debugPrintfEXT("input_scales: \\n"); + [[unroll]] for (int m4 = 0; m4 < TILE_M4; ++m4) { + debugPrintfEXT( + " %f, %f, %f, %f, ", + scales.data[m4].x, + scales.data[m4].y, + scales.data[m4].z, + scales.data[m4].w); + } + debugPrintfEXT("\\n"); +} + +void printInt8InputZeroPoints(const Int8InputZeroPoints zero_points) { + debugPrintfEXT("input_zero_points: \\n"); + [[unroll]] for (int m4 = 0; m4 < TILE_M4; ++m4) { + debugPrintfEXT( + " %d, %d, %d, %d, ", + zero_points.data[m4].x, + zero_points.data[m4].y, + zero_points.data[m4].z, + zero_points.data[m4].w); + } + debugPrintfEXT("\\n"); +} + +#endif // DEBUG_MODE + +#endif // LINEAR_INT8_INPUT_SCALES_ZPS_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_scales_zps_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_scales_zps_load.glslh new file mode 100644 index 00000000000..e1a570622c2 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_scales_zps_load.glslh @@ -0,0 +1,27 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef LINEAR_INT8_INPUT_SCALES_ZPS_LOAD_GLSLH +#define LINEAR_INT8_INPUT_SCALES_ZPS_LOAD_GLSLH + +#include "linear_int8_input_scales_zps.glslh" + +#extension GL_EXT_control_flow_attributes : require + +void load_int8_input_scales_and_zps( + out Int8InputScales scales, + out Int8InputZeroPoints zps, + const int m4_start) { + [[unroll]] for (int m4 = 0; m4 < TILE_M4; m4++) { + scales.data[m4] = + VEC4_T(texelFetch(t_int8_input_scales, ivec3(m4_start + m4, 0, 0), 0)); + zps.data[m4] = texelFetch(t_int8_input_zps, ivec3(m4_start + m4, 0, 0), 0); + } +} + +#endif // LINEAR_INT8_INPUT_SCALES_ZPS_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_sums_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_sums_load.glslh new file mode 100644 index 00000000000..4bb2fa12fcb --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_sums_load.glslh @@ -0,0 +1,33 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef LINEAR_INT_INPUT_SUMS_LOAD_GLSLH +#define LINEAR_INT_INPUT_SUMS_LOAD_GLSLH + +#include "linear_int_per_in_channel_params.glslh" + +#extension GL_EXT_control_flow_attributes : require + +ivec4 load_int8_input_sum_x4( + const int m4, + const int quant_group_idx, + const int M4) { + return t_int8_input_sums[quant_group_idx * M4 + m4]; +} + +void load_int8_input_sums_tile_for_group( + out IntPerInChannelParams sums, + const int m4_start, + const int quant_group_idx, + const int M4) { + [[unroll]] for (int m4 = 0; m4 < TILE_M4; ++m4) { + sums.data[m4] = load_int8_input_sum_x4(m4_start + m4, quant_group_idx, M4); + } +} + +#endif // LINEAR_INT_INPUT_SUMS_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_sums_store.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_sums_store.glslh new file mode 100644 index 00000000000..c9066be8e74 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_sums_store.glslh @@ -0,0 +1,26 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef LINEAR_INT_INPUT_SUMS_LOAD_GLSLH +#define LINEAR_INT_INPUT_SUMS_LOAD_GLSLH + +#include "linear_int_per_in_channel_params.glslh" + +#extension GL_EXT_control_flow_attributes : require + +void store_int8_input_sums_tile_for_group( + out IntPerInChannelParams sums, + const int m4_start, + const int quant_group_idx, + const int M4) { + [[unroll]] for (int m4 = 0; m4 < TILE_M4; ++m4) { + t_int8_input_sums[quant_group_idx * M4 + m4] = sums.data[m4]; + } +} + +#endif // LINEAR_INT_INPUT_SUMS_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_tile.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_tile.glslh index 89a7e1b3f89..5d8f78bae7c 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_tile.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_tile.glslh @@ -7,12 +7,9 @@ */ /* - * Defines the Int8InputTile struct, which is used to represent a tile of the - * quantized int8 input matrix of a quantized matrix multiplication operation. - * - * Settings: - * - TILE_M4: number of (groups of 4) rows in the tile - * - TILE_K4: number of (groups of 4) columns in the tile + * Macro Settings: + * - TILE_M4 + * - TILE_K4 */ #ifndef LINEAR_INT8_INPUT_TILE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_tile_load.glslh index c79badab6c6..fc7c3c9c181 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_tile_load.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_input_tile_load.glslh @@ -7,14 +7,11 @@ */ /* - * Defines functions to load a Int8InputTile from input buffer/texture. + * Assumes the following variables are defined in the shader layout: + * - t_packed_int8_input * - * Requires: - * - t_packed_int8_input to be declared in the shader layout - * - * Settings: - * - PACKED_INT8_INPUT_BUFFER to indicate resource is a buffer, otherwise - * texture storage is assumed. 
+ * Macro Settings: + * - PACKED_INT8_INPUT_BUFFER */ #ifndef LINEAR_INT8_INPUT_TILE_LOAD_GLSLH @@ -24,52 +21,28 @@ #include "linear_int8_input_tile.glslh" -#ifdef PACKED_INT8_INPUT_BUFFER - ivec4 load_int8_input_block( const int block_x, const int block_y, const int nblocks_x) { +#ifdef PACKED_INT8_INPUT_BUFFER return t_packed_int8_input[(block_y * nblocks_x) + block_x]; -} - #else - -ivec4 load_int8_input_block( - const int block_x, - const int block_y, - const int nblocks_x) { return texelFetch(t_packed_int8_input, ivec3(block_x, block_y, 0), 0); -} - #endif // PACKED_INT8_INPUT_BUFFER +} void load_int8_input_tile( out Int8InputTile in_tile, const int block_x, const int block_y, const int nblocks_x) { -#if TILE_M4 == 1 && TILE_K4 == 1 - in_tile.data[0][0] = load_int8_input_block(block_x, block_y, nblocks_x); - -#elif TILE_M4 == 1 && TILE_K4 > 1 - [[unroll]] for (int x = 0; x < TILE_K4; ++x) { - in_tile.data[0][x] = load_int8_input_block(block_x + x, block_y, nblocks_x); - } - -#elif TILE_M4 > 1 && TILE_K4 == 1 - [[unroll]] for (int y = 0; y < TILE_M4; ++y) { - in_tile.data[y][0] = load_int8_input_block(block_x, block_y + y, nblocks_x); - } - -#else [[unroll]] for (int y = 0; y < TILE_M4; ++y) { [[unroll]] for (int x = 0; x < TILE_K4; ++x) { in_tile.data[y][x] = load_int8_input_block(block_x + x, block_y + y, nblocks_x); } } -#endif } #endif // LINEAR_INT8_INPUT_TILE_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_block.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_block.glslh index 6e98caea49e..6fa51452057 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_block.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_block.glslh @@ -10,18 +10,12 @@ #define LINEAR_INT8_WEIGHT_BLOCK_GLSLH /* - * This file defines utilties to perform weight prepacking of quantized int8 - * matrix multiplation weights. It also defines utilities to load source - * weight data from inputbuffer, and write out a packed weight block to output - * texture/buffer. + * Assumes the following variables are defined in the shader layout + * - t_packed_int8_weight + * - t_int8_weight * - * Requires: - * - t_packed_int8_weight to be defined in shader layout (output texture/buffer) - * - t_int8_weight to be defined in shader layout (input buffer) - * - * Settings: - * - USING_BUFFER to indicate if output resource is a buffer. Otherwise texture - * is assumed. 
+ * Macro Settings: + * - USING_BUFFER */ #extension GL_EXT_control_flow_attributes : require @@ -60,28 +54,18 @@ void load_block_data_with_checks( } } -#ifdef USING_BUFFER - void write_weight_block( const Int8WeightBlock block, const int n4, const int k4, const int ntexels_N) { +#ifdef USING_BUFFER t_packed_int8_weight[k4 * ntexels_N + n4] = block.data; -} - -#else // USING_TEXTURE - -void write_weight_block( - const Int8WeightBlock block, - const int n4, - const int k4, - const int ntexels_N) { +#else imageStore(t_packed_int8_weight, ivec2(n4, k4), block.data); +#endif } -#endif // USING_BUFFER - #ifdef DEBUG_MODE void printInt8WeightBlock(const Int8WeightBlockPacked block) { diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_tile.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_tile.glslh index f312db543db..6b3d4d4419e 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_tile.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_tile.glslh @@ -10,12 +10,9 @@ #define LINEAR_INT8_WEIGHT_TILE_GLSLH /* - * Defines the Int8WeightTile struct, which is used to represent a tile of the - * quantized int8 weight matrix of a quantized matrix multiplication operation. - * - * Settings: - * - TILE_K4: number of (groups of 4) rows in the weight tile - * - TILE_N4: number of (groups of 4) columns in the weight tile + * Macro Settings: + * - TILE_K4 + * - TILE_N4 */ #extension GL_EXT_control_flow_attributes : require @@ -31,13 +28,13 @@ void printInt8WeightTile(const Int8WeightTile tile) { "Int8WeightTile [TILE_K4=%d][TILE_N4=%d]:\\n", TILE_K4, TILE_N4); [[unroll]] for (int m4 = 0; m4 < TILE_M4; ++m4) { - [[unroll]] for (int k4 = 0; k4 < TILE_K4; ++k4) { - debugPrintfEXT(" tile[%d][%d] (ivec4): ", m4, k4); + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + debugPrintfEXT(" tile[%d][%d] (ivec4): ", m4, n4); // Each ivec4 contains 4 packed integers, each integer contains 4 8-bit // values [[unroll]] for (int vec_idx = 0; vec_idx < 4; ++vec_idx) { - int packed_int = tile.data[m4][k4][vec_idx]; + int packed_int = tile.data[m4][n4][vec_idx]; debugPrintfEXT("packed_int[%d]=%d -> [", vec_idx, packed_int); // Extract 4 8-bit values from this packed integer diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_tile_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_tile_load.glslh index fe16d3469b3..5021fe24d6f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_tile_load.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int8_weight_tile_load.glslh @@ -10,69 +10,39 @@ #define LINEAR_INT8_WEIGHT_TILE_LOAD_GLSLH /* - * Defines functions to load a Int8WeightTile from input buffer/texture. + * Assumes the following variables are defined in the shader layout: + * - t_packed_int8_weight * - * Requires: - * - t_packed_int8_weight to be declared in the shader layout (input - * buffer/texture) - * - * Settings: - * - WEIGHT_BUFFER to indicate t_packed_int8_weight is a buffer, otherwise - * texture storage is assumed. 
+ * Macro Settings: + * - WEIGHT_BUFFER */ #extension GL_EXT_control_flow_attributes : require #include "linear_int8_weight_tile.glslh" -#ifdef WEIGHT_BUFFER - ivec4 load_int8_weight_block( const int block_x, const int block_y, const int nblocks_x) { +#ifdef WEIGHT_BUFFER return t_packed_int8_weight[(block_y * nblocks_x) + block_x]; -} - -#else // WEIGHT_TEXTURE - -ivec4 load_int8_weight_block( - const int block_x, - const int block_y, - const int nblocks_x) { +#else return texelFetch(t_packed_int8_weight, ivec2(block_x, block_y), 0); -} - #endif // WEIGHT_BUFFER +} void load_int8_weight_tile( out Int8WeightTile weight_tile, const int block_x, const int block_y, const int nblocks_x) { -#if TILE_K4 == 1 && TILE_N4 == 1 - weight_tile.data[0][0] = load_int8_weight_block(block_x, block_y, nblocks_x); - -#elif TILE_K4 == 1 && TILE_N4 > 1 - [[unroll]] for (int x = 0; x < TILE_N4; ++x) { - weight_tile.data[0][x] = - load_int8_weight_block(block_x + x, block_y, nblocks_x); - } - -#elif TILE_K4 > 1 && TILE_N4 == 1 - [[unroll]] for (int y = 0; y < TILE_M4; ++y) { - weight_tile.data[y][0] = - load_int8_weight_block(block_x, block_y + y, nblocks_x); - } - -#else [[unroll]] for (int y = 0; y < TILE_K4; ++y) { [[unroll]] for (int x = 0; x < TILE_N4; ++x) { weight_tile.data[y][x] = load_int8_weight_block(block_x + x, block_y + y, nblocks_x); } } -#endif } #endif // LINEAR_INT8_WEIGHT_TILE_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int_accumulator.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int_accumulator.glslh new file mode 100644 index 00000000000..3bfc27e6f54 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int_accumulator.glslh @@ -0,0 +1,46 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef LINEAR_INT_ACCUM_GLSLH +#define LINEAR_INT_ACCUM_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +// Stores integer accumulators for an output tile. +struct Int32Accum { + ivec4 data[TILE_M][TILE_N4]; +}; + +void initialize(out Int32Accum out_accum) { + [[unroll]] for (int y = 0; y < TILE_M; ++y) { + [[unroll]] for (int x4 = 0; x4 < TILE_N4; ++x4) { + out_accum.data[y][x4] = ivec4(0); + } + } +} + +#ifdef DEBUG_MODE + +void printInt32Accum(const Int32Accum tile) { + debugPrintfEXT("int accum: \\n"); + [[unroll]] for (int m = 0; m < TILE_M; ++m) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + debugPrintfEXT( + " %d, %d, %d, %d,", + tile.data[m][n4].x, + tile.data[m][n4].y, + tile.data[m][n4].z, + tile.data[m][n4].w); + } + debugPrintfEXT("\\n"); + } +} + +#endif + +#endif // LINEAR_INT_ACCUM_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int_per_in_channel_params.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int_per_in_channel_params.glslh new file mode 100644 index 00000000000..06ac91ca92f --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int_per_in_channel_params.glslh @@ -0,0 +1,22 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#ifndef LINEAR_INT_PER_IN_CHANNEL_PARAMS_GLSLH +#define LINEAR_INT_PER_IN_CHANNEL_PARAMS_GLSLH + +#extension GL_EXT_control_flow_attributes : require + +struct IntPerInChannelParams { + ivec4 data[TILE_M4]; +}; + +void initialize(out IntPerInChannelParams params) { + [[unroll]] for (int i = 0; i < TILE_M4; i++) { params.data[i] = ivec4(0); } +} + +#endif // LINEAR_INT_PER_IN_CHANNEL_PARAMS_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int_per_out_channel_params.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int_per_out_channel_params.glslh index ca29fd52780..73a079d2479 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_int_per_out_channel_params.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int_per_out_channel_params.glslh @@ -6,11 +6,6 @@ * LICENSE file in the root directory of this source tree. */ -/* - * Defines common functions and structs to be used across matrix multiplication - * operators. - */ - #ifndef LINEAR_INT_PER_OUT_CHANNEL_PARAMS_GLSLH #define LINEAR_INT_PER_OUT_CHANNEL_PARAMS_GLSLH @@ -19,7 +14,7 @@ #extension GL_EXT_control_flow_attributes : require // Represents floating point parameter tensors where each element is associated -// with an output channel, such as weight scales, biases, etc. +// with an output channel, such as weight sums struct IntPerOutChannelParams { ivec4 data[TILE_N4]; }; diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_int_weight_sums_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/linear_int_weight_sums_load.glslh index 1a17f99ea4e..2fbc13e1bdb 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_int_weight_sums_load.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_int_weight_sums_load.glslh @@ -15,18 +15,29 @@ ivec4 load_weight_sum_x4(const int n4) { return ivec4(t_weight_sums[n4]); } +ivec4 load_weight_sum_x4( + const int n4, + const int quant_group_idx, + const int N4) { + return t_weight_sums[quant_group_idx * N4 + n4]; +} + void load_weight_sums_tile( out IntPerOutChannelParams sums, const int n4_start) { -#if TILE_N4 == 1 - sums.data[0] = load_weight_sum_x4(n4_start); - -#else [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { sums.data[n4] = load_weight_sum_x4(n4_start + n4); } +} -#endif +void load_weight_sums_tile_for_group( + out IntPerOutChannelParams sums, + const int n4_start, + const int quant_group_idx, + const int N4) { + [[unroll]] for (int n4 = 0; n4 < TILE_N4; ++n4) { + sums.data[n4] = load_weight_sum_x4(n4_start + n4, quant_group_idx, N4); + } } #endif // LINEAR_FP_WEIGHT_SUMS_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_coop.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_coop.glsl index 6f0d890a9c4..c9b82425865 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_coop.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_coop.glsl @@ -33,11 +33,23 @@ ${define_required_extensions(DTYPE)} layout(std430) buffer; -${layout_declare_tensor(B, "w", "t_output", DTYPE, IO_STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_input", DTYPE, IO_STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_packed_int4_weight", "int", WEIGHT_STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)} +$if DYNAMIC_QUANT_VARIANT: + ${layout_declare_tensor(B, "w", "t_output", DTYPE, IO_STORAGE, is_scalar_array=False)} + 
${layout_declare_tensor(B, "r", "t_input", DTYPE, IO_STORAGE, is_scalar_array=False)} + ${layout_declare_tensor(B, "r", "t_packed_int8_input", "int", PACKED_INPUT_STORAGE, is_scalar_array=False)} + ${layout_declare_tensor(B, "r", "t_int_input_sums", "int", "buffer", is_scalar_array=False)} + ${layout_declare_tensor(B, "r", "t_input_scale", DTYPE, "texture3d")} + ${layout_declare_tensor(B, "r", "t_input_zp", "int", "texture3d")} + ${layout_declare_tensor(B, "r", "t_packed_int4_weight", "int", WEIGHT_STORAGE, is_scalar_array=False)} + ${layout_declare_tensor(B, "r", "t_weight_sums", "int", "buffer", is_scalar_array=False)} + ${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)} + ${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)} +$else: + ${layout_declare_tensor(B, "w", "t_output", DTYPE, IO_STORAGE, is_scalar_array=False)} + ${layout_declare_tensor(B, "r", "t_input", DTYPE, IO_STORAGE, is_scalar_array=False)} + ${layout_declare_tensor(B, "r", "t_packed_int4_weight", "int", WEIGHT_STORAGE, is_scalar_array=False)} + ${layout_declare_tensor(B, "r", "t_weight_scales", DTYPE, "buffer", is_scalar_array=False)} + ${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer", is_scalar_array=False)} ${layout_declare_ubo(B, "ivec4", "output_sizes")} ${layout_declare_ubo(B, "ivec4", "input_sizes")} diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_coop.yaml b/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_coop.yaml index bb5f44d4086..2c5001fdd17 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_coop.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_q4gsw_coop.yaml @@ -9,10 +9,12 @@ linear_q4gsw_coop: DTYPE: float IO_STORAGE: texture3d WEIGHT_STORAGE: texture2d + PACKED_INPUT_STORAGE: buffer TILE_M: 1 TILE_K4: 1 TILE_N8: 1 WGS: 64 + DYNAMIC_QUANT_VARIANT: false generate_variant_forall: DTYPE: - VALUE: float @@ -26,3 +28,16 @@ linear_q4gsw_coop: - NAME: linear_q4gsw_coop_buffer_buffer IO_STORAGE: buffer WEIGHT_STORAGE: buffer + - NAME: linear_dq8ca_q4gsw_coop_texture3d_texture2d + DYNAMIC_QUANT_VARIANT: true + - NAME: linear_dq8ca_q4gsw_coop_texture3d_buffer + WEIGHT_STORAGE: buffer + DYNAMIC_QUANT_VARIANT: true + - NAME: linear_dq8ca_q4gsw_coop_buffer_texture2d + IO_STORAGE: buffer + WEIGHT_STORAGE: texture2d + DYNAMIC_QUANT_VARIANT: true + - NAME: linear_dq8ca_q4gsw_coop_buffer_buffer + IO_STORAGE: buffer + WEIGHT_STORAGE: buffer + DYNAMIC_QUANT_VARIANT: true diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_q8csw_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_q8csw_tiled.glsl index b6d932f0015..84be82f64e3 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_q8csw_tiled.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_q8csw_tiled.glsl @@ -99,10 +99,10 @@ void main() { FPPerOutChannelParams bias_tile; load_bias_tile(bias_tile, n4); - apply_scales_and_biases(out_tile, weight_scales_tile, bias_tile); + apply_weight_scales_and_biases(out_tile, weight_scales_tile, bias_tile); } else { - apply_scales(out_tile, weight_scales_tile); + apply_weight_scales(out_tile, weight_scales_tile); } if (dont_check_bounds) { diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input_with_sums.glsl b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input_with_sums.glsl new file mode 100644 index 00000000000..792d8be908b --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input_with_sums.glsl @@ -0,0 +1,122 @@ 
+/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} +#define VEC4_T ${texel_load_type(DTYPE, INPUT_STORAGE)} +#define T ${texel_load_component_type(DTYPE, INPUT_STORAGE)} + +$if OUTPUT_STORAGE == "buffer": + #define OUTPUT_BUFFER +$if INPUT_STORAGE == "buffer": + #define INPUT_BUFFER + +${define_required_extensions(DTYPE)} +${define_required_extensions("int8")} + +#extension GL_EXT_integer_dot_product : require + +#define NUM_GROUPS_PER_WG ${NUM_GROUPS_PER_WG} +#define NUM_WORKERS_PER_GROUP ${NUM_WORKERS_PER_GROUP} + +layout(std430) buffer; + +#include "common.glslh" + +${layout_declare_tensor(B, "w", "t_packed_int8_input", "int", OUTPUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "w", "t_int8_input_sums", "int", "buffer", is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_input", DTYPE, INPUT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_int8_input_scales", DTYPE, "texture3d")} +${layout_declare_tensor(B, "r", "t_int8_input_zps", "int8", "texture3d")} + +${layout_declare_ubo(B, "ivec4", "input_sizes")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "int", "K4_per_group", "0")} + +shared ivec4 shared_sums[NUM_GROUPS_PER_WG][NUM_WORKERS_PER_GROUP]; + +#define TILE_M4 1 +#define TILE_K4 1 + +#define TILE_M 4 + +#include "linear_int8_input_block.glslh" +#include "linear_int8_input_scales_zps_load.glslh" +#include "linear_fp_input_tile_load.glslh" + +void main() { + const int group_idx = int(gl_GlobalInvocationID.x); + const int m4 = int(gl_GlobalInvocationID.y); + + const int worker_id = int(gl_LocalInvocationID.z); + const int group_offset = int(gl_LocalInvocationID.x); + + const int K = input_sizes.x; + const int M = input_sizes.y; + + // K4 and M4 represent the number of blocks in each dimension. 
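+  // A block is a 4 (M) x 4 (K) tile of int8 values packed into an ivec4, one
+  // int (4 bytes) per row, so K4 = ceil(K / 4) and M4 = ceil(M / 4).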
+ const int K4 = div_up_4(K); + const int M4 = div_up_4(M); + + const int num_groups = K4 / K4_per_group;; + + if (group_idx >= num_groups || m4 >= M4) { + return; + } + + const int start_k4 = group_idx * K4_per_group + worker_id; + const int end_k4 = (group_idx + 1) * K4_per_group; + + Int8InputScales input_scales; + Int8InputZeroPoints input_zps; + load_int8_input_scales_and_zps(input_scales, input_zps, m4); + + // row of the input tensor to start loading from + const int m = mul_4(m4); + + FPInputTile in_tile; + Int8InputBlock packed_block; + + ivec4 local_sum = ivec4(0, 0, 0, 0); + const int packed_ones = 0x01010101; + + for (int k4 = start_k4; k4 < end_k4; k4 += NUM_WORKERS_PER_GROUP) { + load_input_tile_no_checks(in_tile, k4, m, K4, M); + quantize_and_pack(packed_block, in_tile, input_scales, input_zps); + + // Sum the quantized values in the block + [[unroll]] for (int m = 0; m < TILE_M; m++) { + local_sum[m] += dotPacked4x8AccSatEXT( + packed_block.data[m], packed_ones, local_sum[m]); + } + write_block(packed_block, k4, m4, K4); + } + + shared_sums[group_offset][worker_id] = local_sum; + + memoryBarrierShared(); + barrier(); + + // Tree reduction to compute the overall result + for (int i = NUM_WORKERS_PER_GROUP / 2; i > 0; i >>= 1) { + if (worker_id < i) { + shared_sums[group_offset][worker_id] = + shared_sums[group_offset][worker_id] + + shared_sums[group_offset][worker_id + i]; + } + memoryBarrierShared(); + barrier(); + } + + if (worker_id == 0) { + t_int8_input_sums[group_idx * M4 + m4] = shared_sums[group_offset][0]; + } +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input_with_sums.yaml b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input_with_sums.yaml new file mode 100644 index 00000000000..3fc66db2718 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/quantize_and_pack_linear_input_with_sums.yaml @@ -0,0 +1,30 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
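+#
+# Variant names encode the workgroup configuration: a suffix of oNwM selects
+# NUM_GROUPS_PER_WG = N and NUM_WORKERS_PER_GROUP = M (e.g. o2w32, o4w16).
+# The shader picker in QuantizedLinear.cpp chooses o2w32 for group sizes of
+# 128 and above, and o4w16 otherwise.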
+ +quantize_and_pack_linear_input_with_sums: + parameter_names_with_default_values: + DTYPE: float + OUTPUT_STORAGE: buffer + INPUT_STORAGE: texture3d + NUM_GROUPS_PER_WG: 2 + NUM_WORKERS_PER_GROUP: 32 + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: quantize_and_pack_linear_input_with_sums_o2w32_buffer_texture3d + - NAME: quantize_and_pack_linear_input_with_sums_o2w32_buffer_buffer + OUTPUT_STORAGE: buffer + INPUT_STORAGE: buffer + - NAME: quantize_and_pack_linear_input_with_sums_o4w16_buffer_texture3d + NUM_GROUPS_PER_WG: 4 + NUM_WORKERS_PER_GROUP: 16 + - NAME: quantize_and_pack_linear_input_with_sums_o4w16_buffer_buffer + NUM_GROUPS_PER_WG: 4 + NUM_WORKERS_PER_GROUP: 16 + OUTPUT_STORAGE: buffer + INPUT_STORAGE: buffer diff --git a/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp b/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp index 0d0be08bb38..a4a96ffdb88 100644 --- a/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/ChooseQParams.cpp @@ -186,16 +186,7 @@ vkapi::ShaderInfo pick_choose_qparams_per_row_shader( const ValueRef input = args.at(1).refs.at(0); - // number of output channels - const int64_t width = graph->size_at(-1, input); - const int64_t height = graph->size_at(-2, input); - std::string kernel_name = "choose_qparams_per_row"; - if (width > 256 || height == 1) { - kernel_name += "_o1w64"; - } else { - kernel_name += "_o4w16"; - } add_storage_type_suffix(kernel_name, graph->storage_type_of(input)); add_dtype_suffix(kernel_name, graph->dtype_of(input)); @@ -212,7 +203,7 @@ utils::uvec3 pick_choose_qparams_per_row_global_wg_size( const ValueRef input = args.at(1).refs.at(0); const uint32_t height = graph->size_at(-2, input); - return {1u, height, 1u}; + return {1u, utils::div_up_4(height), 1u}; } utils::uvec3 pick_choose_qparams_per_row_local_wg_size( @@ -228,11 +219,6 @@ utils::uvec3 pick_choose_qparams_per_row_local_wg_size( uint32_t outputs_per_wg = 1u; uint32_t workers_per_output = 64u; - if (shader.kernel_name.find("o4w16") != std::string::npos) { - outputs_per_wg = 4u; - workers_per_output = 16u; - } - return {workers_per_output, outputs_per_wg, 1u}; } diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp index 4831c6f2f85..6a50f81830c 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear.cpp @@ -134,6 +134,89 @@ utils::uvec3 quant_pack_input_global_wg_size( 1u}; } +vkapi::ShaderInfo pick_quantize_and_pack_input_with_sums_shader( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + const ValueRef packed_int_input = args.at(0).refs.at(0); + const ValueRef fp_input = args.at(1).refs.at(0); + const ValueRef group_size = resize_args.at(0); + + const int64_t group_size_val = graph->extract_scalar(group_size); + + std::string shader_name = "quantize_and_pack_linear_input_with_sums"; + if (group_size_val >= 128) { + shader_name += "_o2w32"; + } else { + shader_name += "_o4w16"; + } + + add_storage_type_suffix( + shader_name, graph->storage_type_of(packed_int_input)); + add_storage_type_suffix(shader_name, graph->storage_type_of(fp_input)); + add_dtype_suffix(shader_name, graph->dtype_of(fp_input)); + + return VK_KERNEL_FROM_STR(shader_name); +} + +utils::uvec3 pick_quantize_and_pack_input_with_sums_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, 
+ const std::vector& args, + const std::vector& resize_args) { + const ValueRef fp_input = args.at(1).refs.at(0); + // For gemv cases, skip the quantize and pack input step in favor of computing + // the quantized linear as a weight only quantized linear operation. The + // rationale for this is that gemv is a memory bound operation and may not + // necessarily benefit from quantizing the input and computing with integer + // accumulation. + if (is_gemv(graph, fp_input)) { + return {0u, 0u, 0u}; + } + + const ValueRef group_size = resize_args.at(0); + int64_t num_blocks_M, num_blocks_K; + std::tie(num_blocks_M, num_blocks_K) = + get_quantized_input_num_blocks(*graph, fp_input); + + const int64_t group_size_val = graph->extract_scalar(group_size); + const int64_t blocks_per_group = group_size_val / 4; + + const int64_t num_groups = num_blocks_K / blocks_per_group; + + return { + utils::safe_downcast(num_groups), + utils::safe_downcast(num_blocks_M), + 1u}; +} + +utils::uvec3 pick_quantize_and_pack_input_with_sums_local_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + + const ValueRef fp_input = args.at(1).refs.at(0); + // For gemv, skip the quantize input step since the quantized linear is + // computed as a weight only quantized linear operation. + if (is_gemv(graph, fp_input)) { + return {1u, 1u, 1u}; + } + + uint32_t groups_per_wg = 2u; + uint32_t workers_per_group = 32u; + + if (shader.kernel_name.find("o4w16") != std::string::npos) { + groups_per_wg = 4u; + workers_per_group = 16u; + } + + return {groups_per_wg, 1u, workers_per_group}; +} + vkapi::ShaderInfo pick_linear_qw_shader( ComputeGraph* graph, const std::vector& args, @@ -167,6 +250,40 @@ vkapi::ShaderInfo pick_linear_qw_shader( return VK_KERNEL_FROM_STR(kernel_name); } +vkapi::ShaderInfo pick_linear_dqa_qw_shader( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + (void)resize_args; + + const ValueRef out = args.at(0).refs.at(0); + const ValueRef fp_input = args.at(1).refs.at(0); + const ValueRef int_input = args.at(1).refs.at(1); + (void)int_input; + const ValueRef int_weight = args.at(1).refs.at(5); + + const bool weight_is_4bit = resize_args.at(0) != kDummyValueRef; + const bool is_gemv_case = is_gemv(graph, fp_input); + + std::string kernel_name = "linear_"; + if (weight_is_4bit) { + kernel_name += "dq8ca_q4gsw"; + } else { + kernel_name += "dq8ca_q8csw"; + } + + if (weight_is_4bit && is_gemv_case) { + kernel_name += "_coop"; + } else { + kernel_name += "_tiled"; + } + add_storage_type_suffix(kernel_name, graph->storage_type_of(out)); + add_storage_type_suffix(kernel_name, graph->storage_type_of(int_weight)); + add_dtype_suffix(kernel_name, graph->dtype_of(out)); + + return VK_KERNEL_FROM_STR(kernel_name); +} + // // Prepacking nodes // @@ -409,6 +526,45 @@ DynamicDispatchNode make_quantize_and_pack_linear_input_node( {}); } +DynamicDispatchNode make_quantize_and_pack_linear_input_with_sums_node( + ComputeGraph& graph, + const QuantizationConfig& input_quant_config, + const ValueRef fp_input, + const ValueRef int_input_sums, + const ValueRef packed_input_scales, + const ValueRef packed_input_zps, + const ValueRef packed_int_input, + const ValueRef group_size) { + // Only certain quantization types supported at the moment + VK_CHECK_COND(input_quant_config.granularity == kPerChannel); + + int64_t num_blocks_M, num_blocks_K; + 
std::tie(num_blocks_M, num_blocks_K) = + get_quantized_input_num_blocks(graph, fp_input); + + vkapi::ParamsBindList param_buffers = {graph.sizes_ubo(fp_input)}; + + const int32_t group_size_val = graph.extract_scalar(group_size); + const int32_t blocks_per_group = utils::div_up(group_size_val, int32_t(4)); + + return DynamicDispatchNode( + graph, + pick_quantize_and_pack_input_with_sums_shader, + pick_quantize_and_pack_input_with_sums_global_wg_size, + pick_quantize_and_pack_input_with_sums_local_wg_size, + // Inputs and Outputs + {{{packed_int_input, int_input_sums}, vkapi::kWrite}, + {{fp_input, packed_input_scales, packed_input_zps}, vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + {}, + // Specialization Constants + {blocks_per_group}, + // Resize args + {group_size}); +} + DynamicDispatchNode make_linear_qa_qw_node( ComputeGraph& graph, const QuantizationConfig& input_quant_config, @@ -483,6 +639,80 @@ DynamicDispatchNode make_linear_qa_qw_node( nullptr); } +DynamicDispatchNode make_linear_dqa_qw_node( + ComputeGraph& graph, + const QuantizationConfig& input_quant_config, + const QuantizationConfig& weight_quant_config, + const ValueRef fp_input, + const ValueRef packed_int_input, + const ValueRef int_input_sums, + const ValueRef packed_input_scale, + const ValueRef packed_input_zp, + const ValueRef input_scale_data, + const ValueRef input_zp_data, + const ValueRef weight_data, + const ValueRef packed_weight, + const ValueRef packed_weight_sums, + const ValueRef packed_weight_scales, + const ValueRef group_size, + const ValueRef bias_data, + const ValueRef packed_bias, + const ValueRef output) { + VK_CHECK_COND(input_quant_config.granularity == kPerChannel); + VK_CHECK_COND(input_quant_config.nbits == 8); + VK_CHECK_COND(input_quant_config.is_dynamic); + + VK_CHECK_COND(weight_quant_config.granularity == kPerGroup); + VK_CHECK_COND(weight_quant_config.is_symmetric); + VK_CHECK_COND(weight_quant_config.nbits == 4); + + vkapi::ParamsBindList param_buffers = { + graph.sizes_ubo(output), graph.sizes_ubo(fp_input)}; + + uint32_t apply_bias = 1; + if (graph.val_is_none(bias_data)) { + apply_bias = 0; + } + + int32_t K4_per_group = 0; + if (weight_quant_config.nbits == 4) { + int32_t group_size_val = graph.extract_scalar(group_size); + K4_per_group = utils::div_up(group_size_val, int32_t(4)); + } + + const ValueRef is_4bit_flag = + weight_quant_config.nbits == 4 ? 
group_size : kDummyValueRef; + + // Add the compute node + return DynamicDispatchNode( + graph, + pick_linear_dqa_qw_shader, + quantized_linear_global_wg_size, + quantized_linear_local_wg_size, + // Inputs and Outputs + {{output, vkapi::kWrite}, + {{fp_input, + packed_int_input, + int_input_sums, + packed_input_scale, + packed_input_zp, + packed_weight, + packed_weight_sums, + packed_weight_scales, + packed_bias}, + vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + {}, + // Specialization Constants + {apply_bias, K4_per_group}, + // Resize args + {is_4bit_flag, weight_data}, + // Resizing Logic + resize_linear_qw_node); +} + // // High level operator impl // @@ -553,42 +783,44 @@ void quantized_linear_impl( graph.execute_nodes().emplace_back(new DynamicDispatchNode(linear_qw_node)); return; - } else { - // Otherwise, use input and weight quantized linear computed with integer - // accumulation - - // Input scale/zero point only used for activation & weight quantized linear - ValueRef packed_input_scale = input_scale; - ValueRef packed_input_zp = input_zp; - if (graph.val_is_tref(input_scale)) { - VK_CHECK_COND(graph.val_is_tref(packed_input_zp)); - packed_input_scale = prepack_standard( - graph, input_scale, utils::kBuffer, utils::kWidthPacked); - packed_input_zp = prepack_standard( - graph, input_zp, utils::kBuffer, utils::kWidthPacked); - } + } + // Otherwise, use input and weight quantized linear computed with integer + // accumulation + + // Input scale/zero point only used for activation & weight quantized linear + ValueRef packed_input_scale = input_scale; + ValueRef packed_input_zp = input_zp; + if (graph.val_is_tref(input_scale)) { + VK_CHECK_COND(graph.val_is_tref(packed_input_zp)); + packed_input_scale = prepack_standard( + graph, input_scale, utils::kTexture3D, utils::kWidthPacked); + packed_input_zp = prepack_standard( + graph, input_zp, utils::kTexture3D, utils::kWidthPacked); + } - // Pre-computed per quant group weight sums are needed for int accumulation, - // but not for weight only - const ValueRef packed_weight_sums = prepack_standard( - graph, weight_sums_data, utils::kBuffer, utils::kWidthPacked); + // Pre-computed per quant group weight sums are needed for int accumulation, + // but not for weight only + const ValueRef packed_weight_sums = prepack_standard( + graph, weight_sums_data, utils::kBuffer, utils::kWidthPacked); - // Allocate temporary tensor to store quantized and packed input + // Allocate temporary tensor to store quantized and packed input - int64_t num_blocks_M, num_blocks_K; - std::tie(num_blocks_M, num_blocks_K) = - get_quantized_input_num_blocks(graph, fp_input); + int64_t num_blocks_M, num_blocks_K; + std::tie(num_blocks_M, num_blocks_K) = + get_quantized_input_num_blocks(graph, fp_input); - const int64_t int_input_height = num_blocks_M; - const int64_t int_input_width = num_blocks_K * 4; + const int64_t int_input_height = num_blocks_M; + const int64_t int_input_width = num_blocks_K * 4; - TmpTensor packed_int_input( - &graph, - {int_input_height, int_input_width}, - vkapi::kInt, - utils::kBuffer, - utils::kWidthPacked); + TmpTensor packed_int_input( + &graph, + {int_input_height, int_input_width}, + vkapi::kInt, + utils::kBuffer, + utils::kWidthPacked); + // Non dynamically quantized input case + if (!input_quant_config.is_dynamic) { DynamicDispatchNode quantize_and_pack_linear_node( make_quantize_and_pack_linear_input_node( graph, @@ -625,7 +857,62 @@ void quantized_linear_impl( graph.execute_nodes().emplace_back( 
new DynamicDispatchNode(linear_qa_qw_node)); + + return; + } + + // Otherwise, input is dynamically quantized. Currently only per group 4-bit + // quantized weights is supported for this mode. + VK_CHECK_COND(weight_quant_config.nbits == 4); + + int64_t num_groups = 1; + if (weight_quant_config.granularity == kPerGroup) { + num_groups = graph.size_at(-2, weight_scales_data); } + + TmpTensor int_input_sums( + &graph, + {num_groups, K}, + graph.dtype_of(output), + utils::kBuffer, + utils::kWidthPacked); + + DynamicDispatchNode quantize_and_pack_input_with_sums_node( + make_quantize_and_pack_linear_input_with_sums_node( + graph, + input_quant_config, + fp_input, + int_input_sums, + packed_input_scale, + packed_input_zp, + packed_int_input, + group_size)); + + graph.execute_nodes().emplace_back( + new DynamicDispatchNode(quantize_and_pack_input_with_sums_node)); + + DynamicDispatchNode linear_dqa_qw_node(make_linear_dqa_qw_node( + graph, + input_quant_config, + weight_quant_config, + fp_input, + packed_int_input, + int_input_sums, + packed_input_scale, + packed_input_zp, + input_scale, + input_zp, + weight_data, + packed_weight, + packed_weight_sums, + packed_weight_scales, + group_size, + bias_data, + packed_bias, + output)); + + graph.execute_nodes().emplace_back( + new DynamicDispatchNode(linear_dqa_qw_node)); } void linear_q8ta_q8csw(ComputeGraph& graph, const std::vector& args) { @@ -719,10 +1006,46 @@ void linear_q4gsw(ComputeGraph& graph, const std::vector& args) { output); } +void linear_dq8ca_q4gsw( + ComputeGraph& graph, + const std::vector& args) { + int32_t idx = 0; + const ValueRef fp_input = args.at(idx++); + const ValueRef input_scale = args.at(idx++); + const ValueRef input_zp = args.at(idx++); + const ValueRef weight_data = args.at(idx++); + const ValueRef weight_sums_data = args.at(idx++); + const ValueRef weight_scales_data = args.at(idx++); + const ValueRef group_size = args.at(idx++); + const ValueRef bias_data = args.at(idx++); + const ValueRef output = args.at(idx++); + + const int64_t group_size_val = graph.extract_scalar(group_size); + + QuantizationConfig input_quant_config(8, kPerChannel, {}, false, true); + QuantizationConfig weight_quant_config(4, kPerGroup, {group_size_val}); + + quantized_linear_impl( + graph, + input_quant_config, + weight_quant_config, + fp_input, + input_scale, + input_zp, + weight_data, + weight_sums_data, + weight_scales_data, + kDummyValueRef, // weight_zeros_data + group_size, // group_size + bias_data, + output); +} + REGISTER_OPERATORS { VK_REGISTER_OP(et_vk.linear_q8ta_q8csw.default, linear_q8ta_q8csw); VK_REGISTER_OP(et_vk.linear_q8csw.default, linear_q8csw); VK_REGISTER_OP(et_vk.linear_q4gsw.default, linear_q4gsw); + VK_REGISTER_OP(et_vk.linear_dq8ca_q4gsw.default, linear_dq8ca_q4gsw); } } // namespace vkcompute diff --git a/backends/vulkan/test/custom_ops/choose_qparams_per_row.cpp b/backends/vulkan/test/custom_ops/choose_qparams_per_row.cpp index aa2b21feab8..f984b2b2b40 100644 --- a/backends/vulkan/test/custom_ops/choose_qparams_per_row.cpp +++ b/backends/vulkan/test/custom_ops/choose_qparams_per_row.cpp @@ -71,7 +71,7 @@ TestCase create_test_case_from_config( ValueSpec scale_out( {config.num_channels}, vkapi::kFloat, - utils::kBuffer, // Always buffer as per requirement + utils::kTexture3D, // Always buffer as per requirement utils::kWidthPacked, DataGenType::ZEROS); @@ -79,7 +79,7 @@ TestCase create_test_case_from_config( ValueSpec zero_point_out( {config.num_channels}, vkapi::kChar, // int8 for quantized zero point - 
utils::kBuffer, // Always buffer as per requirement + utils::kTexture3D, // Always buffer as per requirement utils::kWidthPacked, DataGenType::ZEROS); diff --git a/backends/vulkan/test/custom_ops/q4gsw_linear.cpp b/backends/vulkan/test/custom_ops/q4gsw_linear.cpp index 805b67c30a2..a6adfdc2d5c 100644 --- a/backends/vulkan/test/custom_ops/q4gsw_linear.cpp +++ b/backends/vulkan/test/custom_ops/q4gsw_linear.cpp @@ -26,7 +26,7 @@ struct LinearConfig { int64_t group_size; // Number of input channels per quantization group bool has_bias = false; std::string test_case_name = "placeholder"; - std::string op_name = "linear_q4gsw"; + std::string op_name = "linear_dq8ca_q4gsw"; }; // Helper function to unpack 4-bit values from uint8 @@ -70,12 +70,30 @@ TestCase create_test_case_from_config( input_dtype, storage_type, utils::kWidthPacked, - DataGenType::RANDOM); + DataGenType::RANDINT); if (debugging()) { print_valuespec_data(input_tensor, "input_tensor"); } + // For activation+weight quantized linear (linear_dq8ca_q4gsw) + // Input scale and zero point as per-input channel tensors + ValueSpec input_scale( + {1, config.M}, // Per-input channel tensor + input_dtype, + storage_type, + utils::kWidthPacked, + DataGenType::RANDOM_SCALES); + input_scale.set_constant(true); + + ValueSpec input_zero_point( + {1, config.M}, // Per-input channel tensor + vkapi::kChar, + storage_type, + utils::kWidthPacked, + DataGenType::RANDINT); + input_zero_point.set_constant(true); + // For 4-bit weights, packed size is [N, K/2] since 2 weights per byte std::vector weight_size = {config.N, config.K / 2}; // Quantized weight tensor (uint8, packed 4-bit) - [N, K/2] @@ -105,6 +123,22 @@ TestCase create_test_case_from_config( DataGenType::RANDOM_SCALES); weight_scales.set_constant(true); + // Pre-computed per-group weight sums for zero point adjustment + // This is needed for activation+weight quantized operations + // Size: [K/group_size, N] - one sum per group per output feature + ValueSpec weight_sums( + weight_scales_size, // Same size as weight_scales + vkapi::kInt, + storage_type, + utils::kWidthPacked, + DataGenType::ZEROS); + weight_sums.set_constant(true); + + // Compute weight_sums data based on quantized weights + int64_t num_groups = config.K / config.group_size; + compute_weight_sums_4bit_grouped( + weight_sums, quantized_weight, num_groups, config.N, config.group_size); + // Group size parameter ValueSpec group_size_spec(static_cast(config.group_size)); @@ -128,13 +162,27 @@ TestCase create_test_case_from_config( utils::kWidthPacked, DataGenType::ZEROS); - // Add all specs to test case for linear_q4gsw - test_case.add_input_spec(input_tensor); - test_case.add_input_spec(quantized_weight); - test_case.add_input_spec(weight_scales); - test_case.add_input_spec(group_size_spec); - test_case.add_input_spec(bias); - test_case.add_output_spec(output); + // Add all specs to test case based on operator type + if (config.op_name.find("dq8ca") != std::string::npos) { + // For activation+weight quantized linear (linear_dq8ca_q4gsw) + test_case.add_input_spec(input_tensor); + test_case.add_input_spec(input_scale); + test_case.add_input_spec(input_zero_point); + test_case.add_input_spec(quantized_weight); + test_case.add_input_spec(weight_sums); + test_case.add_input_spec(weight_scales); + test_case.add_input_spec(group_size_spec); + test_case.add_input_spec(bias); + test_case.add_output_spec(output); + } else { + // For weight-only quantized linear (linear_q4gsw) + test_case.add_input_spec(input_tensor); + 
test_case.add_input_spec(quantized_weight); + test_case.add_input_spec(weight_scales); + test_case.add_input_spec(group_size_spec); + test_case.add_input_spec(bias); + test_case.add_output_spec(output); + } return test_case; } @@ -144,8 +192,8 @@ std::vector generate_quantized_linear_easy_cases() { std::vector test_cases; // Single simple configuration for debugging - int M = 4; - int K = 32; + int M = 8; + int K = 16; int N = 16; int group_size = 8; @@ -179,9 +227,9 @@ std::vector generate_quantized_linear_test_cases() { std::vector test_cases; std::vector configs = { - // Gemv test cases - {1, 128, 64, 32}, - {1, 256, 128, 64}, + // // Gemv test cases + // {1, 128, 64, 32}, + // {1, 256, 128, 64}, // Gemm {4, 64, 32, 16}, {4, 128, 64, 32}, @@ -194,9 +242,9 @@ std::vector generate_quantized_linear_test_cases() { {32, 256, 128, 64, false}, // Performance test cases {1, 2048, 2048, 128}, - {128, 2048, 2048, 128}, - {256, 2048, 2048, 128}, - {1024, 2048, 2048, 128}, + {128, 2048, 2048, 256}, + {256, 2048, 2048, 256}, + {1024, 2048, 2048, 256}, }; // Test with different storage types and data types @@ -219,8 +267,14 @@ std::vector generate_quantized_linear_test_cases() { config.test_case_name = generated_test_case_name; for (const auto& storage_type : storage_types) { + // Test both activation+weight quantized and weight only quantized test_cases.push_back( create_test_case_from_config(config, storage_type, vkapi::kFloat)); + + LinearConfig wo_quant_config = config; + wo_quant_config.op_name = "linear_q4gsw"; + test_cases.push_back(create_test_case_from_config( + wo_quant_config, storage_type, vkapi::kFloat)); } } @@ -315,11 +369,138 @@ void linear_q4gsw_reference_impl(TestCase& test_case) { } } +// Reference implementation for activation+weight quantized linear (dq8ca_q4gsw) +void linear_dq8ca_q4gsw_reference_impl(TestCase& test_case) { + // Extract input specifications + int32_t idx = 0; + const ValueSpec& input_spec = test_case.inputs()[idx++]; + const ValueSpec& input_scale_spec = test_case.inputs()[idx++]; + const ValueSpec& input_zeros_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_sums_spec = test_case.inputs()[idx++]; + const ValueSpec& weight_scales_spec = test_case.inputs()[idx++]; + const ValueSpec& group_size_spec = test_case.inputs()[idx++]; + const ValueSpec& bias_spec = test_case.inputs()[idx++]; + + // Extract output specification (mutable reference) + ValueSpec& output_spec = test_case.outputs()[0]; + + // Get tensor dimensions + auto input_sizes = input_spec.get_tensor_sizes(); // [batch_size, in_features] + auto weight_sizes = + weight_spec.get_tensor_sizes(); // [out_features, in_features/2] + auto output_sizes = + output_spec.get_tensor_sizes(); // [batch_size, out_features] + + int64_t batch_size = input_sizes[0]; + int64_t in_features = input_sizes[1]; + int64_t out_features = output_sizes[1]; + int64_t group_size = group_size_spec.get_int_value(); + + // Skip for large tensors since computation time will be extremely slow + if (batch_size > kRefDimSizeLimit || in_features > kRefDimSizeLimit || + out_features > kRefDimSizeLimit) { + throw std::invalid_argument( + "One or more dimensions (batch_size, in_features, out_features) exceed the allowed limit for reference implementation."); + } + + if (input_spec.dtype != vkapi::kFloat) { + throw std::invalid_argument("Unsupported dtype"); + } + + // Get raw data pointers + auto& input_data = input_spec.get_float_data(); + auto& input_scale_data = + 
input_scale_spec.get_float_data(); // Per-input channel tensor + auto& input_zero_point_data = + input_zeros_spec.get_int8_data(); // Per-input channel tensor + + auto& weight_data = weight_spec.get_uint8_data(); + auto& weight_sums_data = weight_sums_spec.get_int32_data(); + (void)weight_sums_data; // Unused for now + auto& weight_scales_data = weight_scales_spec.get_float_data(); + auto& bias_data = bias_spec.get_float_data(); + + // Calculate number of output elements + int64_t num_output_elements = batch_size * out_features; + + auto& ref_data = output_spec.get_ref_float_data(); + ref_data.resize(num_output_elements); + + // Perform quantized linear transformation (matrix multiplication) with + // integer accumulation + for (int64_t b = 0; b < batch_size; ++b) { + for (int64_t out_f = 0; out_f < out_features; ++out_f) { + int32_t int_sum = 0; + (void)int_sum; + int32_t weight_sum = 0; // Track weight sum on the fly for each group + (void)weight_sum; + + // For group symmetric quantization, compute with proper grouping for + // accurate reference + float float_result = 0.0f; + + for (int64_t in_f = 0; in_f < in_features; ++in_f) { + // Get input value and quantize to int8 using per-input channel + // parameters + int64_t input_idx = b * in_features + in_f; + + // Use per-input channel scale and zero point - index by batch dimension + float input_scale = input_scale_data[b]; // {1, M} -> index by batch + int8_t input_zero_point = + input_zero_point_data[b]; // {1, M} -> index by batch + + float quant_input_f = + std::round(input_data[input_idx] / input_scale) + input_zero_point; + quant_input_f = std::min(std::max(quant_input_f, -128.0f), 127.0f); + int8_t quantized_input = static_cast(quant_input_f); + + // Get quantized weight and its scale + int64_t weight_idx = out_f * (in_features / 2) + (in_f / 2); + uint8_t packed_weight = weight_data[weight_idx]; + auto unpacked = unpack_4bit(packed_weight); + int8_t quantized_weight = + (in_f % 2 == 0) ? 
unpacked.first : unpacked.second; + + // Get the appropriate scale for this group + int64_t group_idx = in_f / group_size; + int64_t scales_idx = group_idx * out_features + out_f; + float weight_scale = weight_scales_data[scales_idx]; + + // Compute the contribution with proper scaling + float contribution = + static_cast(quantized_input - input_zero_point) * + static_cast(quantized_weight) * input_scale * weight_scale; + + float_result += contribution; + } + + // Add bias and store result + if (!bias_spec.is_none()) { + float_result += bias_data[out_f]; + } + int64_t output_idx = b * out_features + out_f; + ref_data[output_idx] = float_result; + } + } +} + void reference_impl(TestCase& test_case) { - linear_q4gsw_reference_impl(test_case); + if (test_case.operator_name().find("dq8ca") != std::string::npos) { + linear_dq8ca_q4gsw_reference_impl(test_case); + } else { + linear_q4gsw_reference_impl(test_case); + } } int64_t quantized_linear_flop_calculator(const TestCase& test_case) { + int input_idx = 0; + int weight_idx = 1; + if (test_case.operator_name().find("dq8ca") != std::string::npos) { + input_idx = 0; + weight_idx = 3; // Weight comes after input, input_scale, input_zero_point + } + // Get input and weight dimensions const auto& input_sizes = test_case.inputs()[0].get_tensor_sizes(); const auto& output_sizes = test_case.outputs()[0].get_tensor_sizes(); @@ -339,6 +520,7 @@ int64_t quantized_linear_flop_calculator(const TestCase& test_case) { // - Unpack 4-bit weight: 1 op per weight element used // - Dequantize weight: 1 op per weight element used // - Add bias: 1 op per output element + // - For activation+weight quantized: add input quantization ops int64_t quantization_ops = ops_per_output * 2 + 1; // Simplified estimate int64_t flop = output_elements * (ops_per_output + quantization_ops); @@ -365,7 +547,7 @@ int main(int argc, char* argv[]) { generate_quantized_linear_test_cases, quantized_linear_flop_calculator, "QuantizedLinearQ4GSW", - 0, + 10, 10, ref_fn); diff --git a/backends/vulkan/test/custom_ops/utils.cpp b/backends/vulkan/test/custom_ops/utils.cpp index 37e0060b3f2..2aa827a4d5a 100644 --- a/backends/vulkan/test/custom_ops/utils.cpp +++ b/backends/vulkan/test/custom_ops/utils.cpp @@ -213,6 +213,9 @@ void ValueSpec::generate_tensor_data() { generate_random_2xint4_data(uint8_data); } else if (data_gen_type == DataGenType::ONES) { std::fill(uint8_data.begin(), uint8_data.end(), 1); + } else if (data_gen_type == DataGenType::ONES_INT4) { + uint8_t packed_data = (9 << 4) | 9; + std::fill(uint8_data.begin(), uint8_data.end(), packed_data); } else if (data_gen_type == DataGenType::ZEROS) { std::fill(uint8_data.begin(), uint8_data.end(), 0); } else { @@ -1712,6 +1715,62 @@ void compute_weight_sums( } } +// Helper function to unpack 4-bit values from uint8 (same as in +// q4gsw_linear.cpp) +std::pair unpack_4bit_utils(uint8_t packed) { + // Extract lower 4 bits and upper 4 bits + int8_t lower = packed & 0x0F; + int8_t upper = (packed >> 4) & 0x0F; + + // Subtract 8 from unpacked 4-bit values + lower -= 8; + upper -= 8; + + return std::make_pair(lower, upper); +} + +// Compute weight sums for 4-bit group symmetric quantized weights +void compute_weight_sums_4bit_grouped( + ValueSpec& weight_sums, + const ValueSpec& quantized_weight, + int64_t num_groups, + int64_t out_features, + int64_t group_size) { + auto& weight_sums_data = weight_sums.get_int32_data(); + auto& quantized_weight_data = quantized_weight.get_uint8_data(); + + // Resize to [num_groups, out_features] + 
weight_sums_data.resize(num_groups * out_features); + + // For each group and each output feature, compute the sum of quantized + // weights in that group + for (int64_t group_idx = 0; group_idx < num_groups; ++group_idx) { + for (int64_t out_f = 0; out_f < out_features; ++out_f) { + int32_t sum = 0; + + // Sum weights for this group and output feature + for (int64_t in_group = 0; in_group < group_size; ++in_group) { + int64_t in_f = group_idx * group_size + in_group; + + // Get packed weight value - weight matrix is [N, K/2] + int64_t weight_idx = + out_f * ((num_groups * group_size) / 2) + (in_f / 2); + uint8_t packed_weight = quantized_weight_data[weight_idx]; + + // Unpack 4-bit weight + auto unpacked = unpack_4bit_utils(packed_weight); + int8_t weight_4bit = (in_f % 2 == 0) ? unpacked.first : unpacked.second; + + sum += static_cast(weight_4bit); + } + + // Store sum for this group and output feature + int64_t sums_idx = group_idx * out_features + out_f; + weight_sums_data[sums_idx] = sum; + } + } +} + } // namespace prototyping } // namespace vulkan } // namespace executorch diff --git a/backends/vulkan/test/custom_ops/utils.h b/backends/vulkan/test/custom_ops/utils.h index 2440e225ef2..f1736f1d144 100644 --- a/backends/vulkan/test/custom_ops/utils.h +++ b/backends/vulkan/test/custom_ops/utils.h @@ -653,6 +653,14 @@ void compute_weight_sums( int64_t out_features, int64_t elements_per_output_feature); +// Compute weight sums for 4-bit group symmetric quantized weights +void compute_weight_sums_4bit_grouped( + ValueSpec& weight_sums, + const ValueSpec& quantized_weight, + int64_t num_groups, + int64_t out_features, + int64_t group_size); + // Setup compute graph based on TestCase and operation name ComputeGraph setup_compute_graph(TestCase& test_case, std::string op_name); diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py index 00a357b0b67..a832915359f 100644 --- a/backends/vulkan/test/test_vulkan_delegate.py +++ b/backends/vulkan/test/test_vulkan_delegate.py @@ -2650,3 +2650,85 @@ def forward(self, x): atol=1e-2, rtol=1e-1, ) + + def test_vulkan_backend_torchao_8da4w_quantized_linear(self): + """ + Test TorchAO 8da4w quantization (int8 dynamic activation + int4 weight) with Vulkan backend. + This test uses the same quantization approach as the 8da4w qmode in quantize.py. 
+ """ + in_features = 1024 + out_features = 512 + bias = False + group_size = 128 + + class TorchAO8da4wQuantizedLinearModule(torch.nn.Module): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = False, + group_size: int = 128, + ): + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features, bias=bias) + self.group_size = group_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + + def apply_8da4w_quantization(self): + """Apply TorchAO 8da4w quantization (int8 dynamic activation + int4 weight).""" + from torchao.quantization import ( + int8_dynamic_activation_int4_weight, + quantize_, + ) + from torchao.utils import unwrap_tensor_subclass + + quantize_( + self, + int8_dynamic_activation_int4_weight(group_size=self.group_size), + ) + unwrap_tensor_subclass(self) + return self + + # Test with GEMV pattern (batch_size=1, seq_len=1) + quantized_linear_module = TorchAO8da4wQuantizedLinearModule( + in_features=in_features, + out_features=out_features, + bias=bias, + group_size=group_size, + ) + + # Apply 8da4w quantization + quantized_linear_module = quantized_linear_module.apply_8da4w_quantization() + + # Test with 2D input (GEMV pattern) + sample_inputs = (torch.randn(size=(1, in_features), dtype=torch.float32),) + + # Use higher tolerance since quantization introduces some error + self.lower_module_and_test_output( + quantized_linear_module, sample_inputs, atol=1e-2, rtol=1e-2 + ) + + # Test with GEMM pattern (batch_size > 1) + quantized_linear_module_gemm = TorchAO8da4wQuantizedLinearModule( + in_features=in_features, + out_features=out_features, + bias=bias, + group_size=group_size, + ) + + # Apply 8da4w quantization + quantized_linear_module_gemm = ( + quantized_linear_module_gemm.apply_8da4w_quantization() + ) + + # Test with 3D input (GEMM pattern) + sample_inputs_gemm = ( + torch.randn(size=(1, 248, in_features), dtype=torch.float32), + ) + + # Use higher tolerance since quantization introduces some error + self.lower_module_and_test_output( + quantized_linear_module_gemm, sample_inputs_gemm, atol=1e-2, rtol=1e-2 + ) diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py index 1291eb62936..474c6386a9e 100644 --- a/backends/vulkan/utils.py +++ b/backends/vulkan/utils.py @@ -84,6 +84,13 @@ def is_quant_node(node: torch.fx.Node) -> bool: return node_name in _Q_OPS +def is_choose_qparams_node(node: torch.fx.Node) -> bool: + if node.op != "call_function": + return False + node_name = format_target_name(node.target.__name__) # pyre-ignore + return "choose_qparams" in node_name + + def is_dequant_per_channel_node(node: torch.fx.Node) -> bool: if node.op != "call_function": return False @@ -91,6 +98,13 @@ def is_dequant_per_channel_node(node: torch.fx.Node) -> bool: return node_name == "dequantize_per_channel.default" +def is_view_copy_node(node: torch.fx.Node) -> bool: + if node.op != "call_function": + return False + node_name = format_target_name(node.target.__name__) # pyre-ignore + return "view_copy" in node_name + + def is_linear_node(node: torch.fx.Node) -> bool: if node.op != "call_function": return False @@ -1126,8 +1140,15 @@ def maybe_skip_q_dq_arg_chain( if not isinstance(arg, torch.fx.Node): return None, None, None - if is_dequant_node(arg): - dequant_node = arg + # If the arg is a view copy node, check if the original node is a dequant node + if is_dequant_node(arg) or ( + is_view_copy_node(arg) and is_dequant_node(arg.args[0]) + ): + if is_view_copy_node(arg): + dequant_node = arg.args[0] 
+ else: + dequant_node = arg + quant_node = dequant_node.args[0] assert isinstance(quant_node, torch.fx.Node) source_arg = quant_node.args[0] From 994ca6dba0646245baa8ff6d499b7e6e547429f4 Mon Sep 17 00:00:00 2001 From: ssjia Date: Mon, 8 Sep 2025 09:36:55 -0700 Subject: [PATCH 2/7] Update on "[ET-VK] Implemement linear_dq8ta_q4gsw" Title says it all! Build upon the support for quantized linear introduced in the previous diffs to enable dynamically quantized linear. Also included in this diff is a cleanup of the glslh files used across quantized linear implementations. Differential Revision: [D81931060](https://our.internmc.facebook.com/intern/diff/D81931060/) [ghstack-poisoned] --- .github/workflows/pull.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 455f427b386..fa566afca85 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -933,7 +933,7 @@ jobs: PYTHON_EXECUTABLE=python bash backends/vulkan/test/custom_ops/build_and_run.sh add ./cmake-out/backends/vulkan/test/custom_ops/q8csw_linear ./cmake-out/backends/vulkan/test/custom_ops/q8csw_conv2d - ./cmake-out/backends/vulkan/test/custom_ops/q4gsw_conv2d + ./cmake-out/backends/vulkan/test/custom_ops/q4gsw_linear ./cmake-out/backends/vulkan/test/custom_ops/choose_qparams_per_row # Run e2e testing for selected operators. More operators will be tested via this From 06ffc8c73245adc48d5d084dcf30ce52f7a46111 Mon Sep 17 00:00:00 2001 From: ssjia Date: Mon, 8 Sep 2025 10:38:14 -0700 Subject: [PATCH 3/7] Update on "[ET-VK] Implemement linear_dq8ta_q4gsw" Title says it all! Build upon the support for quantized linear introduced in the previous diffs to enable dynamically quantized linear. Also included in this diff is a cleanup of the glslh files used across quantized linear implementations. 
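For reviewers, a minimal Python sketch of the math `linear_dq8ca_q4gsw` is expected to compute, mirroring the reference implementation added to q4gsw_linear.cpp (this is not the shader): per-row dynamic int8 quantization of the activation, 4-bit group-symmetric weights unpacked two-per-byte with an offset of 8, and per-group weight scales laid out as [K / group_size, N]. The function name, argument names, and the [M, 1] layout of the per-row scale/zero point are illustrative assumptions.

```python
import torch

def ref_linear_dq8ca_q4gsw(x, input_scale, input_zp, packed_w, w_scales, group_size, bias=None):
    # x: [M, K] float; input_scale / input_zp: per-row, assumed here as [M, 1] tensors
    # packed_w: [N, K // 2] uint8, two 4-bit values per byte, stored with an offset of 8
    # w_scales: [K // group_size, N] float (per-group, per-output-channel scales)
    M, K = x.shape
    N = packed_w.shape[0]
    # Dynamic per-row int8 quantization of the activation
    q = torch.clamp(torch.round(x / input_scale) + input_zp, -128, 127)
    # Unpack the 4-bit group-symmetric weights (low nibble holds the even input channel)
    lo = (packed_w & 0x0F).to(torch.int8) - 8
    hi = ((packed_w >> 4) & 0x0F).to(torch.int8) - 8
    w = torch.stack([lo, hi], dim=-1).reshape(N, K).to(torch.float32)
    out = torch.zeros(M, N)
    for g in range(K // group_size):
        ks = slice(g * group_size, (g + 1) * group_size)
        acc = (q[:, ks] - input_zp) @ w[:, ks].t()   # [M, N] partial sum for this group
        out += acc * input_scale * w_scales[g]       # apply activation scale and group scale
    return out if bias is None else out + bias
```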
Differential Revision: [D81931060](https://our.internmc.facebook.com/intern/diff/D81931060/) [ghstack-poisoned] --- backends/vulkan/patterns/quantized_linear.py | 2 - .../vulkan/test/custom_ops/CMakeLists.txt | 4 +- .../test/custom_ops/quantized_int4_linear.cpp | 366 --------------- .../custom_ops/quantized_q4gaw_linear.cpp | 433 ------------------ 4 files changed, 2 insertions(+), 803 deletions(-) delete mode 100644 backends/vulkan/test/custom_ops/quantized_int4_linear.cpp delete mode 100644 backends/vulkan/test/custom_ops/quantized_q4gaw_linear.cpp diff --git a/backends/vulkan/patterns/quantized_linear.py b/backends/vulkan/patterns/quantized_linear.py index fd2708327c9..2858cda111a 100644 --- a/backends/vulkan/patterns/quantized_linear.py +++ b/backends/vulkan/patterns/quantized_linear.py @@ -116,7 +116,6 @@ def __init__(self, mm_node: torch.fx.Node) -> None: # If input is not quantized, then we are done if self.quantize_input_node is None: - raise Exception("Input is not quantized") self.match_found = True return @@ -478,7 +477,6 @@ def replace_quantized_linear_patterns( and match.is_weight_pergroup_quantized() and utils.is_in_4bit_range(weight_tensor) ): - raise Exception("Unsupported pattern") make_linear_q4gsw_op( ep, graph_module, match, weight_tensor, weight_scales_tensor ) diff --git a/backends/vulkan/test/custom_ops/CMakeLists.txt b/backends/vulkan/test/custom_ops/CMakeLists.txt index fe58055f649..97b632338db 100644 --- a/backends/vulkan/test/custom_ops/CMakeLists.txt +++ b/backends/vulkan/test/custom_ops/CMakeLists.txt @@ -92,7 +92,7 @@ if(TARGET vulkan_backend) # Define operator prototypes add_operator_prototype(add) add_operator_prototype(q8csw_linear) - add_operator_prototype(quantized_q4gaw_linear) - add_operator_prototype(quantized_int4_linear) add_operator_prototype(q8csw_conv2d) + add_operator_prototype(q4gsw_linear) + add_operator_prototype(choose_qparams_per_row) endif() diff --git a/backends/vulkan/test/custom_ops/quantized_int4_linear.cpp b/backends/vulkan/test/custom_ops/quantized_int4_linear.cpp deleted file mode 100644 index c125ce2d09c..00000000000 --- a/backends/vulkan/test/custom_ops/quantized_int4_linear.cpp +++ /dev/null @@ -1,366 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. - -#include -#include -#include -#include -#include "utils.h" - -#include - -using namespace executorch::vulkan::prototyping; - -using namespace vkcompute; - -// Linear configuration struct -struct LinearConfig { - int64_t M; // Batch size / number of rows in input - int64_t K; // Input features / columns in input, rows in weight - int64_t N; // Output features / columns in weight - int64_t group_size; // Number of input channels per quantization group - std::string name_suffix; - std::string shader_variant_name = "default"; -}; - -// Utility function to create a test case from a LinearConfig -TestCase create_test_case_from_config( - const LinearConfig& config, - utils::StorageType storage_type, - vkapi::ScalarType input_dtype) { - TestCase test_case; - - // Create a descriptive name for the test case - std::string storage_str = - (storage_type == utils::kTexture3D) ? "Texture3D" : "Buffer"; - std::string dtype_str = (input_dtype == vkapi::kFloat) ? 
"Float" : "Half"; - - std::string test_name = "QuantizedLinearInt4_" + config.name_suffix + "_" + - storage_str + "_" + dtype_str; - test_case.set_name(test_name); - - // Set the operator name for the test case - std::string operator_name = "et_vk.linear_weight_int4.default"; - test_case.set_operator_name(operator_name); - - // Derive sizes from M, K, N - std::vector input_size = {config.M, config.K}; - std::vector weight_size = { - config.N, config.K / 2}; // Packed 4-bit weights - - // Input tensor (float/half) - [M, K] - ValueSpec input_tensor( - input_size, - input_dtype, - storage_type, - utils::kWidthPacked, - DataGenType::ONES); - - if (debugging()) { - print_valuespec_data(input_tensor, "input_tensor"); - } - - // Quantized weight tensor (int8, packed 4-bit) - [N, K/2] - ValueSpec quantized_weight( - weight_size, - vkapi::kChar, // int8 for packed 4-bit quantized weights - storage_type, - utils::kWidthPacked, - DataGenType::ONES); - quantized_weight.set_constant(true); - quantized_weight.set_int4(true); - - if (debugging()) { - print_valuespec_data(quantized_weight, "weight_tensor"); - } - - // Group size parameter - ValueSpec group_size_spec(static_cast(config.group_size)); - - // Weight quantization scales and zeros (float/half, per-group) - - // [K/group_size, N, 2] - std::vector scales_and_zeros_size = { - config.K / config.group_size, config.N, 2}; - ValueSpec scales_and_zeros( - scales_and_zeros_size, - input_dtype, - storage_type, - utils::kWidthPacked, - DataGenType::ONES); - scales_and_zeros.set_constant(true); - - if (debugging()) { - print_valuespec_data(scales_and_zeros, "scales_and_zeros"); - } - - // Output tensor (float/half) - [M, N] - ValueSpec output( - {config.M, config.N}, - input_dtype, - storage_type, - utils::kWidthPacked, - DataGenType::ZEROS); - - // Add all specs to test case - test_case.add_input_spec(input_tensor); - test_case.add_input_spec(quantized_weight); - test_case.add_input_spec(group_size_spec); - test_case.add_input_spec(scales_and_zeros); - // Add dummy value for inner_k_tiles (unused but required by operator - // signature) - ValueSpec dummy_inner_k_tiles(static_cast(8)); - test_case.add_input_spec(dummy_inner_k_tiles); - - test_case.add_output_spec(output); - - return test_case; -} - -// Generate easy test cases for quantized linear operation (for debugging) -std::vector generate_quantized_linear_easy_cases() { - std::vector test_cases; - - // Single simple configuration for debugging - int M = 8; - int K = 16; - int N = 16; - int group_size = 8; - - LinearConfig config = { - M, // Batch size - K, // Input features - N, // Output features - group_size, // Group size - "simple", // descriptive name - "default" // shader variant name - }; - - // Test with both storage types and data types for completeness - std::vector storage_types = { - utils::kTexture3D, utils::kBuffer}; - std::vector float_types = {vkapi::kFloat}; - - // Generate test cases for each combination - for (const auto& storage_type : storage_types) { - for (const auto& input_dtype : float_types) { - test_cases.push_back( - create_test_case_from_config(config, storage_type, input_dtype)); - } - } - - return test_cases; -} - -// Generate test cases for quantized linear operation -std::vector generate_quantized_linear_test_cases() { - std::vector test_cases; - - std::vector configs = { - {8, 64, 32, 8, "correctness_8_64_32_g8"}, - {8, 128, 64, 16, "correctness_8_128_64_g16"}, - {8, 256, 128, 32, "correctness_8_256_128_g32"}, - {32, 64, 32, 8, "correctness_32_64_32_g8"}, - {32, 128, 
64, 16, "correctness_32_128_64_g16"}, - {32, 256, 128, 32, "correctness_32_256_128_g32"}, - {1, 256, 128, 32, "correctness_32_256_128_g32"}, - // Performance test cases - {1, 2048, 2048, 128, "performance_128_2048_2048_g128"}, - {128, 2048, 2048, 128, "performance_128_2048_2048_g128"}, - {248, 2048, 2048, 128, "performance_128_2048_2048_g128"}, - {1024, 2048, 2048, 128, "performance_128_2048_2048_g128"}, - // {16384, 576, 128, 32, "performance_16384_576_128_g32"} - }; - - // Test with different storage types and data types - std::vector storage_types = { - utils::kTexture3D, utils::kBuffer}; - - // Generate test cases for each combination - for (const auto& config : configs) { - for (const auto& storage_type : storage_types) { - test_cases.push_back( - create_test_case_from_config(config, storage_type, vkapi::kFloat)); - } - } - - return test_cases; -} - -// Helper function to unpack 4-bit values from int8 -std::pair unpack_4bit(int8_t packed) { - // Extract lower 4 bits and upper 4 bits - int8_t lower = packed & 0x0F; - int8_t upper = (packed >> 4) & 0x0F; - - // Sign extend from 4-bit to 8-bit - if (lower & 0x08) - lower |= 0xF0; - if (upper & 0x08) - upper |= 0xF0; - - return std::make_pair(lower, upper); -} - -// Reference implementation for quantized linear operation -void quantized_linear_reference_impl(TestCase& test_case) { - static constexpr int64_t kRefDimSizeLimit = 300; - // Extract input specifications - int32_t idx = 0; - const ValueSpec& input_spec = test_case.inputs()[idx++]; - const ValueSpec& weight_spec = test_case.inputs()[idx++]; - const ValueSpec& group_size_spec = test_case.inputs()[idx++]; - const ValueSpec& scales_and_zeros_spec = test_case.inputs()[idx++]; - // Skip dummy inner_k_tiles - idx++; - - // Extract output specification (mutable reference) - ValueSpec& output_spec = test_case.outputs()[0]; - - // Get tensor dimensions - auto input_sizes = input_spec.get_tensor_sizes(); // [batch_size, in_features] - auto weight_sizes = - weight_spec.get_tensor_sizes(); // [out_features, in_features/2] - auto output_sizes = - output_spec.get_tensor_sizes(); // [batch_size, out_features] - - int64_t batch_size = input_sizes[0]; - int64_t in_features = input_sizes[1]; - int64_t out_features = output_sizes[1]; - int64_t group_size = group_size_spec.get_int_value(); - - // Skip for large tensors since computation time will be extremely slow - if (batch_size > kRefDimSizeLimit || in_features > kRefDimSizeLimit || - out_features > kRefDimSizeLimit) { - throw std::invalid_argument( - "One or more dimensions (batch_size, in_features, out_features) exceed the allowed limit for reference implementation."); - } - - if (input_spec.dtype != vkapi::kFloat) { - throw std::invalid_argument("Unsupported dtype"); - } - - // Get raw data pointers - auto& input_data = input_spec.get_float_data(); - auto& weight_data = weight_spec.get_int8_data(); - auto& scales_and_zeros_data = scales_and_zeros_spec.get_float_data(); - - // Calculate number of output elements - int64_t num_output_elements = batch_size * out_features; - - auto& ref_data = output_spec.get_ref_float_data(); - ref_data.resize(num_output_elements); - - // Perform quantized linear transformation (matrix multiplication) - for (int64_t b = 0; b < batch_size; ++b) { - for (int64_t out_f = 0; out_f < out_features; ++out_f) { - float sum = 0.0f; - - bool should_print = b == 0 && out_f == 0; - should_print = false; - - if (should_print) { - std::cout << "Weights seen: "; - } - - // Matrix multiplication: output[b][out_f] = 
sum(input[b][in_f] * - // weight[out_f][in_f]) - for (int64_t in_f = 0; in_f < in_features; ++in_f) { - // Get input value - int64_t input_idx = b * in_features + in_f; - float input_val = input_data[input_idx]; - - // Get weight value and dequantize (4-bit group affine quantization) - int64_t group_idx = in_f / group_size; - int64_t scales_and_zeros_idx = group_idx * out_features * 2 + out_f * 2; - - // Get packed weight value - int64_t weight_idx = out_f * (in_features / 2) + (in_f / 2); - int8_t packed_weight = weight_data[weight_idx]; - - // Unpack 4-bit weight - auto unpacked = unpack_4bit(packed_weight); - int8_t weight_4bit = (in_f % 2 == 0) ? unpacked.first : unpacked.second; - - // Dequantize weight using group affine quantization - float weight_scale = scales_and_zeros_data[scales_and_zeros_idx]; - float weight_zero = scales_and_zeros_data[scales_and_zeros_idx + 1]; - float dequant_weight = - (static_cast(weight_4bit) - 8.0f) * weight_scale + - weight_zero; - - if (should_print) { - std::cout << int(weight_4bit) << ", "; - } - - sum += input_val * dequant_weight; - } - - if (should_print) { - std::cout << std::endl; - } - - // Store result - int64_t output_idx = b * out_features + out_f; - ref_data[output_idx] = sum; - } - } -} - -// Custom FLOP calculator for quantized linear operation -int64_t quantized_linear_flop_calculator(const TestCase& test_case) { - if (test_case.num_inputs() < 4 || test_case.num_outputs() < 1) { - return 0; - } - - // Get input and weight dimensions - const auto& input_sizes = test_case.inputs()[0].get_tensor_sizes(); - const auto& output_sizes = test_case.outputs()[0].get_tensor_sizes(); - - int64_t batch_size = input_sizes[0]; - int64_t in_features = input_sizes[1]; - int64_t out_features = output_sizes[1]; - - // Calculate FLOPs for quantized linear operation - // Each output element requires: - // - in_features multiply-accumulate operations - // - Additional operations for quantization/dequantization - int64_t output_elements = batch_size * out_features; - int64_t ops_per_output = in_features; - - // Add quantization overhead (approximate) - // - Dequantize weight: 2 ops per weight element used (unpack + dequantize) - int64_t quantization_ops = ops_per_output * 2; // Simplified estimate - - int64_t flop = output_elements * (ops_per_output + quantization_ops); - - return flop; -} - -int main(int argc, char* argv[]) { - set_debugging(false); - set_print_output(false); - set_print_latencies(false); - set_use_gpu_timestamps(true); - - print_performance_header(); - std::cout << "Quantized 4-bit Int4 Linear Operation Prototyping Framework" - << std::endl; - print_separator(); - - ReferenceComputeFunc ref_fn = quantized_linear_reference_impl; - - // Execute easy test cases using the new framework with custom FLOP - // calculator - auto results = execute_test_cases( - generate_quantized_linear_test_cases, - quantized_linear_flop_calculator, - "QuantizedLinearInt4", - 0, - 10, - ref_fn); - - return 0; -} diff --git a/backends/vulkan/test/custom_ops/quantized_q4gaw_linear.cpp b/backends/vulkan/test/custom_ops/quantized_q4gaw_linear.cpp deleted file mode 100644 index 084d718b502..00000000000 --- a/backends/vulkan/test/custom_ops/quantized_q4gaw_linear.cpp +++ /dev/null @@ -1,433 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. 
- -#include -#include -#include -#include -#include "utils.h" - -#include - -using namespace executorch::vulkan::prototyping; - -using namespace vkcompute; - -// Linear configuration struct -struct LinearConfig { - int64_t M; // Batch size / number of rows in input - int64_t K; // Input features / columns in input, rows in weight - int64_t N; // Output features / columns in weight - int64_t group_size; // Number of input channels per quantization group - std::string name_suffix; - std::string shader_variant_name = "default"; -}; - -// Utility function to create a test case from a LinearConfig -TestCase create_test_case_from_config( - const LinearConfig& config, - utils::StorageType storage_type, - vkapi::ScalarType input_dtype) { - TestCase test_case; - - // Create a descriptive name for the test case - std::string storage_str = - (storage_type == utils::kTexture3D) ? "Texture3D" : "Buffer"; - std::string dtype_str = (input_dtype == vkapi::kFloat) ? "Float" : "Half"; - - std::string test_name = "QuantizedLinear4GAW_" + config.name_suffix + "_" + - storage_str + "_" + dtype_str; - test_case.set_name(test_name); - - // Set the operator name for the test case - std::string operator_name = "et_vk.linear_q8ta_q4gaw."; - operator_name += config.shader_variant_name; - test_case.set_operator_name(operator_name); - - // Derive sizes from M, K, N - std::vector input_size = {config.M, config.K}; - std::vector weight_size = { - config.K, config.N / 2}; // Packed 4-bit weights - - // Input tensor (float/half) - [M, K] - ValueSpec input_tensor( - input_size, - input_dtype, - storage_type, - utils::kWidthPacked, - DataGenType::RANDINT); - - if (debugging()) { - print_valuespec_data(input_tensor, "input_tensor"); - } - - float input_scale_val = 1.0f; - ValueSpec input_scale(input_scale_val); - - int32_t input_zero_point_val = 0; - ValueSpec input_zero_point(input_zero_point_val); - - // Group size parameter - ValueSpec group_size_spec(static_cast(config.group_size)); - - // Quantized weight tensor (int8, packed 4-bit) - [K, N/2] - ValueSpec quantized_weight( - weight_size, - vkapi::kChar, // int8 for packed 4-bit quantized weights - storage_type, - utils::kWidthPacked, - DataGenType::RANDINT4); - quantized_weight.set_constant(true); - quantized_weight.set_int4(true); - - if (debugging()) { - print_valuespec_data(quantized_weight, "weight_tensor"); - } - - // Weight quantization scales (float/half, per-group) - [N, K/group_size] - std::vector weight_scales_size = { - config.N, config.K / config.group_size}; - ValueSpec weight_scales( - weight_scales_size, - input_dtype, - storage_type, - utils::kWidthPacked, - DataGenType::RANDOM_SCALES); - weight_scales.set_constant(true); - - if (debugging()) { - print_valuespec_data(weight_scales, "weight_scales"); - } - - // Weight zeros (int32, per-group) - [N, K/group_size] - ValueSpec weight_zeros( - weight_scales_size, - vkapi::kInt, // int32 for zeros - storage_type, - utils::kWidthPacked, - DataGenType::ZEROS); - weight_zeros.set_constant(true); - - ValueSpec weight_sums( - {config.N}, // Per output features - vkapi::kFloat, - storage_type, - utils::kWidthPacked, - DataGenType::ZEROS); - weight_sums.set_constant(true); - - // Compute weight_sums data based on quantized weights - int64_t in_features = config.K; - int64_t out_features = config.N; - - ValueSpec orig_OC(static_cast(config.N)); - - // Bias (optional, float/half) - [N] - ValueSpec bias( - {config.N}, // Per output feature - input_dtype, - storage_type, - utils::kWidthPacked, - DataGenType::ZEROS); - 
bias.set_constant(true); - - // Output tensor (float/half) - [M, N] - ValueSpec output( - {config.M, config.N}, - input_dtype, - storage_type, - utils::kWidthPacked, - DataGenType::ZEROS); - - // Add all specs to test case - test_case.add_input_spec(input_tensor); - test_case.add_input_spec(input_scale); - test_case.add_input_spec(input_zero_point); - test_case.add_input_spec(quantized_weight); - test_case.add_input_spec(weight_sums); - test_case.add_input_spec(weight_scales); - test_case.add_input_spec(weight_zeros); - test_case.add_input_spec(orig_OC); - test_case.add_input_spec(group_size_spec); - test_case.add_input_spec(bias); - - test_case.add_output_spec(output); - - return test_case; -} - -// Generate easy test cases for quantized linear operation (for debugging) -std::vector generate_quantized_linear_easy_cases() { - std::vector test_cases; - - // Single simple configuration for debugging - int M = 4; - int K = 32; - int N = 32; - int group_size = 8; - - LinearConfig config = { - M, // Batch size - K, // Input features - N, // Output features - group_size, // Group size - "simple", // descriptive name - "noint8" // shader variant name - }; - - // Test with both storage types and data types for completeness - std::vector storage_types = { - utils::kTexture3D, utils::kBuffer}; - std::vector float_types = {vkapi::kFloat}; - - // Generate test cases for each combination - for (const auto& storage_type : storage_types) { - for (const auto& input_dtype : float_types) { - test_cases.push_back( - create_test_case_from_config(config, storage_type, input_dtype)); - } - } - - return test_cases; -} - -// Generate test cases for quantized linear operation -std::vector generate_quantized_linear_test_cases() { - std::vector test_cases; - - std::vector configs = { - {8, 64, 32, 8, "correctness_1_64_32_g8"}, - {8, 128, 64, 16, "correctness_1_128_64_g16"}, - {8, 256, 128, 32, "correctness_1_256_128_g32"}, - {32, 64, 32, 8, "correctness_32_64_32_g8"}, - {32, 128, 64, 16, "correctness_32_128_64_g16"}, - {32, 256, 128, 32, "correctness_32_256_128_g32"}, - {1, 256, 128, 32, "todo"}, - // Performance test cases - {1, 2048, 2048, 128, "todo"}, - {128, 2048, 2048, 128, "performance_128_2048_2048_g64"}, - {248, 2048, 2048, 128, "performance_128_2048_2048_g64"}, - {1024, 2048, 2048, 128, "performance_128_2048_2048_g64"}, - // {16384, 576, 128, 32, "performance_16384_576_128_g32"} - }; - - // Test with different storage types and data types - std::vector storage_types = { - utils::kTexture3D, utils::kBuffer}; - - // Generate test cases for each combination - for (const auto& config : configs) { - for (const auto& storage_type : storage_types) { - // Test both with and without shader int8 dot product - test_cases.push_back( - create_test_case_from_config(config, storage_type, vkapi::kFloat)); - - // LinearConfig no_int_config = config; - // no_int_config.name_suffix = config.name_suffix + "_noint8"; - // no_int_config.shader_variant_name = "noint8"; - - // test_cases.push_back(create_test_case_from_config( - // no_int_config, storage_type, vkapi::kFloat)); - } - } - - return test_cases; -} - -// Helper function to unpack 4-bit values from int8 -std::pair unpack_4bit(int8_t packed) { - // Extract lower 4 bits and upper 4 bits - int8_t lower = packed & 0x0F; - int8_t upper = (packed >> 4) & 0x0F; - - // Sign extend from 4-bit to 8-bit - if (lower & 0x08) - lower |= 0xF0; - if (upper & 0x08) - upper |= 0xF0; - - return std::make_pair(lower, upper); -} - -// Reference implementation for quantized linear operation 
-void quantized_linear_reference_impl(TestCase& test_case) { - static constexpr int64_t kRefDimSizeLimit = 300; - // Extract input specifications - int32_t idx = 0; - const ValueSpec& input_spec = test_case.inputs()[idx++]; - const ValueSpec& input_scale_spec = test_case.inputs()[idx++]; - const ValueSpec& input_zeros_spec = test_case.inputs()[idx++]; - const ValueSpec& weight_spec = test_case.inputs()[idx++]; - const ValueSpec& weight_sums_spec = test_case.inputs()[idx++]; - (void)weight_sums_spec; - const ValueSpec& weight_scales_spec = test_case.inputs()[idx++]; - const ValueSpec& weight_zeros_spec = test_case.inputs()[idx++]; - const ValueSpec& orig_OC = test_case.inputs()[idx++]; - (void)orig_OC; - const ValueSpec& group_size_spec = test_case.inputs()[idx++]; - const ValueSpec& bias_spec = test_case.inputs()[idx++]; - - // Extract output specification (mutable reference) - ValueSpec& output_spec = test_case.outputs()[0]; - - // Get tensor dimensions - auto input_sizes = input_spec.get_tensor_sizes(); // [batch_size, in_features] - auto weight_sizes = - weight_spec.get_tensor_sizes(); // [in_features, out_features/2] - auto output_sizes = - output_spec.get_tensor_sizes(); // [batch_size, out_features] - - int64_t batch_size = input_sizes[0]; - int64_t in_features = input_sizes[1]; - int64_t out_features = output_sizes[1]; - int64_t group_size = group_size_spec.get_int_value(); - - // Skip for large tensors since computation time will be extremely slow - if (batch_size > kRefDimSizeLimit || in_features > kRefDimSizeLimit || - out_features > kRefDimSizeLimit) { - throw std::invalid_argument( - "One or more dimensions (batch_size, in_features, out_features) exceed the allowed limit for reference implementation."); - } - - if (input_spec.dtype != vkapi::kFloat) { - throw std::invalid_argument("Unsupported dtype"); - } - - // Get raw data pointers - auto& input_data = input_spec.get_float_data(); - const float input_scale = input_scale_spec.get_float_value(); - const int32_t input_zero_point = input_zeros_spec.get_int_value(); - - auto& weight_data = weight_spec.get_int8_data(); - auto& weight_scales_data = weight_scales_spec.get_float_data(); - auto& weight_zeros_data = weight_zeros_spec.get_int32_data(); - auto& bias_data = bias_spec.get_float_data(); - - // Calculate number of output elements - int64_t num_output_elements = batch_size * out_features; - - auto& ref_data = output_spec.get_ref_float_data(); - ref_data.resize(num_output_elements); - - // Perform quantized linear transformation (matrix multiplication) - for (int64_t b = 0; b < batch_size; ++b) { - for (int64_t out_f = 0; out_f < out_features; ++out_f) { - float sum = 0.0f; - - bool should_print = b == 0 && out_f == 0; - should_print = false; - - if (should_print) { - std::cout << "Weights seen: "; - } - - // Matrix multiplication: output[b][out_f] = sum(input[b][in_f] * - // weight[out_f][in_f]) - for (int64_t in_f = 0; in_f < in_features; ++in_f) { - // Get input value and dequantize - int64_t input_idx = b * in_features + in_f; - - float quant_input = - std::round(input_data[input_idx] / input_scale) + input_zero_point; - quant_input = std::min(std::max(quant_input, -128.0f), 127.0f); - float dequant_input = (quant_input - input_zero_point) * input_scale; - - // Get weight value and dequantize (4-bit group affine quantization) - int64_t group_idx = in_f / group_size; - int64_t scales_idx = group_idx * out_features + out_f; - - // Get packed weight value - int64_t weight_idx = in_f * (out_features / 2) + (out_f / 2); - 
int8_t packed_weight = weight_data[weight_idx]; - - // Unpack 4-bit weight - auto unpacked = unpack_4bit(packed_weight); - int8_t weight_4bit = - (out_f % 2 == 0) ? unpacked.first : unpacked.second; - - // Dequantize weight using group affine quantization - float weight_scale = weight_scales_data[scales_idx]; - int32_t weight_zero = weight_zeros_data[scales_idx]; - float dequant_weight = - (static_cast(weight_4bit) - weight_zero) * weight_scale; - - if (should_print) { - std::cout << int(weight_4bit) << ", "; - } - - sum += dequant_input * dequant_weight; - } - - if (should_print) { - std::cout << std::endl; - } - - // Add bias and store result - sum += bias_data[out_f]; - int64_t output_idx = b * out_features + out_f; - ref_data[output_idx] = sum; - } - } -} - -// Custom FLOP calculator for quantized linear operation -int64_t quantized_linear_flop_calculator(const TestCase& test_case) { - if (test_case.num_inputs() < 6 || test_case.num_outputs() < 1) { - return 0; - } - - // Get input and weight dimensions - const auto& input_sizes = test_case.inputs()[0].get_tensor_sizes(); - const auto& output_sizes = test_case.outputs()[0].get_tensor_sizes(); - - int64_t batch_size = input_sizes[0]; - int64_t in_features = input_sizes[1]; - int64_t out_features = output_sizes[1]; - - // Calculate FLOPs for quantized linear operation - // Each output element requires: - // - in_features multiply-accumulate operations - // - Additional operations for quantization/dequantization - int64_t output_elements = batch_size * out_features; - int64_t ops_per_output = in_features; - - // Add quantization overhead (approximate) - // - Dequantize input: 1 op per input element used - // - Dequantize weight: 2 ops per weight element used (unpack + dequantize) - // - Add bias: 1 op per output element - int64_t quantization_ops = ops_per_output * 2 + 1; // Simplified estimate - - int64_t flop = output_elements * (ops_per_output + quantization_ops); - - return flop; -} - -int main(int argc, char* argv[]) { - set_debugging(false); - set_print_output(false); - set_print_latencies(false); - set_use_gpu_timestamps(true); - - print_performance_header(); - std::cout - << "Quantized 4-bit Group Affine Weights Linear Operation Prototyping Framework" - << std::endl; - print_separator(); - - ReferenceComputeFunc ref_fn = quantized_linear_reference_impl; - - // Execute easy test cases using the new framework with custom FLOP - // calculator - auto results = execute_test_cases( - generate_quantized_linear_test_cases, - quantized_linear_flop_calculator, - "QuantizedLinear4GAW", - 0, - 10, - ref_fn); - - return 0; -} From 7e75e5fd05cea5442c55ce7f7c2e0e82a7ee320e Mon Sep 17 00:00:00 2001 From: ssjia Date: Mon, 8 Sep 2025 11:48:17 -0700 Subject: [PATCH 4/7] Update on "[ET-VK] Implemement linear_dq8ta_q4gsw" Title says it all! Build upon the support for quantized linear introduced in the previous diffs to enable dynamically quantized linear. Also included in this diff is a cleanup of the glslh files used across quantized linear implementations. 
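A small Python sketch of the weight_sums precompute that `compute_weight_sums_4bit_grouped` performs in utils.cpp: one int32 sum of the unpacked 4-bit weights per (group, output channel), stored as a [K / group_size, N] tensor. Function and variable names are illustrative; K is assumed divisible by group_size.

```python
import torch

def weight_sums_4bit_grouped(packed_w: torch.Tensor, group_size: int) -> torch.Tensor:
    # packed_w: [N, K // 2] uint8, two 4-bit values per byte, offset by 8
    N, half_K = packed_w.shape
    K = half_K * 2
    lo = (packed_w & 0x0F).to(torch.int32) - 8
    hi = ((packed_w >> 4) & 0x0F).to(torch.int32) - 8
    w = torch.stack([lo, hi], dim=-1).reshape(N, K)               # [N, K], values in [-8, 7]
    # Sum over each group of input channels, then lay out as [K // group_size, N]
    sums = w.reshape(N, K // group_size, group_size).sum(dim=-1)  # [N, K // group_size]
    return sums.transpose(0, 1).contiguous()
```

These sums let the kernel fold the activation zero point out of the integer accumulation, since sum_k (q_k - zp) * w_k = sum_k q_k * w_k - zp * sum_k w_k.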
Differential Revision: [D81931060](https://our.internmc.facebook.com/intern/diff/D81931060/) [ghstack-poisoned] --- backends/vulkan/targets.bzl | 3 ++- backends/vulkan/test/custom_ops/q4gsw_linear.cpp | 11 ++++++++--- backends/vulkan/utils.py | 8 ++++---- 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/backends/vulkan/targets.bzl b/backends/vulkan/targets.bzl index 775341d420d..40831856ac5 100644 --- a/backends/vulkan/targets.bzl +++ b/backends/vulkan/targets.bzl @@ -330,7 +330,8 @@ def define_common_targets(is_fbcode = False): "//executorch/exir:tensor", "//executorch/exir/backend/canonical_partitioners:config_partitioner_lib", "//executorch/backends/vulkan/serialization:lib", - ] + ], + typing = True, ) runtime.python_library( diff --git a/backends/vulkan/test/custom_ops/q4gsw_linear.cpp b/backends/vulkan/test/custom_ops/q4gsw_linear.cpp index a6adfdc2d5c..1c09fdd471f 100644 --- a/backends/vulkan/test/custom_ops/q4gsw_linear.cpp +++ b/backends/vulkan/test/custom_ops/q4gsw_linear.cpp @@ -267,9 +267,14 @@ std::vector generate_quantized_linear_test_cases() { config.test_case_name = generated_test_case_name; for (const auto& storage_type : storage_types) { - // Test both activation+weight quantized and weight only quantized - test_cases.push_back( - create_test_case_from_config(config, storage_type, vkapi::kFloat)); + // Test both activation+weight quantized and weight only quantized, but + // only if the current device supports int8 dot product + if (vkcompute::api::context() + ->adapter_ptr() + ->supports_int8_dot_product()) { + test_cases.push_back( + create_test_case_from_config(config, storage_type, vkapi::kFloat)); + } LinearConfig wo_quant_config = config; wo_quant_config.op_name = "linear_q4gsw"; diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py index 474c6386a9e..96f200eecbc 100644 --- a/backends/vulkan/utils.py +++ b/backends/vulkan/utils.py @@ -1142,17 +1142,17 @@ def maybe_skip_q_dq_arg_chain( # If the arg is a view copy node, check if the original node is a dequant node if is_dequant_node(arg) or ( - is_view_copy_node(arg) and is_dequant_node(arg.args[0]) + is_view_copy_node(arg) and is_dequant_node(arg.args[0]) # pyre-ignore[6] ): + dequant_node = arg if is_view_copy_node(arg): dequant_node = arg.args[0] - else: - dequant_node = arg - quant_node = dequant_node.args[0] + quant_node = dequant_node.args[0] # pyre-ignore[16] assert isinstance(quant_node, torch.fx.Node) source_arg = quant_node.args[0] assert isinstance(source_arg, torch.fx.Node) + assert isinstance(dequant_node, torch.fx.Node) return source_arg, quant_node, dequant_node else: return arg, None, None From 45d9b43680b6dbc1c9612edc43d489c41f804ba6 Mon Sep 17 00:00:00 2001 From: ssjia Date: Mon, 8 Sep 2025 12:32:38 -0700 Subject: [PATCH 5/7] Update on "[ET-VK] Implemement linear_dq8ta_q4gsw" Title says it all! Build upon the support for quantized linear introduced in the previous diffs to enable dynamically quantized linear. Also included in this diff is a cleanup of the glslh files used across quantized linear implementations. 
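For context, a simplified, illustrative sketch of the guard that `is_input_dynamic_perchannel_quantized` performs after this change; the real logic lives in backends/vulkan/patterns/quantized_linear.py and also validates the rest of the q/dq chain.

```python
import operator
import torch

def looks_like_dynamic_per_token_input(scales_node) -> bool:
    # A constant (non-Node) scale means statically quantized input, which this path rejects.
    if not isinstance(scales_node, torch.fx.Node):
        return False
    # Dynamic scales are computed at runtime by a choose_qparams-style op and reach
    # the quantize node through operator.getitem on its (scale, zero_point) output.
    if scales_node.target != operator.getitem:
        return False
    producer = scales_node.args[0]
    return isinstance(producer, torch.fx.Node) and "choose_qparams" in str(producer.target)
```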
Differential Revision: [D81931060](https://our.internmc.facebook.com/intern/diff/D81931060/) [ghstack-poisoned] --- backends/vulkan/patterns/quantized_linear.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/backends/vulkan/patterns/quantized_linear.py b/backends/vulkan/patterns/quantized_linear.py index 2858cda111a..882d0d41e6d 100644 --- a/backends/vulkan/patterns/quantized_linear.py +++ b/backends/vulkan/patterns/quantized_linear.py @@ -181,6 +181,9 @@ def is_input_dynamic_perchannel_quantized(self) -> bool: if self.quantize_input_node is None: return False + if not isinstance(self.input_scales_node, torch.fx.Node): + return False + # For dynamic quantization, input scale node should be a getitem operator # retrieving the output of a choose_qparams op if self.input_scales_node.target != operator.getitem: From 31f68011a64461224ee1bef6c18aa68b211d3c96 Mon Sep 17 00:00:00 2001 From: ssjia Date: Mon, 8 Sep 2025 14:04:06 -0700 Subject: [PATCH 6/7] Update on "[ET-VK] Implemement linear_dq8ta_q4gsw" Title says it all! Build upon the support for quantized linear introduced in the previous diffs to enable dynamically quantized linear. Also included in this diff is a cleanup of the glslh files used across quantized linear implementations. Differential Revision: [D81931060](https://our.internmc.facebook.com/intern/diff/D81931060/) [ghstack-poisoned] --- .github/workflows/pull.yml | 1 + backends/vulkan/test/test_vulkan_delegate.py | 1 + 2 files changed, 2 insertions(+) diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index fa566afca85..67fa28b2868 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -939,6 +939,7 @@ jobs: # Run e2e testing for selected operators. More operators will be tested via this # route in the future. python -m unittest backends/vulkan/test/test_vulkan_delegate.py -k "*pt2e*" + python -m unittest backends/vulkan/test/test_vulkan_delegate.py -k "*torchao*" nxp-build-test: name: nxp-build-test diff --git a/backends/vulkan/test/test_vulkan_delegate.py b/backends/vulkan/test/test_vulkan_delegate.py index a832915359f..01547d7140d 100644 --- a/backends/vulkan/test/test_vulkan_delegate.py +++ b/backends/vulkan/test/test_vulkan_delegate.py @@ -2651,6 +2651,7 @@ def forward(self, x): rtol=1e-1, ) + @unittest.skip("Cannot run on swiftshader due to no 8-bit int support") def test_vulkan_backend_torchao_8da4w_quantized_linear(self): """ Test TorchAO 8da4w quantization (int8 dynamic activation + int4 weight) with Vulkan backend. From df7249603431c4f97e21ecf24d34f2a8dce0eacb Mon Sep 17 00:00:00 2001 From: ssjia Date: Mon, 8 Sep 2025 14:25:08 -0700 Subject: [PATCH 7/7] Update on "[ET-VK] Implemement linear_dq8ta_q4gsw" Title says it all! Build upon the support for quantized linear introduced in the previous diffs to enable dynamically quantized linear. Also included in this diff is a cleanup of the glslh files used across quantized linear implementations. Differential Revision: [D81931060](https://our.internmc.facebook.com/intern/diff/D81931060/) [ghstack-poisoned]
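For anyone reproducing the new end-to-end coverage locally, a hedged usage sketch that mirrors test_vulkan_backend_torchao_8da4w_quantized_linear (module structure and sizes here are illustrative):

```python
import torch
from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
from torchao.utils import unwrap_tensor_subclass

# Quantize a plain nn.Linear with TorchAO's 8da4w config
# (int8 dynamic activation + int4 group-wise weight) before lowering to Vulkan.
model = torch.nn.Sequential(torch.nn.Linear(1024, 512, bias=False))
quantize_(model, int8_dynamic_activation_int4_weight(group_size=128))
unwrap_tensor_subclass(model)

# A GEMV-shaped input (batch of 1) exercises the dynamically quantized linear path.
x = torch.randn(1, 1024)
y = model(x)
```

The same path is covered in CI via `python -m unittest backends/vulkan/test/test_vulkan_delegate.py -k "*torchao*"`, and the test is skipped on SwiftShader, which lacks 8-bit integer support.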