53 changes: 44 additions & 9 deletions torchao/prototype/custom_fp_utils.py
@@ -8,14 +8,26 @@
# It has been refactored to support any sub-byte FP dtypes. However, some behaviors of MX dtypes remain:
# 1. No encodings are reserved for special values (+/-inf, NaN).
# 2. When downcasting from FP32 to Floatx,
# - Rounding mode is round to nearest, ties to even.
# - Rounding mode is round to nearest, ties to even (default).
# - Values outside the representable range of Floatx after rounding are clamped to the maximum Floatx
# magnitude (sign is preserved).

from enum import Enum

import torch
from torch import Tensor


class RoundingMode(Enum):
"""Rounding modes for floating point quantization.

RN: Round to nearest, ties to even (default)
RS: Stochastic rounding
"""
RN = "round_nearest"
RS = "round_stochastic"


def _n_ones(n: int) -> int:
return (1 << n) - 1

@@ -24,7 +36,9 @@ def _n_ones(n: int) -> int:
F32_EXP_BIAS = _n_ones(EBITS_F32 - 1)


def _f32_to_floatx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor:
def _f32_to_floatx_unpacked(
x: Tensor, ebits: int, mbits: int, rounding_mode: RoundingMode = RoundingMode.RN
) -> Tensor:
"""Convert FP32 numbers to sub-byte floating point numbers with the given
number of exponent and mantissa bits.

@@ -38,6 +52,12 @@ def _f32_to_floatx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor:
outside the representable range of Floatx after rounding are clamped to the
maximum Floatx magnitude (sign is preserved).

Args:
x: Input tensor of dtype torch.float
ebits: Number of exponent bits
mbits: Number of mantissa bits
rounding_mode: Rounding mode to use (RN, RS)

Code below is an adaptation of https://fburl.com/code/ciwofcg4

Background 1: last answer in https://stackoverflow.com/questions/8981913/how-to-perform-round-to-even-with-floating-point-numbers # noqa: E501
@@ -111,13 +131,28 @@ def _f32_to_floatx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor:
# branch 3: stay in normal range, adjust the exponent and round
#
normal_x = x.view(torch.int32)
# resulting mantissa is odd
mant_odd = (normal_x >> (MBITS_F32 - mbits)) & 1
# update exponent, rounding bias part 1
val_to_add = ((exp_bias - F32_EXP_BIAS) << MBITS_F32) + magic_adder
normal_x += val_to_add
# rounding bias part 2
normal_x += mant_odd
val_to_add = (exp_bias - F32_EXP_BIAS) << MBITS_F32

if rounding_mode == RoundingMode.RN:
# Round to nearest, ties to even
# resulting mantissa is odd
mant_odd = (normal_x >> (MBITS_F32 - mbits)) & 1
# update exponent, rounding bias part 1
val_to_add += magic_adder
normal_x += val_to_add
# rounding bias part 2
normal_x += mant_odd
elif rounding_mode == RoundingMode.RS:
# Stochastic rounding
# Add random bits to the discarded precision
rnd = torch.randint_like(normal_x, 0, 1 << (MBITS_F32 - mbits), dtype=torch.int32)
# update exponent
normal_x += val_to_add
# add randomness
normal_x += rnd
else:
raise ValueError(f"Unsupported rounding mode: {rounding_mode}")

# take the bits!
normal_x = normal_x >> (MBITS_F32 - mbits)
normal_x = normal_x.to(torch.uint8)
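The example below is not part of the diff; it is a minimal sketch of how the new rounding mode could be exercised once this patch is applied. It casts a value that sits between two representable FP4 (E2M1) numbers and checks that round-nearest always picks the same neighbor while stochastic rounding is unbiased on average. `_floatx_unpacked_to_f32` is assumed to be the existing decode helper in `custom_fp_utils.py`.

```python
import torch

from torchao.prototype.custom_fp_utils import (
    RoundingMode,
    _f32_to_floatx_unpacked,
    _floatx_unpacked_to_f32,  # assumed existing decode helper
)

# 1.2 lies between the E2M1-representable values 1.0 and 1.5.
x = torch.full((100_000,), 1.2, dtype=torch.float32)

# Round-nearest: every element lands on 1.0, the closer neighbor.
rn = _floatx_unpacked_to_f32(
    _f32_to_floatx_unpacked(x, ebits=2, mbits=1), ebits=2, mbits=1
)

# Stochastic rounding: each element goes to 1.0 or 1.5, weighted so that
# the expectation matches the input value.
rs = _floatx_unpacked_to_f32(
    _f32_to_floatx_unpacked(x, ebits=2, mbits=1, rounding_mode=RoundingMode.RS),
    ebits=2,
    mbits=1,
)

print(rn.mean().item())  # 1.0
print(rs.mean().item())  # ~1.2
```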
3 changes: 3 additions & 0 deletions torchao/prototype/mx_formats/__init__.py
@@ -11,6 +11,8 @@
NVFP4MMConfig,
)

from torchao.prototype.custom_fp_utils import RoundingMode

# import mx_linear here to register the quantize_ transform logic
# ruff: noqa: I001
import torchao.prototype.mx_formats.mx_linear # noqa: F401
@@ -22,4 +24,5 @@
"MXFPInferenceConfig",
"NVFP4InferenceConfig",
"NVFP4MMConfig",
"RoundingMode",
]
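With the re-export above, callers can pick up the enum from the package namespace instead of reaching into the private `custom_fp_utils` module. A minimal sketch, assuming this diff is applied:

```python
from torchao.prototype.mx_formats import RoundingMode

# The enum values mirror the definitions in custom_fp_utils.py.
assert RoundingMode.RN.value == "round_nearest"
assert RoundingMode.RS.value == "round_stochastic"
```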
15 changes: 12 additions & 3 deletions torchao/prototype/mx_formats/nvfp4_tensor.py
@@ -20,6 +20,7 @@
triton_quantize_nvfp4,
unpack_uint4,
)
from torchao.prototype.custom_fp_utils import RoundingMode
from torchao.prototype.mx_formats.mx_tensor import (
tensor_size_fp4x2_to_hp,
tensor_size_hp_to_fp4x2,
@@ -158,6 +159,7 @@ def to_nvfp4(
is_swizzled_scales: bool = False,
use_triton_kernel: bool = False,
act_quant_kwargs: Optional[QuantizeTensorToNVFP4Kwargs] = None,
rounding_mode: RoundingMode = RoundingMode.RN,
):
"""Convert high precision tensor to NVFP4 format.

@@ -171,6 +173,7 @@
is_swizzled_scales: If True, store scales in swizzled format for faster matrix multiplication
use_triton_kernel: If True, use Triton kernel for quantization
act_quant_kwargs: If specified, config for quantizing the activation
rounding_mode: Rounding mode to use (RN for round-nearest, RS for stochastic)

Returns:
NVFP4Tensor: Quantized tensor in NVFP4 format
@@ -183,10 +186,14 @@
assert K % 16 == 0, (
f"Triton kernel requires K (dim -1) to be divisible by 16, got {K}"
)
blockwise_scales, data_lp = triton_quantize_nvfp4(data_hp, per_tensor_scale)
# Convert RoundingMode enum to boolean for triton kernel
use_stochastic_rounding = rounding_mode == RoundingMode.RS
blockwise_scales, data_lp = triton_quantize_nvfp4(
data_hp, per_tensor_scale, use_stochastic_rounding
)
else:
blockwise_scales, data_lp = nvfp4_quantize(
data_hp, block_size, per_tensor_scale
data_hp, block_size, per_tensor_scale, rounding_mode
)
if is_swizzled_scales:
scale_shape = (math.prod(leading_dims) * M, K // block_size)
@@ -677,6 +684,7 @@ def nvfp4_quantize(
data_hp: torch.Tensor,
block_size: int = 16,
per_tensor_scale: Optional[torch.Tensor] = None,
rounding_mode: RoundingMode = RoundingMode.RN,
) -> tuple[torch.Tensor, torch.Tensor]:
"""NVIDIA FP4 quantization with UE4M3 scales.

@@ -688,6 +696,7 @@
block_size: Block size for quantization (must be 16)
per_tensor_amax: Optional pre-computed absolute maximum for calibration.
If provided, uses per-tensor scaling. If None, uses block-wise scaling only.
rounding_mode: Rounding mode to use (RN or RS)

Returns:
tuple: A tuple containing:
@@ -742,7 +751,7 @@

data_scaled = torch.clamp(data_scaled, -F4_E2M1_MAX, F4_E2M1_MAX)
data_scaled = data_scaled.view(orig_shape)
data_lp = f32_to_f4_unpacked(data_scaled)
data_lp = f32_to_f4_unpacked(data_scaled, rounding_mode)
# TODO: NotImplementedError: "copy_kernel" not implemented for 'Float4_e2m1fn_x2'
# data_lp = pack_uint4(data_lp).view(torch.float4_e2m1fn_x2)
data_lp = pack_uint4(data_lp)
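A usage sketch for the new keyword on the block-wise quantization path; it is not part of the diff and assumes the patch is applied. It quantizes a bf16 weight to NVFP4 with stochastic rounding and returns the block scales together with the packed FP4 payload, as documented in `nvfp4_quantize` above.

```python
import torch

from torchao.prototype.mx_formats import RoundingMode
from torchao.prototype.mx_formats.nvfp4_tensor import nvfp4_quantize

# K (the last dim) must be a multiple of the 16-element block size.
w = torch.randn(128, 64, dtype=torch.bfloat16)

# Stochastic rounding instead of the default round-nearest.
blockwise_scales, data_lp = nvfp4_quantize(
    w, block_size=16, rounding_mode=RoundingMode.RS
)

print(blockwise_scales.dtype)  # expected: torch.float8_e4m3fn (UE4M3 block scales)
print(data_lp.shape)           # FP4 values packed two per uint8 -> (128, 32)
```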