From d8792dd905b6bd81ae61d85f7862f412a4b8f7e3 Mon Sep 17 00:00:00 2001
From: Syed Tousif Ahmed
Date: Mon, 24 Nov 2025 10:01:15 -0800
Subject: [PATCH] Implements rounding mode for NVFP4 tensor

---
 torchao/prototype/custom_fp_utils.py         | 53 ++++++++++++++++----
 torchao/prototype/mx_formats/__init__.py     |  3 ++
 torchao/prototype/mx_formats/nvfp4_tensor.py | 15 ++++--
 3 files changed, 59 insertions(+), 12 deletions(-)

diff --git a/torchao/prototype/custom_fp_utils.py b/torchao/prototype/custom_fp_utils.py
index 3d8de6f0de..2a6939023e 100644
--- a/torchao/prototype/custom_fp_utils.py
+++ b/torchao/prototype/custom_fp_utils.py
@@ -8,14 +8,26 @@
 # It has been refactored to support any sub-byte FP dtypes. However, some behaviors of MX dtypes remain:
 # 1. No encodings are reserved for special values (+/-inf, NaN).
 # 2. When downcasting from FP32 to Floatx,
-#     - Rounding mode is round to nearest, ties to even.
+#     - Rounding mode is round to nearest, ties to even (default).
 #     - Values outside the representable range of Floatx after rounding are clamped to the maximum Floatx
 #       magnitude (sign is preserved).
 
+from enum import Enum
+
 import torch
 from torch import Tensor
 
 
+class RoundingMode(Enum):
+    """Rounding modes for floating point quantization.
+
+    RN: Round to nearest, ties to even (default)
+    RS: Stochastic rounding
+    """
+    RN = "round_nearest"
+    RS = "round_stochastic"
+
+
 def _n_ones(n: int) -> int:
     return (1 << n) - 1
 
@@ -24,7 +36,9 @@ def _n_ones(n: int) -> int:
 F32_EXP_BIAS = _n_ones(EBITS_F32 - 1)
 
 
-def _f32_to_floatx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor:
+def _f32_to_floatx_unpacked(
+    x: Tensor, ebits: int, mbits: int, rounding_mode: RoundingMode = RoundingMode.RN
+) -> Tensor:
     """Convert FP32 numbers to sub-byte floating point numbers with the given
     number of exponent and mantissa bits.
 
@@ -38,6 +52,12 @@ def _f32_to_floatx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor:
     outside the representable range of Floatx after rounding are clamped to the
     maximum Floatx magnitude (sign is preserved).
 
+    Args:
+        x: Input tensor of dtype torch.float
+        ebits: Number of exponent bits
+        mbits: Number of mantissa bits
+        rounding_mode: Rounding mode to use (RN, RS)
+
     Code below is an adaptation of https://fburl.com/code/ciwofcg4
 
     Background 1: last answer in https://stackoverflow.com/questions/8981913/how-to-perform-round-to-even-with-floating-point-numbers  # noqa: E501
@@ -111,13 +131,28 @@ def _f32_to_floatx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor:
     # branch 3: stay in normal range, adjust the exponent and round
     #
     normal_x = x.view(torch.int32)
-    # resulting mantissa is odd
-    mant_odd = (normal_x >> (MBITS_F32 - mbits)) & 1
-    # update exponent, rounding bias part 1
-    val_to_add = ((exp_bias - F32_EXP_BIAS) << MBITS_F32) + magic_adder
-    normal_x += val_to_add
-    # rounding bias part 2
-    normal_x += mant_odd
+    val_to_add = (exp_bias - F32_EXP_BIAS) << MBITS_F32
+
+    if rounding_mode == RoundingMode.RN:
+        # Round to nearest, ties to even
+        # resulting mantissa is odd
+        mant_odd = (normal_x >> (MBITS_F32 - mbits)) & 1
+        # update exponent, rounding bias part 1
+        val_to_add += magic_adder
+        normal_x += val_to_add
+        # rounding bias part 2
+        normal_x += mant_odd
+    elif rounding_mode == RoundingMode.RS:
+        # Stochastic rounding
+        # Add random bits to the discarded precision
+        rnd = torch.randint_like(normal_x, 0, 1 << (MBITS_F32 - mbits), dtype=torch.int32)
+        # update exponent
+        normal_x += val_to_add
+        # add randomness
+        normal_x += rnd
+    else:
+        raise ValueError(f"Unsupported rounding mode: {rounding_mode}")
+
     # take the bits!
     normal_x = normal_x >> (MBITS_F32 - mbits)
     normal_x = normal_x.to(torch.uint8)
diff --git a/torchao/prototype/mx_formats/__init__.py b/torchao/prototype/mx_formats/__init__.py
index c7a4c47f9d..ed76402b2d 100644
--- a/torchao/prototype/mx_formats/__init__.py
+++ b/torchao/prototype/mx_formats/__init__.py
@@ -11,6 +11,8 @@
     NVFP4MMConfig,
 )
 
+from torchao.prototype.custom_fp_utils import RoundingMode
+
 # import mx_linear here to register the quantize_ transform logic
 # ruff: noqa: I001
 import torchao.prototype.mx_formats.mx_linear  # noqa: F401
@@ -22,4 +24,5 @@
     "MXFPInferenceConfig",
     "NVFP4InferenceConfig",
     "NVFP4MMConfig",
+    "RoundingMode",
 ]
diff --git a/torchao/prototype/mx_formats/nvfp4_tensor.py b/torchao/prototype/mx_formats/nvfp4_tensor.py
index 18f05290e5..87a879507e 100644
--- a/torchao/prototype/mx_formats/nvfp4_tensor.py
+++ b/torchao/prototype/mx_formats/nvfp4_tensor.py
@@ -20,6 +20,7 @@
     triton_quantize_nvfp4,
     unpack_uint4,
 )
+from torchao.prototype.custom_fp_utils import RoundingMode
 from torchao.prototype.mx_formats.mx_tensor import (
     tensor_size_fp4x2_to_hp,
     tensor_size_hp_to_fp4x2,
@@ -158,6 +159,7 @@ def to_nvfp4(
         is_swizzled_scales: bool = False,
         use_triton_kernel: bool = False,
         act_quant_kwargs: Optional[QuantizeTensorToNVFP4Kwargs] = None,
+        rounding_mode: RoundingMode = RoundingMode.RN,
     ):
         """Convert high precision tensor to NVFP4 format.
 
@@ -171,6 +173,7 @@ def to_nvfp4(
             is_swizzled_scales: If True, store scales in swizzled format for faster matrix multiplication
             use_triton_kernel: If True, use Triton kernel for quantization
             act_quant_kwargs: If specified, config for quantizing the activation
+            rounding_mode: Rounding mode to use (RN for round-nearest, RS for stochastic)
 
         Returns:
             NVFP4Tensor: Quantized tensor in NVFP4 format
@@ -183,10 +186,14 @@ def to_nvfp4(
             assert K % 16 == 0, (
                 f"Triton kernel requires K (dim -1) to be divisible by 16, got {K}"
             )
-            blockwise_scales, data_lp = triton_quantize_nvfp4(data_hp, per_tensor_scale)
+            # Convert RoundingMode enum to boolean for triton kernel
+            use_stochastic_rounding = rounding_mode == RoundingMode.RS
+            blockwise_scales, data_lp = triton_quantize_nvfp4(
+                data_hp, per_tensor_scale, use_stochastic_rounding
+            )
         else:
             blockwise_scales, data_lp = nvfp4_quantize(
-                data_hp, block_size, per_tensor_scale
+                data_hp, block_size, per_tensor_scale, rounding_mode
             )
         if is_swizzled_scales:
             scale_shape = (math.prod(leading_dims) * M, K // block_size)
@@ -677,6 +684,7 @@ def nvfp4_quantize(
     data_hp: torch.Tensor,
     block_size: int = 16,
     per_tensor_scale: Optional[torch.Tensor] = None,
+    rounding_mode: RoundingMode = RoundingMode.RN,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """NVIDIA FP4 quantization with UE4M3 scales.
 
@@ -688,6 +696,7 @@ def nvfp4_quantize(
         block_size: Block size for quantization (must be 16)
         per_tensor_amax: Optional pre-computed absolute maximum for calibration.
             If provided, uses per-tensor scaling. If None, uses block-wise scaling only.
+        rounding_mode: Rounding mode to use (RN or RS)
 
     Returns:
         tuple: A tuple containing:
@@ -742,7 +751,7 @@ def nvfp4_quantize(
     data_scaled = torch.clamp(data_scaled, -F4_E2M1_MAX, F4_E2M1_MAX)
     data_scaled = data_scaled.view(orig_shape)
-    data_lp = f32_to_f4_unpacked(data_scaled)
+    data_lp = f32_to_f4_unpacked(data_scaled, rounding_mode)
     # TODO: NotImplementedError: "copy_kernel" not implemented for 'Float4_e2m1fn_x2'
     # data_lp = pack_uint4(data_lp).view(torch.float4_e2m1fn_x2)
     data_lp = pack_uint4(data_lp)
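Usage sketch (illustrative, not part of the patch): assuming the patch above is applied, the new rounding_mode argument can be exercised through the module-level nvfp4_quantize helper and the RoundingMode export added to mx_formats. The tensor shape, dtype, and values below are arbitrary.

# Minimal sketch of the new API surface; shape/dtype are arbitrary choices.
import torch

from torchao.prototype.mx_formats import RoundingMode
from torchao.prototype.mx_formats.nvfp4_tensor import nvfp4_quantize

# High-precision input; last dim must be divisible by block_size (16).
x = torch.randn(128, 64, dtype=torch.bfloat16)

# Default path: round to nearest, ties to even (existing behavior).
scales_rn, data_rn = nvfp4_quantize(x, block_size=16)

# New path: stochastic rounding, where the discarded mantissa bits are
# rounded up with probability proportional to their value.
scales_rs, data_rs = nvfp4_quantize(
    x, block_size=16, rounding_mode=RoundingMode.RS
)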