@@ -61,6 +61,38 @@ def test_linear(self):
error = compute_error(original, quantized)
self.assertTrue(error > 20)

def test_hqq_intx_weight_only_config(self):
dtype = torch.bfloat16
device = "cpu"
config = IntxWeightOnlyConfig(
weight_dtype=torch.int4,
granularity=PerGroup(32),
intx_choose_qparams_algorithm="hqq_scale_only",
)
input = torch.randn(1, 128, dtype=dtype, device=device)
linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
original = linear(input)
quantize_(linear, config)
quantized = linear(input)
error = compute_error(original, quantized)
self.assertTrue(error > 20, f"Got error {error}")

def test_hqq_int8_dyn_act_intx_weight_config(self):
dtype = torch.bfloat16
device = "cpu"
config = Int8DynamicActivationIntxWeightConfig(
weight_dtype=torch.int4,
weight_granularity=PerGroup(32),
intx_choose_qparams_algorithm="hqq_scale_only",
)
input = torch.randn(1, 128, dtype=dtype, device=device)
linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
original = linear(input)
quantize_(linear, config)
quantized = linear(input)
error = compute_error(original, quantized)
self.assertTrue(error > 20, f"Got error {error}")

def test_slice(self):
dtype = torch.bfloat16
device = "cpu"
11 changes: 11 additions & 0 deletions torchao/quantization/qat/api.py
Contributor: btw, are the changes in this file tested as well?
@@ -4,6 +4,8 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import copy
import logging
from dataclasses import dataclass
from enum import Enum
from typing import Any, List, Optional, Tuple
@@ -232,13 +234,15 @@ def _qat_config_transform(
# Optionally pass custom scales and zero points to base config handler
# This is only for range learning and only applies to weights
kwargs = {}
has_custom_scale_and_zero_point = False
weight_config = module.weight_fake_quantizer.config
if (
isinstance(weight_config, IntxFakeQuantizeConfig)
and weight_config.range_learning
):
kwargs["custom_scale"] = module.weight_fake_quantizer.scale
kwargs["custom_zero_point"] = module.weight_fake_quantizer.zero_point
has_custom_scale_and_zero_point = True

# Swap FakeQuantizedLinear -> nn.Linear
# Swap FakeQuantizedEmbedding -> nn.Embedding
@@ -253,6 +257,13 @@
f"Encountered unexpected module {module}, should never happen"
)
if base_config is not None:
# If passing custom scales and zero points, we need to disable the choose_qparam_algorithm on the config
if has_custom_scale_and_zero_point and hasattr(
base_config, "intx_choose_qparams_algorithm"
):
logging.debug("Disabling intx_choose_qparams_algorithm")
base_config = copy.deepcopy(base_config)
base_config.intx_choose_qparams_algorithm = None
return _QUANTIZE_CONFIG_HANDLER[type(base_config)](
module, base_config, **kwargs
)
22 changes: 21 additions & 1 deletion torchao/quantization/quant_api.py
@@ -78,6 +78,7 @@
Int4PreshuffledTensor,
Int4Tensor,
Int4TilePackedTo4dTensor,
IntxChooseQParamsAlgorithm,
IntxOpaqueTensor,
IntxPackingFormat,
IntxUnpackedToInt8Tensor,
@@ -748,6 +749,7 @@ class Int8DynamicActivationIntxWeightConfig(AOBaseConfig):
`intx_packing_format`: The format to use for the packed weight tensor (version 2 only).
- unpacked_to_int8: this format is the default and is intended for export applications like ExecuTorch.
- opaque_torchao_auto: this format is optimized for CPU performance.
`intx_choose_qparams_algorithm`: The algorithm to use for choosing the quantization parameters.
`version`: version of the config to use, only subset of above args are valid based on version, see note for more details.

Note:
@@ -766,6 +768,9 @@
act_mapping_type: MappingType = MappingType.ASYMMETRIC
layout: Layout = QDQLayout()
intx_packing_format: IntxPackingFormat = IntxPackingFormat.UNPACKED_TO_INT8
intx_choose_qparams_algorithm: IntxChooseQParamsAlgorithm = (
IntxChooseQParamsAlgorithm.AFFINE
)

version: int = 2

@@ -830,6 +835,7 @@ def _int8_dynamic_activation_intx_weight_quantize_tensor(
act_mapping_type = config.act_mapping_type
layout = config.layout
intx_packing_format = config.intx_packing_format
intx_choose_qparams_algorithm = config.intx_choose_qparams_algorithm

assert weight.dim() == 2, (
f"Int8DynamicActivationIntxWeightConfig only works for 2-d Tensor, got: {weight.dim()}"
@@ -868,6 +874,7 @@ def _int8_dynamic_activation_intx_weight_quantize_tensor(
weight_dtype,
mapping_type=weight_mapping_type,
activation_quantization="int8_asym_per_token",
intx_choose_qparams_algorithm=intx_choose_qparams_algorithm,
custom_scale=custom_scale,
custom_zero_point=custom_zero_point,
)
@@ -889,6 +896,9 @@ def _int8_dynamic_activation_intx_weight_quantize_tensor(

# Version 1
assert config.version == 1
assert intx_choose_qparams_algorithm == IntxChooseQParamsAlgorithm.AFFINE, (
"IntxChooseQParamsAlgorithm.AFFINE is the only supported algorithm for version 1"
)
warnings.warn(
"Config Deprecation: version 1 of Int8DynamicActivationIntxWeightConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2967 for more details"
)
@@ -2169,6 +2179,7 @@ class IntxWeightOnlyConfig(AOBaseConfig):
- QDQLayout: this layout represents the quantization with Q/DQ quant primitives,
and is intended for export applications like ExecuTorch.
`intx_packing_format`: The format to use for the packed weight tensor (version 2 only).
`intx_choose_qparams_algorithm`: The algorithm to use for choosing the quantization parameters.
`version`: version of the config to use, only subset of above args are valid based on version, see note for more details.

Note:
Expand All @@ -2185,6 +2196,9 @@ class IntxWeightOnlyConfig(AOBaseConfig):
scale_dtype: Optional[torch.dtype] = None
layout: Layout = QDQLayout()
intx_packing_format: IntxPackingFormat = IntxPackingFormat.UNPACKED_TO_INT8
intx_choose_qparams_algorithm: IntxChooseQParamsAlgorithm = (
IntxChooseQParamsAlgorithm.AFFINE
)
version: int = 2

def __post_init__(self):
@@ -2202,8 +2216,9 @@ def __post_init__(self):
assert self.mapping_type in [
MappingType.ASYMMETRIC,
MappingType.SYMMETRIC,
MappingType.SYMMETRIC_NO_CLIPPING_ERR,
], (
f"mapping_type must be MappingType.ASYMMETRIC or MappingType.SYMMETRIC, but got {self.mapping_type}"
f"mapping_type must be MappingType.ASYMMETRIC, MappingType.SYMMETRIC, or MappingType.SYMMETRIC_NO_CLIPPING_ERR, but got {self.mapping_type}"
)


@@ -2220,6 +2235,7 @@ def _intx_weight_only_quantize_tensor(
scale_dtype = config.scale_dtype
layout = config.layout
intx_packing_format = config.intx_packing_format
intx_choose_qparams_algorithm = config.intx_choose_qparams_algorithm

assert weight.dim() == 2, (
f"IntxWeightOnlyConfig only works for 2-d Tensor, got: {weight.dim()}"
@@ -2247,6 +2263,7 @@ def _intx_weight_only_quantize_tensor(
mapping_type=mapping_type,
custom_scale=custom_scale,
custom_zero_point=custom_zero_point,
intx_choose_qparams_algorithm=intx_choose_qparams_algorithm,
)
if scale_dtype is not None and scale_dtype != weight.dtype:
_adjust_scale_dtype_in_intx_unpacked_tensor(
@@ -2258,6 +2275,9 @@
raise ValueError(f"Unsupported packing format: {intx_packing_format}")

# Version 1
assert config.intx_choose_qparams_algorithm == IntxChooseQParamsAlgorithm.AFFINE, (
"version 1 only supports affine algorithm"
)
assert config.version == 1
warnings.warn(
"Config Deprecation: version 1 of IntxWeightOnlyConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2967 for more details"
94 changes: 94 additions & 0 deletions torchao/quantization/quant_primitives.py
@@ -32,6 +32,7 @@
"_choose_qparams_affine_dont_preserve_zero",
"_choose_qparams_affine_floatx",
"_choose_qparams_and_quantize_affine_hqq",
"_choose_qparams_and_quantize_scale_only_hqq",
"_choose_qparams_and_quantize_affine_qqq",
"_choose_scale_float8",
"_choose_qparams_gguf",
@@ -2125,6 +2126,99 @@ def _choose_qparams_and_quantize_affine_hqq(
return W_q, scale, zero, shape


@torch.no_grad()
def _choose_qparams_and_quantize_scale_only_hqq(
hp_tensor: torch.Tensor,
block_size: List[int],
qmin: int,
qmax: int,
*,
iters: int = 20,
stochastic: bool = False,
early_stop_tol: float = 1e-5,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Half-Quadratic Quantization (scale-only, symmetric) for 2D weights with row-wise blocks.
- hp_tensor: [out, in] (bf16/fp16/fp32 accepted; promoted to fp32 internally)
- block_size: must be [1, group_size]; groups along the last dim
- qmin, qmax: integer range (e.g., -8, 7 for signed 4-bit)
Returns:
qdata: int32, same shape as hp_tensor
scale: hp_tensor.dtype, shape [out, in // group_size] (one scale per row-wise block)
"""
# --- strict interface guarantees ---
assert hp_tensor.ndim == 2, "hp_tensor must be 2D [out, in]"
assert isinstance(block_size, (list, tuple)) and len(block_size) == 2, (
"block_size must be a 2-element list/tuple"
)
assert block_size[0] == 1 and block_size[1] >= 1, (
"block_size must be [1, group_size] with group_size >= 1"
)
assert qmin < qmax, "qmin must be < qmax"

# Promote to fp32 for stable math
compute_dtype = torch.float32
compute_eps = torch.finfo(compute_dtype).eps

n, k = hp_tensor.shape
group_size = int(block_size[1])
assert k % group_size == 0, (
f"in_features={k} must be divisible by group_size={group_size}"
)

def round_det(x: torch.Tensor) -> torch.Tensor:
# ties-to-even; fine for PTQ
return x.round()

def round_stoch(x: torch.Tensor) -> torch.Tensor:
# unbiased stochastic rounding
return torch.floor(x + torch.rand_like(x))

_r = round_stoch if stochastic else round_det

# Reshape Wg into [n, n_groups, group_size]
W = hp_tensor.to(compute_dtype).contiguous()
n_groups = k // group_size
Wg = W.view(n, n_groups, group_size)

# Initialize per-block scales as max-abs / qabs
# scale.shape = [n, n_groups]
qabs = max(abs(qmin), abs(qmax)) or 1
scale = (Wg.abs().amax(dim=2) / qabs).clamp_min(compute_eps)
prev_scale = scale.clone()

# Iterate HQQ updates
for _ in range(max(1, iters)):
# Quantize using current scale
# Qg.shape = [n, n_groups, group_size]
Qg = _r(Wg / scale.unsqueeze(-1)).clamp(qmin, qmax)

# Solve least-square problem min_{s} ||Wg - s * Qg||^2 and project
# solution onto positive space, or take previous value
num = (Wg * Qg).sum(dim=2, dtype=torch.float32) # [n, n_groups]
den = (Qg * Qg).sum(dim=2, dtype=torch.float32) # [n, n_groups]
scale = torch.where(den > 0, num / den, prev_scale)
scale = scale.clamp_min(
compute_eps
).abs() # project LS solution onto [eps, inf]

rel = ((scale - prev_scale).abs() / prev_scale.clamp_min(compute_eps)).max()
if rel < early_stop_tol:
break
prev_scale = scale

# Quantize using final scale
Qg = _r(Wg / scale.unsqueeze(-1)).clamp(qmin, qmax)

# Restore shapes
qdata = Qg.view(n, k).contiguous().to(torch.int32)

out_dtype = hp_tensor.dtype
scale = scale.to(out_dtype)

return qdata, scale


def _choose_qparams_affine_floatx(
tensor: torch.Tensor, ebits: int, mbits: int
) -> torch.Tensor:
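For context, a minimal sketch of calling the new scale-only HQQ primitive directly. The call follows the signature and docstring added above; the dequantization check at the end is illustrative only and is not part of this PR.

import torch

from torchao.quantization.quant_primitives import (
    _choose_qparams_and_quantize_scale_only_hqq,
)

# int4 symmetric range (qmin=-8, qmax=7) with row-wise groups of 32 along the input dim
weight = torch.randn(256, 128, dtype=torch.bfloat16)
qdata, scale = _choose_qparams_and_quantize_scale_only_hqq(
    weight, [1, 32], qmin=-8, qmax=7, iters=20
)
# qdata: int32, shape [256, 128]; scale: bfloat16, shape [256, 4]
# Scale-only (symmetric, no zero point) dequantization: broadcast each
# per-group scale over its group of 32 columns, then compare to the original.
dequant = (
    qdata.view(256, 4, 32).to(torch.float32) * scale.view(256, 4, 1).to(torch.float32)
).view(256, 128)
mse = (weight.to(torch.float32) - dequant).pow(2).mean()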
2 changes: 2 additions & 0 deletions torchao/quantization/quantize_/workflows/__init__.py
@@ -20,6 +20,7 @@
Int4Tensor,
)
from .int4.int4_tile_packed_to_4d_tensor import Int4TilePackedTo4dTensor
from .intx.intx_choose_qparams_algorithm import IntxChooseQParamsAlgorithm
from .intx.intx_opaque_tensor import (
IntxOpaqueTensor,
)
@@ -41,6 +42,7 @@
"Int4OpaqueTensor",
"Int4ChooseQParamsAlgorithm",
"Int4PackingFormat",
"IntxChooseQParamsAlgorithm",
"IntxPackingFormat",
"IntxUnpackedToInt8Tensor",
"IntxOpaqueTensor",
23 changes: 23 additions & 0 deletions torchao/quantization/quantize_/workflows/intx/intx_choose_qparams_algorithm.py
@@ -0,0 +1,23 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from enum import Enum


# can switch to StrEnum (https://docs.python.org/3/library/enum.html#enum.StrEnum)
# after python 3.10 is end of life (https://devguide.python.org/versions/)
class IntxChooseQParamsAlgorithm(str, Enum):

Contributor: Also, do we need to introduce yet another class?

Can we just extend existing Int4ChooseQParamsAlgorithm and add affine and hqq_scale_only?

And then rename/promote Int4ChooseQParamsAlgorithm to IntxChooseQParamsAlgorithm in a follow-up PR?

Contributor Author: In torchao's refactor (removing AffineQuantizedTensor), the direction is for subclasses to not share higher-level abstractions, but instead define their own enums. This is how packing format works as well (intx_packing_format for the intx subclass, and int4_packing_format for the int4 subclass).

I'll let @jerryzh168 comment here as well.

Contributor: Will defer to @jerryzh168 then.

Contributor: yeah we want local abstractions instead of global abstractions unless it's required.

"""Variant of quantization algorithm to calculate scale and zero_point"""

"""
Uses `torchao.quantization.quant_primitives.choose_qparams_affine`
"""
AFFINE = "affine"

"""
Uses `torchao.quantization.quant_primitives._choose_qparams_and_quantize_scale_only_hqq`
"""
HQQ_SCALE_ONLY = "hqq_scale_only"