From 5ac446f783d5623de58a032c5dcfe1dc90e8d7cb Mon Sep 17 00:00:00 2001 From: vasiliy Date: Wed, 26 Nov 2025 13:24:18 -0800 Subject: [PATCH] Update [ghstack-poisoned] --- .../float8/float8_inference_roofline.py | 6 +- .../mx_formats/test_inference_workflow.py | 42 +++++----- .../mx_formats/test_mx_serialization.py | 6 +- .../prototype/mx_formats/test_nvfp4_tensor.py | 18 ++--- test/quantization/test_qat.py | 16 ++-- torchao/prototype/mx_formats/README.md | 36 +++++---- torchao/prototype/mx_formats/__init__.py | 8 +- .../mx_formats/inference_workflow.py | 76 ++++++++++++++----- torchao/prototype/mx_formats/nvfp4_tensor.py | 6 -- torchao/prototype/qat/nvfp4.py | 6 +- .../quantization/qat/fake_quantize_config.py | 18 ++--- 11 files changed, 137 insertions(+), 101 deletions(-) diff --git a/benchmarks/float8/float8_inference_roofline.py b/benchmarks/float8/float8_inference_roofline.py index e1eaa43c1b..188bb46224 100644 --- a/benchmarks/float8/float8_inference_roofline.py +++ b/benchmarks/float8/float8_inference_roofline.py @@ -40,8 +40,7 @@ import torchao from torchao.prototype.mx_formats.inference_workflow import ( MXDynamicActivationMXWeightConfig, - NVFP4InferenceConfig, - NVFP4MMConfig, + NVFP4DynamicActivationNVFP4WeightConfig, ) from torchao.prototype.mx_formats.utils import to_blocked from torchao.quantization.quant_api import ( @@ -445,8 +444,7 @@ def run( kernel_preference=KernelPreference.AUTO, ) elif recipe_name == "nvfp4": - config = NVFP4InferenceConfig( - mm_config=NVFP4MMConfig.DYNAMIC, + config = NVFP4DynamicActivationNVFP4WeightConfig( use_dynamic_per_tensor_scale=False, ) else: diff --git a/test/prototype/mx_formats/test_inference_workflow.py b/test/prototype/mx_formats/test_inference_workflow.py index fd03eec29f..6c89f0efa8 100644 --- a/test/prototype/mx_formats/test_inference_workflow.py +++ b/test/prototype/mx_formats/test_inference_workflow.py @@ -14,8 +14,8 @@ from torchao.prototype.mx_formats.inference_workflow import ( MXDynamicActivationMXWeightConfig, - NVFP4InferenceConfig, - NVFP4MMConfig, + NVFP4DynamicActivationNVFP4WeightConfig, + NVFP4WeightOnlyConfig, ) from torchao.quantization import quantize_ from torchao.quantization.quantize_.common import KernelPreference @@ -138,9 +138,7 @@ def test_inference_workflow_mx( ) @pytest.mark.parametrize("bias", [True, False]) @pytest.mark.parametrize("compile", [True, False]) -@pytest.mark.parametrize( - "mm_config", [NVFP4MMConfig.DYNAMIC, NVFP4MMConfig.WEIGHT_ONLY] -) +@pytest.mark.parametrize("quant_type", ["dynamic", "weight_only"]) @pytest.mark.parametrize("inpt_dtype", [torch.bfloat16, torch.float32]) @pytest.mark.parametrize("use_triton_kernel", [True, False]) @pytest.mark.parametrize("use_dynamic_per_tensor_scale", [True, False]) @@ -164,7 +162,7 @@ def test_inference_workflow_mx( def test_inference_workflow_nvfp4( bias: bool, compile: bool, - mm_config: NVFP4MMConfig, + quant_type: str, inpt_dtype: torch.dtype, use_triton_kernel: bool, use_dynamic_per_tensor_scale: bool, @@ -177,14 +175,16 @@ def test_inference_workflow_nvfp4( Tests both DYNAMIC and WEIGHT_ONLY mm_config modes """ # DYNAMIC mode requires SM100+, but WEIGHT_ONLY works on older GPUs - if mm_config == NVFP4MMConfig.DYNAMIC and not is_sm_at_least_100(): + if quant_type == "dynamic" and not is_sm_at_least_100(): pytest.skip("CUDA capability >= 10.0 required for DYNAMIC float4 gemm") if bias and inpt_dtype == torch.float32: pytest.xfail("Bias is not supported when module weight is in fp32") - if mm_config == NVFP4MMConfig.WEIGHT_ONLY and compile: - pytest.skip("TODO: NVFP4MMConfig.WEIGHT_ONLY currently errors w/ compile") + if quant_type == "weight_only" and compile: + pytest.skip("TODO: weight_only quant currently errors w/ compile") + if quant_type == "weight_only" and use_triton_kernel: + pytest.skip("unsupported configuration") if use_inference_mode and ( shapes != (128, 64, 256) or inpt_dtype != torch.bfloat16 or use_triton_kernel @@ -200,11 +200,15 @@ def test_inference_workflow_nvfp4( m = nn.Linear(in_features, out_features, bias=bias, dtype=inpt_dtype, device="cuda") m_mx = copy.deepcopy(m) - config = NVFP4InferenceConfig( - mm_config=mm_config, - use_triton_kernel=use_triton_kernel, - use_dynamic_per_tensor_scale=use_dynamic_per_tensor_scale, - ) + if quant_type == "dynamic": + config = NVFP4DynamicActivationNVFP4WeightConfig( + use_triton_kernel=use_triton_kernel, + use_dynamic_per_tensor_scale=use_dynamic_per_tensor_scale, + ) + else: + config = NVFP4WeightOnlyConfig( + use_dynamic_per_tensor_scale=use_dynamic_per_tensor_scale, + ) quantize_(m_mx, config=config) if compile: @@ -216,7 +220,7 @@ def test_inference_workflow_nvfp4( y_ref = m(x) - if use_triton_kernel and mm_config != NVFP4MMConfig.WEIGHT_ONLY: + if use_triton_kernel and quant_type == "dynamic": with cuda_kernel_profiler("quantize_nvfp4_triton_kernel") as result: y_mx = m_mx(x) assert result["found"], "Expected quantize_nvfp4 kernel to be found" @@ -229,14 +233,14 @@ def test_inference_workflow_nvfp4( sqnr = compute_error(y_ref, y_mx) - if mm_config == NVFP4MMConfig.WEIGHT_ONLY: + if quant_type == "weight_only": SQNR_THRESHOLD = 18.0 else: SQNR_THRESHOLD = 15.0 assert y_mx.dtype == inpt_dtype, f"Got {y_mx.dtype} for inpt_dtype={inpt_dtype}" assert sqnr >= SQNR_THRESHOLD, ( - f"Got a sqnr of {sqnr} for NVFP4 recipe with bias={bias}, mm_config={mm_config}" + f"Got a sqnr of {sqnr} for NVFP4 recipe with bias={bias}, {quant_type=}" ) @@ -273,9 +277,7 @@ def test_narrow_similar_to_vllm(self): reason="torch.compile requires PyTorch 2.8+", ) def test_nvfp4_quantize_3d_param_similar_to_vllm(self): - config = NVFP4InferenceConfig( - mm_config=NVFP4MMConfig.WEIGHT_ONLY, - use_triton_kernel=False, + config = NVFP4WeightOnlyConfig( use_dynamic_per_tensor_scale=False, ) self._test_quantize_3d_param_similar_to_vllm(config) diff --git a/test/prototype/mx_formats/test_mx_serialization.py b/test/prototype/mx_formats/test_mx_serialization.py index 9649da98f7..a109b63aef 100644 --- a/test/prototype/mx_formats/test_mx_serialization.py +++ b/test/prototype/mx_formats/test_mx_serialization.py @@ -14,8 +14,7 @@ from torchao.prototype.mx_formats.inference_workflow import ( MXDynamicActivationMXWeightConfig, - NVFP4InferenceConfig, - NVFP4MMConfig, + NVFP4DynamicActivationNVFP4WeightConfig, ) from torchao.quantization import quantize_ from torchao.quantization.quantize_.common import KernelPreference @@ -48,8 +47,7 @@ def test_serialization(recipe_name): ) else: assert recipe_name == "nvfp4", "unsupported" - config = NVFP4InferenceConfig( - mm_config=NVFP4MMConfig.DYNAMIC, + config = NVFP4DynamicActivationNVFP4WeightConfig( use_triton_kernel=False, use_dynamic_per_tensor_scale=False, ) diff --git a/test/prototype/mx_formats/test_nvfp4_tensor.py b/test/prototype/mx_formats/test_nvfp4_tensor.py index e098edb745..2f734cef2c 100644 --- a/test/prototype/mx_formats/test_nvfp4_tensor.py +++ b/test/prototype/mx_formats/test_nvfp4_tensor.py @@ -12,9 +12,6 @@ from torchao.prototype.mx_formats.constants import ( F4_E2M1_MAX, ) -from torchao.prototype.mx_formats.inference_workflow import ( - NVFP4MMConfig, -) from torchao.prototype.mx_formats.nvfp4_tensor import ( NVFP4Tensor, QuantizeTensorToNVFP4Kwargs, @@ -422,7 +419,8 @@ def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype): ) @pytest.mark.parametrize("use_gelu", [True, False]) @pytest.mark.parametrize( - "mm_config", [NVFP4MMConfig.DYNAMIC, NVFP4MMConfig.WEIGHT_ONLY] + "quant_type", + ["dynamic", "weight_only"], ) @pytest.mark.parametrize("compile", [False]) @pytest.mark.parametrize("bias", [True, False]) @@ -448,7 +446,7 @@ def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype): ) def test_nvfp4_matmul_with_amax( use_gelu: bool, - mm_config: NVFP4MMConfig, + quant_type: str, compile: bool, bias: bool, inpt_dtype: torch.dtype, @@ -456,14 +454,14 @@ def test_nvfp4_matmul_with_amax( shapes: tuple, ): # DYNAMIC mode requires SM100+, but WEIGHT_ONLY works on older GPUs - if mm_config == NVFP4MMConfig.DYNAMIC and not is_sm_at_least_100(): + if quant_type == "dynamic" and not is_sm_at_least_100(): pytest.skip("CUDA capability >= 10.0 required for DYNAMIC float4 gemm") if bias and inpt_dtype == torch.float32: pytest.xfail("Bias is not supported when module weight is in fp32") - if mm_config == NVFP4MMConfig.WEIGHT_ONLY and compile: - pytest.skip("TODO: NVFP4MMConfig.WEIGHT_ONLY currently errors w/ compile") + if quant_type == "weight_only" and compile: + pytest.skip("TODO: weight_only currently errors w/ compile") m, k, n = shapes @@ -483,7 +481,7 @@ def test_nvfp4_matmul_with_amax( a_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(A))) b_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(B))) act_quant_kwargs = None - if mm_config == NVFP4MMConfig.DYNAMIC: + if quant_type == "dynamic": act_quant_kwargs = QuantizeTensorToNVFP4Kwargs() A_nvfp4 = NVFP4Tensor.to_nvfp4( A, @@ -509,7 +507,7 @@ def test_nvfp4_matmul_with_amax( sqnr = compute_error(C_ref, C_nvfp4) SQNR_THRESHOLD = 16.0 assert sqnr >= SQNR_THRESHOLD, ( - f"SQNR {sqnr:.2f} < {SQNR_THRESHOLD}, use_gelu={use_gelu}, mm_config={mm_config}, compile={compile}, bias={bias}" + f"SQNR {sqnr:.2f} < {SQNR_THRESHOLD}, use_gelu={use_gelu}, {quant_type=}, compile={compile}, bias={bias}" ) diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index db33561fa9..c5b0fef069 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -2082,13 +2082,15 @@ def test_infer_int4_weight_only_config(self): def test_quantize_api_nvfp4(self, use_per_tensor_scale: bool): """ Test the following: - quantize_(model, QATConfig(NVFP4InferenceConfig(), step="prepare")) - quantize_(model, QATConfig(NVFP4InferenceConfig(), step="convert")) + quantize_(model, QATConfig(NVFP4DynamicActivationNVFP4WeightConfig(), step="prepare")) + quantize_(model, QATConfig(NVFP4DynamicActivationNVFP4WeightConfig(), step="convert")) """ - from torchao.prototype.mx_formats import NVFP4InferenceConfig + from torchao.prototype.mx_formats import NVFP4DynamicActivationNVFP4WeightConfig self._test_quantize_api_against_ptq( - NVFP4InferenceConfig(use_dynamic_per_tensor_scale=use_per_tensor_scale), + NVFP4DynamicActivationNVFP4WeightConfig( + use_dynamic_per_tensor_scale=use_per_tensor_scale + ), target_prepare_sqnr=float("inf"), target_convert_sqnr=float("inf"), ) @@ -2100,7 +2102,7 @@ def test_qat_nvfp4(self, use_per_tensor_scale: bool): """ Test QAT with `NVFP4FakeQuantizeConfig`. """ - from torchao.prototype.mx_formats import NVFP4InferenceConfig + from torchao.prototype.mx_formats import NVFP4DynamicActivationNVFP4WeightConfig from torchao.prototype.qat import NVFP4FakeQuantizeConfig torch.manual_seed(self.SEED) @@ -2108,7 +2110,9 @@ def test_qat_nvfp4(self, use_per_tensor_scale: bool): baseline_model = copy.deepcopy(m) quantize_( baseline_model, - NVFP4InferenceConfig(use_dynamic_per_tensor_scale=use_per_tensor_scale), + NVFP4DynamicActivationNVFP4WeightConfig( + use_dynamic_per_tensor_scale=use_per_tensor_scale + ), ) qat_config = QATConfig( activation_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale), diff --git a/torchao/prototype/mx_formats/README.md b/torchao/prototype/mx_formats/README.md index f4a9ff1045..e644fee8fe 100644 --- a/torchao/prototype/mx_formats/README.md +++ b/torchao/prototype/mx_formats/README.md @@ -109,8 +109,8 @@ from torchao.quantization import quantize_ import torchao.prototype.mx_formats from torchao.prototype.mx_formats.inference_workflow import ( MXDynamicActivationMXWeightConfig, - NVFP4InferenceConfig, - NVFP4MMConfig, + NVFP4DynamicActivationNVFP4WeightConfig, + NVFP4WeightOnlyConfig, ) from torchao.quantization.quantize_.common import KernelPreference @@ -129,6 +129,27 @@ quantize_(m_mxfp8, config=config) m_mxfp8 = torch.compile(m_mxfp8, fullgraph=True) y_mxfp8 = m_mxfp8(x) +# nvfp4 dynamic quant + +m_nvfp4 = copy.deepcopy(m) +config = NVFP4DynamicActivationNVFP4WeightConfig( + use_dynamic_per_tensor_scale=True, + use_triton_kernel=True, +) +quantize_(m_nvfp4, config=config) +m_nvfp4 = torch.compile(m_nvfp4, fullgraph=True) +y_nvfp4 = m_nvfp4(x) + +# nvfp4 weight-only quant + +m_nvfp4_wo = copy.deepcopy(m) +config = NVFP4WeightOnlyConfig( + use_dynamic_per_tensor_scale=True, +) +quantize_(m_nvfp4_wo, config=config) +m_nvfp4_wo = torch.compile(m_nvfp4_wo, fullgraph=True) +y_nvfp4 = m_nvfp4_wo(x) + # mxfp4 m_mxfp4 = copy.deepcopy(m) @@ -140,17 +161,6 @@ config = MXDynamicActivationMXWeightConfig( quantize_(m_mxfp4, config=config) m_mxfp4 = torch.compile(m_mxfp4, fullgraph=True) y_mxfp4 = m_mxfp4(x) - -# nvfp4 - -m_nvfp4 = copy.deepcopy(m) -config = NVFP4InferenceConfig( - mm_config=NVFP4MMConfig.DYNAMIC, - use_dynamic_per_tensor_scale=True, -) -quantize_(m_nvfp4, config=config) -m_nvfp4 = torch.compile(m_nvfp4, fullgraph=True) -y_nvfp4 = m_nvfp4(x) ``` ## MXTensor diff --git a/torchao/prototype/mx_formats/__init__.py b/torchao/prototype/mx_formats/__init__.py index bd0d05eee1..4c845847f7 100644 --- a/torchao/prototype/mx_formats/__init__.py +++ b/torchao/prototype/mx_formats/__init__.py @@ -6,8 +6,8 @@ # Note: Prototype and subject to change from torchao.prototype.mx_formats.inference_workflow import ( MXDynamicActivationMXWeightConfig, - NVFP4InferenceConfig, - NVFP4MMConfig, + NVFP4DynamicActivationNVFP4WeightConfig, + NVFP4WeightOnlyConfig, ) # import mx_linear here to register the quantize_ transform logic @@ -18,6 +18,6 @@ "MXLinearConfig", "MXLinearRecipeName", "MXDynamicActivationMXWeightConfig", - "NVFP4InferenceConfig", - "NVFP4MMConfig", + "NVFP4DynamicActivationNVFP4WeightConfig", + "NVFP4WeightOnlyConfig", ] diff --git a/torchao/prototype/mx_formats/inference_workflow.py b/torchao/prototype/mx_formats/inference_workflow.py index 6cd011fc0a..c771015ccd 100644 --- a/torchao/prototype/mx_formats/inference_workflow.py +++ b/torchao/prototype/mx_formats/inference_workflow.py @@ -20,7 +20,6 @@ ScaleCalculationMode, ) from torchao.prototype.mx_formats.nvfp4_tensor import ( - NVFP4MMConfig, NVFP4Tensor, QuantizeTensorToNVFP4Kwargs, per_tensor_amax_to_scale, @@ -104,13 +103,12 @@ def _mx_inference_linear_transform( @dataclass -class NVFP4InferenceConfig(AOBaseConfig): +class NVFP4DynamicActivationNVFP4WeightConfig(AOBaseConfig): """ NVIDIA FP4 (NVFP4) Inference Quantization Configuration This is a specialized configuration for NVIDIA's FP4 format. Configuration parameters: - - mm_config: NVFP4MMConfig, which can be set to DYNAMIC or WEIGHT_ONLY (emulated mm in high precision) - use_triton_kernel: bool, whether to use fused triton kernel for activation scaling (default: True) - use_dynamic_per_tensor_scale: bool, whether to dynamically compute per tensor scale (default: True) - Data: float4_e2m1fn_x2 @@ -121,25 +119,25 @@ class NVFP4InferenceConfig(AOBaseConfig): must satisfy M % 128 == 0 and K % 64 == 0. Will automatically fallback when constraints aren't met. """ - mm_config: NVFP4MMConfig = NVFP4MMConfig.DYNAMIC use_triton_kernel: bool = True use_dynamic_per_tensor_scale: bool = True def __post_init__(self): # Validate PyTorch version if not torch_version_at_least("2.8.0"): - raise RuntimeError("NVFP4InferenceConfig requires PyTorch 2.8 or later") + raise RuntimeError( + "NVFP4DynamicActivationNVFP4WeightConfig requires PyTorch 2.8 or later" + ) -@register_quantize_module_handler(NVFP4InferenceConfig) +@register_quantize_module_handler(NVFP4DynamicActivationNVFP4WeightConfig) def _nvfp4_inference_linear_transform( - module: torch.nn.Linear, config: NVFP4InferenceConfig + module: torch.nn.Linear, config: NVFP4DynamicActivationNVFP4WeightConfig ): - """Quantization handler for NVFP4InferenceConfig""" - if config.mm_config == NVFP4MMConfig.DYNAMIC: - assert is_sm_at_least_100(), ( - "NVFP4 DYNAMIC mode is only supported on sm100+ machines" - ) + """Quantization handler for NVFP4DynamicActivationNVFP4WeightConfig""" + assert is_sm_at_least_100(), ( + "NVFP4 DYNAMIC mode is only supported on sm100+ machines" + ) weight = module.weight @@ -159,13 +157,11 @@ def _nvfp4_inference_linear_transform( tensor_amax = torch.max(torch.abs(weight)) per_tensor_scale = per_tensor_amax_to_scale(tensor_amax) - act_quant_kwargs = None - if config.mm_config == NVFP4MMConfig.DYNAMIC: - act_quant_kwargs = QuantizeTensorToNVFP4Kwargs( - use_dynamic_per_tensor_scale=config.use_dynamic_per_tensor_scale, - use_triton_kernel=config.use_triton_kernel, - is_swizzled_scales=True, - ) + act_quant_kwargs = QuantizeTensorToNVFP4Kwargs( + use_dynamic_per_tensor_scale=config.use_dynamic_per_tensor_scale, + use_triton_kernel=config.use_triton_kernel, + is_swizzled_scales=True, + ) quantized_weight = NVFP4Tensor.to_nvfp4( weight, @@ -181,11 +177,51 @@ def _nvfp4_inference_linear_transform( return module +@dataclass +class NVFP4WeightOnlyConfig(AOBaseConfig): + use_dynamic_per_tensor_scale: bool = True + + def __post_init__(self): + # Validate PyTorch version + if not torch_version_at_least("2.8.0"): + raise RuntimeError( + "NVFP4DynamicActivationNVFP4WeightConfig requires PyTorch 2.8 or later" + ) + + +@register_quantize_module_handler(NVFP4WeightOnlyConfig) +def _nvfp4_weight_only_linear_transform( + module: torch.nn.Linear, config: NVFP4WeightOnlyConfig +): + """Quantization handler for NVFP4WeightOnlyConfig""" + weight = module.weight + + if weight.shape[-2] % 16 != 0 or weight.shape[-1] % 16 != 0: + raise RuntimeError( + f"NVFP4 only supports weight shape with last 2 dims divisible by 16, got {weight.shape}" + ) + + per_tensor_scale = None + if config.use_dynamic_per_tensor_scale: + tensor_amax = torch.max(torch.abs(weight)) + per_tensor_scale = per_tensor_amax_to_scale(tensor_amax) + + quantized_weight = NVFP4Tensor.to_nvfp4( + weight, + per_tensor_scale=per_tensor_scale, + is_swizzled_scales=True, + act_quant_kwargs=None, + ) + # Set triton preference after construction + module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module + + torch.serialization.add_safe_globals( [ MXTensor, NVFP4Tensor, - NVFP4MMConfig, QuantizeTensorToMXKwargs, QuantizeTensorToNVFP4Kwargs, ScaleCalculationMode, diff --git a/torchao/prototype/mx_formats/nvfp4_tensor.py b/torchao/prototype/mx_formats/nvfp4_tensor.py index 69aa62afd4..8dbdc5ab15 100644 --- a/torchao/prototype/mx_formats/nvfp4_tensor.py +++ b/torchao/prototype/mx_formats/nvfp4_tensor.py @@ -6,7 +6,6 @@ import math from dataclasses import dataclass -from enum import Enum from typing import Optional import torch @@ -40,11 +39,6 @@ aten = torch.ops.aten -class NVFP4MMConfig(Enum): - DYNAMIC = "dynamic" - WEIGHT_ONLY = "weight_only" - - @dataclass class QuantizeTensorToNVFP4Kwargs(QuantizeTensorKwargs): block_size: int = 16 diff --git a/torchao/prototype/qat/nvfp4.py b/torchao/prototype/qat/nvfp4.py index 396389d22e..0a2316621f 100644 --- a/torchao/prototype/qat/nvfp4.py +++ b/torchao/prototype/qat/nvfp4.py @@ -73,7 +73,7 @@ def forward( use_triton_kernel=False, ) - # Follow `NVFP4InferenceConfig`, always use traditional construction + # Follow `NVFP4DynamicActivationNVFP4WeightConfig`, always use traditional construction # for weights and set `use_triton_kernel` afterwards weight.use_triton_kernel = weight_config.use_triton_kernel @@ -112,9 +112,9 @@ class NVFP4FakeQuantizedLinear(torch.nn.Linear): Example usage:: from torchao.quantization import quantize_ - from torchao.prototype.mx_formats import NVFP4InferenceConfig + from torchao.prototype.mx_formats import NVFP4DynamicActivationNVFP4WeightConfig - base_config = NVFP4InferenceConfig() + base_config = NVFP4DynamicActivationNVFP4WeightConfig() quantize_(model, QATConfig(base_config, step="prepare")) # Model contains `NVFP4FakeQuantizedLinear` now diff --git a/torchao/quantization/qat/fake_quantize_config.py b/torchao/quantization/qat/fake_quantize_config.py index 3a1c7c78f1..b0dffb17ae 100644 --- a/torchao/quantization/qat/fake_quantize_config.py +++ b/torchao/quantization/qat/fake_quantize_config.py @@ -357,8 +357,7 @@ def _infer_fake_quantize_configs( # TODO: rewrite using registration API so we don't need to import here # avoid circular imports from torchao.prototype.mx_formats import ( - NVFP4InferenceConfig, - NVFP4MMConfig, + NVFP4DynamicActivationNVFP4WeightConfig, ) from torchao.prototype.qat import ( NVFP4FakeQuantizeConfig, @@ -441,15 +440,12 @@ def _infer_fake_quantize_configs( group_size=128, activation_dtype=e4m3_dtype, ) - elif isinstance(base_config, NVFP4InferenceConfig): - if NVFP4MMConfig.DYNAMIC: - act_config = NVFP4FakeQuantizeConfig( - use_per_tensor_scale=base_config.use_dynamic_per_tensor_scale, - use_swizzled_scales=False, - use_triton_kernel=False, - ) - else: - act_config = None + elif isinstance(base_config, NVFP4DynamicActivationNVFP4WeightConfig): + act_config = NVFP4FakeQuantizeConfig( + use_per_tensor_scale=base_config.use_dynamic_per_tensor_scale, + use_swizzled_scales=False, + use_triton_kernel=False, + ) weight_config = NVFP4FakeQuantizeConfig( use_per_tensor_scale=base_config.use_dynamic_per_tensor_scale, use_swizzled_scales=True,