From 6802b79cf69cdd9079522339f6e23d271a3ddfa4 Mon Sep 17 00:00:00 2001 From: youn17 Date: Sat, 11 Oct 2025 01:29:10 +0900 Subject: [PATCH 1/2] fix enum backend in int4 hqq packing format --- test/prototype/test_awq.py | 6 +++++- torchao/quantization/quant_api.py | 6 ++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/test/prototype/test_awq.py b/test/prototype/test_awq.py index d5318de7ca..22a6d93313 100644 --- a/test/prototype/test_awq.py +++ b/test/prototype/test_awq.py @@ -59,8 +59,12 @@ def forward(self, x): device_to_base_configs = { "cuda": [ Int4WeightOnlyConfig(group_size=128), - # Note: the functionality unit test doesn't work for hqq Int4WeightOnlyConfig(group_size=128, int4_packing_format="tile_packed_to_4d"), + Int4WeightOnlyConfig( + group_size=128, + int4_packing_format="tile_packed_to_4d", + int4_choose_qparams_algorithm="hqq", + ), ], "cpu": [Int4WeightOnlyConfig(group_size=128, int4_packing_format="opaque")], "xpu": [Int4WeightOnlyConfig(group_size=128, int4_packing_format="plain_int32")], diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 3bda8f91ab..bef68d4bc0 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -1156,6 +1156,12 @@ class Int4WeightOnlyConfig(AOBaseConfig): def __post_init__(self): torch._C._log_api_usage_once("torchao.quantization.Int4WeightOnlyConfig") + if isinstance(self.int4_packing_format, str): + self.int4_packing_format = Int4PackingFormat(self.int4_packing_format) + if isinstance(self.int4_choose_qparams_algorithm, str): + self.int4_choose_qparams_algorithm = Int4ChooseQParamsAlgorithm( + self.int4_choose_qparams_algorithm + ) # for BC From ad15bb9604aa1f035f41a3b2f184d5b04a5a216d Mon Sep 17 00:00:00 2001 From: youn17 Date: Sat, 11 Oct 2025 02:49:44 +0900 Subject: [PATCH 2/2] revert unhelpful change --- torchao/quantization/quant_api.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index bef68d4bc0..3bda8f91ab 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -1156,12 +1156,6 @@ class Int4WeightOnlyConfig(AOBaseConfig): def __post_init__(self): torch._C._log_api_usage_once("torchao.quantization.Int4WeightOnlyConfig") - if isinstance(self.int4_packing_format, str): - self.int4_packing_format = Int4PackingFormat(self.int4_packing_format) - if isinstance(self.int4_choose_qparams_algorithm, str): - self.int4_choose_qparams_algorithm = Int4ChooseQParamsAlgorithm( - self.int4_choose_qparams_algorithm - ) # for BC