@@ -26,10 +26,11 @@
)


def get_config(group_size):
def get_config(group_size, use_hqq):
return Int4WeightOnlyConfig(
group_size=group_size,
int4_packing_format="opaque",
int4_choose_qparams_algorithm="hqq" if use_hqq else "tinygemm",
)


@@ -45,13 +46,14 @@ class TestInt4OpaqueTensor(TestCase):
)
@parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@parametrize("group_size", [32, 64, 128])
def test_linear(self, sizes, dtype, group_size):
@parametrize("use_hqq", [True, False])
def test_linear(self, sizes, dtype, group_size, use_hqq):
device = "cpu"
M, N, K = sizes
input = torch.randn(*M, K, dtype=dtype, device=device)
linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
original = linear(input)
quantize_(linear, get_config(group_size))
quantize_(linear, get_config(group_size, use_hqq))
quantized = linear(input)
self.assertTrue(compute_error(original, quantized) > 20)

@@ -60,9 +62,10 @@ def test_linear(self, sizes, dtype, group_size):
self.assertTrue(compute_error(original, quantized_and_compiled) > 20)

@parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
def test_module_path(self, dtype):
@parametrize("use_hqq", [True, False])
def test_module_path(self, dtype, use_hqq):
linear = torch.nn.Linear(128, 256, dtype=dtype)
quantize_(linear, get_config(group_size=128))
quantize_(linear, get_config(group_size=128, use_hqq=use_hqq))
self.assertEqual(
str(type(linear.weight)),
"<class 'torchao.quantization.Int4OpaqueTensor'>",
@@ -77,12 +80,13 @@ def test_module_path(self, dtype):
"<class 'torchao.quantization.Int4OpaqueTensor'>",
)

def test_activation_prescaling(self):
@parametrize("use_hqq", [True, False])
def test_activation_prescaling(self, use_hqq):
dtype = torch.bfloat16
input = torch.randn(1, 128, dtype=dtype)
linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype)
original_output = linear(input)
quantize_(linear, get_config(group_size=128))
quantize_(linear, get_config(group_size=128, use_hqq=use_hqq))
qw = linear.weight
assert isinstance(qw, SupportsActivationPreScaling), (
"Expected int4 tensor supports activation prescaling"
13 changes: 9 additions & 4 deletions torchao/quantization/quant_api.py
@@ -1085,15 +1085,15 @@ class Int4WeightOnlyConfig(AOBaseConfig):
Args:
`group_size`: parameter for quantization, controls the granularity of quantization, smaller
size is more fine grained, choices are [256, 128, 64, 32], used in both version 1 and 2
`packing_format`: the packing format for int4 tensor, used in version 2 only
`int4_packing_format`: the packing format for int4 tensor, used in version 2 only
`int4_choose_qparams_algorithm`: variants of choose qparams algorithm to use for int4,
currently support TINYGEMM ("tinygemm") and HQQ ("hqq"), used in version 2 only
`layout`: layout type for quantized tensor, default is `TensorCoreTiledLayout(inner_k_tiles=8)`, used in version 1 only
`use_hqq`: whether to use hqq or default quantization mode, default is False, used in version 1 only
`zero_point_domain`: data type of zero points, choices are [ZeroPointDomain.FLOAT, ZeroPointDomain.INT, ZeroPointDomain.NONE], used in version 1 only
`set_inductor_config`: if True, adjusts `torchinductor` settings to recommended values. used in both version 1 and 2
`preserve_zero`: whether to preserve zero, default is None. Will be set to True if zero_point_domain is ZeroPointDomain.INT, used in version 1 only
`version`: version of the config to use, only subset of above args are valid for version 1, and subset of above args are valid for version 2, default is 1, see note for more details
`version`: version of the config to use, only subset of above args are valid for version 1, and subset of above args are valid for version 2, default is 2, see note for more details

Note:
Current state for Int4WeightOnlyConfig is that it supports both v1 (legacy) and v2
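For orientation, a minimal usage sketch of the version-2 config path this change extends, using the same string values exercised by the test's get_config helper (int4_packing_format="opaque", int4_choose_qparams_algorithm="hqq"); the Linear module below is a stand-in, not part of the PR:

import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_

# a hypothetical CPU module to quantize in place
model = torch.nn.Linear(128, 256, dtype=torch.bfloat16)

# opaque packing with HQQ-selected qparams (the combination this PR enables)
config = Int4WeightOnlyConfig(
    group_size=128,
    int4_packing_format="opaque",
    int4_choose_qparams_algorithm="hqq",
)
quantize_(model, config)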
@@ -1150,8 +1150,12 @@ def _int4_weight_only_quantize_tensor(weight, config):
block_size = list(block_size)

if int4_choose_qparams_algorithm == Int4ChooseQParamsAlgorithm.HQQ:
assert int4_packing_format == Int4PackingFormat.TILE_PACKED_TO_4D, (
f"Int4ChooseQParamsAlgorithm.HQQ is not supported by packing format {int4_packing_format}, it's only supported by Int4PackingFormat.TILE_PACKED_TO_4D curretnly"
assert int4_packing_format in [
Int4PackingFormat.TILE_PACKED_TO_4D,
Int4PackingFormat.OPAQUE,
], (
f"Int4ChooseQParamsAlgorithm.HQQ is not supported by packing format {int4_packing_format}, "
f"it's only supported by Int4PackingFormat.TILE_PACKED_TO_4D and Int4PackingFormat.OPAQUE currently"
)

if int4_packing_format == Int4PackingFormat.PRESHUFFLED:
@@ -1183,6 +1187,7 @@ def _int4_weight_only_quantize_tensor(weight, config):
new_weight = Int4OpaqueTensor.from_hp(
weight,
block_size,
int4_choose_qparams_algorithm=int4_choose_qparams_algorithm,
)
return new_weight
elif int4_packing_format == Int4PackingFormat.TILE_PACKED_TO_4D:
75 changes: 54 additions & 21 deletions torchao/quantization/quantize_/workflows/int4/int4_opaque_tensor.py
@@ -5,19 +5,24 @@
# LICENSE file in the root directory of this source tree.


import math
from typing import List, Optional

import torch

from torchao.quantization.quant_primitives import (
MappingType,
_choose_qparams_affine_tinygemm,
_choose_qparams_and_quantize_affine_hqq,
_quantize_affine_tinygemm,
)
from torchao.quantization.utils import pack_tinygemm_scales_and_zeros
from torchao.utils import (
TorchAOBaseTensor,
)

from .int4_choose_qparams_algorithm import Int4ChooseQParamsAlgorithm

__all__ = [
"Int4OpaqueTensor",
]
@@ -95,6 +100,7 @@ def from_hp(
cls,
w: torch.Tensor,
block_size: List[int],
int4_choose_qparams_algorithm: Int4ChooseQParamsAlgorithm = Int4ChooseQParamsAlgorithm.TINYGEMM,
):
assert w.ndim == 2 and w.device.type == "cpu", (
f"Expecting 2D tensor on CPU, but got: {w.shape} on {w.device.type}"
@@ -111,26 +117,54 @@
eps = 1e-6
scale_dtype = None
zero_point_dtype = w.dtype
scale, zero_point = _choose_qparams_affine_tinygemm(
w,
mapping_type,
block_size,
target_dtype,
quant_min,
quant_max,
eps,
scale_dtype,
zero_point_dtype,
)
int_data = _quantize_affine_tinygemm(
w,
block_size,
scale,
zero_point,
target_dtype,
quant_min,
quant_max,
)

# we support two paths for constructing an Int4OpaqueTensor
# 1. use the [hqq](https://mobiusml.github.io/hqq_blog/) algorithm to compute
# scale and zero_point, then convert to the format that's compatible with tinygemm kernels
# 2. don't use hqq; use the default tinygemm algorithm to compute scale and zero_point
#
# both approaches should have the same performance since both use the CPU tinygemm kernel for gemm;
# approach 1 typically has higher accuracy than approach 2
if int4_choose_qparams_algorithm == Int4ChooseQParamsAlgorithm.HQQ:
nbits = int(math.log2(quant_max + 1))
axis = 1
group_size = block_size[-1]
int_data, scale, zero_point, _ = _choose_qparams_and_quantize_affine_hqq(
w,
nbits=nbits,
group_size=group_size,
axis=axis,
compute_dtype=zero_point_dtype,
device=w.device,
)
int_data = int_data.to(target_dtype)
else:
assert (
int4_choose_qparams_algorithm == Int4ChooseQParamsAlgorithm.TINYGEMM
), (
f"Unsupported Int4ChooseQParamsAlgorithm: {int4_choose_qparams_algorithm}"
)

scale, zero_point = _choose_qparams_affine_tinygemm(
w,
mapping_type,
block_size,
target_dtype,
quant_min,
quant_max,
eps,
scale_dtype,
zero_point_dtype,
)
int_data = _quantize_affine_tinygemm(
w,
block_size,
scale,
zero_point,
target_dtype,
quant_min,
quant_max,
)
assert int_data.dtype == torch.int32, (
"torch.ops.aten._convert_weight_to_int4pack_for_cpu expects `int32` dtype"
)
@@ -141,7 +175,6 @@ def from_hp(

scale = scale.reshape(int_data.shape[0], -1)
zero_point = zero_point.reshape(int_data.shape[0], -1)
from torchao.quantization.utils import pack_tinygemm_scales_and_zeros

scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point, scale.dtype)
return Int4OpaqueTensor(
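As a rough check of the accuracy claim in the from_hp comment (the hqq path typically gives higher accuracy than the tinygemm path), here is a sketch comparing the SQNR of the two algorithms on the same CPU linear. It assumes compute_error is importable from torchao.quantization.utils as in the tests, and that the config accepts the string values used by the test's get_config helper:

import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_
from torchao.quantization.utils import compute_error  # assumed location of the SQNR helper used by the tests

def sqnr_for(algorithm):
    # build the same bfloat16 CPU linear and input each time, then quantize with the given algorithm
    torch.manual_seed(0)
    linear = torch.nn.Linear(128, 256, bias=False, dtype=torch.bfloat16)
    x = torch.randn(1, 128, dtype=torch.bfloat16)
    ref = linear(x)
    quantize_(
        linear,
        Int4WeightOnlyConfig(
            group_size=128,
            int4_packing_format="opaque",
            int4_choose_qparams_algorithm=algorithm,
        ),
    )
    return compute_error(ref, linear(x))

print("tinygemm SQNR:", sqnr_for("tinygemm"))
print("hqq SQNR:", sqnr_for("hqq"))

If the comment holds, both figures should clear the > 20 threshold asserted in test_linear, with the hqq result typically equal or higher.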