From ec07cbaa8eeea8172c7d7c6b3e0b9f48b0b06d2d Mon Sep 17 00:00:00 2001
From: Scott Roy
Date: Fri, 12 Sep 2025 09:47:31 -0700
Subject: [PATCH] Forward fix for D82242003 (#14241)

Summary:
This fixes internal failures on D82242003:

* pyre errors
* buck build --flagfile fbcode//mode/dev fbcode//executorch/examples/models/fb/llama4:ngtts_semantic_lm_xnnpack_quantized.pte

The second failure occurs because the old and new APIs behave differently when
group_size is incompatible with the nn.Linear module's shape (i.e., when
in_features is not divisible by group_size). The old API silently skips
quantizing the layer, whereas the new API is explicit and raises an error.
This diff uses a filter_fn to restore the previous behavior.

Reviewed By: digantdesai

Differential Revision: D82265586
---
 backends/xnnpack/test/ops/test_linear.py            |  4 +++-
 .../models/llama/source_transformation/quantize.py  | 12 ++++++++++++
 2 files changed, 15 insertions(+), 1 deletion(-)

diff --git a/backends/xnnpack/test/ops/test_linear.py b/backends/xnnpack/test/ops/test_linear.py
index ac6fec25732..dc92a9542a9 100644
--- a/backends/xnnpack/test/ops/test_linear.py
+++ b/backends/xnnpack/test/ops/test_linear.py
@@ -395,7 +395,9 @@ def _test_groupwise_dq_linear(
         quantize_(
             mod,
             Int8DynamicActivationIntxWeightConfig(
-                weight_dtype=torch.int4, weight_granularity=PerGroup(group_size)
+                # pyre-ignore[16]
+                weight_dtype=torch.int4,
+                weight_granularity=PerGroup(group_size),
             ),
         )
         unwrap_tensor_subclass(mod)
diff --git a/examples/models/llama/source_transformation/quantize.py b/examples/models/llama/source_transformation/quantize.py
index 835972b7f3e..8b76b7650fe 100644
--- a/examples/models/llama/source_transformation/quantize.py
+++ b/examples/models/llama/source_transformation/quantize.py
@@ -135,6 +135,7 @@ def quantize(  # noqa C901
                     PerAxis(0) if group_size == 0 else PerGroup(group_size)
                 ),
                 weight_mapping_type=MappingType.SYMMETRIC,
+                # pyre-ignore[6]
                 intx_packing_format="opaque_torchao_auto",
             ),
         )
@@ -154,12 +155,23 @@ def quantize(  # noqa C901
         from torchao.quantization.granularity import PerGroup
         from torchao.utils import unwrap_tensor_subclass
 
+        def filter_fn(m, fqn):
+            is_linear = isinstance(m, nn.Linear)
+            has_shape_compatible_with_group_size = False
+            if is_linear:
+                has_shape_compatible_with_group_size = (
+                    m.weight.shape[1] % group_size == 0
+                )
+            return is_linear and has_shape_compatible_with_group_size
+
         quantize_(
             model,
             Int8DynamicActivationIntxWeightConfig(
+                # pyre-ignore[16]
                 weight_dtype=torch.int4,
                 weight_granularity=PerGroup(group_size),
             ),
+            filter_fn=filter_fn,
         )
         model = unwrap_tensor_subclass(model)
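
--
A minimal standalone sketch of the silent-skip behavior the filter_fn above
restores, not part of the patch itself. The quantize_/PerGroup usage follows
the diff; the Int8DynamicActivationIntxWeightConfig import path and the
example layer shapes are assumptions and may differ across torchao versions.

import torch
import torch.nn as nn
from torchao.quantization import quantize_
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import Int8DynamicActivationIntxWeightConfig

group_size = 32
model = nn.Sequential(
    nn.Linear(64, 16),  # in_features 64 % 32 == 0 -> gets quantized
    nn.Linear(24, 16),  # in_features 24 % 32 != 0 -> old API skipped it;
                        # the new API raises unless the layer is filtered out
)

def filter_fn(m, fqn):
    # Quantize only Linear layers whose in_features divides evenly into
    # groups of group_size, mirroring the old API's silent skip.
    return isinstance(m, nn.Linear) and m.weight.shape[1] % group_size == 0

quantize_(
    model,
    Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4,
        weight_granularity=PerGroup(group_size),
    ),
    filter_fn=filter_fn,
)

Filtering at the call site keeps the silent-skip semantics local to this
model's export flow instead of weakening the quantization API's error
contract for all callers.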