6 changes: 5 additions & 1 deletion test/prototype/mx_formats/test_inference_workflow.py
@@ -105,6 +105,7 @@ def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool):
)
@pytest.mark.parametrize("inpt_dtype", [torch.bfloat16, torch.float32])
@pytest.mark.parametrize("use_triton_kernel", [True, False])
@pytest.mark.parametrize("use_dynamic_per_tensor_scale", [True, False])
@pytest.mark.parametrize(
"shapes",
[
@@ -126,6 +127,7 @@ def test_inference_workflow_nvfp4(
mm_config: NVFP4MMConfig,
inpt_dtype: torch.dtype,
use_triton_kernel: bool,
use_dynamic_per_tensor_scale: bool,
shapes: tuple,
):
"""
@@ -147,7 +149,9 @@ def test_inference_workflow_nvfp4(
m_mx = copy.deepcopy(m)

config = NVFP4InferenceConfig(
mm_config=mm_config, use_triton_kernel=use_triton_kernel
mm_config=mm_config,
use_triton_kernel=use_triton_kernel,
use_dynamic_per_tensor_scale=use_dynamic_per_tensor_scale,
)
quantize_(m_mx, config=config)

7 changes: 4 additions & 3 deletions test/quantization/test_qat.py
@@ -2077,7 +2077,8 @@ def test_infer_int4_weight_only_config(self):
self.assertEqual(weight_config.activation_dtype, torch.bfloat16)

@unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
def test_quantize_api_nvfp4(self):
@parametrize("use_per_tensor_scale", [True, False])
def test_quantize_api_nvfp4(self, use_per_tensor_scale: bool):
"""
Test the following:
quantize_(model, QATConfig(NVFP4InferenceConfig(), step="prepare"))
@@ -2086,8 +2087,8 @@ def test_quantize_api_nvfp4(self):
from torchao.prototype.mx_formats import NVFP4InferenceConfig

self._test_quantize_api_against_ptq(
NVFP4InferenceConfig(),
target_prepare_sqnr=8,
NVFP4InferenceConfig(use_dynamic_per_tensor_scale=use_per_tensor_scale),
target_prepare_sqnr=12,
target_convert_sqnr=float("inf"),
)

15 changes: 13 additions & 2 deletions torchao/prototype/mx_formats/inference_workflow.py
@@ -22,6 +22,7 @@
NVFP4MMConfig,
NVFP4Tensor,
QuantizeTensorToNVFP4Kwargs,
per_tensor_amax_to_scale,
)
from torchao.quantization.transform_module import (
register_quantize_module_handler,
@@ -134,7 +135,8 @@ class NVFP4InferenceConfig(AOBaseConfig):
This is a specialized configuration for NVIDIA's FP4 format.
Configuration parameters:
- mm_config: NVFP4MMConfig, which can be set to DYNAMIC or WEIGHT_ONLY (emulated mm in high precision)
- use_triton_kernel: bool, whether to use fused triton kernel for activation scaling (default: False)
- use_triton_kernel: bool, whether to use fused triton kernel for activation scaling (default: True)
- use_dynamic_per_tensor_scale: bool, whether to dynamically compute per tensor scale (default: True)
- Data: float4_e2m1fn_x2
- Scales: float8_e4m3fn
- Block size: 16 along the reduction dim
@@ -145,6 +147,7 @@ class NVFP4InferenceConfig(AOBaseConfig):

mm_config: NVFP4MMConfig = NVFP4MMConfig.DYNAMIC
use_triton_kernel: bool = True
use_dynamic_per_tensor_scale: bool = True

def __post_init__(self):
# Validate PyTorch version
@@ -175,12 +178,20 @@ def _nvfp4_inference_linear_transform(
"Please use bfloat16 or float16 weights, or remove the bias from the linear layer."
)

per_tensor_scale = None
if config.use_dynamic_per_tensor_scale:
tensor_amax = torch.max(torch.abs(weight))
per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)

act_quant_kwargs = None
if config.mm_config == NVFP4MMConfig.DYNAMIC:
act_quant_kwargs = QuantizeTensorToNVFP4Kwargs()
act_quant_kwargs = QuantizeTensorToNVFP4Kwargs(
use_dynamic_per_tensor_scale=config.use_dynamic_per_tensor_scale,
)

quantized_weight = NVFP4Tensor.to_nvfp4(
weight,
per_tensor_scale=per_tensor_scale,
is_swizzled_scales=True,
use_triton_kernel=False, # Always use traditional construction for weights
act_quant_kwargs=act_quant_kwargs,
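For context, here is a minimal usage sketch of the updated config (a sketch only: it assumes a CUDA device with FP4 support, a bfloat16 model, and the import paths used in the tests; shapes are illustrative):

```python
import torch
from torchao.quantization import quantize_
from torchao.prototype.mx_formats import NVFP4InferenceConfig

# Illustrative module; any nn.Linear-based model is handled the same way.
m = torch.nn.Sequential(torch.nn.Linear(64, 256, bias=False)).cuda().to(torch.bfloat16)

config = NVFP4InferenceConfig(
    # mm_config defaults to NVFP4MMConfig.DYNAMIC (dynamic activation quantization)
    use_triton_kernel=True,               # fused triton kernel for activation scaling
    use_dynamic_per_tensor_scale=True,    # new flag: compute per-tensor amax-based scales
)
quantize_(m, config=config)

with torch.inference_mode():
    y = m(torch.randn(32, 64, device="cuda", dtype=torch.bfloat16))
```

With `use_dynamic_per_tensor_scale=True`, the weight's per-tensor scale is computed once from its amax at quantization time, and (for DYNAMIC mm) the activation's per-tensor scale is recomputed from its amax on each forward call.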
25 changes: 21 additions & 4 deletions torchao/prototype/mx_formats/nvfp4_tensor.py
@@ -47,6 +47,7 @@ class QuantizeTensorToNVFP4Kwargs(QuantizeTensorKwargs):
block_size: int = 16
is_swizzled_scales: bool = False
use_triton_kernel: bool = False
use_dynamic_per_tensor_scale: bool = False


# TODO(future PR): move over to TorchAOBaseTensor's dispatch
@@ -245,7 +246,7 @@ def get_hp_scales(self) -> torch.Tensor:

return (
scale_e4m3.to(self._orig_dtype)
if not self._per_tensor_scale
if self._per_tensor_scale is None
else self._per_tensor_scale * scale_e4m3.to(self._orig_dtype)
)

@@ -645,10 +646,15 @@ def nvfp4_linear(func, types, args, kwargs):
else:
# dynamic quant
k = weight_tensor.act_quant_kwargs
if k.use_dynamic_per_tensor_scale:
tensor_amax = torch.max(torch.abs(input_tensor))
per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
else:
per_tensor_scale = weight_tensor._act_per_tensor_scale
input_tensor = NVFP4Tensor.to_nvfp4(
input_tensor,
block_size=k.block_size,
per_tensor_scale=weight_tensor._act_per_tensor_scale,
per_tensor_scale=per_tensor_scale,
is_swizzled_scales=k.is_swizzled_scales,
use_triton_kernel=k.use_triton_kernel,
)
@@ -672,10 +678,15 @@ def nvfp4_mm(func, types, args, kwargs):
else:
if not isinstance(input_tensor, NVFP4Tensor):
k = weight_tensor.act_quant_kwargs
if k.use_dynamic_per_tensor_scale:
tensor_amax = torch.max(torch.abs(input_tensor))
per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
else:
per_tensor_scale = weight_tensor._act_per_tensor_scale
input_tensor = NVFP4Tensor.to_nvfp4(
input_tensor,
block_size=k.block_size,
per_tensor_scale=weight_tensor._act_per_tensor_scale,
per_tensor_scale=per_tensor_scale,
is_swizzled_scales=k.is_swizzled_scales,
use_triton_kernel=k.use_triton_kernel,
)
@@ -697,12 +708,18 @@ def nvfp4_addmm(func, types, args, kwargs):
else:
return torch.addmm(bias, input_tensor, weight_dequant)
else:
# TODO: refactor duplicate code
if not isinstance(input_tensor, NVFP4Tensor):
k = weight_tensor.act_quant_kwargs
if k.use_dynamic_per_tensor_scale:
tensor_amax = torch.max(torch.abs(input_tensor))
per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)
else:
per_tensor_scale = weight_tensor._act_per_tensor_scale
input_tensor = NVFP4Tensor.to_nvfp4(
input_tensor,
block_size=k.block_size,
per_tensor_scale=weight_tensor._act_per_tensor_scale,
per_tensor_scale=per_tensor_scale,
is_swizzled_scales=k.is_swizzled_scales,
use_triton_kernel=k.use_triton_kernel,
)
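The three dispatch changes above (`nvfp4_linear`, `nvfp4_mm`, `nvfp4_addmm`) follow the same pattern: when `use_dynamic_per_tensor_scale` is set in the activation quantization kwargs, the activation's per-tensor scale is recomputed from its amax at call time instead of reusing the stored `_act_per_tensor_scale`. As a rough sketch of what the amax-to-scale step amounts to (the constants assume the usual NVFP4 two-level scheme of float8_e4m3fn block scales over float4_e2m1 elements; this is not a copy of the library's `per_tensor_amax_to_scale`):

```python
import torch

F8E4M3_MAX = 448.0  # largest normal float8_e4m3fn value
F4E2M1_MAX = 6.0    # largest float4_e2m1 value

def per_tensor_amax_to_scale_sketch(amax: torch.Tensor) -> torch.Tensor:
    # Choose a per-tensor scale so that amax lands at the top of the
    # representable range of (e4m3 block scale) * (e2m1 element).
    return amax.to(torch.float32) / (F8E4M3_MAX * F4E2M1_MAX)

# Dynamic activation scaling as in the dispatch above (illustrative shapes):
x = torch.randn(32, 64, dtype=torch.bfloat16)
per_tensor_scale = per_tensor_amax_to_scale_sketch(torch.max(torch.abs(x)))
```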
13 changes: 6 additions & 7 deletions torchao/quantization/qat/fake_quantize_config.py
@@ -442,16 +442,15 @@ def _infer_fake_quantize_configs(
activation_dtype=e4m3_dtype,
)
elif isinstance(base_config, NVFP4InferenceConfig):
# Note: today the PTQ config does not allow the user to specify
# `per_tensor_scales` due to serialization concerns. In the future
# we may add a way to compute these dynamically (for activations),
# but for now QAT will mimic the existing behavior of not having
# `per_tensor_scales` (subject to change)
if NVFP4MMConfig.DYNAMIC:
act_config = NVFP4FakeQuantizeConfig(False)
act_config = NVFP4FakeQuantizeConfig(
use_per_tensor_scale=base_config.use_dynamic_per_tensor_scale
)
else:
act_config = None
weight_config = NVFP4FakeQuantizeConfig(False)
weight_config = NVFP4FakeQuantizeConfig(
use_per_tensor_scale=base_config.use_dynamic_per_tensor_scale
)
elif isinstance(base_config, Int8DynamicActivationIntxWeightConfig):
assert base_config.version >= 2, "Only version 2+ is supported"
assert base_config.intx_packing_format == "unpacked_to_int8", (
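With the config inference updated, the QAT flow exercised in `test_quantize_api_nvfp4` now propagates the per-tensor-scale setting from the base PTQ config into both the activation and weight fake-quantize configs. A hedged sketch of that flow (assuming `QATConfig` is importable from `torchao.quantization.qat`; the prepare/convert usage mirrors the docstring of `test_quantize_api_nvfp4`):

```python
import torch
from torchao.quantization import quantize_
from torchao.quantization.qat import QATConfig
from torchao.prototype.mx_formats import NVFP4InferenceConfig

m = torch.nn.Sequential(torch.nn.Linear(64, 256, bias=False)).cuda().to(torch.bfloat16)
base_config = NVFP4InferenceConfig(use_dynamic_per_tensor_scale=True)

# Prepare: insert NVFP4 fake quantization for training; activation and weight
# fake-quant configs both pick up use_per_tensor_scale from base_config.
quantize_(m, QATConfig(base_config, step="prepare"))
# ... fine-tune m ...
# Convert: replace fake quantization with real NVFP4 quantized inference.
quantize_(m, QATConfig(base_config, step="convert"))
```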