test/quantization/test_qat.py: 10 changes (5 additions, 5 deletions)
@@ -50,7 +50,7 @@
 )
 from torchao.quantization.qat.fake_quantize_config import (
     Float8FakeQuantizeConfig,
-    Int4WeightPreshuffledFakeQuantizeConfig,
+    Int4WeightFakeQuantizeConfig,
     IntxFakeQuantizeConfig,
 )
 from torchao.quantization.qat.fake_quantizer import (
@@ -1985,7 +1985,7 @@ def test_infer_fp8_int4_config(self):
         self.assertIsInstance(act_config, Float8FakeQuantizeConfig)
         self.assertEqual(act_config.dtype, e4m3_dtype)
         self.assertIsInstance(act_config.granularity, PerRow)
-        self.assertIsInstance(weight_config, Int4WeightPreshuffledFakeQuantizeConfig)
+        self.assertIsInstance(weight_config, Int4WeightFakeQuantizeConfig)
         self.assertEqual(weight_config.group_size, 128)
         self.assertEqual(weight_config.activation_dtype, e4m3_dtype)

@@ -2008,7 +2008,7 @@ def test_infer_int4_weight_only_config(self):
         base_config = Int4WeightOnlyConfig(version=2)
         (act_config, weight_config) = _infer_fake_quantize_configs(base_config)
         self.assertIsNone(act_config)
-        self.assertIsInstance(weight_config, Int4WeightPreshuffledFakeQuantizeConfig)
+        self.assertIsInstance(weight_config, Int4WeightFakeQuantizeConfig)
         self.assertEqual(weight_config.group_size, 128)
         self.assertEqual(weight_config.activation_dtype, torch.bfloat16)

@@ -2102,7 +2102,7 @@ def test_fbgemm_fp8_int4_preshuffled_primitives(self):
         """
         Compare numerics between:
         (1) fbgemm_gpu.experimental.gen_ai.quantize.quantize_int4_preshuffle
-        (2) Our reference QAT version in `Int4WeightPreshuffledFakeQuantizer`
+        (2) Our reference QAT version in `Int4WeightFakeQuantizer`
         """
         from fbgemm_gpu.experimental.gen_ai.quantize import (
             int4_row_quantize,
@@ -2184,7 +2184,7 @@ def test_fbgemm_int4_weight_only_primitives(self):
         """
         Compare numerics between:
         (1) fbgemm_gpu.experimental.gen_ai.quantize.int4_row_quantize_zp
-        (2) Our reference QAT version in `Int4WeightPreshuffledFakeQuantizer`
+        (2) Our reference QAT version in `Int4WeightFakeQuantizer`
         """
         from fbgemm_gpu.experimental.gen_ai.quantize import (
             int4_row_quantize_zp,
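Taken together, the assertions above document the renamed config's behavior in one place: inferring QAT configs from a base Int4WeightOnlyConfig yields no activation fake quantization and an Int4WeightFakeQuantizeConfig with group_size=128 and bf16 activations. A minimal sketch of that flow follows; it assumes Int4WeightOnlyConfig is importable from torchao.quantization and that _infer_fake_quantize_configs is reachable from the module changed in the next file.

import torch

from torchao.quantization import Int4WeightOnlyConfig
from torchao.quantization.qat.fake_quantize_config import (
    Int4WeightFakeQuantizeConfig,
    _infer_fake_quantize_configs,
)

# Mirror test_infer_int4_weight_only_config: weight-only int4 means no
# activation fake quantization, int4 weights with bf16 activations.
act_config, weight_config = _infer_fake_quantize_configs(Int4WeightOnlyConfig(version=2))
assert act_config is None
assert isinstance(weight_config, Int4WeightFakeQuantizeConfig)
assert weight_config.group_size == 128
assert weight_config.activation_dtype == torch.bfloat16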
torchao/quantization/qat/fake_quantize_config.py: 7 changes (3 additions, 4 deletions)
@@ -78,9 +78,8 @@ def __post_init__(self):
 )


-# TODO: rename this config, it actually works for both plain and preshuffled
 @dataclass
-class Int4WeightPreshuffledFakeQuantizeConfig(FakeQuantizeConfigBase):
+class Int4WeightFakeQuantizeConfig(FakeQuantizeConfigBase):
     """
     Config for int4 weight fake quantization that targets the numerics in the following preshuffled kernel:
         torch.ops.fbgemm.f8i4bf16_shuffled
@@ -393,7 +392,7 @@ def _infer_fake_quantize_configs(
             raise ValueError(
                 f"Packing format must be one of {supported_packing_formats}"
             )
-        weight_config = Int4WeightPreshuffledFakeQuantizeConfig(
+        weight_config = Int4WeightFakeQuantizeConfig(
             group_size=128,
             activation_dtype=torch.bfloat16,
         )
@@ -436,7 +435,7 @@ def _infer_fake_quantize_configs(
             dtype=e4m3_dtype,
             granularity=PerRow(),
         )
-        weight_config = Int4WeightPreshuffledFakeQuantizeConfig(
+        weight_config = Int4WeightFakeQuantizeConfig(
             group_size=128,
             activation_dtype=e4m3_dtype,
         )
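As a quick illustration of how the renamed config is consumed downstream, the sketch below builds it with the same field values the inference path fills in above and hands it to FakeQuantizerBase.from_config, which (per the fake_quantizer.py changes below) now dispatches to Int4WeightFakeQuantizer. Treat it as a sketch: it assumes from_config can be called as a static method, which its signature in this diff suggests.

import torch

from torchao.quantization.qat.fake_quantize_config import Int4WeightFakeQuantizeConfig
from torchao.quantization.qat.fake_quantizer import (
    FakeQuantizerBase,
    Int4WeightFakeQuantizer,
)

# Same field values that _infer_fake_quantize_configs uses above.
config = Int4WeightFakeQuantizeConfig(group_size=128, activation_dtype=torch.bfloat16)

# from_config dispatches on the config type and should hand back the renamed quantizer.
fake_quantizer = FakeQuantizerBase.from_config(config)
assert isinstance(fake_quantizer, Int4WeightFakeQuantizer)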
torchao/quantization/qat/fake_quantizer.py: 15 changes (6 additions, 9 deletions)
@@ -35,7 +35,7 @@
 from .fake_quantize_config import (
     FakeQuantizeConfigBase,
     Float8FakeQuantizeConfig,
-    Int4WeightPreshuffledFakeQuantizeConfig,
+    Int4WeightFakeQuantizeConfig,
     IntxFakeQuantizeConfig,
 )
 from .utils import (
@@ -68,8 +68,8 @@ def from_config(config: FakeQuantizeConfigBase) -> "FakeQuantizerBase":

         if isinstance(config, IntxFakeQuantizeConfig):
             return IntxFakeQuantizer(config)
-        elif isinstance(config, Int4WeightPreshuffledFakeQuantizeConfig):
-            return Int4WeightPreshuffledFakeQuantizer(config)
+        elif isinstance(config, Int4WeightFakeQuantizeConfig):
+            return Int4WeightFakeQuantizer(config)
         elif isinstance(config, Float8FakeQuantizeConfig):
             return Float8FakeQuantizer(config)
         elif isinstance(config, NVFP4FakeQuantizeConfig):
@@ -103,8 +103,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return dq


-# TODO: rename this, it also works for plain Int4Tensor
-class Int4WeightPreshuffledFakeQuantizer(FakeQuantizerBase):
+class Int4WeightFakeQuantizer(FakeQuantizerBase):
     """
     Generic module for applying int4 fake quantization to a weight tensor,
     targeting the following FBGEMM kernels:
@@ -113,12 +112,10 @@ class Int4WeightPreshuffledFakeQuantizer(FakeQuantizerBase):
         torch.ops.fbgemm.bf16i4bf16_rowwise
     """

-    def __init__(self, config: Int4WeightPreshuffledFakeQuantizeConfig):
+    def __init__(self, config: Int4WeightFakeQuantizeConfig):
         super().__init__()
         self.config = config
-        torch._C._log_api_usage_once(
-            "torchao.quantization.qat.Int4WeightPreshuffledFakeQuantizer"
-        )
+        torch._C._log_api_usage_once("torchao.quantization.qat.Int4WeightFakeQuantizer")

     def forward(self, w: torch.Tensor) -> torch.Tensor:
         if self.config.activation_dtype == torch.float8_e4m3fn:
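To round out the picture, here is a hedged end-to-end sketch of the renamed fake quantizer applied to a weight tensor. The bf16 activation dtype and the group size of 128 come straight from the configs above; the weight shape is an illustrative assumption (the reduction dimension is chosen as a multiple of the group size), and running it requires a torchao build that ships these QAT modules and their quantization primitives.

import torch

from torchao.quantization.qat.fake_quantize_config import Int4WeightFakeQuantizeConfig
from torchao.quantization.qat.fake_quantizer import Int4WeightFakeQuantizer

config = Int4WeightFakeQuantizeConfig(group_size=128, activation_dtype=torch.bfloat16)
fake_quantizer = Int4WeightFakeQuantizer(config)

# Fake quantization simulates the int4 numerics of the FBGEMM kernels listed in
# the class docstring while keeping the weight as a regular floating-point tensor.
weight = torch.randn(256, 512, dtype=torch.bfloat16)  # 512 is a multiple of group_size
weight_fq = fake_quantizer(weight)
assert weight_fq.shape == weight.shape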