From 3490c1e82debb7ea17bd3fb3a9431535fe53f9d6 Mon Sep 17 00:00:00 2001
From: Digant Desai
Date: Thu, 11 Sep 2025 14:32:43 -0700
Subject: [PATCH] [ExecuTorch] XNNPACK: prefer qc over qb when gs == k for non-int4

Pull Request resolved: https://github.com/pytorch/executorch/pull/14173

* Prefer channelwise over groupwise quantization when possible, both for performance and because int8 has no groupwise support
* Fix a bug / improve behavior for affine q/dq with gs == k in the per_channel case
* Refactor the is_per_channel_group state variable (now per_channel_group)
* Add QuantParams.__str__()

TODO - improve affine quant primitives - T237476295

ghstack-source-id: 309177704
@exported-using-ghexport

Differential Revision: [D82060758](https://our.internmc.facebook.com/intern/diff/D82060758/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D82060758/)!
---
 backends/xnnpack/operators/node_visitor.py |  8 +--
 backends/xnnpack/operators/quant_params.py | 63 ++++++++++++++++++++--
 2 files changed, 62 insertions(+), 9 deletions(-)

diff --git a/backends/xnnpack/operators/node_visitor.py b/backends/xnnpack/operators/node_visitor.py
index 6a055c9413f..68226644859 100644
--- a/backends/xnnpack/operators/node_visitor.py
+++ b/backends/xnnpack/operators/node_visitor.py
@@ -232,7 +232,7 @@ def get_per_channel_dtype(
     if quant_params.dtype == torch.int32:
         return XNNDatatype.xnn_datatype_qcint32
     elif quant_params.dtype == torch.int8:
-        if quant_params.is_per_channel_group:
+        if quant_params.per_channel_group:
             # 4-bit per channel group quantized weights
             # No 8-bit support yet
             assert (
@@ -282,7 +282,7 @@ def get_quant_params(
         buffer_idx = len(xnn_graph.constant_data)
         num_scales = scale.numel()
 
-        if quant_params.is_per_channel_group:
+        if quant_params.per_channel_group:
             scale = scale.to(torch.bfloat16)
 
         num_bytes = scale.untyped_storage().nbytes()
@@ -300,7 +300,7 @@ def get_quant_params(
             scale_name, bytes(scale_array), CONSTANT_TENSOR_ALIGNMENT
         )
 
-        if quant_params.is_per_channel_group:
+        if quant_params.per_channel_group:
             return PerChannelGroupQuant(
                 scale=[],
                 channel_dim=quant_params.axis,
@@ -335,7 +335,7 @@ def _check_per_channel_group_params(
 ) -> None:
     # Make sure things are lining up for per_channel_group quantization case
     # Has to be done this late because we don't have clean access to the actual tensor
-    assert quant_params.is_per_channel_group, "Not per_channel_group quantization"
+    assert quant_params.per_channel_group, "Not per_channel_group quantization"
     # linear weights will be in [oc, ic]. And per_channel quantization must be on axis 0
     num_groups = cast(torch.Tensor, quant_params.scale).shape[1]
     assert (
diff --git a/backends/xnnpack/operators/quant_params.py b/backends/xnnpack/operators/quant_params.py
index 88a1f660f0e..f1c87c0b8b6 100644
--- a/backends/xnnpack/operators/quant_params.py
+++ b/backends/xnnpack/operators/quant_params.py
@@ -89,6 +89,9 @@ def __init__(
         # Groupwise quantization for weight
         self.per_channel_group = False
         self.group_size = group_size
+
+        tensor = q_input.meta["val"]
+
         if self.group_size > 0:
             assert (
                 self.per_channel is True
@@ -96,12 +99,29 @@ def __init__(
             assert (
                 cast(torch.Tensor, scale).ndim == 2
             ), "Scale must be 2D for per channel groupwise quant"
-            self.per_channel_group = True
-            assert group_size > 0, "Group size must be greater than 0"
-        self.is_per_channel_group = self.per_channel and self.group_size > 0
+            # Assumed scale shape - [out_channels, in_channels/group_size]
+            input_channels = cast(torch.Tensor, scale).shape[1] * self.group_size
+            # 2d weight tensor shape - [out_channels, in_channels]
+            assert (
+                tensor.shape[1] == input_channels
+            ), "Invalid input channels for groupwise quant"
+            # Prefer per_channel over per_channel_group when group_size == input_channels, for non-int4 cases only.
+            # The int4 case needs more fixes to map qb4w to qc4w; incorrect scales are being passed down to XNNPACK.
+            self.per_channel_group = (
+                self.group_size <= input_channels
+                if self.is_qc4w
+                else self.group_size < input_channels
+            )
+
+            if not self.per_channel_group:
+                if cast(torch.Tensor, scale).ndim == 2:
+                    # TODO: don't reshape scale for per_channel cases
+                    assert (
+                        cast(torch.Tensor, scale).shape[1] == 1
+                    ), "Invalid scale shape for per channel quantization"
+                    scale = cast(torch.Tensor, scale).squeeze(1)
 
-        if per_channel and not self.is_per_channel_group:
-            tensor = q_input.meta["val"]
+        if per_channel and not self.per_channel_group:
             assert (
                 tensor.shape[self.axis] == cast(torch.Tensor, self.scale).shape[0]
             ), f"Invalid size of per channel quantization scales, axis: {self.axis}, scale size: {self.scale.shape}, tensor shape: {tensor.shape}"
@@ -110,6 +130,39 @@ def __init__(
                 tensor.shape[self.axis] == cast(torch.Tensor, self.zp).shape[0]
             ), f"Invalid size of per channel quantization zero-points, axis: {self.axis}, zp size: {self.zp.shape}, tensor shape: {tensor.shape}"
 
+    def __str__(self) -> str:
+        """String representation of QuantParams for debugging and logging."""
+        assert isinstance(self.scale, float) or isinstance(self.scale, torch.Tensor)
+        scale_str = (
+            f"{self.scale}"
+            if isinstance(self.scale, float)
+            else f"tensor{tuple(self.scale.shape)}"
+        )
+        assert isinstance(self.zp, float) or isinstance(self.zp, torch.Tensor)
+        zp_str = (
+            f"{self.zp}"
+            if isinstance(self.zp, float)
+            else f"tensor{tuple(self.zp.shape)}"
+        )
+
+        return (
+            f"QuantParams("
+            f"per_channel={self.per_channel}, "
+            f"per_channel_group={self.per_channel_group}, "
+            f"scale={scale_str}, "
+            f"zp={zp_str}, "
+            f"axis={self.axis}, "
+            f"dtype={self.dtype}, "
+            f"qmin={self.qmin}, "
+            f"qmax={self.qmax}, "
+            f"is_dynamic={self.is_dynamic}, "
+            f"is_input={self.is_input}, "
+            f"is_output={self.is_output}, "
+            f"group_size={self.group_size}, "
+            f"is_qc4w={self.is_qc4w}"
+            f")"
+        )
+
     def quantize_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
         # Do nothing if already quantized by the Quantizer
         if tensor.dtype == self.dtype:
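
---

For reviewers unfamiliar with the gs == k case this patch targets: when group_size equals the number of input channels, each output channel has exactly one group, so the groupwise scale of shape [out_channels, in_channels/group_size] collapses to [out_channels, 1], and squeezing it yields the 1D per-channel scale the qc path expects. Below is a minimal sketch of that equivalence using only plain torch ops; it is not part of the patch, and the variable names are illustrative rather than ExecuTorch APIs:

```python
# Illustrative sketch (not part of the patch): groupwise dequantization with
# group_size == in_channels is exactly per-channel dequantization.
import torch

oc, ic = 8, 16
group_size = ic  # gs == k: one group spans each full weight row

qweight = torch.randint(-128, 128, (oc, ic), dtype=torch.int8)
# Groupwise scales have shape [oc, ic // group_size] == [oc, 1] in this case.
scale_qb = torch.rand(oc, ic // group_size) + 0.01

# qb-style dequant: broadcast each group's scale over its group_size columns.
dq_qb = qweight.float().reshape(oc, -1, group_size) * scale_qb.unsqueeze(-1)
dq_qb = dq_qb.reshape(oc, ic)

# qc-style dequant with the squeezed 1D scale, as the patch now prefers.
scale_qc = scale_qb.squeeze(1)  # shape: [oc]
dq_qc = qweight.float() * scale_qc.unsqueeze(1)

assert torch.equal(dq_qb, dq_qc)
```

As the in-diff comment notes, the int4 (is_qc4w) case is deliberately excluded from this fallback for now: mapping qb4w to qc4w still passes incorrect scales down to XNNPACK, so gs == k keeps taking the per_channel_group path there until the affine quant primitives are fixed (T237476295).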