diff --git a/backends/xnnpack/test/ops/test_linear.py b/backends/xnnpack/test/ops/test_linear.py
index ac6fec25732..dc92a9542a9 100644
--- a/backends/xnnpack/test/ops/test_linear.py
+++ b/backends/xnnpack/test/ops/test_linear.py
@@ -395,7 +395,9 @@ def _test_groupwise_dq_linear(
         quantize_(
             mod,
             Int8DynamicActivationIntxWeightConfig(
-                weight_dtype=torch.int4, weight_granularity=PerGroup(group_size)
+                # pyre-ignore[16]
+                weight_dtype=torch.int4,
+                weight_granularity=PerGroup(group_size),
             ),
         )
         unwrap_tensor_subclass(mod)
diff --git a/examples/models/llama/source_transformation/quantize.py b/examples/models/llama/source_transformation/quantize.py
index 835972b7f3e..8b76b7650fe 100644
--- a/examples/models/llama/source_transformation/quantize.py
+++ b/examples/models/llama/source_transformation/quantize.py
@@ -135,6 +135,7 @@ def quantize(  # noqa C901
                     PerAxis(0) if group_size == 0 else PerGroup(group_size)
                 ),
                 weight_mapping_type=MappingType.SYMMETRIC,
+                # pyre-ignore[6]
                 intx_packing_format="opaque_torchao_auto",
             ),
         )
@@ -154,12 +155,23 @@ def quantize(  # noqa C901
         from torchao.quantization.granularity import PerGroup
         from torchao.utils import unwrap_tensor_subclass

+        def filter_fn(m, fqn):
+            is_linear = isinstance(m, nn.Linear)
+            has_shape_compatible_with_group_size = False
+            if is_linear:
+                has_shape_compatible_with_group_size = (
+                    m.weight.shape[1] % group_size == 0
+                )
+            return is_linear and has_shape_compatible_with_group_size
+
         quantize_(
             model,
             Int8DynamicActivationIntxWeightConfig(
+                # pyre-ignore[16]
                 weight_dtype=torch.int4,
                 weight_granularity=PerGroup(group_size),
             ),
+            filter_fn=filter_fn,
         )
         model = unwrap_tensor_subclass(model)

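Note on the new filter_fn (not part of the diff above): a minimal, self-contained sketch of the predicate's behavior on a hypothetical two-layer toy model. The group_size value and the nn.Linear shapes below are made up for illustration; the function body mirrors the diff.

import torch.nn as nn

group_size = 32

def filter_fn(m, fqn):
    # Quantize only nn.Linear modules whose input-feature dimension
    # (weight.shape[1]) divides evenly into the group size; every other
    # module is left untouched by quantize_.
    is_linear = isinstance(m, nn.Linear)
    has_shape_compatible_with_group_size = False
    if is_linear:
        has_shape_compatible_with_group_size = m.weight.shape[1] % group_size == 0
    return is_linear and has_shape_compatible_with_group_size

model = nn.Sequential(nn.Linear(64, 64), nn.Linear(30, 8))
assert filter_fn(model[0], "0")      # 64 % 32 == 0 -> eligible for quantization
assert not filter_fn(model[1], "1")  # 30 % 32 != 0 -> skipped, stays in float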