
NVfp4 #2408


Status: Open — wants to merge 1 commit into base: main

112 changes: 111 additions & 1 deletion test/prototype/mx_formats/test_mx_linear.py
@@ -25,7 +25,10 @@
    MXInferenceLinear,
    MXLinear,
)
from torchao.prototype.mx_formats.mx_subclass import MXFPInferenceConfig
from torchao.prototype.mx_formats.mx_subclass import (
    MXFPInferenceConfig,
    NVFP4InferenceConfig,
)
from torchao.quantization import quantize_
from torchao.quantization.utils import compute_error
from torchao.testing.utils import skip_if_rocm
@@ -441,3 +444,110 @@ def test_inference_subclass(elem_dtype, bias: bool, compile: bool):
    assert sqnr >= SQNR_THRESHOLD, (
        f"Got a sqnr of {sqnr} for {elem_dtype} and bias={bias}"
    )


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
    not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
)
@pytest.mark.skipif(
    not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for float4 gemm"
)
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("compile", [True, False])
@torch.no_grad()
@skip_if_rocm("ROCm float4 gemm require gfx950")
def test_inference_subclass_nvfp4(bias: bool, compile: bool):
    """
    Test NVFP4 recipe with scale_dtype=float8_e4m3fn and block_size=16
    """
    m = nn.Linear(64, 256, bias=bias, dtype=torch.bfloat16, device="cuda")
    m_mx = copy.deepcopy(m)

    config = NVFP4InferenceConfig()
    quantize_(m_mx, config=config)
    if compile:
        m_mx = torch.compile(m_mx, fullgraph=True)

    x = torch.randn(128, 64, device="cuda", dtype=torch.bfloat16)
    y_ref = m(x)
    y_mx = m_mx(x)
    sqnr = compute_error(y_ref, y_mx)
    SQNR_THRESHOLD = 15.0
    assert sqnr >= SQNR_THRESHOLD, (
        f"Got a sqnr of {sqnr} for NVFP4 recipe with bias={bias}"
    )


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
    not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
)
@pytest.mark.skipif(
    not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required for float4 gemm"
)
@pytest.mark.parametrize("use_gelu", [True, False])
@pytest.mark.parametrize("emulate", [True, False])
@pytest.mark.parametrize("compile", [False])
@pytest.mark.parametrize("bias", [True, False])
@torch.no_grad()
@skip_if_rocm("ROCm float4 gemm require gfx950")
def test_nvfp4_matmul_with_amax(
    use_gelu: bool, emulate: bool, compile: bool, bias: bool
):
    from torchao.prototype.mx_formats.nvfp4_tensor import (
        NVFP4Tensor,
        per_tensor_amax_to_scale,
    )

    m, k, n = 64, 256, 128

    # Create activation tensor
    if use_gelu:
        x = torch.randn(m, k, dtype=torch.bfloat16, device="cuda")
        A = torch.nn.functional.gelu(x)
    else:
        A = torch.randn(m, k, dtype=torch.bfloat16, device="cuda")

    B = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
    bias_tensor = torch.randn(n, dtype=torch.bfloat16, device="cuda") if bias else None

    # Compute reference
    C_ref = torch.matmul(A, B.t())
    if bias:
        C_ref = C_ref + bias_tensor

    a_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(A)))
    b_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(B)))
    A_nvfp4 = NVFP4Tensor.to_nvfp4(A, per_tensor_scale=a_scale)
    B_nvfp4 = NVFP4Tensor.to_nvfp4(B, per_tensor_scale=b_scale)

    if emulate:
        # Cast back to original dtype and compute
        A_emulated = A_nvfp4.to_dtype(A.dtype)
        B_emulated = B_nvfp4.to_dtype(B.dtype)
        mm = torch.compile(torch.matmul, fullgraph=True) if compile else torch.matmul
        C_emulated = mm(A_emulated, B_emulated.t())
        if bias:
            C_emulated = C_emulated + bias_tensor
        sqnr = compute_error(C_ref, C_emulated)
    else:
        if bias:
            linear_fn = (
                torch.compile(torch.nn.functional.linear, fullgraph=True)
                if compile
                else torch.nn.functional.linear
            )
            C_nvfp4 = linear_fn(A_nvfp4, B_nvfp4, bias_tensor)
        else:
            mm = (
                torch.compile(torch.matmul, fullgraph=True) if compile else torch.matmul
            )
            C_nvfp4 = mm(A_nvfp4, B_nvfp4.t())
        sqnr = compute_error(C_ref, C_nvfp4)

    # Check quality threshold
    SQNR_THRESHOLD = 16.0
    assert sqnr >= SQNR_THRESHOLD, (
        f"SQNR {sqnr:.2f} < {SQNR_THRESHOLD}, use_gelu={use_gelu}, emulate={emulate}, compile={compile}, bias={bias}"
    )
55 changes: 55 additions & 0 deletions test/prototype/mx_formats/test_mx_tensor.py
@@ -14,6 +14,7 @@
from torchao.prototype.mx_formats.constants import (
    DTYPE_FP6_E2M3,
    DTYPE_FP6_E3M2,
    F4_E2M1_MAX,
    SUPPORTED_ELEM_DTYPES,
)
from torchao.prototype.mx_formats.kernels import pack_uint4, pack_uint6
@@ -591,3 +592,57 @@ def to_f8(x):
    torch.testing.assert_close(
        data_in_range_f8_c, data_out_of_range_f8_c, atol=0, rtol=0
    )


@pytest.mark.parametrize(
    "dtype,shape,use_per_tensor_scale",
    [
        (torch.bfloat16, (32, 64), False),
        (torch.float32, (64, 128), False),
        (torch.bfloat16, (128, 256), False),
        (torch.bfloat16, (64, 128), True),
    ],
)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
    not TORCH_VERSION_AT_LEAST_2_8, reason="torch.compile requires PyTorch 2.8+"
)
def test_nvfp4_reconstruction(dtype, shape, use_per_tensor_scale):
    from torchao.prototype.mx_formats.nvfp4_tensor import (
        NVFP4Tensor,
        per_tensor_amax_to_scale,
    )

    x = torch.randn(shape, dtype=dtype, device="cuda")
    if use_per_tensor_scale:
        tensor_amax = torch.max(torch.abs(x))
        scale = per_tensor_amax_to_scale(tensor_amax)
    else:
        scale = None

    x_nvfp4 = NVFP4Tensor.to_nvfp4(x, per_tensor_scale=scale)
    x_reconstructed = x_nvfp4.to_dtype(dtype)

    def assert_sqnr_gt_threshold(orig, new, threshold):
        sqnr = compute_error(orig, new)
        if torch.all(torch.isnan(sqnr)):
            # if both operands are full of zeroes, sqnr is nan and this is ok
            # test for this explicitly
            assert torch.all(orig == 0) and torch.all(new == 0)
        else:
            assert sqnr >= threshold

    reconstructed_amax = x_nvfp4.get_scales().view(shape[0], -1, 1) * F4_E2M1_MAX
    max_abs = torch.amax(
        torch.abs(x.reshape(shape[0], -1, x_nvfp4._block_size)), dim=-1
    ).unsqueeze(-1)

    assert_sqnr_gt_threshold(max_abs, reconstructed_amax, 30.0)
    assert_sqnr_gt_threshold(x, x_reconstructed, 8.0)

    assert x.shape == x_reconstructed.shape, (
        f"Shape mismatch: {x.shape} vs {x_reconstructed.shape}"
    )
    assert x.dtype == x_reconstructed.dtype, (
        f"Dtype mismatch: {x.dtype} vs {x_reconstructed.dtype}"
    )
6 changes: 5 additions & 1 deletion torchao/prototype/mx_formats/__init__.py
@@ -6,7 +6,10 @@
)

# Note: Prototype and subject to change
from torchao.prototype.mx_formats.mx_subclass import MXFPInferenceConfig
from torchao.prototype.mx_formats.mx_subclass import (
    MXFPInferenceConfig,
    NVFP4InferenceConfig,
)

# import mx_linear here to register the quantize_ transform logic
# ruff: noqa: I001
@@ -18,4 +21,5 @@
    "MXLinearConfig",
    "MXLinearRecipeName",
    "MXFPInferenceConfig",
    "NVFP4InferenceConfig",
]
6 changes: 3 additions & 3 deletions torchao/prototype/mx_formats/config.py
@@ -57,10 +57,10 @@ def _validate_gemm_kernel_choice(gemm_kernel_choice, block_size, elem_dtype):
            f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}"
        )
    elif gemm_kernel_choice == MXGemmKernelChoice.CUBLAS:
        assert block_size == 32, (
            f"block_size must be 32 to use the cuBLAS MX gemm kernels, got {block_size}"
        assert block_size in [16, 32], (
            f"block_size must be in [16, 32] to use the cuBLAS MX gemm kernels, got {block_size}"
        )
        valid_dtypes = [torch.float8_e4m3fn]
        valid_dtypes = [torch.float8_e4m3fn, torch.float4_e2m1fn_x2]
        assert elem_dtype in valid_dtypes, (
            f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}"
        )
123 changes: 113 additions & 10 deletions torchao/prototype/mx_formats/mx_subclass.py
@@ -5,12 +5,11 @@
# LICENSE file in the root directory of this source tree.

import types
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Optional

import torch

import torchao
from torchao.core.config import AOBaseConfig
from torchao.prototype.mx_formats import (
    MXGemmKernelChoice,
@@ -20,11 +19,16 @@
    _validate_gemm_kernel_choice,
)
from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
from torchao.quantization.quant_api import to_linear_activation_quantized
from torchao.quantization.transform_module import (
    register_quantize_module_handler,
)
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_at_least_100
from torchao.utils import (
    TORCH_VERSION_AT_LEAST_2_5,
    TORCH_VERSION_AT_LEAST_2_8,
    is_sm_at_least_100,
)


# Note: This API is extra prototype and will change in the future
@@ -63,16 +67,13 @@ class MXFPInferenceConfig(AOBaseConfig):

    block_size: int = 32

    # Dtypes for Input and Weights
    # Dtypes for Input and Weights, supports Fp8 and Fp4 formats
    activation_dtype: torch.dtype = torch.float8_e4m3fn
    weight_dtype: torch.dtype = torch.float8_e4m3fn

    # Which kernel to run for mm
    gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.CUBLAS

    # Set some magic perf settings
    set_inductor_config: bool = False

    def __post_init__(self):
        assert self.activation_dtype == self.weight_dtype, (
            "For now - we only support matching input/weight dtypes."
@@ -115,8 +116,6 @@ def _mx_inference_linear_transform(
    # TODO Sm120 has slightly more restrictive reqs
    # TODO handle AMD
    assert is_sm_at_least_100(), "MXFP is only supported on sm100 machiens for now"
    if config.set_inductor_config:
        torchao.quantization.utils.recommended_inductor_config_setter()

    activation_dtype = config.activation_dtype
    weight_dtype = config.weight_dtype
@@ -151,7 +150,111 @@ def _mx_inference_linear_transform(
    return module


def _get_nvfp4_dtype():
    """Factory function for NVFP4 dtype defaults."""
    if not TORCH_VERSION_AT_LEAST_2_8:
        raise RuntimeError("NVFP4InferenceConfig requires PyTorch 2.8 or later")
    return torch.float4_e2m1fn_x2


@dataclass
class NVFP4InferenceConfig(AOBaseConfig):
    """
    NVIDIA FP4 (NVFP4) Inference Quantization Configuration

    This is a specialized configuration for NVIDIA's FP4 format with UE4M3 scales.
    It provides defaults optimized for NVFP4:
    - Data: float4_e2m1fn_x2
    - Scales: float8_e4m3fn (UE4M3)
    - Block size: 16 (required for NVFP4)
    - CUBLAS kernel (optimized for VEC16_UE4M3)
    """

    block_size: int = 16  # NVFP4 requires block size 16

    # NVFP4 uses FP4 data
    activation_dtype: torch.dtype = field(default_factory=_get_nvfp4_dtype)
    weight_dtype: torch.dtype = field(default_factory=_get_nvfp4_dtype)

    # NVFP4 uses E4M3 scales
    scale_dtype: torch.dtype = torch.float8_e4m3fn

    # CUBLAS is preferred for NVFP4 with VEC16_UE4M3 support
    gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.CUBLAS

    def __post_init__(self):
        # Validate NVFP4 constraints
        if not TORCH_VERSION_AT_LEAST_2_8:
            raise RuntimeError("NVFP4InferenceConfig requires PyTorch 2.8 or later")

        assert self.activation_dtype == torch.float4_e2m1fn_x2, (
            f"NVFP4 requires activation_dtype=float4_e2m1fn_x2, got {self.activation_dtype}"
        )
        assert self.weight_dtype == torch.float4_e2m1fn_x2, (
            f"NVFP4 requires weight_dtype=float4_e2m1fn_x2, got {self.weight_dtype}"
        )
        assert self.scale_dtype == torch.float8_e4m3fn, (
            f"NVFP4 requires scale_dtype=float8_e4m3fn, got {self.scale_dtype}"
        )
        assert self.block_size == 16, (
            f"NVFP4 requires block_size=16, got {self.block_size}"
        )


def _input_activation_quant_func_nvfp4(
    x: torch.Tensor,
    block_size: int = 16,
    scale: Optional[torch.Tensor] = None,
):
    """NVFP4-specific activation quantization function"""
    # TODO: scale for static quant
    activation = NVFP4Tensor.to_nvfp4(
        x,
        block_size=block_size,
    )
    return activation


@register_quantize_module_handler(NVFP4InferenceConfig)
def _nvfp4_inference_linear_transform(
    module: torch.nn.Module, config: NVFP4InferenceConfig
):
    """Quantization handler for NVFP4InferenceConfig"""
    assert is_sm_at_least_100(), "NVFP4 is only supported on sm100+ machines"

    weight = module.weight
    assert weight.dtype == torch.bfloat16, (
        f"Only supporting bf16 out dtype for now, got {weight.dtype}"
    )

    # Convert weight to NVFP4 Tensor
    quantized_weight = NVFP4Tensor.to_nvfp4(
        weight,
        block_size=config.block_size,
    )

    input_quant_func = _input_activation_quant_func_nvfp4
    input_quant_kwargs = {
        "block_size": config.block_size,
        "scale": None,
    }

    quantized_weight = to_linear_activation_quantized(

Review comment (Contributor), on the to_linear_activation_quantized call above: can we just write the logic here instead of using to_linear_activation_quantized? I remember the same feedback on the mxfp4 inference tensor.

Reply (Contributor, Author): You can't just move the logic out here; the entirety of the forward behavior has to be "wrapped" by the subclass. Currently there are two ways to do that without changing nn.modules:

1. Like above; this is subclass composition.
2. The other is to copy the same behavior into the implementations of the ops, e.g. NVFP4's dispatch would need to copy:

@implements([torch.nn.functional.linear, aten.linear.default])
def _(func, types, args, kwargs):
    input_tensor, weight_tensor, bias = (
        args[0],
        args[1],
        args[2] if len(args) > 2 else None,
    )
    if isinstance(weight_tensor, LinearActivationQuantizedTensor):
        return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias)
    raise NotImplementedError(
        "LinearActivationQuantizedTensor: No specialized dispatch found for linear op"
    )


@implements([aten.mm.default, aten.addmm.default])
def _(func, types, args, kwargs):
    if not args[0].is_floating_point():
        raise NotImplementedError(
            "LinearActivationQuantizedTensor: expecting a floating point input"
        )
    if func == aten.addmm.default:
        assert args[1].shape[-1] == args[2].shape[0], (
            f"need mat1 shape: {args[1].shape} final"
            f"dim to match mat2 shape: {args[2].shape} first dim "
        )
        input_tensor, weight_tensor, bias = (
            args[1],
            args[2],
            args[0],
        )
        input_quant_func = weight_tensor.input_quant_func
        original_weight_tensor = weight_tensor.original_weight_tensor
        qtensor = input_quant_func(input_tensor, **weight_tensor.quant_kwargs)
        return func(bias, qtensor, original_weight_tensor)
    else:
        # aten.mm.default
        assert args[0].shape[-1] == args[1].shape[0], (
            f"need mat1 shape: {args[0].shape} final dim"
            f"to match mat2 shape: {args[1].shape} first dim"
        )
        input_tensor, weight_tensor = (
            args[0],
            args[1],
        )
        input_quant_func = weight_tensor.input_quant_func
        original_weight_tensor = weight_tensor.original_weight_tensor
        qtensor = input_quant_func(input_tensor, **weight_tensor.quant_kwargs)
        return func(qtensor, original_weight_tensor)

Not the end of the world, but for some subclasses that serve a dual purpose (dynamic + weight only, + static, + training) it can be a lot of switch statements in the ops, as opposed to having the base subclass plus some sugar.
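
For contrast, here is a minimal sketch of option 1, the composition route this PR takes. It only uses APIs that appear in this diff (NVFP4Tensor.to_nvfp4 and to_linear_activation_quantized); the helper names wrap_linear_weight_nvfp4 and _quantize_activation_nvfp4 are illustrative, not part of the PR:

# Hypothetical sketch of the subclass-composition path; helper names are
# illustrative and not part of this PR. Assumes the prototype APIs as used
# in the diff above.
import torch
from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
from torchao.quantization.quant_api import to_linear_activation_quantized


def _quantize_activation_nvfp4(x: torch.Tensor, block_size: int = 16, scale=None):
    # same role as _input_activation_quant_func_nvfp4 above
    return NVFP4Tensor.to_nvfp4(x, block_size=block_size)


def wrap_linear_weight_nvfp4(linear: torch.nn.Linear, block_size: int = 16):
    # quantize the weight once, then wrap it so that every linear call
    # first runs the activation quant func and then hits the NVFP4 kernel
    qweight = NVFP4Tensor.to_nvfp4(linear.weight, block_size=block_size)
    wrapped = to_linear_activation_quantized(
        qweight,
        _quantize_activation_nvfp4,
        quant_kwargs={"block_size": block_size, "scale": None},
    )
    linear.weight = torch.nn.Parameter(wrapped, requires_grad=False)
    return linear

The wrapper keeps the forward behavior in one place, which is why the handler below reuses it rather than duplicating the dispatch logic per op.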

        quantized_weight, input_quant_func, quant_kwargs=input_quant_kwargs
    )

    module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
    module.extra_repr = types.MethodType(_linear_extra_repr, module)
    return module


if TORCH_VERSION_AT_LEAST_2_5:
    torch.serialization.add_safe_globals(
        [MXTensor, MXGemmKernelChoice, _input_activation_quant_func_mxfp]
        [
            MXTensor,
            NVFP4Tensor,
            MXGemmKernelChoice,
            _input_activation_quant_func_mxfp,
            _input_activation_quant_func_nvfp4,
        ]
    )
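
End to end, the intended use mirrors test_inference_subclass_nvfp4 above. A minimal usage sketch, assuming an sm100+ GPU, PyTorch 2.8+, and the prototype import paths used in this PR (prototype APIs, so names and defaults may change):

# Minimal usage sketch based on test_inference_subclass_nvfp4 above.
import copy

import torch
import torch.nn as nn

from torchao.prototype.mx_formats import NVFP4InferenceConfig
from torchao.quantization import quantize_

m = nn.Linear(64, 256, bias=True, dtype=torch.bfloat16, device="cuda")
m_nvfp4 = copy.deepcopy(m)

# weights become NVFP4Tensor; activations are quantized to NVFP4 on the fly
quantize_(m_nvfp4, config=NVFP4InferenceConfig())
m_nvfp4 = torch.compile(m_nvfp4, fullgraph=True)  # optional

x = torch.randn(128, 64, device="cuda", dtype=torch.bfloat16)
y = m_nvfp4(x)  # FP4 gemm with E4M3 block scales (block_size=16)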