diff --git a/backends/vulkan/_passes/TARGETS b/backends/vulkan/_passes/TARGETS index 8558a2eea93..aed41114ada 100644 --- a/backends/vulkan/_passes/TARGETS +++ b/backends/vulkan/_passes/TARGETS @@ -34,20 +34,6 @@ runtime.python_library( ], ) -runtime.python_library( - name = "int4_weight_only_quantizer", - srcs = [ - "int4_weight_only_quantizer.py", - ], - visibility = [ - "//executorch/backends/...", - ], - deps = [ - "//executorch/backends/vulkan:custom_ops_lib", - "//pytorch/ao:torchao", - ] -) - runtime.python_library( name = "squeeze_unsqueeze_inputs", srcs = [ @@ -161,7 +147,6 @@ runtime.python_library( ":fuse_patterns", ":fuse_quantized_ops", ":insert_prepack_nodes", - ":int4_weight_only_quantizer", ":remove_asserts", ":remove_local_scalar_dense", ":remove_redundant_ops", diff --git a/backends/vulkan/_passes/__init__.py b/backends/vulkan/_passes/__init__.py index 2c4588ac43d..f4ef6b2ac0e 100644 --- a/backends/vulkan/_passes/__init__.py +++ b/backends/vulkan/_passes/__init__.py @@ -12,9 +12,6 @@ FuseQuantizedOpsTransform, ) from executorch.backends.vulkan._passes.insert_prepack_nodes import insert_prepack_nodes -from executorch.backends.vulkan._passes.int4_weight_only_quantizer import ( - VkInt4WeightOnlyQuantizer, -) from executorch.backends.vulkan._passes.remove_asserts import ( remove_asserts, RemoveAssertsTransform, @@ -35,7 +32,6 @@ "FusePatternsPass", "FuseQuantizedOpsTransform", "insert_prepack_nodes", - "VkInt4WeightOnlyQuantizer", "remove_asserts", "RemoveAssertsTransform", "RemoveLocalScalarDenseOpsTransform", diff --git a/backends/vulkan/_passes/fuse_quantized_ops.py b/backends/vulkan/_passes/fuse_quantized_ops.py index 3d3214bb4ee..ca9f7541159 100644 --- a/backends/vulkan/_passes/fuse_quantized_ops.py +++ b/backends/vulkan/_passes/fuse_quantized_ops.py @@ -210,278 +210,6 @@ def fuse_into_linear_qcnw_node( graph_module.graph.erase_node(dq_weight_node) -######################### -## linear_qta8a_qga4w ## -######################### - - -def _is_dequantize_affine_node(node: torch.fx.Node) -> bool: - """Check if a node is a dequantize_affine operation.""" - return ( - node.op == "call_function" - and node.target is not None - and hasattr(node.target, "__name__") - and "dequantize_affine" in getattr(node.target, "__name__", "") - ) - - -def _is_view_copy_node(node: torch.fx.Node) -> bool: - """Check if a node is a view_copy operation.""" - return ( - node.op == "call_function" - and node.target is not None - and hasattr(node.target, "__name__") - and "view_copy" in getattr(node.target, "__name__", "") - ) - - -def _validate_qta8a_qga4w_nodes( - input_node: torch.fx.node.Argument, weight_node: torch.fx.node.Argument -) -> Optional[torch.fx.Node]: - """ - Validate input and weight nodes for QTA8A_QGA4W pattern. - Returns the actual input node (after handling view operations) or None if invalid. 
- """ - # Type checking - ensure we have torch.fx.Node objects - if not isinstance(weight_node, torch.fx.Node) or not isinstance( - input_node, torch.fx.Node - ): - return None - - # Input may be preprocessed with a view node - actual_input_node = input_node - if _is_view_copy_node(input_node): - actual_input_node = input_node.args[0] - if not isinstance(actual_input_node, torch.fx.Node): - return None - - # Check if input is dequantized with dequantize_affine (from dynamic quantization) - if not _is_dequantize_affine_node(actual_input_node): - return None - - # Check if weight is dequantized with dequantize_affine - if not _is_dequantize_affine_node(weight_node): - return None - - return actual_input_node - - -def _extract_weight_params( - program: ExportedProgram, weight_node: torch.fx.Node -) -> Optional[Tuple[torch.fx.Node, torch.fx.Node, torch.fx.Node]]: - """Extract and validate weight parameters from dequantize_affine node.""" - # Get the original quantized weight and quantization parameters - if len(weight_node.args) < 4: - return None - - orig_weight = weight_node.args[0] - weight_scales = weight_node.args[2] - weight_zeros = weight_node.args[3] - - # Type checking - if not isinstance(orig_weight, torch.fx.Node) or not is_param_node( - program, orig_weight - ): - return None - if not isinstance(weight_scales, torch.fx.Node) or not is_param_node( - program, weight_scales - ): - return None - if not isinstance(weight_zeros, torch.fx.Node) or not is_param_node( - program, weight_zeros - ): - return None - - return orig_weight, weight_scales, weight_zeros - - -def _validate_4bit_quantization(weight_tensor: torch.Tensor) -> bool: - """Check if weight tensor is quantized to 4 bits (values in [-8, 7] range).""" - quant_min = weight_tensor.min().item() - quant_max = weight_tensor.max().item() - return quant_min >= -8 and quant_max <= 7 - - -def _calculate_group_size( - orig_weight_tensor: torch.Tensor, weight_scales_tensor: torch.Tensor -) -> Optional[int]: - """Calculate and validate group size from weight and scales tensors.""" - out_features, in_features = orig_weight_tensor.shape - - if len(weight_scales_tensor.shape) != 2: - return None - - scales_out_features, num_groups = weight_scales_tensor.shape - - if scales_out_features != out_features: - return None - - group_size = in_features // num_groups - if in_features % group_size != 0: - return None - - return group_size - - -def matches_linear_qta8a_qga4w_pattern( - program: ExportedProgram, node: torch.fx.Node -) -> Optional[Tuple[int, int]]: - """ - Checks if the nodes surrounding a linear node matches the pattern for dynamic - activation + grouped weight quantized linear (QTA8A_QGA4W). - - This pattern involves: - 1. Dynamic quantization of input activations (8-bit) - 2. Grouped quantization of weights (4-bit with group size) - - The expected pattern from Int8DynActInt4WeightQuantizer is: - scale, zero_point = choose_qparams_affine(input) - quantized_input = quantize_affine(input, scale, zero_point) - dequantized_input = dequantize_affine(quantized_input, ...) - dequantized_weight = dequantize_affine(weight, weight_scales, weight_zeros) - output = linear(dequantized_input, dequantized_weight) - - If the pattern matches, return (group_size, weight_bits), otherwise None. 
- """ - if not utils.is_linear_node(node): - return None - - input_node = node.args[0] - weight_node = node.args[1] - - # Validate nodes and get actual input node - actual_input_node = _validate_qta8a_qga4w_nodes(input_node, weight_node) - if actual_input_node is None: - return None - - # Extract weight parameters - if not isinstance(weight_node, torch.fx.Node): - return None - weight_params = _extract_weight_params(program, weight_node) - if weight_params is None: - return None - - orig_weight, weight_scales, weight_zeros = weight_params - - # Get tensors to analyze the quantization scheme - orig_weight_tensor = get_param_tensor(program, orig_weight) - weight_scales_tensor = get_param_tensor(program, weight_scales) - weight_zeros_tensor = get_param_tensor(program, weight_zeros) - - if not isinstance(orig_weight_tensor, torch.Tensor): - return None - if not isinstance(weight_scales_tensor, torch.Tensor): - return None - if not isinstance(weight_zeros_tensor, torch.Tensor): - return None - - # Check if weight is quantized to 4 bits - if not _validate_4bit_quantization(orig_weight_tensor): - return None - - # Calculate group size - group_size = _calculate_group_size(orig_weight_tensor, weight_scales_tensor) - if group_size is None: - return None - - # Verify this is 4-bit grouped quantization - weight_bits = 4 - - return group_size, weight_bits - - -def fuse_into_linear_qta8a_qga4w_node( - program: ExportedProgram, - graph_module: torch.fx.GraphModule, - linear_node: torch.fx.Node, - group_size: int, - weight_bits: int, -) -> None: - """ - Fuse the dynamic activation + grouped weight quantized linear pattern into - a single linear_qta8a_qga4w operator. - - The pattern: - dequantized_input = dequantize_affine(quantized_input, block_size, scale, zero_point, ...) - dequantized_weight = dequantize_affine(weight, block_size, weight_scales, weight_zeros, ...) 
- output = linear(dequantized_input, dequantized_weight) - - Becomes: - output = linear_qta8a_qga4w(quantized_input, input_scale, input_zero_point, - weight, group_size, weight_scales, weight_zeros) - """ - dq_input_node = linear_node.args[0] - dq_weight_node = linear_node.args[1] - - assert isinstance(dq_input_node, torch.fx.Node) - - input_view_node = None - # Input may be preprocessed with a view node - if ( - dq_input_node.op == "call_function" - and dq_input_node.target is not None - and hasattr(dq_input_node.target, "__name__") - and "view_copy" in getattr(dq_input_node.target, "__name__", "") - ): - input_view_node = dq_input_node - dq_input_node = dq_input_node.args[0] - assert isinstance(dq_input_node, torch.fx.Node) - - assert isinstance(dq_input_node, torch.fx.Node) - assert isinstance(dq_weight_node, torch.fx.Node) - - # Get the quantized input and quantization parameters from the input dequantize_affine node - # Args: (input, block_size, scale, zero_point, input_dtype, quant_min, quant_max, output_dtype) - quantized_input = dq_input_node.args[0] - input_scale = dq_input_node.args[2] # scale is the 3rd argument - input_zero_point = dq_input_node.args[3] if len(dq_input_node.args) > 3 else None - - # Get the weight and its quantization parameters from dequantize_affine - # Args: (weight, block_size, weight_scales, weight_zeros, input_dtype, quant_min, quant_max, output_dtype) - orig_weight = dq_weight_node.args[0] - weight_scales = dq_weight_node.args[2] - weight_zeros = dq_weight_node.args[3] - - # Pack the 4-bit weight tensor for efficient storage - assert isinstance(orig_weight, torch.fx.Node) - orig_weight_tensor = get_param_tensor(program, orig_weight) - assert isinstance(orig_weight_tensor, torch.Tensor) - packed_weight_tensor = pack_4bit_weight_tensor(orig_weight_tensor) - utils.update_program_state_dict( - program, - orig_weight.name, - packed_weight_tensor, - ) - # Update the metadata to reflect the new packed shape - orig_weight.meta["val"] = orig_weight.meta["val"][:, ::2].to(torch.uint8) - - # Create the linear_qta8a_qga4w node - with graph_module.graph.inserting_before(linear_node): - linear_qta8a_qga4w_node = graph_module.graph.create_node( - "call_function", - exir_ops.edge.et_vk.linear_qta8a_qga4w.default, - ( - quantized_input, # quantized input (int8) - input_scale, # mat1_scale - input_zero_point, # mat1_zero_point - orig_weight, # mat2_data (packed 4-bit weights) - group_size, # group_size (int) - weight_scales, # weight_scales - weight_zeros, # weight_zeros - ), - ) - - # Replace the linear node with the new fused node - linear_node.replace_all_uses_with(linear_qta8a_qga4w_node) - - # Erase nodes in the correct order (users first, then dependencies) - graph_module.graph.erase_node(linear_node) - if input_view_node is not None: - graph_module.graph.erase_node(input_view_node) - graph_module.graph.erase_node(dq_weight_node) - graph_module.graph.erase_node(dq_input_node) - - class FuseQuantizedOpsTransform(ExportPass): def __init__(self, exported_program: ExportedProgram) -> None: super().__init__() @@ -498,15 +226,6 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: ) continue - # Check for linear_qta8a_qga4w pattern (dynamic activation + grouped weight quantization) - qta8a_qga4w_details = None - if qta8a_qga4w_details is not None: - group_size, weight_bits = qta8a_qga4w_details - fuse_into_linear_qta8a_qga4w_node( - self.program, graph_module, node, group_size, weight_bits - ) - continue - graph_module.recompile() 
dead_code_elimination_pass(graph_module) diff --git a/backends/vulkan/_passes/int4_weight_only_quantizer.py b/backends/vulkan/_passes/int4_weight_only_quantizer.py deleted file mode 100644 index 34ff5937822..00000000000 --- a/backends/vulkan/_passes/int4_weight_only_quantizer.py +++ /dev/null @@ -1,283 +0,0 @@ -# pyre-unsafe -import logging -from typing import Any, Callable, Dict, Optional, Type - -import executorch.backends.vulkan.custom_ops_lib # noqa - -import torch -import torch.nn.functional as F - -from torchao.quantization.unified import Quantizer -from torchao.quantization.utils import groupwise_affine_quantize_tensor - - -# TODO: import from from torchao.quantization.GPTQ.GPTQ import _check_linear_int4_k -# Once diff train catches up -def _check_linear_int4_k(k, group_size=1, inner_k_tiles=None): - """ - Check if the dimensions are compatible with int4 quantization. - - Args: - k: The dimension size to check - group_size: The group size for quantization - inner_k_tiles: The inner k tiles size - - Returns: - bool: Whether the dimensions are compatible - """ - k_divisible_by_group_size = k % group_size == 0 - if inner_k_tiles is not None: - k_divisible_by_16_times_inner_k_tiles = k % (inner_k_tiles * 16) == 0 - return k_divisible_by_group_size and k_divisible_by_16_times_inner_k_tiles - return k_divisible_by_group_size - - -# This module is copied from torchao.quantization.GPTQ.WeightOnlyInt4Linear with -# changes at the annotated lines. -class VkWeightOnlyInt4Linear(torch.nn.Module): - __constants__ = ["in_features", "out_features"] - in_features: int - out_features: int - weight: torch.Tensor - - def __init__( - self, - in_features: int, - out_features: int, - # TODO: remove dtype field, not used - bias=False, - device=None, - dtype=None, - groupsize: int = 128, - inner_k_tiles: int = 8, - precision: torch.dtype = torch.bfloat16, - scales_precision: torch.dtype = torch.bfloat16, - ) -> None: - super().__init__() - self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles) - if self.padding: - from torchao.utils import find_multiple - - self.origin_in_features = in_features - # pyre-ignore[6]: Incompatible parameter type - in_features = find_multiple(in_features, 1024) - - self.use_bias = bias - self.in_features = in_features - self.out_features = out_features - self.device = device - self.groupsize = groupsize - self.inner_k_tiles = inner_k_tiles - self.precision = precision - self.scales_precision = scales_precision - - if dtype is not None: - raise ValueError("Please specify 'precision' instead of 'dtype'") - - assert out_features % 8 == 0, "require out_features % 8 == 0" - assert ( - in_features % (inner_k_tiles * 16) == 0 - ), "require in_features % (innerKTiles * 16) == 0" - # In the original implementation, the weight buffer is registered with the packed - # sizes, i.e. the result of calling the _convert_weight_to_int4pack operator. - # However, the Vulkan implementation does not expect the weights to be packed - # therefore the weight tensor is registered with the unpacked sizes instead. - # Note that in_features is divided by 2 because each `uint8` tensor element - # contains 2 4-bit packed values. 
- self.register_buffer( - "weight", - torch.empty( - (out_features, in_features // 2), - dtype=torch.uint8, - device=device, - ), - ) - self.dtype = dtype - self.register_buffer( - "scales_and_zeros", - torch.empty( - (in_features // groupsize, out_features, 2), - dtype=self.scales_precision, - device=device, - ), - ) - if bias: - self.register_buffer( - "bias", - torch.empty((out_features,), dtype=torch.float32, device=device), - ) - - def forward(self, input: torch.Tensor) -> torch.Tensor: - if self.padding: - input = F.pad(input, pad=(0, self.in_features - self.origin_in_features)) - # The forward method is replaced. In the original implementation, the forward - # method is torchao.quantization.GPTQ.linear_forward_int4; here a Vulkan custom - # operator is called instead. - r = torch.ops.et_vk.linear_weight_int4( - input, - self.weight, - self.groupsize, - self.scales_and_zeros, - self.inner_k_tiles, - ) - if self.use_bias: - return r + self.bias - return r - - -# This function is coped from torchao.quantization.GPTQ._replace_linear_int4 -# with small changes at the annotated locations. -def _vk_replace_linear_int4( - module: torch.nn.Module, - groupsize: int, - inner_k_tiles: Optional[int], - padding_allowed: bool, - skip_layer_func: Optional[Callable] = None, - precision: torch.dtype = torch.bfloat16, - scales_precision: torch.dtype = torch.bfloat16, - # Use custom vulkan linear layer as default - linear_class: Type[torch.nn.Module] = VkWeightOnlyInt4Linear, - copy_weights: bool = False, -): - for name, child in module.named_children(): - if isinstance(child, torch.nn.Linear) and ( - skip_layer_func is None or not skip_layer_func(child.weight) - ): - # Add an additional condition that the out/in features must not exceed the - # `feature_limit` argument. - if ( - _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles) - or padding_allowed - ): - new_linear = linear_class( - child.in_features, - child.out_features, - bias=child.bias is not None, - device=child.weight.device, - groupsize=groupsize, - inner_k_tiles=inner_k_tiles, - precision=precision, - scales_precision=scales_precision, - ) - if copy_weights and child.weight.device != torch.device("meta"): - # pyre-fixme[16]: `Module` has no attribute `weight`. - new_linear.weight = child.weight - if child.bias is not None: - # pyre-fixme[16]: `Module` has no attribute `bias`. - new_linear.bias = child.bias - setattr(module, name, new_linear) - else: - _vk_replace_linear_int4( - child, - groupsize, - inner_k_tiles, - padding_allowed, - skip_layer_func, - precision, - scales_precision, - linear_class, - copy_weights, - ) - - -# This module is copied from torchao.quantization.GPTQ.Int4WeightOnlyQuantizer -# with some changes at the annotated lines. 
-class VkInt4WeightOnlyQuantizer(Quantizer): - def __init__( - self, - groupsize: int = 256, - padding_allowed: bool = True, - inner_k_tiles: Optional[int] = 8, - device: torch.device = torch.device("cpu"), # noqa - precision: torch.dtype = torch.float32, - ) -> None: - super().__init__() - assert inner_k_tiles in [2, 4, 8] - assert groupsize in [32, 64, 128, 256] - - self.inner_k_tiles = inner_k_tiles - self.groupsize: int = groupsize - self.padding_allowed: bool = padding_allowed - self.device: torch.device = device - self.precision: torch.dtype = precision - - @torch.no_grad() - def _create_quantized_state_dict( - self, model: torch.nn.Module - ) -> Dict[str, torch.Tensor]: - cur_state_dict = model.state_dict() - for fqn, mod in model.named_modules(): - # Add additional check to make sure features do not exceed feature limit - if isinstance(mod, torch.nn.Linear): - out_features = mod.out_features - in_features = mod.in_features - logging.info(f"linear: {fqn}, in={in_features}, out={out_features}") - - assert ( - in_features % self.groupsize == 0 - ), f"require in_features:{in_features} % self.groupsize:{self.groupsize} == 0" - - weight = mod.weight.data - if not _check_linear_int4_k( - in_features, self.groupsize, self.inner_k_tiles - ): - if self.padding_allowed: - import torch.nn.functional as F - - from torchao.utils import find_multiple - - logging.warn( - f"warning: {fqn} is padded to satisfy in_features % 1024 == 0" - ) - # pyre-ignore[6]: Incompatible parameter type - padded_in_features = find_multiple(in_features, 1024) - weight = F.pad( - weight, pad=(0, padded_in_features - in_features) - ) - else: - logging.warn( - f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " - + "and that groupsize and inner_k_tiles*16 evenly divide into it" - ) - continue - (w_int4x8, scales_and_zeros) = groupwise_affine_quantize_tensor( - weight, - 4, # n_bit - self.groupsize, - self.precision, # dtype for scales_and_zeros - ) - # If the packing of 2 4-bit values into a single 8-bit value was not - # performed in the previous function call, then do it manually now. - if w_int4x8.shape == weight.shape: - w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to( - torch.uint8 - ) - # In the original implementation, w_int4x8 is packed via calling the - # _convert_weight_to_int4pack operator before storing the weight. However - # the Vulkan implementation does not expect the weights to be packed, so - # the w_int4x8 tensor is stored as the weight instead. 
- cur_state_dict[f"{fqn}.weight"] = w_int4x8.to(self.device) - cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to( - self.device - ) - return cur_state_dict - - def _convert_for_runtime(self, model: torch.nn.Module) -> torch.nn.Module: - _vk_replace_linear_int4( - model, - self.groupsize, - self.inner_k_tiles, - self.padding_allowed, - skip_layer_func=None, - precision=self.precision, - scales_precision=self.precision, - ) - return model - - def quantize( - self, model: torch.nn.Module, *args: Any, **kwargs: Any - ) -> torch.nn.Module: - state_dict = self._create_quantized_state_dict(model) - model = self._convert_for_runtime(model) - model.load_state_dict(state_dict, strict=False) - return model diff --git a/backends/vulkan/_passes/squeeze_unsqueeze_inputs.py b/backends/vulkan/_passes/squeeze_unsqueeze_inputs.py index c415249383e..25b28ce3117 100644 --- a/backends/vulkan/_passes/squeeze_unsqueeze_inputs.py +++ b/backends/vulkan/_passes/squeeze_unsqueeze_inputs.py @@ -22,7 +22,6 @@ class SqueezeUnsqueezeInputs(ExportPass): _squeezable_ops: Set[OpType] = { - exir_ops.edge.et_vk.linear_weight_int4.default, exir_ops.edge.aten.relu.default, exir_ops.edge.aten.gelu.default, } diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py index 336ca2117a0..56e803b9127 100644 --- a/backends/vulkan/custom_ops_lib.py +++ b/backends/vulkan/custom_ops_lib.py @@ -155,38 +155,6 @@ def grid_priors_out_impl( ) lib.impl(name, grid_priors_out_impl, "CompositeExplicitAutograd") -######################## -## linear_weight_int4 ## -######################## - - -def linear_weight_int4_impl( - x: torch.Tensor, - weights_4x8: torch.Tensor, - groupsize: int, - scales_and_zeros: torch.Tensor, - inner_k_tiles: int, -): - original_x_size = x.size() - out_features = weights_4x8.size(0) - x = x.reshape(-1, original_x_size[-1]) - weight_int4pack = torch.ops.aten._convert_weight_to_int4pack( - weights_4x8, inner_k_tiles - ) - out = torch.ops.aten._weight_int4pack_mm( - x, weight_int4pack, groupsize, scales_and_zeros - ) - out_shape = original_x_size[:-1] + (out_features,) - return out.reshape(out_shape) - - -name = "linear_weight_int4" -lib.define( - f"{name}(Tensor self, Tensor mat2, int qGroupSize, Tensor qScaleAndZeros, int inner_k_tiles) -> Tensor" -) -lib.impl(name, linear_weight_int4_impl, "CompositeExplicitAutograd") -linear_weight_int4_op = getattr(getattr(torch.ops, namespace), name) - ################## ## linear_qcs4w ## ################## @@ -337,95 +305,6 @@ def linear_dq8ca_q4gsw( lib.impl(name, linear_dq8ca_q4gsw, "CompositeExplicitAutograd") linear_dq8ca_q4gsw_op = getattr(getattr(torch.ops, namespace), name) -######################## -## linear_qta8a_qga4w ## -######################## - - -def linear_qta8a_qga4w( - x_quantized: torch.Tensor, - input_scale: torch.Tensor, - input_zero_point: torch.Tensor, - weights_4bit: torch.Tensor, - group_size: int, - weight_scales: torch.Tensor, - weight_zeros: torch.Tensor, -): - """ - Dynamic activation + grouped weight quantized linear (QTA8A_QGA4W). 
- - Args: - x_quantized: Already quantized input tensor (int8, per-token quantized) - input_scale: Scale for per-token quantization of input (shape: [batch_size]) - input_zero_point: Zero point for per-token quantization of input (shape: [batch_size]) - weights_4bit: Packed 4-bit quantized weights - group_size: Group size for weight quantization (int) - weight_scales: Per-group scales for weights - weight_zeros: Per-group zero points for weights - """ - original_x_shape = x_quantized.shape - feature_dim = original_x_shape[-1] - - # Reshape for processing - x_quantized_2d = x_quantized.reshape(-1, feature_dim) - - # Unpack 4-bit weights - unpacked_weights_shape = weights_4bit.shape - out_features = unpacked_weights_shape[0] - in_features = unpacked_weights_shape[1] - - weights_unpacked = torch.empty( - (out_features, in_features * 2), dtype=torch.int8, device=weights_4bit.device - ) - - weights_unpacked[:, ::2] = weights_4bit >> 4 - weights_unpacked[:, 1::2] = weights_4bit & 0x0F - - # Convert to signed 4-bit range [-8, 7] - weights_unpacked = torch.where( - weights_unpacked > 7, weights_unpacked - 16, weights_unpacked - ) - - # Dequantize weights using grouped quantization - actual_in_features = in_features * 2 - num_groups = actual_in_features // group_size - - # Reshape weights for grouped dequantization - weights_grouped = weights_unpacked.view(out_features, num_groups, group_size) - - # Expand scales and zeros to match grouped weights - scales_expanded = weight_scales.unsqueeze(-1).expand(-1, -1, group_size) - zeros_expanded = weight_zeros.unsqueeze(-1).expand(-1, -1, group_size) - - # Dequantize: (quantized - zero_point) * scale - dq_weights_grouped = (weights_grouped.float() - zeros_expanded) * scales_expanded - dq_weights = dq_weights_grouped.view(out_features, actual_in_features) - - # Dequantize input (per-token) - # For per-token quantization, each token (row) has its own scale and zero_point - x_dequantized = torch.ops.quantized_decomposed.dequantize_per_token( - x_quantized_2d, - input_scale, - input_zero_point, - -128, - 127, - torch.int8, - torch.float32, - ) - - # Perform linear operation - out = torch.nn.functional.linear(x_dequantized, dq_weights) - out_shape = original_x_shape[:-1] + (out_features,) - return out.reshape(out_shape) - - -name = "linear_qta8a_qga4w" -lib.define( - f"{name}(Tensor self, Tensor input_scale, Tensor input_zero_point, Tensor weight, int group_size, Tensor weight_scales, Tensor weight_zeros) -> Tensor" -) -lib.impl(name, linear_qta8a_qga4w, "CompositeExplicitAutograd") -linear_qta8a_qga4w_op = getattr(getattr(torch.ops, namespace), name) - ################# ## qaqw_linear ## ################# @@ -475,9 +354,9 @@ def linear_q8ta_q8csw( lib.impl(name, linear_q8ta_q8csw, "CompositeExplicitAutograd") qa_q8csw_linear = getattr(getattr(torch.ops, namespace), name) -################## +####################### ## conv2d_q8ta_q8csw ## -################## +####################### def conv2d_q8ta_q8csw( diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 9f1561fb05e..4c686e0cfc5 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -387,40 +387,6 @@ def register_linear_dqa_qw_ops(): ) -@update_features( - [ - exir_ops.edge.et_vk.linear_weight_int4.default, - ] -) -def register_int4_mm_op(): - return OpFeatures( - inputs_storage=utils.CONTIGUOUS_ANY, - supports_resize=True, - supports_prepacking=True, - ) - - -@update_features( - [ - exir_ops.edge.et_vk.linear_qta8a_qga4w.default, - ] -) -def 
register_dqlinear_op(): - return OpFeatures( - inputs_storage=[ - utils.CONTIGUOUS_ANY, # input - utils.CONTIGUOUS_BUFFER, # mat1 scales - utils.CONTIGUOUS_BUFFER, # mat1 zeros - utils.NO_STORAGE, # weight (prepacked) - utils.NO_STORAGE, # group size (non tensor) - utils.CONTIGUOUS_BUFFER, # mat2 scales - utils.CONTIGUOUS_BUFFER, # mat2 zeros - ], - supports_resize=True, - supports_prepacking=True, - ) - - @update_features( [ exir_ops.edge.aten._log_softmax.default, diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_coop.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_coop.glsl deleted file mode 100644 index 150efeef1ad..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_coop.glsl +++ /dev/null @@ -1,155 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#version 450 core - -#define PRECISION ${PRECISION} - -#define T ${texel_load_component_type(DTYPE, IO_STORAGE)} -#define VEC4_T ${texel_load_type(DTYPE, IO_STORAGE)} - -#define WGS ${WGS} - -${define_required_extensions(DTYPE)} - -layout(std430) buffer; - -#include "indexing_utils.h" - -${layout_declare_tensor(B, "w", "t_output", DTYPE, IO_STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_input", DTYPE, IO_STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_qmat2", "uint", WEIGHT_STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_qparams", DTYPE, "buffer", is_scalar_array=False)} - -layout(push_constant) uniform restrict Block { - ivec4 output_sizes; - ivec4 input_sizes; - ivec4 weight_sizes; -}; - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -layout(constant_id = 3) const int group_size = 64; - -shared VEC4_T partial_sums[WGS][2]; - -$if IO_STORAGE == "buffer": - #define BUFFER_IO -$if WEIGHT_STORAGE == "buffer": - #define BUFFER_WEIGHT - -#include "qlinear_utils.glslh" - -void main() { - const uint lid = gl_LocalInvocationID.x; - const uint n8 = gl_GlobalInvocationID.y; - // The output tensor will have a shape of [n, 1, 1, 1]. Each thread computes - // 8 output elements, so each thread will write to 8 elements starting at the - // tensor index (gid.x * 8, 0, 0, 0). - const uint n = MUL_8(n8); - const uint K4 = DIV_UP_4(input_sizes.x); - - if (n >= output_sizes.x) { - return; - } - - VEC4_T out_texels[2]; - out_texels[0] = VEC4_T(0); - out_texels[1] = VEC4_T(0); - - // initialize the group index to a value larger than the largest possible - uint cur_group_idx = input_sizes.x; - - // Each thread in the work group accumulates a partial result. - for (uint k4 = lid; k4 < DIV_UP_4(input_sizes.x); k4 += WGS) { - const uint k = MUL_4(k4); - const uint group_idx = k / group_size; - - VEC4_T scales[2]; - VEC4_T zeros[2]; - - // Only update the scales/zeros if the current iteration is now working on a - // new quantization group. - if (group_idx != cur_group_idx) { - // The qparams tensor contains the quantization scales and zeros, with - // shape [2, N, K / group_size, 1]. - // Loading a texel from the qparams tensor will return 2 scales and 2 - // zeros for 2 adjacent output channels. 
- uint qparams_bufi = group_idx * DIV_2(output_sizes.x) + DIV_2(n); - VEC4_T scales_zeros_texels[4]; - $for comp in range(4): - scales_zeros_texels[${comp}] = t_qparams[qparams_bufi++]; - - scales[0] = VEC4_T(scales_zeros_texels[0].xz, scales_zeros_texels[1].xz); - zeros[0] = VEC4_T(scales_zeros_texels[0].yw, scales_zeros_texels[1].yw); - - scales[1] = VEC4_T(scales_zeros_texels[2].xz, scales_zeros_texels[3].xz); - zeros[1] = VEC4_T(scales_zeros_texels[2].yw, scales_zeros_texels[3].yw); - - cur_group_idx = group_idx; - } - // The input tensor will have a shape of [K, 1, 1, 1]; in each iteration, - // load 4 elements starting from the tensor index (k, 0, 0, 0). - VEC4_T in_texel = load_input_texel_1d(k4); - // Extract each element of the in_texel into a separate vectorized variable; - // these are used to "broadcast" the input values in subsequent fma calls. - VEC4_T in_texel_val[4]; - $for comp in range(4): - in_texel_val[${comp}] = VEC4_T(in_texel[${comp}]); - - uvec4 packed_weight_block = load_transposed_weight_block(k4, n8, K4); - - VEC4_T weight_texels[2]; - $for comp in range(4): - { - weight_texels[0].x = extract_4bit_from_transposed_block(packed_weight_block, 0, ${comp}); - weight_texels[0].y = extract_4bit_from_transposed_block(packed_weight_block, 1, ${comp}); - weight_texels[0].z = extract_4bit_from_transposed_block(packed_weight_block, 2, ${comp}); - weight_texels[0].w = extract_4bit_from_transposed_block(packed_weight_block, 3, ${comp}); - - weight_texels[1].x = extract_4bit_from_transposed_block(packed_weight_block, 4, ${comp}); - weight_texels[1].y = extract_4bit_from_transposed_block(packed_weight_block, 5, ${comp}); - weight_texels[1].z = extract_4bit_from_transposed_block(packed_weight_block, 6, ${comp}); - weight_texels[1].w = extract_4bit_from_transposed_block(packed_weight_block, 7, ${comp}); - - weight_texels[0] = fma(weight_texels[0], scales[0], zeros[0]); - weight_texels[1] = fma(weight_texels[1], scales[1], zeros[1]); - - out_texels[0] = fma(in_texel_val[${comp}], weight_texels[0], out_texels[0]); - out_texels[1] = fma(in_texel_val[${comp}], weight_texels[1], out_texels[1]); - } - } - - partial_sums[lid][0] = out_texels[0]; - partial_sums[lid][1] = out_texels[1]; - - memoryBarrierShared(); - barrier(); - - // Tree reduction to compute the overall result. - for (int i = WGS / 2; i > 0; i /= 2) { - if (lid < i) { - partial_sums[lid][0] = partial_sums[lid][0] + partial_sums[lid + i][0]; - partial_sums[lid][1] = partial_sums[lid][1] + partial_sums[lid + i][1]; - } - memoryBarrierShared(); - barrier(); - } - - // Only the first thread will write out result - if (lid == 0) { - out_texels[0] = partial_sums[0][0]; - out_texels[1] = partial_sums[0][1]; - - uint n4 = DIV_4(n); - write_output_texel_1d(out_texels[0], n4); - if (n + 4 < output_sizes.x) { - write_output_texel_1d(out_texels[1], n4 + 1); - } - } -} diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_coop.yaml b/backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_coop.yaml deleted file mode 100644 index 04e803a2e94..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_coop.yaml +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -linear_qga4w_coop: - parameter_names_with_default_values: - DTYPE: float - IO_STORAGE: texture3d - WEIGHT_STORAGE: texture2d - WGS: 64 - shader_variants: - - NAME: linear_qga4w_coop_texture3d_texture3d_texture2d_float - - NAME: linear_qga4w_coop_buffer_buffer_texture2d_float - IO_STORAGE: buffer - - NAME: linear_qga4w_coop_buffer_buffer_buffer_float - IO_STORAGE: buffer - WEIGHT_STORAGE: buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_tiled.glsl deleted file mode 100644 index 97327ea5818..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_tiled.glsl +++ /dev/null @@ -1,130 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#version 450 core - -#define PRECISION ${PRECISION} - -#define T ${texel_load_component_type(DTYPE, IO_STORAGE)} -#define VEC4_T ${texel_load_type(DTYPE, IO_STORAGE)} - -${define_required_extensions(DTYPE)} - -layout(std430) buffer; - -#include "indexing_utils.h" - -${layout_declare_tensor(B, "w", "t_output", DTYPE, IO_STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_input", DTYPE, IO_STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_qmat2", "uint", WEIGHT_STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_qparams", DTYPE, "buffer", is_scalar_array=False)} - -layout(push_constant) uniform restrict Block { - ivec4 output_sizes; - ivec4 input_sizes; - ivec4 weight_sizes; -}; - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -layout(constant_id = 3) const int group_size = 64; - -$if IO_STORAGE == "buffer": - #define BUFFER_IO -$if WEIGHT_STORAGE == "buffer": - #define BUFFER_WEIGHT - -#include "qlinear_utils.glslh" - -void main() { - // Each thread writes out a 8 wide x 4 high tile of output values - const uint n8 = gl_GlobalInvocationID.x; - const uint m4 = gl_GlobalInvocationID.y; - - const uint n = MUL_8(n8); // output col idx - const uint m = MUL_4(m4); // output row idx - const uint n4 = MUL_2(n8); // output col texel idx - - const uint group_num = input_sizes.x / group_size; - const uint group_ntexels = DIV_UP_4(group_size); - - if (n >= output_sizes.x || m >= output_sizes.y) { - return; - } - - const uint K4 = DIV_UP_4(input_sizes.x); - const uint N4 = DIV_UP_4(output_sizes.x); // number of texels in each row - - VEC4_T out_texels[4][2]; - // Initialize to 0 - $for row_i in range(4): - $for col_i in range(2): - out_texels[${row_i}][${col_i}] = VEC4_T(0.00); - - for (uint group_i = 0; group_i < group_num; ++group_i) { - // Load quantization scales and zeros for the current group - VEC4_T scales[2]; - VEC4_T zeros[2]; - { - uint qparams_bufi = group_i * DIV_2(output_sizes.x) + DIV_2(n); - - VEC4_T scales_zeros_texels[4]; - $for comp in range(4): - scales_zeros_texels[${comp}] = t_qparams[qparams_bufi++]; - - scales[0] = VEC4_T(scales_zeros_texels[0].xz, scales_zeros_texels[1].xz); - zeros[0] = VEC4_T(scales_zeros_texels[0].yw, scales_zeros_texels[1].yw); - - scales[1] = VEC4_T(scales_zeros_texels[2].xz, scales_zeros_texels[3].xz); - zeros[1] = VEC4_T(scales_zeros_texels[2].yw, scales_zeros_texels[3].yw); - } - - for (uint inner_k4 = 0; inner_k4 < group_ntexels; inner_k4++) { - const uint k4 = group_i * group_ntexels + inner_k4; - - // Load 4x4 block of the input tensor, with the top left 
corner of the - // block at (k, m) - VEC4_T in_texels[4]; - $for comp in range(4): - in_texels[${comp}] = load_input_texel_2d(k4, m + ${comp}, K4); - - uvec4 packed_weight_block = load_transposed_weight_block(k4, n8, K4); - - VEC4_T weight_texels[2]; - $for tile_k in range(4): - // Process weight row k + comp - { - // Weight columns n + 0, 1, 2, 3 - weight_texels[0].x = extract_4bit_from_transposed_block(packed_weight_block, 0, ${tile_k}); - weight_texels[0].y = extract_4bit_from_transposed_block(packed_weight_block, 1, ${tile_k}); - weight_texels[0].z = extract_4bit_from_transposed_block(packed_weight_block, 2, ${tile_k}); - weight_texels[0].w = extract_4bit_from_transposed_block(packed_weight_block, 3, ${tile_k}); - - // Weight colums n + 4, 5, 6, 7 - weight_texels[1].x = extract_4bit_from_transposed_block(packed_weight_block, 4, ${tile_k}); - weight_texels[1].y = extract_4bit_from_transposed_block(packed_weight_block, 5, ${tile_k}); - weight_texels[1].z = extract_4bit_from_transposed_block(packed_weight_block, 6, ${tile_k}); - weight_texels[1].w = extract_4bit_from_transposed_block(packed_weight_block, 7, ${tile_k}); - - weight_texels[0] = fma(weight_texels[0], scales[0], zeros[0]); - weight_texels[1] = fma(weight_texels[1], scales[1], zeros[1]); - - $for tile_m in range(4): - out_texels[${tile_m}][0] = fma(VEC4_T(in_texels[${tile_m}][${tile_k}]), weight_texels[0], out_texels[${tile_m}][0]); - out_texels[${tile_m}][1] = fma(VEC4_T(in_texels[${tile_m}][${tile_k}]), weight_texels[1], out_texels[${tile_m}][1]); - } - } - } - - for (uint row_i = 0; row_i < 4 && m + row_i < output_sizes.y; ++row_i) { - write_output_texel_2d(out_texels[row_i][0], n4, m + row_i, N4); - if (n + 4 < output_sizes.x) { - write_output_texel_2d(out_texels[row_i][1], n4 + 1, m + row_i, N4); - } - } -} diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_tiled.yaml deleted file mode 100644 index 94d10dcf978..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qga4w_tiled.yaml +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -linear_qga4w_tiled: - parameter_names_with_default_values: - DTYPE: float - IO_STORAGE: texture3d - WEIGHT_STORAGE: texture2d - shader_variants: - - NAME: linear_qga4w_tiled_texture3d_texture3d_texture2d_float - - NAME: linear_qga4w_tiled_buffer_buffer_texture2d_float - IO_STORAGE: buffer - - NAME: linear_qga4w_tiled_buffer_buffer_buffer_float - IO_STORAGE: buffer - WEIGHT_STORAGE: buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_coop.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_coop.glsl deleted file mode 100644 index 174ea1cc9bb..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_coop.glsl +++ /dev/null @@ -1,232 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -#version 450 core - -#define PRECISION ${PRECISION} - -#define T ${buffer_scalar_type(DTYPE)} -#define VEC4_T ${buffer_gvec_type(DTYPE, 4)} - -#define TILE_ROWS ${TILE_ROWS} - -#define NGROUPS 8 -#define NWORKERS 8 - -${define_required_extensions(DTYPE)} -$if IN_STORAGE == "buffer": - ${define_required_extensions("int8")} -$if WEIGHT_STORAGE == "buffer": - ${define_required_extensions("uint8")} - -#extension GL_EXT_control_flow_attributes : require - -layout(std430) buffer; - -${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_mat1", "int8", IN_STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_qmat2", "uint8", WEIGHT_STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_weight_scales", "float", PARAMS_STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_weight_zeros", "int", PARAMS_STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_input_scale", "float", PARAMS_STORAGE, is_scalar_array=True)} -${layout_declare_tensor(B, "r", "t_input_zero_point", "int", PARAMS_STORAGE, is_scalar_array=True)} - -layout(push_constant) uniform restrict Block { - ivec4 out_sizes; - ivec4 mat1_sizes; - ivec4 qmat2_sizes; -}; - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -layout(constant_id = 3) const int group_size = 64; - -shared vec4 partial_results[NGROUPS][NWORKERS][TILE_ROWS][2]; - -/* - * This shader computes a linear operator between a quantized int8 input matrix - * x and a weights matrix that is quantized to 4 bits, producing a float output. - * - * This shader implements a co-operative algorithm to compute the output. The - * work group size is {NGROUP, 1, NWORKERS}, and each group of NWORKERS threads - * cooperative to compute TILE_ROWS * 2 output texels. Therefore, - * NGROUP * TILE_ROWS * 2 output texels are computed across one work group. - * - * The threads co-operate by each thread computing a partial reduction along the - * K dimension. To illustrate the computation, consider a scalar variant of the - * algorithm that computes the dot product of 2 vectors. Also assume that - * NWORKERS is 8. - * - * Thread 1 in each group will compute: - * (mat1[0] * mat2[0]) + (mat1[8] * mat2[8]) + (mat1[16] * mat2[16]) + ... - * - * Thread 2 in each group will compute: - * (mat1[1] * mat2[1]) + (mat2[9] * mat2[9]) + (mat1[17] * mat2[17]) + ... - * - * Thread 3 in each group will compute: - * (mat1[2] * mat2[2]) + (mat2[10] * mat2[10]) + (mat1[18] * mat2[18]) + ... - * - * The partial accumulations is structured such that memory accesses in each - * loop iteration can be coalesced. - * - * Then, at the end first thread in each group will accumulate the partial - * accumulations computed by each thread to obtain the final result. - * - * Note that this shader assumes that all tensors are width packed. 
- */ - -void main() { - const uint out_row = gl_GlobalInvocationID.y * TILE_ROWS; - const uint out_col = gl_GlobalInvocationID.x << 3; - const int out_col_texel_idx = int(gl_GlobalInvocationID.x) << 1; - - const uint gid = gl_LocalInvocationID.x; // group id - const uint wid = gl_LocalInvocationID.z; // worker id - - if (out_col >= out_sizes.x || out_row >= out_sizes.y) { - return; - } - - const int num_blocks = mat1_sizes.x / group_size; - - ivec4 mat1_quantized[TILE_ROWS]; - ivec4 qmat2_quantized[4][2]; - vec4 final_result[TILE_ROWS][2]; - - // Initialize accumulators - [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { - final_result[r][0] = vec4(0.0); - final_result[r][1] = vec4(0.0); - } - - vec4 scales[2]; - vec4 zeros[2]; - - $if WEIGHT_STORAGE == "buffer": - const int qmat2_stride = qmat2_sizes.x >> 2; - $if PARAMS_STORAGE == "buffer": - const int qparams_stride = out_sizes.x >> 2; - - for (int block_idx = 0; block_idx < num_blocks; ++block_idx) { - $if PARAMS_STORAGE == "buffer": - scales[0] = t_weight_scales[block_idx * qparams_stride + out_col_texel_idx]; - scales[1] = t_weight_scales[block_idx * qparams_stride + out_col_texel_idx + 1]; - - zeros[0] = vec4(t_weight_zeros[block_idx * qparams_stride + out_col_texel_idx]); - zeros[1] = vec4(t_weight_zeros[block_idx * qparams_stride + out_col_texel_idx + 1]); - $else: - scales[0] = texelFetch(t_weight_scales, ivec3(out_col_texel_idx, block_idx, 0), 0); - scales[1] = texelFetch(t_weight_scales, ivec3(out_col_texel_idx + 1, block_idx, 0), 0); - - zeros[0] = vec4(texelFetch(t_weight_zeros, ivec3(out_col_texel_idx, block_idx, 0), 0)); - zeros[1] = vec4(texelFetch(t_weight_zeros, ivec3(out_col_texel_idx + 1, block_idx, 0), 0)); - - ivec4 int32_sums[TILE_ROWS][2]; - int input_sums[TILE_ROWS]; - - // Initialize accumulators for this block - [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { - int32_sums[r][0] = ivec4(0); - int32_sums[r][1] = ivec4(0); - input_sums[r] = 0; - } - - for (int g_idx = 4 * int(wid); g_idx < group_size; g_idx += (4 * NWORKERS)) { - const int k = block_idx * group_size + g_idx; - - // Preload B (weights) - keep as quantized integers - [[unroll]] for (int r = 0; r < 4; ++r) { - $if WEIGHT_STORAGE == "buffer": - const u8vec4 packed_weight_tex = t_qmat2[(k + r) * qmat2_stride + gl_GlobalInvocationID.x]; - $else: - const uvec4 packed_weight_tex = texelFetch( - t_qmat2, - ivec2(gl_GlobalInvocationID.x, k + r), - 0); - - // Unpack 4-bit weights to integers and subtract zero point (8 for 4-bit) - qmat2_quantized[r][0] = ivec4((packed_weight_tex & 0xF0) >> 4) - 8; - qmat2_quantized[r][1] = ivec4(packed_weight_tex & 0x0F) - 8; - } - - // Preload A (quantized input) - keep as quantized integers - [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { - $if IN_STORAGE == "buffer": - mat1_quantized[r] = t_mat1[((out_row + r) * mat1_sizes.x + k) >> 2] - t_input_zero_point[int(out_row) + r]; - $else: - mat1_quantized[r] = texelFetch(t_mat1, ivec3(k >> 2, out_row + r, 0), 0) - t_input_zero_point[int(out_row) + r]; - } - - // Accumulate in integer arithmetic: (input_quantized - input_zero_point) * (weight_quantized - weight_zero_point) - [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { - input_sums[r] += mat1_quantized[r].x + mat1_quantized[r].y + mat1_quantized[r].z + mat1_quantized[r].w; - - int32_sums[r][0] += mat1_quantized[r].x * qmat2_quantized[0][0] - + mat1_quantized[r].y * qmat2_quantized[1][0] - + mat1_quantized[r].z * qmat2_quantized[2][0] - + mat1_quantized[r].w * qmat2_quantized[3][0]; - - int32_sums[r][1] += 
mat1_quantized[r].x * qmat2_quantized[0][1] - + mat1_quantized[r].y * qmat2_quantized[1][1] - + mat1_quantized[r].z * qmat2_quantized[2][1] - + mat1_quantized[r].w * qmat2_quantized[3][1]; - } - } - - // Incorporates this block's results into the final accumulation - // Following proper quantization paradigm: result = input_scale * weight_scale * - // Sum((input_quantized - input_zero) * (weight_quantized - weight_zero)) - [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { - if (out_row + r >= out_sizes.y) { - continue; - } - - float input_scale = t_input_scale[int(out_row) + r]; - float input_sum_scalar = float(input_sums[r]); - - // Apply proper quantization paradigm: input_scale * weight_scale * (accumulator - weight_zero * input_sum) - final_result[r][0] += input_scale * scales[0] * (vec4(int32_sums[r][0]) - zeros[0] * input_sum_scalar); - final_result[r][1] += input_scale * scales[1] * (vec4(int32_sums[r][1]) - zeros[1] * input_sum_scalar); - } - } - - // Store worker results in shared memory - [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { - partial_results[gid][wid][r][0] = final_result[r][0]; - partial_results[gid][wid][r][1] = final_result[r][1]; - } - - memoryBarrierShared(); - barrier(); - - // Only the first worker in each group accumulates and writes output - if (wid != 0) { - return; - } - - vec4 cooperative_result[TILE_ROWS][2]; - - for (int r = 0; r < TILE_ROWS; ++r) { - cooperative_result[r][0] = vec4(0.0); - cooperative_result[r][1] = vec4(0.0); - [[unroll]] for (int worker = 0; worker < NWORKERS; ++worker) { - cooperative_result[r][0] += partial_results[gid][worker][r][0]; - cooperative_result[r][1] += partial_results[gid][worker][r][1]; - } - } - - // Apply final output quantization - [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { - $if OUT_STORAGE == "buffer": - t_out[((out_row + r) * out_sizes.x + out_col) >> 2] = cooperative_result[r][0]; - t_out[((out_row + r) * out_sizes.x + out_col + 4) >> 2] = cooperative_result[r][1]; - $else: - imageStore(t_out, ivec3(out_col_texel_idx, out_row + r, 0), cooperative_result[r][0]); - imageStore(t_out, ivec3(out_col_texel_idx + 1, out_row + r, 0), cooperative_result[r][1]); - } -} diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_coop.yaml b/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_coop.yaml deleted file mode 100644 index 9f6db77094a..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_coop.yaml +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -linear_qta8a_qga4w_coop: - parameter_names_with_default_values: - DTYPE: float - OUT_STORAGE: texture3d - IN_STORAGE: texture3d - WEIGHT_STORAGE: texture2d - PARAMS_STORAGE: buffer - TILE_ROWS: 1 - shader_variants: - - NAME: linear_qta8a_qga4w_coop_texture3d_texture3d_texture2d_float - - NAME: linear_qta8a_qga4w_coop_buffer_buffer_texture2d_float - OUT_STORAGE: buffer - IN_STORAGE: buffer - - NAME: linear_qta8a_qga4w_coop_buffer_buffer_buffer_float - OUT_STORAGE: buffer - IN_STORAGE: buffer - WEIGHT_STORAGE: buffer - - NAME: linear_qta8a_qga4w_coop_buffer_texture2d_buffer_float - OUT_STORAGE: buffer - WEIGHT_STORAGE: buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_tiled.glsl deleted file mode 100644 index dbb7da998f4..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_tiled.glsl +++ /dev/null @@ -1,196 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#version 450 core - -#define PRECISION ${PRECISION} - -#define T ${buffer_scalar_type(DTYPE)} -#define VEC4_T ${buffer_gvec_type(DTYPE, 4)} - -#define TILE_ROWS ${TILE_ROWS} - -${define_required_extensions(DTYPE)} -$if IN_STORAGE == "buffer": - ${define_required_extensions("int8")} -$if WEIGHT_STORAGE == "buffer": - ${define_required_extensions("uint8")} - -#extension GL_EXT_control_flow_attributes : require - -layout(std430) buffer; - -${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_mat1", "int8", IN_STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_qmat2", "uint8", WEIGHT_STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_weight_scales", "float", PARAMS_STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_weight_zeros", "int", PARAMS_STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_input_scale", "float", "buffer", is_scalar_array=True)} -${layout_declare_tensor(B, "r", "t_input_zero_point", "int", "buffer", is_scalar_array=True)} - -layout(push_constant) uniform restrict Block { - ivec4 out_sizes; - ivec4 mat1_sizes; - ivec4 qmat2_sizes; -}; - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -layout(constant_id = 3) const int group_size = 64; - -/* - * This shader computes a linear operator between a quantized int8 input matrix - * x and a weights matrix that is quantized to 4 bits, producing a float output. - * - * The (W, H, C) shape of each tensor is: - * - x: (K, M) - quantized int8 input with per-token quantization - * - weights: (N / 2, K) - * - The weights tensor has a data type of `uint8`. Each element in the tensor - * contains 2 4-bit values packed into a uint8. - * - See the pack_int4_linear_weight_transposed_interleave shader to see more - * details on how the weight tensor is stored. - * - qparams: (2, N, number_of_groups) - * - This tensor contains the scales and zeros quantization parameters for the - * weights tensor. The weight tensor is quantized group-wise, which means - * that every `group_size` elements along the K dimension of the weights - * tensor has independent quantization parameters. Along the width dim, the - * first value contains the scale for the group and the second value - * contains the zero point for the group. 
- * - input_scale: (num_tokens,) - per-token scale values for input quantization - * - input_zero_point: (num_tokens,) - per-token zero points for input quantization - * - output: (N, M) - float output - * - * Each thread computes a tile of TILE_ROWS * 2 texels of the output tensor. - * - * Note that this shader assumes that all tensors are width packed. - */ - -void main() { - const uint out_row = gl_GlobalInvocationID.y * TILE_ROWS; - const uint out_col = gl_GlobalInvocationID.x << 3; - const int out_col_texel_idx = int(gl_GlobalInvocationID.x) << 1; - - if (out_col >= out_sizes.x || out_row >= out_sizes.y) { - return; - } - - const int num_blocks = mat1_sizes.x / group_size; - - ivec4 mat1_quantized[TILE_ROWS]; - ivec4 qmat2_quantized[4][2]; - vec4 final_result[TILE_ROWS][2]; - - // Initialize accumulatoxrs - [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { - final_result[r][0] = vec4(0.0); - final_result[r][1] = vec4(0.0); - } - - vec4 scales[2]; - vec4 zeros[2]; - - $if WEIGHT_STORAGE == "buffer": - const int qmat2_stride = qmat2_sizes.x >> 2; - $if PARAMS_STORAGE == "buffer": - const int qparams_stride = out_sizes.x >> 2; - - for (int block_idx = 0; block_idx < num_blocks; ++block_idx) { - $if PARAMS_STORAGE == "buffer": - scales[0] = t_weight_scales[block_idx * qparams_stride + out_col_texel_idx]; - scales[1] = t_weight_scales[block_idx * qparams_stride + out_col_texel_idx + 1]; - - zeros[0] = vec4(t_weight_zeros[block_idx * qparams_stride + out_col_texel_idx]); - zeros[1] = vec4(t_weight_zeros[block_idx * qparams_stride + out_col_texel_idx + 1]); - $else: - scales[0] = texelFetch(t_weight_scales, ivec3(out_col_texel_idx, block_idx, 0), 0); - scales[1] = texelFetch(t_weight_scales, ivec3(out_col_texel_idx + 1, block_idx, 0), 0); - - zeros[0] = vec4(texelFetch(t_weight_zeros, ivec3(out_col_texel_idx, block_idx, 0), 0)); - zeros[1] = vec4(texelFetch(t_weight_zeros, ivec3(out_col_texel_idx + 1, block_idx, 0), 0)); - - ivec4 int32_sums[TILE_ROWS][2]; - int input_sums[TILE_ROWS]; - - // Initialize accumulators - [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { - int32_sums[r][0] = ivec4(0); - int32_sums[r][1] = ivec4(0); - input_sums[r] = 0; - } - - for (int g_idx = 0; g_idx < group_size; g_idx += 4) { - const int k = block_idx * group_size + g_idx; - - // Preload B (weights) - keep as quantized integers - [[unroll]] for (int r = 0; r < 4; ++r) { - $if WEIGHT_STORAGE == "buffer": - const u8vec4 packed_weight_tex = t_qmat2[(k + r) * qmat2_stride + gl_GlobalInvocationID.x]; - $else: - const uvec4 packed_weight_tex = texelFetch( - t_qmat2, - ivec2(gl_GlobalInvocationID.x, k + r), - 0); - - // Unpack 4-bit weights to integers (subtract 8 as the 4-bit zero point) - qmat2_quantized[r][0] = ivec4((packed_weight_tex & 0xF0) >> 4) - 8; - qmat2_quantized[r][1] = ivec4(packed_weight_tex & 0x0F) - 8; - } - - // Preload A (quantized input) - keep as quantized integers - [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { - $if IN_STORAGE == "buffer": - mat1_quantized[r] = t_mat1[((out_row + r) * mat1_sizes.x + k) >> 2] - t_input_zero_point[int(out_row) + r]; - $else: - mat1_quantized[r] = texelFetch(t_mat1, ivec3(k >> 2, out_row + r, 0), 0) - t_input_zero_point[int(out_row) + r]; - } - - // Accumulate in integer arithmetic: (input_quantized - input_zero_point) * (weight_quantized - weight_zero_point) - [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { - input_sums[r] += mat1_quantized[r].x + mat1_quantized[r].y + mat1_quantized[r].z + mat1_quantized[r].w; - - int32_sums[r][0] += mat1_quantized[r].x * 
qmat2_quantized[0][0] - + mat1_quantized[r].y * qmat2_quantized[1][0] - + mat1_quantized[r].z * qmat2_quantized[2][0] - + mat1_quantized[r].w * qmat2_quantized[3][0]; - - int32_sums[r][1] += mat1_quantized[r].x * qmat2_quantized[0][1] - + mat1_quantized[r].y * qmat2_quantized[1][1] - + mat1_quantized[r].z * qmat2_quantized[2][1] - + mat1_quantized[r].w * qmat2_quantized[3][1]; - } - } - - // Incorporates this block's results into the final accumulation - // Following proper quantization paradigm: result = input_scale * weight_scale * - // Sum((input_quantized - input_zero) * (weight_quantized - weight_zero)) - [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { - if (out_row + r >= out_sizes.y) { - continue; - } - - float input_scale = t_input_scale[int(out_row) + r]; - float input_sum_scalar = float(input_sums[r]); - - // Apply proper quantization paradigm: input_scale * weight_scale * (accumulator - weight_zero * input_sum) - final_result[r][0] += input_scale * scales[0] * (vec4(int32_sums[r][0]) - zeros[0] * input_sum_scalar); - final_result[r][1] += input_scale * scales[1] * (vec4(int32_sums[r][1]) - zeros[1] * input_sum_scalar); - } - } - - // Apply ALL scaling at the very end - [[unroll]] for (int r = 0; r < TILE_ROWS; ++r) { - $if OUT_STORAGE == "buffer": - if (out_row + r < out_sizes.y) { - t_out[((out_row + r) * out_sizes.x + out_col) >> 2] = final_result[r][0]; - t_out[((out_row + r) * out_sizes.x + out_col + 4) >> 2] = final_result[r][1]; - } - $else: - imageStore(t_out, ivec3(out_col_texel_idx, out_row + r, 0), final_result[r][0]); - imageStore(t_out, ivec3(out_col_texel_idx + 1, out_row + r, 0), final_result[r][1]); - } -} diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_tiled.yaml deleted file mode 100644 index c96d693834b..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_tiled.yaml +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -linear_qta8a_qga4w_tiled: - parameter_names_with_default_values: - DTYPE: float - OUT_STORAGE: texture3d - IN_STORAGE: texture3d - WEIGHT_STORAGE: texture2d - PARAMS_STORAGE: buffer - TILE_ROWS: 3 - shader_variants: - - NAME: linear_qta8a_qga4w_tiled_texture3d_texture3d_texture2d_float - - NAME: linear_qta8a_qga4w_tiled_buffer_buffer_texture2d_float - OUT_STORAGE: buffer - IN_STORAGE: buffer - - NAME: linear_qta8a_qga4w_tiled_buffer_buffer_buffer_float - OUT_STORAGE: buffer - IN_STORAGE: buffer - WEIGHT_STORAGE: buffer - - NAME: linear_qta8a_qga4w_tiled_buffer_texture2d_buffer_float - OUT_STORAGE: buffer - WEIGHT_STORAGE: buffer diff --git a/backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_block_4x8.glsl b/backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_block_4x8.glsl deleted file mode 100644 index e42cf05dd7f..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_block_4x8.glsl +++ /dev/null @@ -1,154 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -#version 450 core - -#define PRECISION ${PRECISION} - -layout(std430) buffer; - -#include "indexing_utils.h" - -${layout_declare_tensor(B, "w", "t_qmat2", "uint", STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_input", "uint", "buffer")} - -layout(push_constant) uniform restrict Block { - ivec4 qmat2_sizes; - ivec2 orig_sizes; -}; - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -$if STORAGE == "buffer": - #define BUFFER_WEIGHT - -#include "qlinear_weight_pack_utils.glslh" - -#define extract_4bit(input_block_data, col, row) \ - (extract_4bit_from_packed_uint_le(input_block_data[row], col)) - -/* - * This shader packs the weight tensor into blocks for efficient consumption. - * - * The input tensor has shape [K/2, N] where each element is a uint8 containing - * 2 packed 4-bit values. The logical tensor shape is [K, N] of 4-bit values. - * - * The transformation partitions the tensor into blocks of size 4x8 (4-bit values) - * and transposes each block to 8x4, then packs the result so that each uvec4 - * contains an entire transposed block. - * - * Original block (4x8 4-bit values, shown as 2x8 uint8 values): - * w00|w10, w20|w30, - * w01|w11, w21|w31, - * w02|w12, w22|w32, - * w03|w13, w23|w33, - * w04|w14, w24|w34, - * w05|w15, w25|w35, - * w06|w16, w26|w36, - * w07|w17, w27|w37, - * - * Transposed block (8x4 4-bit values, packed into uvec4): - * w00|w01, w02|w03, w04|w05, w06|w07 - * w10|w11, w12|w13, w14|w15, w16|w17 - * w20|w21, w22|w23, w24|w25, w26|w27 - * w30|w31, w32|w33, w34|w35, w36|w37 - */ -void main() { - // Each thread writes out 2 adjacent 8 wide x 4 high transposed block. Each - // block is packed as one uvec4. - ivec2 block_pos = ivec2( - MUL_2(gl_GlobalInvocationID.x), - gl_GlobalInvocationID.y); - - // There are K wide x N high 4-bit values in the original weight tensor - const int input_width = orig_sizes.x; // K - const int input_height = orig_sizes.y; // N - - const int input_width_uint = DIV_UP_8(input_width); - - // Original block spans 4 wide x 8 high 4-bit values. Since uint is used to - // read the input tensor, each block spans 0.5 wide x 8 high uint values. - const ivec2 block_start = ivec2( - DIV_2(block_pos.x), - MUL_8(block_pos.y)); - - // Check bounds - if (block_start.x >= input_width_uint || block_start.y >= input_height) { - return; - } - - // Read input block. Note that this block will contain the source data for - // both output blocks, as it contains 1 wide x 8 high uint values, which is - // equivalent to 8 wide x 8 high 4-bit values. - uint input_block_data[8]; - - // Read in 8 rows along the same column of uints, each uint contains 4 4-bit - // values. This will be the source data for the transposed block. 
-  for (int i = 0; i < 8; ++i) {
-    uint input_bufi = (block_start.y + i) * input_width_uint + block_start.x;
-    input_block_data[i] = t_input[input_bufi];
-  }
-
-  for (int col_offset = 0; col_offset <= 4; col_offset += 4) {
-    uvec4 output_block;
-
-    output_block.x = pack_8x4bit_into_uint(
-        extract_4bit(input_block_data, col_offset, 0),
-        extract_4bit(input_block_data, col_offset, 1),
-        extract_4bit(input_block_data, col_offset, 2),
-        extract_4bit(input_block_data, col_offset, 3),
-        extract_4bit(input_block_data, col_offset, 4),
-        extract_4bit(input_block_data, col_offset, 5),
-        extract_4bit(input_block_data, col_offset, 6),
-        extract_4bit(input_block_data, col_offset, 7));
-
-    output_block.y = pack_8x4bit_into_uint(
-        extract_4bit(input_block_data, col_offset + 1, 0),
-        extract_4bit(input_block_data, col_offset + 1, 1),
-        extract_4bit(input_block_data, col_offset + 1, 2),
-        extract_4bit(input_block_data, col_offset + 1, 3),
-        extract_4bit(input_block_data, col_offset + 1, 4),
-        extract_4bit(input_block_data, col_offset + 1, 5),
-        extract_4bit(input_block_data, col_offset + 1, 6),
-        extract_4bit(input_block_data, col_offset + 1, 7));
-
-    output_block.z = pack_8x4bit_into_uint(
-        extract_4bit(input_block_data, col_offset + 2, 0),
-        extract_4bit(input_block_data, col_offset + 2, 1),
-        extract_4bit(input_block_data, col_offset + 2, 2),
-        extract_4bit(input_block_data, col_offset + 2, 3),
-        extract_4bit(input_block_data, col_offset + 2, 4),
-        extract_4bit(input_block_data, col_offset + 2, 5),
-        extract_4bit(input_block_data, col_offset + 2, 6),
-        extract_4bit(input_block_data, col_offset + 2, 7));
-
-    output_block.w = pack_8x4bit_into_uint(
-        extract_4bit(input_block_data, col_offset + 3, 0),
-        extract_4bit(input_block_data, col_offset + 3, 1),
-        extract_4bit(input_block_data, col_offset + 3, 2),
-        extract_4bit(input_block_data, col_offset + 3, 3),
-        extract_4bit(input_block_data, col_offset + 3, 4),
-        extract_4bit(input_block_data, col_offset + 3, 5),
-        extract_4bit(input_block_data, col_offset + 3, 6),
-        extract_4bit(input_block_data, col_offset + 3, 7));
-
-    const uint qmat2_texel_stride_x = DIV_UP_4(qmat2_sizes.x);
-    write_transposed_weight_block(
-        output_block,
-        block_pos.x,
-        block_pos.y,
-        qmat2_texel_stride_x);
-
-    if (MUL_8(block_start.x) + 4 >= input_width) {
-      return;
-    }
-    // Otherwise, increment the block position to write to the next block in the
-    // following iteration.
-    block_pos.x += 1;
-  }
-}
diff --git a/backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_block_4x8.yaml b/backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_block_4x8.yaml
deleted file mode 100644
index c72a2cc1df6..00000000000
--- a/backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_block_4x8.yaml
+++ /dev/null
@@ -1,14 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-pack_int4_linear_weight_transposed_block_4x8:
-  parameter_names_with_default_values:
-    STORAGE: buffer
-  shader_variants:
-    - NAME: pack_int4_linear_weight_transposed_block_4x8_buffer
-      STORAGE: buffer
-    - NAME: pack_int4_linear_weight_transposed_block_4x8_texture2d
-      STORAGE: texture2d
diff --git a/backends/vulkan/runtime/graph/ops/glsl/qlinear_utils.glslh b/backends/vulkan/runtime/graph/ops/glsl/qlinear_utils.glslh
deleted file mode 100644
index 80ec44c153a..00000000000
--- a/backends/vulkan/runtime/graph/ops/glsl/qlinear_utils.glslh
+++ /dev/null
@@ -1,116 +0,0 @@
-/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
- * All rights reserved.
- *
- * This source code is licensed under the BSD-style license found in the
- * LICENSE file in the root directory of this source tree.
- */
-
-#ifndef QLINEAR_UTILS_H
-#define QLINEAR_UTILS_H
-
-/***********************************
- * Packed Weight data read/write functions
- *
- * These functions assume that t_qmat2 is declared in the shader layout as a storage
- * buffer or storage image.
- */
-
-#ifdef BUFFER_WEIGHT
-
-uvec4 load_transposed_weight_block(const uint k4, const uint n8, const uint K4) {
-  return t_qmat2[n8 * K4 + k4];
-}
-
-#else // TEXTURE_WEIGHT
-
-uvec4 load_transposed_weight_block(const uint k4, const uint n8, const uint K4) {
-  return texelFetch(t_qmat2, ivec2(k4, n8), 0);
-}
-
-#endif // BUFFER_WEIGHT
-
-/***********************************
- * Packed weight data extraction functions
- */
-
-/*
- * uvec4 block contains a packed 4 high x 8 wide matrix of 4-bit signed integers. This
- * function extracts the 4-bit values at the given column and row index.
- *
- * Each uint in the uvec4 corresponds to one row; thus the desired row can be extracted
- * via block[row]. From there, column 0 is packed in bits 28-31, column 1 is packed into
- * bits 24-27, column 2 is packed into bits 20-23, and so on. To extract the desired
- * value:
- *
- * 1. First, shift the row uint by 4 * (7 - col) bits
- * 2. Apply a mask of 0b1111 = 15
- *
- * Finally, convert the masked value to int and subtract 8 to obtain the desired
- * signed integer.
- */
-T extract_4bit_from_transposed_block(const uvec4 block, const uint col, const uint row) {
-  return T(int((block[row] >> (4 * (7 - col))) & 15) - 8);
-}
-
-/***********************************
- * Input/Output read/write functions
- *
- * These functions assume that t_input and t_output are declared in the shader layout as
- * storage buffers or storage images.
- */ - -#ifdef BUFFER_IO - -VEC4_T load_input_texel_1d(const uint k4) { - return t_input[k4]; -} - -VEC4_T load_input_texel_2d( - const uint k4, - const uint m, - const uint K4) { - return t_input[(m * K4) + k4]; -} - -void write_output_texel_1d(const VEC4_T out_texel, const uint n4) { - t_output[n4] = out_texel; -} - -void write_output_texel_2d( - const VEC4_T out_texel, - const uint n4, - const uint m, - const uint N4) { - t_output[m * N4 + n4] = out_texel; -} - -#else // TEXTURE_IO - -VEC4_T load_input_texel_1d(const uint k4) { - return texelFetch(t_input, ivec3(k4, 0, 0), 0); -} - -VEC4_T load_input_texel_2d( - const uint k4, - const uint m, - const uint K4) { - return texelFetch(t_input, ivec3(k4, m, 0), 0); -} - - -void write_output_texel_1d(const VEC4_T out_texel, const uint n4) { - imageStore(t_output, ivec3(n4, 0, 0), out_texel); -} - -void write_output_texel_2d( - const VEC4_T out_texel, - const uint n4, - const uint m, - const uint N4) { - imageStore(t_output, ivec3(n4, m, 0), out_texel); -} - -#endif // BUFFER_IO - -#endif // QLINEAR_UTILS_H diff --git a/backends/vulkan/runtime/graph/ops/glsl/qlinear_weight_pack_utils.glslh b/backends/vulkan/runtime/graph/ops/glsl/qlinear_weight_pack_utils.glslh deleted file mode 100644 index 1f481f4f859..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/qlinear_weight_pack_utils.glslh +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#ifndef QLINEAR_WEIGHT_PACK_UTILS_H -#define QLINEAR_WEIGHT_PACK_UTILS_H - -/*********************************** - * Packed Weight data write functions - * - * These functions assume that t_qmat2 has been defined in the shader layout as either - * a storage buffer or a storage image. - */ - -#ifdef BUFFER_WEIGHT - -void write_transposed_weight_block(const uvec4 block, const uint k4, const uint n8, const uint K4) { - t_qmat2[n8 * K4 + k4] = block; -} - -#else // TEXTURE_WEIGHT - -void write_transposed_weight_block(const uvec4 block, const uint k4, const uint n8, const uint K4) { - imageStore(t_qmat2, ivec2(k4, n8), block); -} - -#endif // BUFFER_WEIGHT - -/*********************************** - * Utilities for packing weight data - */ - -uint extract_4bit_from_packed_uint_le(const uint packed, const uint i) { - // account for little endian - uint byte = packed >> (8 * (i / 2)) & 255; - return (byte >> (4 - 4 * (i % 2))) & 15; -} - -uint pack_8x4bit_into_uint( - const uint val0, - const uint val1, - const uint val2, - const uint val3, - const uint val4, - const uint val5, - const uint val6, - const uint val7) { - return uint( - (val0 << 28) | (val1 << 24) | (val2 << 20) | (val3 << 16) | (val4 << 12) | - (val5 << 8) | (val6 << 4) | val7 - ); -} - -#endif // QLINEAR_WEIGHT_PACK_UTILS_H diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQGANW.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQGANW.cpp deleted file mode 100644 index 52cf75e28b5..00000000000 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQGANW.cpp +++ /dev/null @@ -1,225 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */
-
-#include
-
-#include
-#include
-
-#include
-#include
-
-namespace vkcompute {
-
-void check_linear_qga4w_args(
-    ComputeGraph& graph,
-    const ValueRef mat1,
-    const ValueRef mat2_data,
-    const ValueRef group_size,
-    const ValueRef scales_and_zeros,
-    const ValueRef out) {
-  VK_CHECK_COND(graph.val_is_tensor(mat1));
-  VK_CHECK_COND(graph.val_is_tref(mat2_data));
-  VK_CHECK_COND(graph.val_is_tref(scales_and_zeros));
-
-  VK_CHECK_COND(graph.dim_of(mat1) <= 3);
-  VK_CHECK_COND(graph.dim_of(mat2_data) == 2);
-  VK_CHECK_COND(graph.dim_of(scales_and_zeros) == 3);
-
-  VK_CHECK_COND(graph.size_at(-3, mat1) == 1);
-  const int K = graph.size_at(-1, mat1);
-  VK_CHECK_COND(graph.size_at(-1, mat2_data) * 2 == K);
-
-  const int group_size_val = graph.extract_scalar(group_size);
-  VK_CHECK_COND(K % group_size_val == 0);
-  // Due to the way weight packing works, group size needs to be a multiple of 8
-  VK_CHECK_COND(group_size_val % 8 == 0);
-
-  VK_CHECK_COND(graph.has_standard_axis_map(mat1));
-  VK_CHECK_COND(graph.has_standard_axis_map(out));
-}
-
-void resize_linear_qga4w_node(
-    ComputeGraph* graph,
-    const std::vector& args,
-    const std::vector& extra_args) {
-  (void)extra_args;
-
-  ValueRef out = args.at(0).refs.at(0);
-  ValueRef mat1 = args.at(1).refs.at(0);
-  ValueRef mat2_data = extra_args.at(0);
-
-  std::vector mat1_sizes = graph->sizes_of(mat1);
-  std::vector mat2_sizes = graph->sizes_of(mat2_data);
-
-  const int64_t out_cols = utils::val_at(-2, mat1_sizes);
-  const int64_t out_rows = utils::val_at(-2, mat2_sizes);
-
-  std::vector new_out_sizes(3);
-  if (mat1_sizes.size() == 2) {
-    new_out_sizes.resize(2);
-    new_out_sizes.at(0) = out_cols;
-    new_out_sizes.at(1) = out_rows;
-  } else {
-    new_out_sizes.at(0) = mat1_sizes.at(0);
-    new_out_sizes.at(1) = out_cols;
-    new_out_sizes.at(2) = out_rows;
-  }
-
-  graph->virtual_resize(out, new_out_sizes);
-}
-
-/**
- * Determines if the cooperative algorithm should be used based on input tensor
- * dimensions. Apply the coop algorithm for gemv cases, i.e. mat1 is a vector,
- * as opposed to a matrix.
- */ -bool should_use_coop_algorithm(ComputeGraph* graph, const ValueRef& mat1) { - return graph->size_at(-2, mat1) == 1; -} - -vkapi::ShaderInfo pick_linear_qga4w_shader( - ComputeGraph* graph, - const std::vector& args, - const std::vector& resize_args) { - (void)resize_args; - - const ValueRef out = args.at(0).refs.at(0); - const ValueRef mat1 = args.at(1).refs.at(0); - const ValueRef mat2 = args.at(1).refs.at(1); - - const bool use_coop_algorithm = should_use_coop_algorithm(graph, mat1); - - std::string kernel_name = "linear_qga4w"; - if (use_coop_algorithm) { - kernel_name += "_coop"; - } else { - kernel_name += "_tiled"; - } - add_storage_type_suffix(kernel_name, graph->storage_type_of(out)); - add_storage_type_suffix(kernel_name, graph->storage_type_of(mat1)); - add_storage_type_suffix(kernel_name, graph->storage_type_of(mat2)); - add_dtype_suffix(kernel_name, graph->dtype_of(out)); - - return VK_KERNEL_FROM_STR(kernel_name); -} - -utils::uvec3 linear_qga4w_global_wg_size( - ComputeGraph* graph, - const vkapi::ShaderInfo& shader, - const std::vector& args, - const std::vector& resize_args) { - (void)resize_args; - const ValueRef out = args.at(0).refs.at(0); - - const bool use_coop_algorithm = - shader.kernel_name.find("_coop") != std::string::npos; - - if (!use_coop_algorithm) { - // Constructing the global workgroup size for the tiled algorithm - utils::uvec3 global_wg_size = graph->logical_limits_of(out); - // Each shader thread computes a 4 high x 8 wide tile of the output matrix, - // which is equivalent to 4 x 2 texels. Since the output tensor must be - // width packed, div-up the "texel-width" of the output by 2 and the height - // of the output tensor by 4 to obtain the number of tiles that need to be - // computed. - global_wg_size[0] = utils::div_up(global_wg_size[0], uint32_t(2)); - global_wg_size[1] = utils::div_up(global_wg_size[1], uint32_t(4)); - return global_wg_size; - } - - uint32_t output_channels = graph->size_at(-1, out); - uint32_t batch_size = graph->size_at(-2, out); - - // Constructing the global workgroup size of the co-operative algorithm. The - // local work group size is 64, and each local work group co-operates to - // compute 8 output channels of the output. Therefore, a total of - // (output_channels / 8 x 64) threads should be launched, assuming a batch - // size of 1. 
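-  // Illustrative example (values not from the original source): with
-  // output_channels = 256 and batch_size = 1, the return below evaluates to a
-  // global workgroup size of {64, div_up(256, 8) = 32, 1}.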
- return {64, utils::div_up(output_channels, 8u), batch_size}; -} - -utils::uvec3 linear_qga4w_local_wg_size( - ComputeGraph* graph, - const vkapi::ShaderInfo& shader, - const utils::uvec3& global_workgroup_size, - const std::vector& args, - const std::vector& resize_args) { - (void)args; - (void)resize_args; - const bool use_coop_algorithm = - shader.kernel_name.find("_coop") != std::string::npos; - - if (use_coop_algorithm) { - return {64, 1, 1}; - } else { - return pick_hw_square_wg_size( - graph, shader, global_workgroup_size, args, resize_args); - } -} - -void add_linear_qga4w_node( - ComputeGraph& graph, - const ValueRef mat1, - const ValueRef mat2_data, - const ValueRef group_size, - const ValueRef scales_and_zeros_data, - const ValueRef out) { - check_linear_qga4w_args( - graph, mat1, mat2_data, group_size, scales_and_zeros_data, out); - - const uint32_t group_size_val = graph.extract_scalar(group_size); - - ValueRef mat2 = - prepack_int4_linear_weight_transposed_block_4x8(graph, mat2_data); - - ValueRef scales_and_zeros = prepack_standard( - graph, scales_and_zeros_data, utils::kBuffer, utils::kWidthPacked); - - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - pick_linear_qga4w_shader, - linear_qga4w_global_wg_size, - linear_qga4w_local_wg_size, - // Inputs and Outputs - {{out, vkapi::kWrite}, {{mat1, mat2, scales_and_zeros}, vkapi::kRead}}, - // Shader params buffers - {}, - // Push Constants - {graph.sizes_pc_of(out), - graph.sizes_pc_of(mat1), - graph.sizes_pc_of(mat2)}, - // Specialization Constants - {SV(group_size_val)}, - // Resize Args - {mat2_data}, - // Resizing Logic - resize_linear_qga4w_node)); -} - -void linear_weight_int4( - ComputeGraph& graph, - const std::vector& args) { - return add_linear_qga4w_node( - graph, - args[0], // mat1 - args[1], // mat2 - args[2], // group_size - args[3], // scales_and_zeros - // There is an unused variable inner_k_tiles which is used to call - // _convert_weight_to_int4pack in the AOT custom op, which is why the 4th - // argument is skipped. - args[5] // out - ); -} - -REGISTER_OPERATORS { - VK_REGISTER_OP(et_vk.linear_weight_int4.default, linear_weight_int4); -} - -} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear_QTA8A_QGA4W.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear_QTA8A_QGA4W.cpp deleted file mode 100644 index e3443ca34e6..00000000000 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinear_QTA8A_QGA4W.cpp +++ /dev/null @@ -1,270 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. 
- */ - -#include - -#include -#include - -#include -#include - -namespace vkcompute { - -void check_linear_qta8a_qga4w_args( - ComputeGraph& graph, - const ValueRef mat1, - const ValueRef mat1_scale, - const ValueRef mat1_zero_point, - const ValueRef mat2_data, - const ValueRef group_size, - const ValueRef weight_scales, - const ValueRef weight_zeros, - const ValueRef out) { - VK_CHECK_COND(graph.val_is_tensor(mat1)); - VK_CHECK_COND(graph.val_is_tensor(mat1_scale)); - VK_CHECK_COND(graph.val_is_tensor(mat1_zero_point)); - VK_CHECK_COND(graph.val_is_tref(mat2_data)); - VK_CHECK_COND(graph.val_is_tref(weight_scales)); - VK_CHECK_COND(graph.val_is_tref(weight_zeros)); - - VK_CHECK_COND(graph.dim_of(mat1) <= 3); - VK_CHECK_COND(graph.dim_of(mat2_data) == 2); - VK_CHECK_COND(graph.dim_of(weight_scales) == 2); - VK_CHECK_COND(graph.dim_of(weight_zeros) == 2); - - VK_CHECK_COND(graph.size_at(-3, mat1) == 1); - const int K = graph.size_at(-1, mat1); - VK_CHECK_COND(graph.size_at(-1, mat2_data) * 2 == K); - - const int group_size_val = graph.extract_scalar(group_size); - VK_CHECK_COND(K % group_size_val == 0); - // Due to the way weight packing works, group size needs to be a multiple of 8 - VK_CHECK_COND(group_size_val % 8 == 0); - - VK_CHECK_COND(graph.has_standard_axis_map(mat1)); - VK_CHECK_COND(graph.has_standard_axis_map(out)); - - // Check that scale and zero_point tensors are buffer storage with width - // packing - VK_CHECK_COND(graph.is_buffer_storage(mat1_scale)); - VK_CHECK_COND(graph.packed_dim_of(mat1_scale) == WHCN::kWidthDim); - VK_CHECK_COND(graph.is_buffer_storage(mat1_zero_point)); - VK_CHECK_COND(graph.packed_dim_of(mat1_zero_point) == WHCN::kWidthDim); - - // Calculate number of tokens for input - int64_t input_num_tokens = 1; - const auto mat1_sizes = graph.sizes_of(mat1); - for (size_t i = 0; i < mat1_sizes.size() - 1; i++) { - input_num_tokens *= mat1_sizes[i]; - } - - // Verify scale and zero_point tensor sizes match number of tokens - const auto mat1_scale_sizes = graph.sizes_of(mat1_scale); - const auto mat1_zero_point_sizes = graph.sizes_of(mat1_zero_point); - - VK_CHECK_COND( - utils::val_at(-1, mat1_scale_sizes) == input_num_tokens); - VK_CHECK_COND( - utils::val_at(-1, mat1_zero_point_sizes) == input_num_tokens); - - // Verify weight scales and zeros have the same shape - const auto weight_scales_sizes = graph.sizes_of(weight_scales); - const auto weight_zeros_sizes = graph.sizes_of(weight_zeros); - VK_CHECK_COND(weight_scales_sizes == weight_zeros_sizes); -} - -void resize_linear_qta8a_qga4w_node( - ComputeGraph* graph, - const std::vector& args, - const std::vector& extra_args) { - (void)extra_args; - - const ValueRef out = args.at(0).refs.at(0); - const ValueRef mat1 = args.at(1).refs.at(0); - const ValueRef mat2 = args.at(1).refs.at(1); - - const std::vector mat1_sizes = graph->sizes_of(mat1); - const std::vector mat2_sizes = graph->sizes_of(mat2); - - const int64_t out_cols = utils::val_at(-2, mat1_sizes); - const int64_t out_rows = utils::val_at(-1, mat2_sizes) * 2; - - std::vector new_out_sizes(3); - if (mat1_sizes.size() == 2) { - new_out_sizes.resize(2); - new_out_sizes.at(0) = out_cols; - new_out_sizes.at(1) = out_rows; - } else { - new_out_sizes.at(0) = mat1_sizes.at(0); - new_out_sizes.at(1) = out_cols; - new_out_sizes.at(2) = out_rows; - } - - graph->virtual_resize(out, new_out_sizes); -} - -/** - * Determines if the cooperative algorithm should be used based on input tensor - * dimensions. 
Apply the coop algorithm for vectors (GEMV cases), tiled for - * matrices (GEMM cases). - */ -bool should_use_coop_algorithm_qta8a_qga4w( - ComputeGraph* graph, - const ValueRef& mat1) { - const uint32_t M = graph->size_at(-2, mat1); - // Use coop algorithm for vectors (GEMV), tiled for larger matrices (GEMM) - return M == 1; -} - -vkapi::ShaderInfo pick_linear_qta8a_qga4w_shader( - ComputeGraph* graph, - const std::vector& args, - const std::vector& resize_args) { - (void)resize_args; - - const ValueRef out = args.at(0).refs.at(0); - const ValueRef mat1 = args.at(1).refs.at(0); - const ValueRef mat2 = args.at(1).refs.at(1); - - const bool use_coop_algorithm = - should_use_coop_algorithm_qta8a_qga4w(graph, mat1); - - std::string kernel_name = "linear_qta8a_qga4w"; - if (use_coop_algorithm) { - kernel_name += "_coop"; - } else { - kernel_name += "_tiled"; - } - add_storage_type_suffix(kernel_name, graph->storage_type_of(out)); - add_storage_type_suffix(kernel_name, graph->storage_type_of(mat1)); - add_storage_type_suffix(kernel_name, graph->storage_type_of(mat2)); - add_dtype_suffix(kernel_name, graph->dtype_of(out)); - - return VK_KERNEL_FROM_STR(kernel_name); -} - -utils::uvec3 linear_qta8a_qga4w_global_wg_size( - ComputeGraph* graph, - const vkapi::ShaderInfo& shader, - const std::vector& args, - const std::vector& resize_args) { - (void)resize_args; - const ValueRef out = args.at(0).refs.at(0); - - const bool use_coop_algorithm = - shader.kernel_name.find("_coop") != std::string::npos; - - // C = 1, H = 2, W = 3 - // global_wg_size = {round_up(C / 2f), round_up(H / 3f), W} --> (2W, 1H, 0C) - // --> {1, 1, 3} global - - utils::uvec3 global_wg_size = graph->logical_limits_of(out); - global_wg_size[0] = utils::div_up(global_wg_size[0], uint32_t(2)); - if (!use_coop_algorithm) { // GEMM - TILED - global_wg_size[1] = utils::div_up(global_wg_size[1], uint32_t(3)); - } - - return global_wg_size; -} - -utils::uvec3 linear_qta8a_qga4w_local_wg_size( - ComputeGraph* graph, - const vkapi::ShaderInfo& shader, - const utils::uvec3& global_workgroup_size, - const std::vector& args, - const std::vector& resize_args) { - (void)args; - (void)resize_args; - - const bool use_coop_algorithm = - shader.kernel_name.find("_coop") != std::string::npos; - - utils::uvec3 local_wg_size; - if (use_coop_algorithm) { // GEMV - COOP - local_wg_size = {8, 1, 8}; - } else { // GEMM - TILED - local_wg_size = graph->create_local_wg_size(global_workgroup_size); - } - - return local_wg_size; -} - -void add_linear_qta8a_qga4w_node( - ComputeGraph& graph, - const ValueRef mat1, - const ValueRef mat1_scale, - const ValueRef mat1_zero_point, - const ValueRef mat2_data, - const ValueRef group_size, - const ValueRef weight_scales_data, - const ValueRef weight_zeros_data, - const ValueRef out) { - check_linear_qta8a_qga4w_args( - graph, - mat1, - mat1_scale, - mat1_zero_point, - mat2_data, - group_size, - weight_scales_data, - weight_zeros_data, - out); - const uint32_t group_size_val = graph.extract_scalar(group_size); - - ValueRef mat2 = - prepack_int4_linear_weight_transposed_interleaved(graph, mat2_data); - ValueRef weight_scales = prepack_standard( - graph, weight_scales_data, utils::kBuffer, utils::kWidthPacked); - ValueRef weight_zeros = prepack_standard( - graph, weight_zeros_data, utils::kBuffer, utils::kWidthPacked); - - graph.execute_nodes().emplace_back(new DynamicDispatchNode( - graph, - pick_linear_qta8a_qga4w_shader, - linear_qta8a_qga4w_global_wg_size, - linear_qta8a_qga4w_local_wg_size, - // Inputs and Outputs - 
{{out, vkapi::kWrite}, - {{mat1, mat2, weight_scales, weight_zeros, mat1_scale, mat1_zero_point}, - vkapi::kRead}}, - // Shader params buffers - {}, - // Push Constants - {graph.sizes_pc_of(out), - graph.sizes_pc_of(mat1), - graph.sizes_pc_of(mat2)}, - // Specialization Constants - {SV(group_size_val)}, - // Resize Args - {}, - // Resizing Logic - resize_linear_qta8a_qga4w_node)); -} - -void linear_qta8a_qga4w( - ComputeGraph& graph, - const std::vector& args) { - return add_linear_qta8a_qga4w_node( - graph, - args[0], // quantized input (char tensor) - args[1], // input_scale (float buffer tensor) - args[2], // input_zero_point (int buffer tensor) - args[3], // quantized weights (4-bit packed, byte) - args[4], // group_size (int) - args[5], // weight_scales (float tensor) - args[6], // weight_zeros (int tensor) - args[7] // float output tensor - ); -} - -REGISTER_OPERATORS { - VK_REGISTER_OP(et_vk.linear_qta8a_qga4w.default, linear_qta8a_qga4w); -} - -} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp index 6cd5115563a..648d7b8da09 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Staging.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Staging.cpp @@ -327,75 +327,6 @@ ValueRef prepack_int4_linear_weight_transposed_interleaved( return qmat2; } -ValueRef prepack_int4_linear_weight_transposed_block_4x8( - ComputeGraph& graph, - const ValueRef qmat2_data) { - std::vector qmat2_orig_sizes = graph.sizes_of(qmat2_data); - const int64_t ndim = graph.dim_of(qmat2_data); - - const int64_t K_div2 = qmat2_orig_sizes.at(ndim - 1); // Input is [N, K/2] - const int64_t N = qmat2_orig_sizes.at(ndim - 2); - // Logical K dimension. Each value in the tensor is a uint8 that contains 2 - // packed 4-bit values. - const int64_t K = K_div2 * 2; - - // This packing format partitions the weight tensor into 4 wide x 8 high - // blocks. To figure out the size of the output tensor, determine the number - // of blocks along the width and height dims. - const int64_t num_blocks_K = utils::div_up(K, int64_t(4)); - const int64_t num_blocks_N = utils::div_up(N, int64_t(8)); - // Each transposed block is 8 wide x 4 high. In terms of 8-bit values, the - // block is 4 wide x 4 high. To maximize memory loading efficiency, the packed - // weight tensor will use a base data type of uint32_t; in terms of uint32_t, - // each block is 1 wide x 4 high. However, each block is also flattened as it - // is stored, so that the whole block can be loaded at once. As a result, the - // stored block will be 4 wide x 1 high. 
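-  // Illustrative example (values not from the original source): a [N, K/2] =
-  // [64, 128] uint8 weight tensor has K = 256 and N = 64, so num_blocks_K =
-  // div_up(256, 4) = 64 and num_blocks_N = div_up(64, 8) = 8; the packed tensor
-  // computed below is then 64 * 4 = 256 uints wide and 8 rows high.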
- const int64_t output_width = num_blocks_K * 4; - const int64_t output_height = num_blocks_N; - - // Store the original sizes of the tensor to pass to the shader - utils::ivec2 orig_sizes{ - utils::safe_downcast(K), utils::safe_downcast(N)}; - - std::vector qmat2_sizes{output_height, output_width}; - - utils::StorageType storage_type = utils::kTexture2D; - uint32_t max_extent = graph.context()->adapter_ptr()->max_texture2d_dim(); - if (output_width > max_extent * 4 || output_height > max_extent) { - storage_type = utils::kBuffer; - } - - ValueRef qmat2 = graph.add_tensor( - qmat2_sizes, vkcompute::vkapi::kUInt, storage_type, utils::kWidthPacked); - - // Global workgroup size: each thread writes out two adjacent blocks - utils::uvec3 global_wg_size{ - utils::div_up(utils::safe_downcast(num_blocks_K), uint32_t(2)), - utils::safe_downcast(num_blocks_N), - 1u}; - - std::string kernel_name = "pack_int4_linear_weight_transposed_block_4x8"; - add_storage_type_suffix(kernel_name, storage_type); - - graph.prepack_nodes().emplace_back(new PrepackNode( - graph, - VK_KERNEL_FROM_STR(kernel_name), - global_wg_size, - graph.create_local_wg_size(global_wg_size), - // Inputs and Outputs - qmat2_data, - qmat2, - // UBOs - {}, - // Specialization Constants - {}, - // Push Constants - {graph.sizes_pc_of(qmat2), - PushConstantDataInfo(&orig_sizes, sizeof(utils::ivec2))})); - - return qmat2; -} - void prepack_op(ComputeGraph& graph, const std::vector& args) { return add_prepack_standard_node(graph, args[0], args[1]); } diff --git a/backends/vulkan/runtime/graph/ops/impl/Staging.h b/backends/vulkan/runtime/graph/ops/impl/Staging.h index 0b1568ca139..5f5cdd1eda0 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Staging.h +++ b/backends/vulkan/runtime/graph/ops/impl/Staging.h @@ -95,8 +95,4 @@ ValueRef prepack_int4_linear_weight_transposed_interleaved( ComputeGraph& graph, const ValueRef qmat2_data); -ValueRef prepack_int4_linear_weight_transposed_block_4x8( - ComputeGraph& graph, - const ValueRef qmat2_data); - } // namespace vkcompute diff --git a/backends/vulkan/test/op_tests/quantized_linear_test.cpp b/backends/vulkan/test/op_tests/quantized_linear_test.cpp index db95f4a793f..6e81820e735 100644 --- a/backends/vulkan/test/op_tests/quantized_linear_test.cpp +++ b/backends/vulkan/test/op_tests/quantized_linear_test.cpp @@ -33,43 +33,10 @@ class VulkanLinearQCS4WTest : public ::testing::Test { } }; -class VulkanLinearQTA8AQGA4WTest : public ::testing::Test { - public: - void SetUp() override { - if (!vkcompute::api::context() - ->adapter_ptr() - ->has_full_int8_buffers_support()) { - GTEST_SKIP(); - } - } - - void TearDown() override { - // Clean up any resources if needed - } -}; - // // Reference Implementations // -at::Tensor linear_qga4w_reference_impl( - const at::Tensor& x, - const at::Tensor& weights_4x2, - const int64_t groupsize, - const at::Tensor& scales_and_zeros, - const int64_t inner_k_tiles) { - const std::vector original_x_size(x.sizes().vec()); - const size_t ndim = original_x_size.size(); - const int64_t out_features = weights_4x2.size(0); - const at::Tensor x_flattened = x.reshape({-1, original_x_size[ndim - 1]}); - at::Tensor out = at::_weight_int4pack_mm_for_cpu( - x_flattened, weights_4x2, groupsize, scales_and_zeros); - std::vector out_shape( - original_x_size.begin(), original_x_size.end()); - out_shape.at(ndim - 1) = out_features; - return out.reshape(out_shape); -} - at::Tensor unpack_weights_4x2(const at::Tensor& weights_4x2) { std::vector weights_shape(weights_4x2.sizes().vec()); 
weights_shape[1] *= 2; @@ -94,41 +61,6 @@ at::Tensor unpack_weights_4x2(const at::Tensor& weights_4x2) { return weights_unpacked; } -at::Tensor dequantize_and_linear_qga4w( - const at::Tensor& x, - const at::Tensor& weights_4x2, - const int64_t groupsize, - const at::Tensor& scales_and_zeros, - const int64_t inner_k_tiles) { - std::vector weights_shape(weights_4x2.sizes().vec()); - weights_shape[1] *= 2; - - at::Tensor weights_dequantized = - at::empty(weights_shape, at::device(at::kCPU).dtype(at::kFloat)); - - const int64_t N = weights_dequantized.size(0); - const int64_t K = weights_dequantized.size(1); - - const int k_groups = K / groupsize; - for (int n = 0; n < N; n++) { - for (int k = 0; k < K; k += 2) { - const int group_idx = k / groupsize; - // const int scale_idx = k_groups * n + group_idx; - const uint8_t packed_val = weights_4x2[n][k / 2].item().to(); - const uint8_t second_val = packed_val & 0x0F; - const uint8_t first_val = (packed_val & 0xF0) >> 4; - - const float scale = scales_and_zeros[group_idx][n][0].item().to(); - const float zero = scales_and_zeros[group_idx][n][1].item().to(); - - weights_dequantized[n][k] = (float(first_val) - 8.0) * scale + zero; - weights_dequantized[n][k + 1] = (float(second_val) - 8.0) * scale + zero; - } - } - - return at::linear(x, weights_dequantized); -} - at::Tensor dequantize_and_linear_qcs4w( const at::Tensor& x, const at::Tensor& weights_4x2, @@ -179,197 +111,10 @@ at::Tensor linear_qcs4w_reference_impl( return out.reshape(out_shape); } -at::Tensor linear_qta8a_qga4w_quantized_matmul( - const at::Tensor& quantized_input, // [B, M, K] int8 quantized input - const at::Tensor& input_scale, // [B*M] per-token input scales - const at::Tensor& input_zero_point, // [B*M] per-token input zero points - const at::Tensor& weights_4x2, // [N, K/2] 4-bit packed weights - const int64_t group_size, // Group size for weight quantization - const at::Tensor& weight_scales, // [K/group_size, N] weight scales - const at::Tensor& weight_zeros) { // [K/group_size, N] weight zeros - - const int64_t B = quantized_input.size(0); - const int64_t M = quantized_input.size(1); - const int64_t K = quantized_input.size(2); - const int64_t N = weights_4x2.size(0); - - // Create output tensor for floating point results - at::Tensor float_output = - at::zeros({B, M, N}, at::device(at::kCPU).dtype(at::kFloat)); - - // Accessors for efficient access - auto input_accessor = quantized_input.accessor(); - auto output_accessor = float_output.accessor(); - auto weights_accessor = weights_4x2.accessor(); - auto weight_scales_accessor = weight_scales.accessor(); - auto weight_zeros_accessor = weight_zeros.accessor(); - auto input_scale_accessor = input_scale.accessor(); - auto input_zero_accessor = input_zero_point.accessor(); - - // Perform quantized matrix multiplication following quantization.md equation - // (5): result_real_value = lhs_scale * rhs_scale * Sum_over_k( - // (lhs_quantized_value[k] - lhs_zero_point) * - // (rhs_quantized_value[k] - rhs_zero_point) - // ) - for (int64_t b = 0; b < B; b++) { - for (int64_t m = 0; m < M; m++) { - const int64_t token_idx = b * M + m; - const float lhs_scale = - input_scale_accessor[token_idx]; // Per-token input scale - const int32_t lhs_zero_point = - input_zero_accessor[token_idx]; // Per-token input zero point - - for (int64_t n = 0; n < N; n++) { - float result_real_value = 0.0f; - - for (int64_t k = 0; k < K; k++) { - // Get per-group weight quantization parameters - const int64_t group_idx = k / group_size; - const float 
rhs_scale = - weight_scales_accessor[group_idx][n]; // Per-group weight scale - const int32_t rhs_zero_point = - weight_zeros_accessor[group_idx] - [n]; // Per-group weight zero point - - // Unpack the 4-bit weight for this position - const uint8_t packed_val = weights_accessor[n][k / 2]; - uint8_t weight_4bit; - if (k % 2 == 0) { - weight_4bit = (packed_val & 0xF0) >> 4; // First weight in pair - } else { - weight_4bit = packed_val & 0x0F; // Second weight in pair - } - - // Get quantized values - const int32_t lhs_quantized_value = - static_cast(input_accessor[b][m][k]); - // Convert 4-bit weight to signed: subtract 8 to get range [-8, 7] - const int32_t rhs_quantized_value = - static_cast(weight_4bit) - 8; - - // Apply proper quantization paradigm from quantization.md equation - // (3): real_value = scale * (quantized_value - zero_point) Following - // equation (5): result = lhs_scale * rhs_scale * - // (lhs_quantized - lhs_zero) * (rhs_quantized - rhs_zero) - const float lhs_diff = - static_cast(lhs_quantized_value - lhs_zero_point); - const float rhs_diff = - static_cast(rhs_quantized_value - rhs_zero_point); - - result_real_value += lhs_scale * rhs_scale * lhs_diff * rhs_diff; - } - - output_accessor[b][m][n] = result_real_value; - } - } - } - - return float_output; -} - -at::Tensor linear_qta8a_qga4w_4bit_dequant_impl( - const at::Tensor& quantized_input, - const at::Tensor& input_scale, - const at::Tensor& input_zero_point, - const at::Tensor& weights_4x2, - const int64_t group_size, - const at::Tensor& weight_scales, - const at::Tensor& weight_zeros) { - // Calculate number of input tokens - int64_t input_num_tokens = 1; - for (size_t i = 0; i < quantized_input.sizes().size() - 1; i++) { - input_num_tokens *= quantized_input.size(i); - } - - // Manually dequantize the char tensor using per-token quantization - at::Tensor x_float = at::zeros_like(quantized_input, at::kFloat); - - // Apply per-token dequantization - auto input_accessor = quantized_input.accessor(); - auto output_accessor = x_float.accessor(); - - for (int64_t token_idx = 0; token_idx < input_num_tokens; token_idx++) { - float scale_val = input_scale[token_idx].item(); - int zero_point_val = input_zero_point[token_idx].item(); - - // Calculate batch and sequence indices for this token - int64_t b = token_idx / quantized_input.size(1); - int64_t m = token_idx % quantized_input.size(1); - - // Apply dequantization for all features in this token - for (int64_t k = 0; k < quantized_input.size(-1); k++) { - float dequant_val = - (input_accessor[b][m][k] - zero_point_val) * scale_val; - output_accessor[b][m][k] = dequant_val; - } - } - - std::vector weights_shape(weights_4x2.sizes().vec()); - weights_shape[1] *= 2; - - at::Tensor weights_dequantized = - at::empty(weights_shape, at::device(at::kCPU).dtype(at::kFloat)); - - const int64_t N = weights_dequantized.size(0); - const int64_t K = weights_dequantized.size(1); - - for (int n = 0; n < N; n++) { - for (int k = 0; k < K; k += 2) { - const int group_idx = k / group_size; - const uint8_t packed_val = weights_4x2[n][k / 2].item().to(); - const uint8_t second_val = packed_val & 0x0F; - const uint8_t first_val = (packed_val & 0xF0) >> 4; - - const float scale = weight_scales[group_idx][n].item().to(); - const int zero = weight_zeros[group_idx][n].item().to(); - - weights_dequantized[n][k] = - ((float(first_val) - 8.0) - float(zero)) * scale; - weights_dequantized[n][k + 1] = - ((float(second_val) - 8.0) - float(zero)) * scale; - } - } - - at::Tensor linear_result = 
at::linear(x_float, weights_dequantized); - - return linear_result; -} - // // Test functions // -void test_reference_linear_qga4w( - const int B, - const int M, - const int K, - const int N, - const int group_size = 32, - const int inner_k_tiles = 8) { - assert(K % group_size == 0); - - at::Tensor x = at::rand({B, M, K}, at::device(at::kCPU).dtype(at::kFloat)); - at::Tensor weights_4x2 = - at::randint(0, 256, {N, K / 2}, at::device(at::kCPU).dtype(at::kByte)); - at::Tensor weights_int = unpack_weights_4x2(weights_4x2); - - const int k_groups = K / group_size; - at::Tensor scales_and_zeros = - at::rand({k_groups, N, 2}, at::device(at::kCPU).dtype(at::kFloat)); - - at::Tensor out = linear_qga4w_reference_impl( - x, - at::_convert_weight_to_int4pack_for_cpu(weights_int, group_size), - group_size, - scales_and_zeros, - inner_k_tiles); - - at::Tensor out_ref = dequantize_and_linear_qga4w( - x, weights_4x2, group_size, scales_and_zeros, inner_k_tiles); - - ASSERT_TRUE(at::allclose(out, out_ref)); -} - void test_reference_linear_qcs4w( const int B, const int M, @@ -389,118 +134,6 @@ void test_reference_linear_qcs4w( ASSERT_TRUE(at::allclose(out, out_ref)); } -void test_vulkan_linear_qga4w_impl( - const int B, - const int M, - const int K, - const int N, - const int group_size = 32, - const int inner_k_tiles = 8, - const vkcompute::utils::StorageType in_storage = - vkcompute::utils::kTexture3D, - const vkcompute::utils::StorageType out_storage = - vkcompute::utils::kTexture3D) { - assert(K % group_size == 0); - - at::Tensor x = at::rand({B, M, K}, at::device(at::kCPU).dtype(at::kFloat)); - at::Tensor weights_4x2 = - at::randint(0, 256, {N, K / 2}, at::device(at::kCPU).dtype(at::kByte)); - - const int k_groups = K / group_size; - at::Tensor scales_and_zeros = - at::rand({k_groups, N, 2}, at::device(at::kCPU).dtype(at::kFloat)); - - at::Tensor weights_int = unpack_weights_4x2(weights_4x2); - at::Tensor out_ref = linear_qga4w_reference_impl( - x, - at::_convert_weight_to_int4pack_for_cpu(weights_int, group_size), - group_size, - scales_and_zeros, - inner_k_tiles); - - // Build Vulkan graph - using namespace vkcompute; - - GraphConfig config; - config.set_storage_type_override(utils::kTexture3D); - ComputeGraph graph(config); - -#define MAKE_TENSORREF_FOR(x) \ - ValueRef r_##x = graph.add_tensorref( \ - x.sizes().vec(), \ - from_at_scalartype(x.scalar_type()), \ - x.const_data_ptr()); - - MAKE_TENSORREF_FOR(weights_4x2); - MAKE_TENSORREF_FOR(scales_and_zeros); - - IOValueRef r_x = graph.add_input_tensor( - x.sizes().vec(), from_at_scalartype(x.scalar_type()), in_storage); - - const ValueRef r_out = graph.add_tensor( - out_ref.sizes().vec(), - from_at_scalartype(out_ref.scalar_type()), - out_storage); - - VK_GET_OP_FN("et_vk.linear_weight_int4.default") - (graph, - {r_x.value, - r_weights_4x2, - graph.add_scalar(group_size), - r_scales_and_zeros, - kDummyValueRef, - r_out}); - - ValueRef staging_out = graph.set_output_tensor(r_out); - - graph.prepare(); - - graph.prepack(); - - // - // Run model - // - - graph.propagate_resize(); - graph.copy_into_staging(r_x.staging, x.const_data_ptr(), x.numel()); - - graph.execute(); - - at::Tensor vk_out = at::empty_like(out_ref); - graph.copy_from_staging( - staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); - - ASSERT_TRUE(at::allclose(vk_out, out_ref, 1e-4, 1e-4)); -} - -void test_vulkan_linear_qga4w( - const int B, - const int M, - const int K, - const int N, - const int group_size = 32, - const int inner_k_tiles = 8) { - test_vulkan_linear_qga4w_impl( - 
B, - M, - K, - N, - group_size, - inner_k_tiles, - vkcompute::utils::kBuffer, - vkcompute::utils::kBuffer); - - test_vulkan_linear_qga4w_impl( - B, - M, - K, - N, - group_size, - inner_k_tiles, - vkcompute::utils::kTexture3D, - vkcompute::utils::kTexture3D); -} - void test_vulkan_linear_qcs4w_impl( const int B, const int M, @@ -579,210 +212,6 @@ void test_vulkan_linear_qcs4w( B, M, K, N, vkcompute::utils::kTexture3D, vkcompute::utils::kTexture3D); } -void test_vulkan_linear_qta8a_qga4w_impl( - const int B, - const int M, - const int K, - const int N, - const int group_size = 8, - const vkcompute::utils::StorageType in_storage = - vkcompute::utils::kTexture3D, - const vkcompute::utils::StorageType out_storage = - vkcompute::utils::kTexture3D) { - assert(K % group_size == 0); - - const int64_t input_num_tokens = B * M; - const int k_groups = K / group_size; - - at::Tensor input_scale = - at::rand({input_num_tokens}, at::device(at::kCPU).dtype(at::kFloat)); - at::Tensor input_zero_point = at::randint( - -10, 10, {input_num_tokens}, at::device(at::kCPU).dtype(at::kInt)); - - at::Tensor float_x = - at::rand({B, M, K}, at::device(at::kCPU).dtype(at::kFloat)); - - // Create a reference quantized tensor using per-token quantization - // Mimic per-token quantization using at::quantize_per_channel by reshaping - // [num_tokens, features] - at::Tensor float_x_reshaped = float_x.view({input_num_tokens, K}); - at::Tensor qx_ref_reshaped = at::quantize_per_channel( - float_x_reshaped, - input_scale.to(at::kDouble), - input_zero_point.to(at::kLong), - 0, // axis 0 for per-token (first dimension after reshape) - c10::ScalarType::QInt8); - - at::Tensor x = - at::int_repr(qx_ref_reshaped).view(float_x.sizes()).to(at::kChar); - - at::Tensor weights_4x2 = - at::randint(0, 256, {N, K / 2}, at::device(at::kCPU).dtype(at::kByte)); - at::Tensor weight_scales = - at::rand({k_groups, N}, at::device(at::kCPU).dtype(at::kFloat)); - at::Tensor weight_zeros = at::randint( - -128, 128, {k_groups, N}, at::device(at::kCPU).dtype(at::kInt)); - - at::Tensor out_ref = linear_qta8a_qga4w_4bit_dequant_impl( - x, - input_scale, - input_zero_point, - weights_4x2, - group_size, - weight_scales, - weight_zeros); - - // Build Vulkan graph - using namespace vkcompute; - - GraphConfig config; - config.set_storage_type_override(utils::kTexture3D); - ComputeGraph graph(config); - -#define MAKE_TENSORREF_FOR(x) \ - ValueRef r_##x = graph.add_tensorref( \ - x.sizes().vec(), \ - from_at_scalartype(x.scalar_type()), \ - x.const_data_ptr()); - - MAKE_TENSORREF_FOR(weights_4x2); - MAKE_TENSORREF_FOR(weight_scales); - MAKE_TENSORREF_FOR(weight_zeros); - - IOValueRef r_x = graph.add_input_tensor( - x.sizes().vec(), from_at_scalartype(x.scalar_type()), in_storage); - - IOValueRef r_input_scale = graph.add_input_tensor( - input_scale.sizes().vec(), - from_at_scalartype(input_scale.scalar_type()), - utils::kBuffer); - - IOValueRef r_input_zero_point = graph.add_input_tensor( - input_zero_point.sizes().vec(), - from_at_scalartype(input_zero_point.scalar_type()), - utils::kBuffer); - - const ValueRef r_out = graph.add_tensor( - out_ref.sizes().vec(), - from_at_scalartype(out_ref.scalar_type()), - out_storage); - - VK_GET_OP_FN("et_vk.linear_qta8a_qga4w.default") - (graph, - {r_x.value, - r_input_scale.value, - r_input_zero_point.value, - r_weights_4x2, - graph.add_scalar(group_size), - r_weight_scales, - r_weight_zeros, - r_out}); - - ValueRef staging_out = graph.set_output_tensor(r_out); - - graph.prepare(); - - graph.prepack(); - - // - // Run 
model - // - - graph.propagate_resize(); - graph.copy_into_staging(r_x.staging, x.const_data_ptr(), x.numel()); - graph.copy_into_staging( - r_input_scale.staging, input_scale.const_data_ptr(), input_scale.numel()); - graph.copy_into_staging( - r_input_zero_point.staging, - input_zero_point.const_data_ptr(), - input_zero_point.numel()); - - graph.execute(); - - at::Tensor vk_out = at::empty_like(out_ref); - graph.copy_from_staging( - staging_out, vk_out.mutable_data_ptr(), vk_out.numel()); - - // This is a reference implementation that uses the quantized - // matmul paradigm. It should follow closely with how the vulkan - // implementation works, and demonstrates reasonably close results. - at::Tensor qmm_ref = linear_qta8a_qga4w_quantized_matmul( - x, - input_scale, - input_zero_point, - weights_4x2, - group_size, - weight_scales, - weight_zeros); - - // For quantized int8 operations, allow for 1-unit differences due to rounding - bool is_close = at::allclose(vk_out, out_ref, 5e-3, 5e-3); - if (!is_close) { - std::cout << "qmm_ref: \n" << qmm_ref << std::endl; - std::cout << "out_ref: \n" << out_ref << std::endl; - std::cout << "vk_out: \n" << vk_out << std::endl; - } - - ASSERT_TRUE(is_close); -} - -void test_vulkan_linear_qta8a_qga4w( - const int B, - const int M, - const int K, - const int N, - const int group_size = 32) { - test_vulkan_linear_qta8a_qga4w_impl( - B, - M, - K, - N, - group_size, - vkcompute::utils::kBuffer, - vkcompute::utils::kBuffer); - - test_vulkan_linear_qta8a_qga4w_impl( - B, - M, - K, - N, - group_size, - vkcompute::utils::kTexture3D, - vkcompute::utils::kTexture3D); -} - -// Test linear_qga4w operator - -TEST(VulkanLinearQGA4WTest, test_reference_impl) { - test_reference_linear_qga4w( - /*B = */ 1, - /*M = */ 4, - /*K = */ 128, - /*N = */ 32); -} - -TEST(VulkanLinearQGA4WTest, test_vulkan_impl_small_m) { - test_vulkan_linear_qga4w( - /*B = */ 1, - /*M = */ 4, - /*K = */ 128, - /*N = */ 32); - - test_vulkan_linear_qga4w( - /*B = */ 1, - /*M = */ 1, - /*K = */ 256, - /*N = */ 256); -} - -TEST(VulkanLinearQGA4WTest, test_vulkan_impl_gemm) { - test_vulkan_linear_qga4w( - /*B = */ 1, - /*M = */ 256, - /*K = */ 256, - /*N = */ 256); -} - // Test linear_qcs4w operator TEST_F(VulkanLinearQCS4WTest, test_reference_impl) { @@ -814,87 +243,3 @@ TEST_F(VulkanLinearQCS4WTest, test_vulkan_impl_gemm) { /*K = */ 32, /*N = */ 32); } - -// Test linear_qta8a_qga4w operator - -TEST_F( - VulkanLinearQTA8AQGA4WTest, - test_vulkan_linear_quant_gemm_custom_groupsize) { - test_vulkan_linear_qta8a_qga4w( - /*B = */ 1, - /*M = */ 2, - /*K = */ 8, - /*N = */ 8, - /*group_size = */ 8); - - test_vulkan_linear_qta8a_qga4w( - /*B = */ 1, - /*M = */ 2, - /*K = */ 16, - /*N = */ 8, - /*group_size = */ 8); -} - -TEST_F(VulkanLinearQTA8AQGA4WTest, test_vulkan_linear_quant_gemm) { - test_vulkan_linear_qta8a_qga4w( - /*B = */ 1, - /*M = */ 4, - /*K = */ 64, - /*N = */ 32); - - test_vulkan_linear_qta8a_qga4w( - /*B = */ 1, - /*M = */ 4, - /*K = */ 128, - /*N = */ 32); - - test_vulkan_linear_qta8a_qga4w( - /*B = */ 1, - /*M = */ 8, - /*K = */ 64, - /*N = */ 16); - - test_vulkan_linear_qta8a_qga4w( - /*B = */ 1, - /*M = */ 256, - /*K = */ 256, - /*N = */ 256); -} - -TEST_F( - VulkanLinearQTA8AQGA4WTest, - test_vulkan_linear_quant_gemv_custom_groupsize) { - test_vulkan_linear_qta8a_qga4w( - /*B = */ 1, - /*M = */ 1, - /*K = */ 8, - /*N = */ 8, - /*group_size = */ 8); - - test_vulkan_linear_qta8a_qga4w( - /*B = */ 1, - /*M = */ 1, - /*K = */ 16, - /*N = */ 8, - /*group_size = */ 8); -} - 
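-// Note: the GEMV cases (M = 1) exercise the cooperative shader variant, since the
-// op selects the coop algorithm whenever the second-to-last input dim is 1, while
-// the M > 1 GEMM cases are dispatched to the tiled variant.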
-TEST_F(VulkanLinearQTA8AQGA4WTest, test_vulkan_linear_quant_gemv) { - test_vulkan_linear_qta8a_qga4w( - /*B = */ 1, - /*M = */ 1, - /*K = */ 32, - /*N = */ 32); - - test_vulkan_linear_qta8a_qga4w( - /*B = */ 1, - /*M = */ 1, - /*K = */ 64, - /*N = */ 16); - - test_vulkan_linear_qta8a_qga4w( - /*B = */ 1, - /*M = */ 1, - /*K = */ 256, - /*N = */ 256); -} diff --git a/backends/vulkan/test/test_vulkan_passes.py b/backends/vulkan/test/test_vulkan_passes.py index 76e25f2e291..4a30ab6c2de 100644 --- a/backends/vulkan/test/test_vulkan_passes.py +++ b/backends/vulkan/test/test_vulkan_passes.py @@ -17,7 +17,6 @@ from executorch.exir.backend.canonical_partitioners.config_partitioner import ( format_target_name, ) -from torchao.quantization.linear_quant_modules import Int8DynActInt4WeightQuantizer from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e from torchao.quantization.pt2e.quantizer import Quantizer @@ -155,62 +154,6 @@ def test_fuse_linear_qcs4w(self): self.assertEqual(op_node_count(gm, "linear_qcs4w.default"), 1) self.assertEqual(op_node_count(gm, "dequantize_per_channel.default"), 0) - @unittest.skip( - "linear_qta8a_qga4w currently does not support E2E dynamic quantization" - ) - def test_fuse_linear_qta8a_qga4w(self): - """Test fusion of dynamic activation + grouped weight quantized linear (QTA8A_QGA4W).""" - K = 256 - N = 256 - model = SingleLinearModule(K, N) - sample_inputs = model.get_sample_inputs() - - # Use source transform quantizer for dynamic activation + grouped weight quantization - quantizer = Int8DynActInt4WeightQuantizer( - groupsize=128, # Group size for 4-bit weights - padding_allowed=False, - precision=torch.float32, - scales_precision=torch.float32, - device=torch.device("cpu"), - ) - - # Apply source transform quantization - quantized_model = quantizer.quantize(model) - - # Export the quantized model - edge_compile_config = EdgeCompileConfig( - _skip_dim_order=False, - _check_ir_validity=False, - ) - - program = torch.export.export( - quantized_model, sample_inputs, strict=True - ).module() - - program = torch.export.export(program, sample_inputs) - - edge_manager = to_edge( - program, - compile_config=edge_compile_config, - ) - - ep = edge_manager._edge_programs["forward"] - edge_manager.transform( - [ - AddmmToLinearTransform(), - FuseQuantizedOpsTransform(ep), - ] - ) - - gm = ep.graph_module - - # Check that the linear_qta8a_qga4w operator was created - self.assertEqual(op_node_count(gm, "linear_qta8a_qga4w.default"), 1) - # Check that the original quantization/dequantization nodes were removed - self.assertEqual(op_node_count(gm, "quantize_per_token.default"), 0) - self.assertEqual(op_node_count(gm, "dequantize_per_channel.default"), 0) - self.assertEqual(op_node_count(gm, "linear.default"), 0) - def test_fuse_rotary_emb(self): """Test conversion of rotary embedding pattern to et_vk.apply_rotary_emb custom op.""" diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py index 95761ce4211..a72dd1d9b3b 100644 --- a/examples/models/llama/export_llama_lib.py +++ b/examples/models/llama/export_llama_lib.py @@ -782,7 +782,7 @@ def get_quantizer_and_quant_params(llm_config): def _qmode_type(value): - choices = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w", "4w"] + choices = ["int8", "8da4w", "8da4w-gptq", "4w"] patterns = [r"torchao:8da(\d+)w", r"torchao:fpa(\d+)w"] if value in choices: diff --git a/examples/models/llama/source_transformation/quantize.py 
b/examples/models/llama/source_transformation/quantize.py index 0278bc6e912..9f2210b5c64 100644 --- a/examples/models/llama/source_transformation/quantize.py +++ b/examples/models/llama/source_transformation/quantize.py @@ -159,13 +159,6 @@ def quantize( # noqa C901 if verbose: print("quantized model:", model) return model - elif qmode == "vulkan_4w": - from executorch.backends.vulkan._passes import VkInt4WeightOnlyQuantizer - - q_group_size = 256 if group_size is None else group_size - model = VkInt4WeightOnlyQuantizer(groupsize=q_group_size).quantize(model) - - return model elif qmode == "4w": from torchao.quantization.granularity import PerGroup from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_ diff --git a/extension/llm/export/config/llm_config.py b/extension/llm/export/config/llm_config.py index 31751f4011d..04991c0e73e 100644 --- a/extension/llm/export/config/llm_config.py +++ b/extension/llm/export/config/llm_config.py @@ -323,7 +323,6 @@ class QuantizationConfig: "int8", "8da4w", "8da4w-gptq", - "vulkan_4w", "4w", ] AO_QUANT_PATTERNS: ClassVar[List[str]] = [