From 70461d9404411af0990738079f1c48eff3a27ad0 Mon Sep 17 00:00:00 2001
From: andrewor14
Date: Tue, 23 Sep 2025 13:32:00 -0700
Subject: [PATCH 1/3] Support NVFP4 dynamic per tensor scale

**Summary:** This commit adds an option to the existing
`NVFP4InferenceConfig` to dynamically compute an appropriate fp32
per-tensor scale, supporting two-level scaling as described in the
NVFP4 specification:
https://developer.nvidia.com/blog/introducing-nvfp4-for-efficient-and-accurate-low-precision-inference/.
While two-level scaling is supported in `NVFP4Tensor`, today there is
no config API that lets users enable it. The existing
`NVFP4InferenceConfig` only supports single-level scaling, because
including an explicit `per_tensor_scale` field would make serialization
tricky. In the future, we should add an end-to-end calibration flow so
users can first compute an appropriate per-tensor scale for the
activations and then pass it to `NVFP4Tensor` as a static scale,
similar to the proposal in https://github.com/pytorch/ao/issues/2572.

**Test Plan:**
```
pytest test/prototype/mx_formats/test_inference_workflow.py -k test_inference_workflow_nvfp4
pytest test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

Also ran a quick benchmark before and after:
```
import copy
import time
import torch
from torchao.quantization import quantize_
from torchao.prototype.mx_formats import NVFP4InferenceConfig

m_mx1 = torch.nn.Linear(64, 256, bias=True, dtype=torch.bfloat16, device="cuda")
m_mx2 = copy.deepcopy(m_mx1)
config1 = NVFP4InferenceConfig(use_dynamic_per_tensor_scale=False)
config2 = NVFP4InferenceConfig(use_dynamic_per_tensor_scale=True)
quantize_(m_mx1, config=config1)
quantize_(m_mx2, config=config2)
m_mx1 = torch.compile(m_mx1, fullgraph=True, backend="aot_eager")
m_mx2 = torch.compile(m_mx2, fullgraph=True, backend="aot_eager")

start = time.time()
for _ in range(1000):
    m_mx1(torch.randn(128, 64, device="cuda", dtype=torch.bfloat16))
print("No per_tensor_scale = ", time.time() - start, "seconds")

start = time.time()
for _ in range(1000):
    m_mx2(torch.randn(128, 64, device="cuda", dtype=torch.bfloat16))
print("With per_tensor_scale = ", time.time() - start, "seconds")
```

On a single B200:
```
No per_tensor_scale = 1.2855589389801025 seconds
With per_tensor_scale = 1.3009123802185059 seconds
```

[ghstack-poisoned]
---
 .../mx_formats/test_inference_workflow.py    |  6 ++++-
 test/quantization/test_qat.py                |  7 +++---
 .../mx_formats/inference_workflow.py         | 15 +++++++++--
 torchao/prototype/mx_formats/nvfp4_tensor.py | 25 ++++++++++++++++---
 .../quantization/qat/fake_quantize_config.py | 13 +++++-----
 5 files changed, 49 insertions(+), 17 deletions(-)

diff --git a/test/prototype/mx_formats/test_inference_workflow.py b/test/prototype/mx_formats/test_inference_workflow.py
index 988a879b5b..90dc2700ce 100644
--- a/test/prototype/mx_formats/test_inference_workflow.py
+++ b/test/prototype/mx_formats/test_inference_workflow.py
@@ -105,6 +105,7 @@ def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool):
 )
 @pytest.mark.parametrize("inpt_dtype", [torch.bfloat16, torch.float32])
 @pytest.mark.parametrize("use_triton_kernel", [True, False])
+@pytest.mark.parametrize("use_dynamic_per_tensor_scale", [True, False])
 @pytest.mark.parametrize(
     "shapes",
     [
@@ -126,6 +127,7 @@ def test_inference_workflow_nvfp4(
     mm_config: NVFP4MMConfig,
     inpt_dtype: torch.dtype,
     use_triton_kernel: bool,
+    use_dynamic_per_tensor_scale: bool,
     shapes: tuple,
 ):
     """
@@ -147,7 +149,9 @@ def test_inference_workflow_nvfp4(
     m_mx = copy.deepcopy(m)
config = NVFP4InferenceConfig( - mm_config=mm_config, use_triton_kernel=use_triton_kernel + mm_config=mm_config, + use_triton_kernel=use_triton_kernel, + use_dynamic_per_tensor_scale=use_dynamic_per_tensor_scale, ) quantize_(m_mx, config=config) diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index a6ef09e6e8..fd4939232f 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -2077,7 +2077,8 @@ def test_infer_int4_weight_only_config(self): self.assertEqual(weight_config.activation_dtype, torch.bfloat16) @unittest.skipIf(not is_sm_at_least_89(), "Need sm89+") - def test_quantize_api_nvfp4(self): + @parametrize("use_per_tensor_scale", [True, False]) + def test_quantize_api_nvfp4(self, use_per_tensor_scale: bool): """ Test the following: quantize_(model, QATConfig(NVFP4InferenceConfig(), step="prepare")) @@ -2086,8 +2087,8 @@ def test_quantize_api_nvfp4(self): from torchao.prototype.mx_formats import NVFP4InferenceConfig self._test_quantize_api_against_ptq( - NVFP4InferenceConfig(), - target_prepare_sqnr=8, + NVFP4InferenceConfig(use_dynamic_per_tensor_scale=use_per_tensor_scale), + target_prepare_sqnr=12, target_convert_sqnr=float("inf"), ) diff --git a/torchao/prototype/mx_formats/inference_workflow.py b/torchao/prototype/mx_formats/inference_workflow.py index 34cf9e9506..39f0725390 100644 --- a/torchao/prototype/mx_formats/inference_workflow.py +++ b/torchao/prototype/mx_formats/inference_workflow.py @@ -22,6 +22,7 @@ NVFP4MMConfig, NVFP4Tensor, QuantizeTensorToNVFP4Kwargs, + per_tensor_amax_to_scale, ) from torchao.quantization.transform_module import ( register_quantize_module_handler, @@ -134,7 +135,8 @@ class NVFP4InferenceConfig(AOBaseConfig): This is a specialized configuration for NVIDIA's FP4 format. Configuration parameters: - mm_config: NVFP4MMConfig, which can be set to DYNAMIC or WEIGHT_ONLY (emulated mm in high precision) - - use_triton_kernel: bool, whether to use fused triton kernel for activation scaling (default: False) + - use_triton_kernel: bool, whether to use fused triton kernel for activation scaling (default: True) + - use_dynamic_per_tensor_scale: bool, whether to dynamically compute per tensor scale (default: True) - Data: float4_e2m1fn_x2 - Scales: float8_e4m3fn - Block size: 16 along the reduction dim @@ -145,6 +147,7 @@ class NVFP4InferenceConfig(AOBaseConfig): mm_config: NVFP4MMConfig = NVFP4MMConfig.DYNAMIC use_triton_kernel: bool = True + use_dynamic_per_tensor_scale: bool = True def __post_init__(self): # Validate PyTorch version @@ -175,12 +178,20 @@ def _nvfp4_inference_linear_transform( "Please use bfloat16 or float16 weights, or remove the bias from the linear layer." 
) + per_tensor_scale = None + if config.use_dynamic_per_tensor_scale: + tensor_amax = torch.max(torch.abs(weight)) + per_tensor_scale = per_tensor_amax_to_scale(tensor_amax) + act_quant_kwargs = None if config.mm_config == NVFP4MMConfig.DYNAMIC: - act_quant_kwargs = QuantizeTensorToNVFP4Kwargs() + act_quant_kwargs = QuantizeTensorToNVFP4Kwargs( + use_dynamic_per_tensor_scale=config.use_dynamic_per_tensor_scale, + ) quantized_weight = NVFP4Tensor.to_nvfp4( weight, + per_tensor_scale=per_tensor_scale, is_swizzled_scales=True, use_triton_kernel=False, # Always use traditional construction for weights act_quant_kwargs=act_quant_kwargs, diff --git a/torchao/prototype/mx_formats/nvfp4_tensor.py b/torchao/prototype/mx_formats/nvfp4_tensor.py index 3f2e8eeef3..1eb12dee90 100644 --- a/torchao/prototype/mx_formats/nvfp4_tensor.py +++ b/torchao/prototype/mx_formats/nvfp4_tensor.py @@ -47,6 +47,7 @@ class QuantizeTensorToNVFP4Kwargs(QuantizeTensorKwargs): block_size: int = 16 is_swizzled_scales: bool = False use_triton_kernel: bool = False + use_dynamic_per_tensor_scale: bool = False # TODO(future PR): move over to TorchAOBaseTensor's dispatch @@ -245,7 +246,7 @@ def get_hp_scales(self) -> torch.Tensor: return ( scale_e4m3.to(self._orig_dtype) - if not self._per_tensor_scale + if self._per_tensor_scale is None else self._per_tensor_scale * scale_e4m3.to(self._orig_dtype) ) @@ -645,10 +646,15 @@ def nvfp4_linear(func, types, args, kwargs): else: # dynamic quant k = weight_tensor.act_quant_kwargs + if k.use_dynamic_per_tensor_scale: + tensor_amax = torch.max(torch.abs(input_tensor)) + per_tensor_scale = per_tensor_amax_to_scale(tensor_amax) + else: + per_tensor_scale = weight_tensor._act_per_tensor_scale input_tensor = NVFP4Tensor.to_nvfp4( input_tensor, block_size=k.block_size, - per_tensor_scale=weight_tensor._act_per_tensor_scale, + per_tensor_scale=per_tensor_scale, is_swizzled_scales=k.is_swizzled_scales, use_triton_kernel=k.use_triton_kernel, ) @@ -672,10 +678,15 @@ def nvfp4_mm(func, types, args, kwargs): else: if not isinstance(input_tensor, NVFP4Tensor): k = weight_tensor.act_quant_kwargs + if k.use_dynamic_per_tensor_scale: + tensor_amax = torch.max(torch.abs(input_tensor)) + per_tensor_scale = per_tensor_amax_to_scale(tensor_amax) + else: + per_tensor_scale = weight_tensor._act_per_tensor_scale input_tensor = NVFP4Tensor.to_nvfp4( input_tensor, block_size=k.block_size, - per_tensor_scale=weight_tensor._act_per_tensor_scale, + per_tensor_scale=per_tensor_scale, is_swizzled_scales=k.is_swizzled_scales, use_triton_kernel=k.use_triton_kernel, ) @@ -697,12 +708,18 @@ def nvfp4_addmm(func, types, args, kwargs): else: return torch.addmm(bias, input_tensor, weight_dequant) else: + # TODO: refactor duplicate code if not isinstance(input_tensor, NVFP4Tensor): k = weight_tensor.act_quant_kwargs + if k.use_dynamic_per_tensor_scale: + tensor_amax = torch.max(torch.abs(input_tensor)) + per_tensor_scale = per_tensor_amax_to_scale(tensor_amax) + else: + per_tensor_scale = weight_tensor._act_per_tensor_scale input_tensor = NVFP4Tensor.to_nvfp4( input_tensor, block_size=k.block_size, - per_tensor_scale=weight_tensor._act_per_tensor_scale, + per_tensor_scale=per_tensor_scale, is_swizzled_scales=k.is_swizzled_scales, use_triton_kernel=k.use_triton_kernel, ) diff --git a/torchao/quantization/qat/fake_quantize_config.py b/torchao/quantization/qat/fake_quantize_config.py index ebc9864f3d..336a419af5 100644 --- a/torchao/quantization/qat/fake_quantize_config.py +++ b/torchao/quantization/qat/fake_quantize_config.py 
@@ -442,16 +442,15 @@ def _infer_fake_quantize_configs(
             activation_dtype=e4m3_dtype,
         )
     elif isinstance(base_config, NVFP4InferenceConfig):
-        # Note: today the PTQ config does not allow the user to specify
-        # `per_tensor_scales` due to serialization concerns. In the future
-        # we may add a way to compute these dynamically (for activations),
-        # but for now QAT will mimic the existing behavior of not having
-        # `per_tensor_scales` (subject to change)
         if NVFP4MMConfig.DYNAMIC:
-            act_config = NVFP4FakeQuantizeConfig(False)
+            act_config = NVFP4FakeQuantizeConfig(
+                use_per_tensor_scale=base_config.use_dynamic_per_tensor_scale
+            )
         else:
             act_config = None
-        weight_config = NVFP4FakeQuantizeConfig(False)
+        weight_config = NVFP4FakeQuantizeConfig(
+            use_per_tensor_scale=base_config.use_dynamic_per_tensor_scale
+        )
     elif isinstance(base_config, Int8DynamicActivationIntxWeightConfig):
         assert base_config.version >= 2, "Only version 2+ is supported"
         assert base_config.intx_packing_format == "unpacked_to_int8", (

From 6fc1dab0c4e3fe7596df759f6f212e63ecee017f Mon Sep 17 00:00:00 2001
From: andrewor14
Date: Tue, 23 Sep 2025 14:48:53 -0700
Subject: [PATCH 2/3] Update on "Support NVFP4 dynamic per tensor scale"

[ghstack-poisoned]

From 757f6fd0f045ec75b471a749c0bdd6495845e982 Mon Sep 17 00:00:00 2001
From: andrewor14
Date: Tue, 23 Sep 2025 15:01:29 -0700
Subject: [PATCH 3/3] Update on "Support NVFP4 dynamic per tensor scale"

[ghstack-poisoned]
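
For context on what the new flag computes, below is a minimal, self-contained sketch of the two-level scaling idea. It assumes the fp32 per-tensor scale is derived as `amax / (E4M3_MAX * E2M1_MAX)`, consistent with the NVFP4 blog post linked above; the constant and function names are illustrative only and are not torchao's internal API (`per_tensor_amax_to_scale` in the diff plays this role).

```
import torch

# Assumed format constants: max representable magnitudes of the two formats
# involved in NVFP4's two-level scaling (float8_e4m3fn and float4_e2m1).
E4M3_MAX = 448.0
E2M1_MAX = 6.0


def dynamic_per_tensor_scale(x: torch.Tensor) -> torch.Tensor:
    """Illustrative fp32 per-tensor scale for two-level NVFP4 scaling.

    Level 1: a single fp32 scale maps the tensor's absolute max into the
    range representable by (e4m3 block scale) * (e2m1 element) = 448 * 6.
    Level 2: each block of 16 elements then gets its own e4m3 scale
    expressed relative to this per-tensor scale.
    """
    amax = x.abs().max().to(torch.float32)
    return amax / (E4M3_MAX * E2M1_MAX)


# Example: with use_dynamic_per_tensor_scale=True, the weight's scale is
# computed once at quantize_() time from the weight amax, while the
# activation's scale is recomputed from each incoming batch (see the
# nvfp4_linear/mm/addmm changes in the diff above).
weight = torch.randn(256, 64, dtype=torch.bfloat16)
print(dynamic_per_tensor_scale(weight))
```

The per-batch amax reduction on the activations is presumably where the small (~1%) overhead in the benchmark above comes from.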