Skip to content

Commit f32b984

Browse files
mcr229 authored and facebook-github-bot committed
Check group size is divisible by 32 (#6941)
Summary: Currently when checking the per_channel_group quantization parameters we don't check that the group_size must be a multiple of 32. This constraint was added after we implemented the original checks here. Let's add multiple of 32 here. Reviewed By: malfet, digantdesai Differential Revision: D66131456
1 parent 63870b0 commit f32b984

File tree

2 files changed

+33
-7
lines changed

2 files changed

+33
-7
lines changed

backends/xnnpack/operators/node_visitor.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -309,27 +309,30 @@ def _check_per_channel_group_params(
309309
num_groups = cast(torch.Tensor, quant_params.scale).shape[1]
310310
assert (
311311
quant_params.axis == 0
312-
), "For per_channel_group quant, axis must be 0, but got {axis}"
312+
), f"For per_channel_group quant, axis must be 0, but got {quant_params.axis}"
313313
assert (
314314
len(dims) == 2
315-
), "For per_channel_group quant, expecting linear weights to be 2d, but got {len(dims)}"
315+
), f"For per_channel_group quant, expecting linear weights to be 2d, but got {len(dims)}"
316316
assert (
317317
num_groups > 0 and quant_params.group_size > 0
318-
), "For per_channel_group quant, num_groups and group_size must be > 0, but got num_groups: {num_groups}, group_size: {quant_params.group_size}"
318+
), f"For per_channel_group quant, num_groups and group_size must be > 0, but got num_groups: {num_groups}, group_size: {quant_params.group_size}"
319319
output_channels = dims[quant_params.axis]
320320
input_channels = dims[quant_params.axis ^ 1]
321+
assert (
322+
quant_params.group_size % 32 == 0
323+
), f"Delegation to XNNPACK requires group_size to be a multiple of 32, but got {quant_params.group_size}"
321324
assert (
322325
output_channels == cast(torch.Tensor, quant_params.scale).shape[0]
323-
), "For per_channel_group quant, expecting output channels to match scale.shape[0], gut got: {output_channels}, scale.shape[0]: {quant_params.scale.shape[0]}"
326+
), f"For per_channel_group quant, expecting output channels to match scale.shape[0], gut got: {output_channels}, scale.shape[0]: {quant_params.scale.shape[0]}"
324327
assert (
325328
input_channels % num_groups == 0
326-
), "For per_channel_group quant, expecting input channels to be divisible by num_groups, but got ic: {input_channels}, num_groups: {num_groups}"
329+
), f"For per_channel_group quant, expecting input channels to be divisible by num_groups, but got ic: {input_channels}, num_groups: {num_groups}"
327330
assert (
328331
input_channels % quant_params.group_size == 0
329-
), "For per_channel_group quant, expecting input channels to be divisible by group_size, but got ic: {input_channels}, group_size: {quant_params.group_size}"
332+
), f"For per_channel_group quant, expecting input channels to be divisible by group_size, but got ic: {input_channels}, group_size: {quant_params.group_size}"
330333
assert (
331334
input_channels / quant_params.group_size == num_groups
332-
), "For per_channel_group quant, expecting input channels // group_size == num_groups, but got ic: {input_channels}, group_size: {quant_params.group_size}, num_groups: {num_groups}"
335+
), f"For per_channel_group quant, expecting input channels // group_size == num_groups, but got ic: {input_channels}, group_size: {quant_params.group_size}, num_groups: {num_groups}"
333336

334337
# For now group quantization is only supported for 4b weights
335338
assert quant_params.is_qc4w, "Only 4b group quantization is supported"

backends/xnnpack/test/ops/linear.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,29 @@ def test_qd8_fp16_per_token_weight_per_channel_group_int4(self):
457457
self._test_groupwise_dq_linear(
458458
lin_mod, inputs, group_size=bl, use_bias=use_bias, atol=1e-2
459459
)
460+
461+
462+
@unittest.skipIf(
463+
not torchao_installed, "Per Channel Group Quantization Required TorchAO"
464+
)
465+
def test_qd8_fp32_per_token_groupwise_unsupported_groupsize(self):
466+
# groupsize must be multiple of 32
467+
lin_mod = BaseLinear(
468+
in_size=1,
469+
input_channels=60,
470+
output_channels=60,
471+
dtype=torch.float32,
472+
use_bias=True,
473+
)
474+
inputs = lin_mod.get_inputs()
475+
476+
with self.assertRaisesRegex(
477+
AssertionError,
478+
"Delegation to XNNPACK requires group_size to be a multiple of 32, but got 30",
479+
):
480+
self._test_groupwise_dq_linear(
481+
lin_mod, inputs, group_size=30, use_bias=False, atol=1e-2
482+
)
460483

461484
def _test_linear(
462485
self,

0 commit comments

Comments (0)