9 changes: 7 additions & 2 deletions .github/scripts/torchao_model_releases/quantize_and_upload.py
Expand Up @@ -205,7 +205,7 @@ def _untie_weights_and_save_locally(model_id):

_int4_quant_code = """
from torchao.quantization import Int4WeightOnlyConfig
quant_config = Int4WeightOnlyConfig(group_size=128, use_hqq=True)
quant_config = Int4WeightOnlyConfig(group_size=128, packing_format="tile_packed_to_4d", int4_choose_qparams_algorithm="hqq", version=2)
quantization_config = TorchAoConfig(quant_type=quant_config)
quantized_model = AutoModelForCausalLM.from_pretrained(model_to_quantize, device_map="auto", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
tokenizer = AutoTokenizer.from_pretrained(model_id)
Expand Down Expand Up @@ -627,7 +627,12 @@ def quantize_and_upload(
)
quant_to_config = {
"FP8": Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()),
"INT4": Int4WeightOnlyConfig(group_size=128, version=2),
"INT4": Int4WeightOnlyConfig(
group_size=128,
packing_format="tile_packed_to_4d",
int4_choose_qparams_algorithm="hqq",
version=2,
),
"INT8-INT4": ModuleFqnToConfig(
{
"_default": _int8_int4_linear_config,
Expand Down
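For context, below is a minimal sketch of what the updated "INT4" entry does when applied with torchao's quantize_ API. It assumes a CUDA machine and uses a bare nn.Linear as a stand-in for the released model; the TorchAoConfig / transformers plumbing from the script is omitted.

```python
# Hedged sketch: assumes torchao with version=2 Int4 support and a CUDA device.
# A single nn.Linear stands in for the model loaded by the release script.
import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_

quant_config = Int4WeightOnlyConfig(
    group_size=128,
    packing_format="tile_packed_to_4d",
    int4_choose_qparams_algorithm="hqq",
    version=2,
)

linear = torch.nn.Linear(1024, 1024, dtype=torch.bfloat16, device="cuda")
quantize_(linear, quant_config)
# The weight should now be an Int4TilePackedTo4dTensor quantized with HQQ-chosen qparams.
print(type(linear.weight))
```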
6 changes: 6 additions & 0 deletions test/core/test_config.py
Expand Up @@ -53,6 +53,12 @@
Int4WeightOnlyConfig(
group_size=32,
),
Int4WeightOnlyConfig(
group_size=128,
packing_format="tile_packed_to_4d",
int4_choose_qparams_algorithm="hqq",
version=2,
),
Int8DynamicActivationInt4WeightConfig(
group_size=64,
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
)

from torchao.quantization import Int4WeightOnlyConfig, quantize_
from torchao.quantization.quantize_.common.packing_format import PackingFormat
from torchao.quantization.quantize_.workflows.int4.int4_tile_packed_to_4d_tensor import (
Int4TilePackedTo4dTensor,
)
Expand All @@ -25,7 +24,14 @@

INT4_CONFIG = Int4WeightOnlyConfig(
group_size=128,
packing_format=PackingFormat.TILE_PACKED_TO_4D,
packing_format="tile_packed_to_4d",
version=2,
)

INT4_HQQ_CONFIG = Int4WeightOnlyConfig(
group_size=128,
packing_format="tile_packed_to_4d",
int4_choose_qparams_algorithm="hqq",
version=2,
)

Expand All @@ -44,8 +50,8 @@ def setUp(self):
((2, 32, 128), 256, 128),
],
)
def test_linear(self, sizes):
config = INT4_CONFIG
@parametrize("config", [INT4_CONFIG, INT4_HQQ_CONFIG])
def test_linear(self, sizes, config):
dtype = torch.bfloat16
device = "cuda"

Expand All @@ -62,8 +68,8 @@ def test_linear(self, sizes):
quantized_and_compiled = compiled_linear(input)
self.assertTrue(compute_error(original, quantized_and_compiled) > 20)

def test_module_path(self):
config = INT4_CONFIG
@parametrize("config", [INT4_CONFIG, INT4_HQQ_CONFIG])
def test_module_path(self, config):
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
quantize_(linear.cuda(), config)
self.assertEqual(
Expand All @@ -80,11 +86,11 @@ def test_module_path(self):
"<class 'torchao.quantization.Int4TilePackedTo4dTensor'>",
)

def test_slice(self):
@parametrize("config", [INT4_CONFIG, INT4_HQQ_CONFIG])
def test_slice(self, config):
"""Note: we use multiples of 1024 for both in_features and out_features
so that padding does not affect the weight after slicing
"""
config = INT4_CONFIG
dtype = torch.bfloat16
device = "cuda"

Expand Down Expand Up @@ -169,8 +175,8 @@ def test_slice(self):
res2 = test_linear2(input2)
self.assertGreater(compute_error(res_ref2, res2), 14)

def test_slice_preserves_aliasing(self):
config = INT4_CONFIG
@parametrize("config", [INT4_CONFIG, INT4_HQQ_CONFIG])
def test_slice_preserves_aliasing(self, config):
l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
l.weight = torch.nn.Parameter(
torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda")
Expand Down Expand Up @@ -212,8 +218,9 @@ def test_to_device(self):
quantize_(linear, config)
linear.to(device)

def test_slice_and_copy_similar_to_vllm(self):
self._test_slice_and_copy_similar_to_vllm(INT4_CONFIG)
@parametrize("config", [INT4_CONFIG, INT4_HQQ_CONFIG])
def test_slice_and_copy_similar_to_vllm(self, config):
self._test_slice_and_copy_similar_to_vllm(config)

@parametrize("device", ["cuda"])
@parametrize("dtype", [torch.bfloat16])
Expand Down
1 change: 1 addition & 0 deletions torchao/core/config.py
Expand Up @@ -196,6 +196,7 @@ def config_to_dict(config: AOBaseConfig) -> Dict[str, Any]:
"torchao.dtypes",
"torchao.prototype.awq",
"torchao.quantization.quantize_.common",
"torchao.quantization.quantize_.workflows",
}


Expand Down
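The new entry is needed because Int4ChooseQParamsAlgorithm lives in torchao.quantization.quantize_.workflows, and config serialization only reconstructs objects from allowlisted modules. A rough sketch of the round trip this enables, assuming config_from_dict is the matching deserializer exposed next to config_to_dict:

```python
# Hedged sketch of config serialization; config_from_dict is assumed to be the
# deserializer counterpart of config_to_dict in torchao.core.config.
from torchao.core.config import config_from_dict, config_to_dict
from torchao.quantization import Int4WeightOnlyConfig

config = Int4WeightOnlyConfig(
    group_size=128,
    packing_format="tile_packed_to_4d",
    int4_choose_qparams_algorithm="hqq",
    version=2,
)
d = config_to_dict(config)     # serializes the enum, whose module must be allowlisted
restored = config_from_dict(d)
print(restored)                # expected: an equivalent Int4WeightOnlyConfig
```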
50 changes: 32 additions & 18 deletions torchao/quantization/quant_api.py
Expand Up @@ -72,6 +72,7 @@
)
from torchao.quantization.quantize_.workflows import (
Float8Tensor,
Int4ChooseQParamsAlgorithm,
Int4MarlinSparseTensor,
Int4OpaqueTensor,
Int4PreshuffledTensor,
Expand Down Expand Up @@ -1054,27 +1055,29 @@ def _gemlite_uintx_weight_only_transform(
@dataclass
class Int4WeightOnlyConfig(AOBaseConfig):
"""
Configuration for applying uint4 weight-only asymmetric per-group quantization to linear layers, using
"tensor_core_tiled" layout for speedup with tinygemm kernel

Note:
This is targeting `tinygemm` int4mm kernel (`torch.ops.aten._weight_int4pack_mm`
and `torch.ops.aten._weight_int4pack_mm_for_cpu`), the main difference
of quantization algorithm compared to the more traditional type of integer quantization is the following:
1). zero_point is in floating point domain instead of integer domain (`zero_point_domain`=`ZeroPointDomain.FLOAT`)
2). floating point zero does not have to be exactly representable (`preserve_zero`=False in `choose_qparams_affine`)
please follow the relevant code in `choose_qparams_affine`, `quantize_affine` and `dequantize_affine`
to learn about how the quantization parameters are chosen and how the Tensor is quantized/dequantized for tinygemm
Configuration for int4 weight-only quantization. Only groupwise quantization is supported
right now. Both version 1 and version 2 are supported; they are implemented differently but cover the
same use cases. In version 2, different targets are mainly distinguished by the `packing_format` arg, and in version 1, mainly by `layout`.

Args:
`group_size`: parameter for quantization, controls the granularity of quantization, smaller
size is more fine grained, choices are [256, 128, 64, 32]
`layout`: layout type for quantized tensor, default is `TensorCoreTiledLayout(inner_k_tiles=8)`
`use_hqq`: whether to use hqq or default quantization mode, default is False
`zero_point_domain`: data type of zeros points, choices are [ZeroPointDomain.FLOAT, ZeroPointDomain.INT, ZeroPointDomain.NONE]
`set_inductor_config`: if True, adjusts `torchinductor` settings to recommended values.
`preserve_zero`: whether to preserve zero, default is None. Will be set to True if zero_point_domain is ZeroPointDomain.INT
`packing_format`: the packing format for int4 tensor, available from version 2 and above
size is more fine grained, choices are [256, 128, 64, 32], used in both version 1 and 2
`packing_format`: the packing format for the int4 tensor, used in version 2 only
`int4_choose_qparams_algorithm`: variant of the choose qparams algorithm to use for int4,
currently supports TINYGEMM ("tinygemm") and HQQ ("hqq"), used in version 2 only
`layout`: layout type for the quantized tensor, default is `TensorCoreTiledLayout(inner_k_tiles=8)`, used in version 1 only
`use_hqq`: whether to use hqq or the default quantization mode, default is False, used in version 1 only
`zero_point_domain`: data type of zero points, choices are [ZeroPointDomain.FLOAT, ZeroPointDomain.INT, ZeroPointDomain.NONE], used in version 1 only
`set_inductor_config`: if True, adjusts `torchinductor` settings to recommended values, used in both version 1 and 2
`preserve_zero`: whether to preserve zero, default is None. Will be set to True if zero_point_domain is ZeroPointDomain.INT, used in version 1 only
`version`: version of the config to use, default is 1; only a subset of the above args is valid for each version, see the note below for details

Note:
The current state of Int4WeightOnlyConfig is that it supports both v1 (legacy) and v2.
For v2 (version=2), only `group_size`, `packing_format`, `int4_choose_qparams_algorithm` and `set_inductor_config` are valid; all other args will be ignored.
For v1 (version=1), only `group_size`, `layout`, `use_hqq`, `zero_point_domain`, `preserve_zero` and `set_inductor_config` are valid. We plan to deprecate v1 in torchao 0.15 to make this config
less confusing.
"""

group_size: int = 128
Expand All @@ -1085,6 +1088,9 @@ class Int4WeightOnlyConfig(AOBaseConfig):
preserve_zero: Optional[bool] = None
# only used in version >= 2
packing_format: PackingFormat = PackingFormat.PLAIN
int4_choose_qparams_algorithm: Int4ChooseQParamsAlgorithm = (
Int4ChooseQParamsAlgorithm.TINYGEMM
)
version: int = 1

def __post_init__(self):
Expand All @@ -1105,6 +1111,7 @@ def _int4_weight_only_quantize_tensor(weight, config):
group_size = config.group_size
layout = config.layout
use_hqq = config.use_hqq
int4_choose_qparams_algorithm = config.int4_choose_qparams_algorithm
zero_point_domain = config.zero_point_domain
packing_format = config.packing_format

Expand All @@ -1118,6 +1125,12 @@ def _int4_weight_only_quantize_tensor(weight, config):

if config.version == 2:
block_size = list(block_size)

if int4_choose_qparams_algorithm == Int4ChooseQParamsAlgorithm.HQQ:
assert packing_format == PackingFormat.TILE_PACKED_TO_4D, (
f"Int4ChooseQParamsAlgorithm.HQQ is not supported by packing format {packing_format}, it's only supported by PackingFormat.TILE_PACKED_TO_4D curretnly"
)

if packing_format == PackingFormat.PRESHUFFLED:
new_weight = Int4PreshuffledTensor.from_hp(
weight,
Expand Down Expand Up @@ -1147,6 +1160,7 @@ def _int4_weight_only_quantize_tensor(weight, config):
new_weight = Int4TilePackedTo4dTensor.from_hp(
weight,
block_size,
int4_choose_qparams_algorithm=int4_choose_qparams_algorithm,
)
return new_weight
else:
Expand Down
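To illustrate the new guard, here is a hedged sketch (CUDA and bf16 assumed, shapes arbitrary): combining int4_choose_qparams_algorithm="hqq" with a packing format other than tile_packed_to_4d should trip the assertion, while the tile-packed combination quantizes normally.

```python
# Hedged sketch of the HQQ / packing-format guard added above (CUDA + bf16 assumed).
import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_

ok = Int4WeightOnlyConfig(
    group_size=128,
    packing_format="tile_packed_to_4d",
    int4_choose_qparams_algorithm="hqq",
    version=2,
)
bad = Int4WeightOnlyConfig(
    group_size=128,
    packing_format="plain",  # HQQ is only wired up for tile_packed_to_4d
    int4_choose_qparams_algorithm="hqq",
    version=2,
)

m1 = torch.nn.Linear(1024, 1024, dtype=torch.bfloat16, device="cuda")
quantize_(m1, ok)  # expected to succeed

m2 = torch.nn.Linear(1024, 1024, dtype=torch.bfloat16, device="cuda")
try:
    quantize_(m2, bad)
except AssertionError as e:
    print(e)  # expected: "...only supported by PackingFormat.TILE_PACKED_TO_4D"
```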
2 changes: 2 additions & 0 deletions torchao/quantization/quantize_/workflows/__init__.py
Expand Up @@ -2,6 +2,7 @@
Float8Tensor,
QuantizeTensorToFloat8Kwargs,
)
from .int4.int4_choose_qparams_algorithm import Int4ChooseQParamsAlgorithm
from .int4.int4_marlin_sparse_tensor import (
Int4MarlinSparseTensor,
)
Expand Down Expand Up @@ -33,4 +34,5 @@
"Int4OpaqueTensor",
"IntxUnpackedTensor",
"IntxUnpackedToInt8Tensor",
"Int4ChooseQParamsAlgorithm",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from enum import Enum


# can switch to StrEnum (https://docs.python.org/3/library/enum.html#enum.StrEnum)
# after python 3.10 is end of life (https://devguide.python.org/versions/)
class Int4ChooseQParamsAlgorithm(str, Enum):
"""Variant of quantization algorithm to calculate scale and zero_point"""

"""
The choose qparams algorithm native to the tinygemm kernel:
scale = (max_val - min_val) / float(quant_max - quant_min), where
max_val and min_val are the max/min for the slice of the input Tensor based on block_size
quant_max and quant_min are the max/min for the quantized value, e.g. 0, 15 for uint4
zero_point = min_val + scale * mid_point, where
mid_point = (quant_max + quant_min + 1) / 2

implemented in `torchao.quantization.quant_primitives._choose_qparams_affine_tinygemm`
"""
TINYGEMM = "tinygemm"

"""
The choose qparams algorithm based on half-quadratic quantization (HQQ): https://mobiusml.github.io/hqq_blog/

implemented in `torchao.quantization.quant_primitives._choose_qparams_and_quantize_affine_hqq`
"""
HQQ = "hqq"
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,22 @@
# LICENSE file in the root directory of this source tree.


import math
from typing import List

import torch

from torchao.quantization.quant_primitives import (
MappingType,
_choose_qparams_affine_tinygemm,
_choose_qparams_and_quantize_affine_hqq,
_quantize_affine_tinygemm,
)
from torchao.quantization.utils import pack_tinygemm_scales_and_zeros
from torchao.utils import TorchAOBaseTensor, fill_defaults, find_multiple

from .int4_choose_qparams_algorithm import Int4ChooseQParamsAlgorithm

__all__ = [
"Int4TilePackedTo4dTensor",
]
Expand Down Expand Up @@ -76,6 +86,7 @@ def from_hp(
cls,
hp_tensor: torch.Tensor,
block_size: List[int],
int4_choose_qparams_algorithm: Int4ChooseQParamsAlgorithm = Int4ChooseQParamsAlgorithm.TINYGEMM,
):
assert len(block_size) == hp_tensor.ndim, (
f"Expecting the length of block_size to be equal to the dimension of the weight, got {block_size=} and {hp_tensor.ndim=}"
Expand Down Expand Up @@ -115,34 +126,60 @@ def from_hp(
quant_min = 0
quant_max = 15

from torchao.quantization.quant_primitives import (
MappingType,
_choose_qparams_affine_tinygemm,
_quantize_affine_tinygemm,
)

# Calculate scale and zero_point for tinygemm
scale, zero_point = _choose_qparams_affine_tinygemm(
hp_tensor_padded,
mapping_type=MappingType.ASYMMETRIC,
block_size=tuple(block_size),
target_dtype=target_dtype,
quant_min=quant_min,
quant_max=quant_max,
scale_dtype=hp_tensor.dtype,
zero_point_dtype=hp_tensor.dtype,
)
# we support two paths for constructing an Int4TilePackedTo4dTensor
# 1. use the [hqq](https://mobiusml.github.io/hqq_blog/) algorithm to compute
# scale and zero_point, then convert to the format that's compatible with tinygemm kernels
# 2. don't use hqq, use the default tinygemm algorithm to compute scale and zero_point
#
# both approaches should have the same speed since both use the tinygemm kernel for the gemm;
# path 1 typically has higher accuracy than path 2
if int4_choose_qparams_algorithm == Int4ChooseQParamsAlgorithm.HQQ:
nbits = int(math.log2(quant_max + 1))
axis = 1
group_size = block_size[-1]
compute_dtype = hp_tensor_padded.dtype
device = hp_tensor_padded.device
int_data, scale, zero_point, _ = _choose_qparams_and_quantize_affine_hqq(
hp_tensor_padded,
nbits=nbits,
group_size=group_size,
axis=axis,
compute_dtype=compute_dtype,
device=device,
verbose=False,
raw_output=False,
# raw_output=False converts zero_point to the tinygemm convention (adds scale * mid_point),
# which is the format TilePackedTo4d expects; note that _choose_qparams_affine_tinygemm
# produces zero_point in this same convention
)
int_data = int_data.to(target_dtype)
else:
assert (
int4_choose_qparams_algorithm == Int4ChooseQParamsAlgorithm.TINYGEMM
), (
f"Unsupported Int4ChooseQParamsAlgorithm: {int4_choose_qparams_algorithm}"
)
# Calculate scale and zero_point for tinygemm
scale, zero_point = _choose_qparams_affine_tinygemm(
hp_tensor_padded,
mapping_type=MappingType.ASYMMETRIC,
block_size=tuple(block_size),
target_dtype=target_dtype,
quant_min=quant_min,
quant_max=quant_max,
scale_dtype=hp_tensor.dtype,
zero_point_dtype=hp_tensor.dtype,
)

# Quantize for tinygemm
int_data = _quantize_affine_tinygemm(
hp_tensor_padded,
block_size,
scale,
zero_point,
target_dtype,
quant_min=quant_min,
quant_max=quant_max,
)
# Quantize for tinygemm
int_data = _quantize_affine_tinygemm(
hp_tensor_padded,
block_size,
scale,
zero_point,
target_dtype,
quant_min=quant_min,
quant_max=quant_max,
)

# Convert to packed format
def quant_2d(int_data_2d):
Expand Down Expand Up @@ -175,8 +212,6 @@ def quant_2d(int_data_2d):
else None
)

from torchao.quantization.utils import pack_tinygemm_scales_and_zeros

scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point, scale.dtype)

return cls(
Expand Down
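For completeness, a sketch of calling from_hp directly with the new selector, mirroring the parametrized tests earlier in this PR (CUDA and bf16 assumed, shapes chosen so no padding is needed):

```python
# Hedged sketch: construct the tensor subclass directly via from_hp (CUDA + bf16 assumed).
import torch
from torchao.quantization.quantize_.workflows import Int4ChooseQParamsAlgorithm
from torchao.quantization.quantize_.workflows.int4.int4_tile_packed_to_4d_tensor import (
    Int4TilePackedTo4dTensor,
)

weight = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
block_size = [1, 128]  # per-group along in_features, i.e. group_size=128

w_hqq = Int4TilePackedTo4dTensor.from_hp(
    weight, block_size, int4_choose_qparams_algorithm=Int4ChooseQParamsAlgorithm.HQQ
)
w_tinygemm = Int4TilePackedTo4dTensor.from_hp(weight, block_size)  # default: TINYGEMM

# Both pack into the same tinygemm-compatible layout; HQQ only changes how the
# scale/zero_point are chosen and typically yields lower quantization error.
```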