6 changes: 2 additions & 4 deletions benchmarks/float8/float8_inference_roofline.py
@@ -40,8 +40,7 @@
import torchao
from torchao.prototype.mx_formats.inference_workflow import (
MXDynamicActivationMXWeightConfig,
-    NVFP4InferenceConfig,
-    NVFP4MMConfig,
+    NVFP4DynamicActivationNVFP4WeightConfig,
)
from torchao.prototype.mx_formats.utils import to_blocked
from torchao.quantization.quant_api import (
@@ -445,8 +444,7 @@ def run(
kernel_preference=KernelPreference.AUTO,
)
elif recipe_name == "nvfp4":
-        config = NVFP4InferenceConfig(
-            mm_config=NVFP4MMConfig.DYNAMIC,
+        config = NVFP4DynamicActivationNVFP4WeightConfig(
use_dynamic_per_tensor_scale=False,
)
else:
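For context on the rename this diff applies across the repo: the single `NVFP4InferenceConfig` plus its `NVFP4MMConfig` enum is split into two dedicated config classes. A minimal before/after sketch, limited to the constructor arguments that actually appear in this PR (`use_triton_kernel`, `use_dynamic_per_tensor_scale`):

```python
# Illustrative migration sketch (not part of the diff itself).
from torchao.prototype.mx_formats.inference_workflow import (
    NVFP4DynamicActivationNVFP4WeightConfig,
    NVFP4WeightOnlyConfig,
)

# Old: NVFP4InferenceConfig(mm_config=NVFP4MMConfig.DYNAMIC, use_dynamic_per_tensor_scale=False)
dynamic_config = NVFP4DynamicActivationNVFP4WeightConfig(
    use_dynamic_per_tensor_scale=False,
)

# Old: NVFP4InferenceConfig(mm_config=NVFP4MMConfig.WEIGHT_ONLY, use_triton_kernel=False)
weight_only_config = NVFP4WeightOnlyConfig(
    use_dynamic_per_tensor_scale=False,
)
```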
42 changes: 22 additions & 20 deletions test/prototype/mx_formats/test_inference_workflow.py
@@ -14,8 +14,8 @@

from torchao.prototype.mx_formats.inference_workflow import (
MXDynamicActivationMXWeightConfig,
-    NVFP4InferenceConfig,
-    NVFP4MMConfig,
+    NVFP4DynamicActivationNVFP4WeightConfig,
+    NVFP4WeightOnlyConfig,
)
from torchao.quantization import quantize_
from torchao.quantization.quantize_.common import KernelPreference
@@ -138,9 +138,7 @@ def test_inference_workflow_mx(
)
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("compile", [True, False])
-@pytest.mark.parametrize(
-    "mm_config", [NVFP4MMConfig.DYNAMIC, NVFP4MMConfig.WEIGHT_ONLY]
-)
+@pytest.mark.parametrize("quant_type", ["dynamic", "weight_only"])
@pytest.mark.parametrize("inpt_dtype", [torch.bfloat16, torch.float32])
@pytest.mark.parametrize("use_triton_kernel", [True, False])
@pytest.mark.parametrize("use_dynamic_per_tensor_scale", [True, False])
@@ -164,7 +162,7 @@ def test_inference_workflow_mx(
def test_inference_workflow_nvfp4(
bias: bool,
compile: bool,
-    mm_config: NVFP4MMConfig,
+    quant_type: str,
inpt_dtype: torch.dtype,
use_triton_kernel: bool,
use_dynamic_per_tensor_scale: bool,
@@ -177,14 +175,16 @@ def test_inference_workflow_nvfp4(
Tests both DYNAMIC and WEIGHT_ONLY mm_config modes
"""
# DYNAMIC mode requires SM100+, but WEIGHT_ONLY works on older GPUs
-    if mm_config == NVFP4MMConfig.DYNAMIC and not is_sm_at_least_100():
+    if quant_type == "dynamic" and not is_sm_at_least_100():
pytest.skip("CUDA capability >= 10.0 required for DYNAMIC float4 gemm")

if bias and inpt_dtype == torch.float32:
pytest.xfail("Bias is not supported when module weight is in fp32")

-    if mm_config == NVFP4MMConfig.WEIGHT_ONLY and compile:
-        pytest.skip("TODO: NVFP4MMConfig.WEIGHT_ONLY currently errors w/ compile")
+    if quant_type == "weight_only" and compile:
+        pytest.skip("TODO: weight_only quant currently errors w/ compile")
+    if quant_type == "weight_only" and use_triton_kernel:
+        pytest.skip("unsupported configuration")

if use_inference_mode and (
shapes != (128, 64, 256) or inpt_dtype != torch.bfloat16 or use_triton_kernel
@@ -200,11 +200,15 @@
m = nn.Linear(in_features, out_features, bias=bias, dtype=inpt_dtype, device="cuda")
m_mx = copy.deepcopy(m)

-    config = NVFP4InferenceConfig(
-        mm_config=mm_config,
-        use_triton_kernel=use_triton_kernel,
-        use_dynamic_per_tensor_scale=use_dynamic_per_tensor_scale,
-    )
+    if quant_type == "dynamic":
+        config = NVFP4DynamicActivationNVFP4WeightConfig(
+            use_triton_kernel=use_triton_kernel,
+            use_dynamic_per_tensor_scale=use_dynamic_per_tensor_scale,
+        )
+    else:
+        config = NVFP4WeightOnlyConfig(
+            use_dynamic_per_tensor_scale=use_dynamic_per_tensor_scale,
+        )
quantize_(m_mx, config=config)

if compile:
@@ -216,7 +220,7 @@

y_ref = m(x)

-    if use_triton_kernel and mm_config != NVFP4MMConfig.WEIGHT_ONLY:
+    if use_triton_kernel and quant_type == "dynamic":
with cuda_kernel_profiler("quantize_nvfp4_triton_kernel") as result:
y_mx = m_mx(x)
assert result["found"], "Expected quantize_nvfp4 kernel to be found"
@@ -229,14 +233,14 @@

sqnr = compute_error(y_ref, y_mx)

-    if mm_config == NVFP4MMConfig.WEIGHT_ONLY:
+    if quant_type == "weight_only":
SQNR_THRESHOLD = 18.0
else:
SQNR_THRESHOLD = 15.0

assert y_mx.dtype == inpt_dtype, f"Got {y_mx.dtype} for inpt_dtype={inpt_dtype}"
assert sqnr >= SQNR_THRESHOLD, (
-        f"Got a sqnr of {sqnr} for NVFP4 recipe with bias={bias}, mm_config={mm_config}"
+        f"Got a sqnr of {sqnr} for NVFP4 recipe with bias={bias}, {quant_type=}"
)


@@ -273,9 +277,7 @@ def test_narrow_similar_to_vllm(self):
reason="torch.compile requires PyTorch 2.8+",
)
def test_nvfp4_quantize_3d_param_similar_to_vllm(self):
-        config = NVFP4InferenceConfig(
-            mm_config=NVFP4MMConfig.WEIGHT_ONLY,
-            use_triton_kernel=False,
+        config = NVFP4WeightOnlyConfig(
use_dynamic_per_tensor_scale=False,
)
self._test_quantize_3d_param_similar_to_vllm(config)
6 changes: 2 additions & 4 deletions test/prototype/mx_formats/test_mx_serialization.py
@@ -14,8 +14,7 @@

from torchao.prototype.mx_formats.inference_workflow import (
MXDynamicActivationMXWeightConfig,
-    NVFP4InferenceConfig,
-    NVFP4MMConfig,
+    NVFP4DynamicActivationNVFP4WeightConfig,
)
from torchao.quantization import quantize_
from torchao.quantization.quantize_.common import KernelPreference
@@ -48,8 +47,7 @@ def test_serialization(recipe_name):
)
else:
assert recipe_name == "nvfp4", "unsupported"
-        config = NVFP4InferenceConfig(
-            mm_config=NVFP4MMConfig.DYNAMIC,
+        config = NVFP4DynamicActivationNVFP4WeightConfig(
use_triton_kernel=False,
use_dynamic_per_tensor_scale=False,
)
18 changes: 8 additions & 10 deletions test/prototype/mx_formats/test_nvfp4_tensor.py
@@ -12,9 +12,6 @@
from torchao.prototype.mx_formats.constants import (
F4_E2M1_MAX,
)
-from torchao.prototype.mx_formats.inference_workflow import (
-    NVFP4MMConfig,
-)
from torchao.prototype.mx_formats.nvfp4_tensor import (
NVFP4Tensor,
QuantizeTensorToNVFP4Kwargs,
@@ -422,7 +419,8 @@ def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
)
@pytest.mark.parametrize("use_gelu", [True, False])
@pytest.mark.parametrize(
-    "mm_config", [NVFP4MMConfig.DYNAMIC, NVFP4MMConfig.WEIGHT_ONLY]
+    "quant_type",
+    ["dynamic", "weight_only"],
)
@pytest.mark.parametrize("compile", [False])
@pytest.mark.parametrize("bias", [True, False])
@@ -448,22 +446,22 @@ def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
)
def test_nvfp4_matmul_with_amax(
use_gelu: bool,
-    mm_config: NVFP4MMConfig,
+    quant_type: str,
compile: bool,
bias: bool,
inpt_dtype: torch.dtype,
use_triton_kernel: bool,
shapes: tuple,
):
# DYNAMIC mode requires SM100+, but WEIGHT_ONLY works on older GPUs
-    if mm_config == NVFP4MMConfig.DYNAMIC and not is_sm_at_least_100():
+    if quant_type == "dynamic" and not is_sm_at_least_100():
pytest.skip("CUDA capability >= 10.0 required for DYNAMIC float4 gemm")

if bias and inpt_dtype == torch.float32:
pytest.xfail("Bias is not supported when module weight is in fp32")

-    if mm_config == NVFP4MMConfig.WEIGHT_ONLY and compile:
-        pytest.skip("TODO: NVFP4MMConfig.WEIGHT_ONLY currently errors w/ compile")
+    if quant_type == "weight_only" and compile:
+        pytest.skip("TODO: weight_only currently errors w/ compile")

m, k, n = shapes

@@ -483,7 +481,7 @@ def test_nvfp4_matmul_with_amax(
a_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(A)))
b_scale = per_tensor_amax_to_scale(torch.amax(torch.abs(B)))
act_quant_kwargs = None
-    if mm_config == NVFP4MMConfig.DYNAMIC:
+    if quant_type == "dynamic":
act_quant_kwargs = QuantizeTensorToNVFP4Kwargs()
A_nvfp4 = NVFP4Tensor.to_nvfp4(
A,
@@ -509,7 +507,7 @@ def test_nvfp4_matmul_with_amax(
sqnr = compute_error(C_ref, C_nvfp4)
SQNR_THRESHOLD = 16.0
assert sqnr >= SQNR_THRESHOLD, (
-        f"SQNR {sqnr:.2f} < {SQNR_THRESHOLD}, use_gelu={use_gelu}, mm_config={mm_config}, compile={compile}, bias={bias}"
+        f"SQNR {sqnr:.2f} < {SQNR_THRESHOLD}, use_gelu={use_gelu}, {quant_type=}, compile={compile}, bias={bias}"
)


16 changes: 10 additions & 6 deletions test/quantization/test_qat.py
@@ -2082,13 +2082,15 @@ def test_infer_int4_weight_only_config(self):
def test_quantize_api_nvfp4(self, use_per_tensor_scale: bool):
"""
Test the following:
-        quantize_(model, QATConfig(NVFP4InferenceConfig(), step="prepare"))
-        quantize_(model, QATConfig(NVFP4InferenceConfig(), step="convert"))
+        quantize_(model, QATConfig(NVFP4DynamicActivationNVFP4WeightConfig(), step="prepare"))
+        quantize_(model, QATConfig(NVFP4DynamicActivationNVFP4WeightConfig(), step="convert"))
"""
-        from torchao.prototype.mx_formats import NVFP4InferenceConfig
+        from torchao.prototype.mx_formats import NVFP4DynamicActivationNVFP4WeightConfig

self._test_quantize_api_against_ptq(
-            NVFP4InferenceConfig(use_dynamic_per_tensor_scale=use_per_tensor_scale),
+            NVFP4DynamicActivationNVFP4WeightConfig(
+                use_dynamic_per_tensor_scale=use_per_tensor_scale
+            ),
target_prepare_sqnr=float("inf"),
target_convert_sqnr=float("inf"),
)
@@ -2100,15 +2102,17 @@ def test_qat_nvfp4(self, use_per_tensor_scale: bool):
"""
Test QAT with `NVFP4FakeQuantizeConfig`.
"""
-        from torchao.prototype.mx_formats import NVFP4InferenceConfig
+        from torchao.prototype.mx_formats import NVFP4DynamicActivationNVFP4WeightConfig
from torchao.prototype.qat import NVFP4FakeQuantizeConfig

torch.manual_seed(self.SEED)
m = M().cuda()
baseline_model = copy.deepcopy(m)
quantize_(
baseline_model,
-            NVFP4InferenceConfig(use_dynamic_per_tensor_scale=use_per_tensor_scale),
+            NVFP4DynamicActivationNVFP4WeightConfig(
+                use_dynamic_per_tensor_scale=use_per_tensor_scale
+            ),
)
qat_config = QATConfig(
activation_config=NVFP4FakeQuantizeConfig(use_per_tensor_scale),
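The docstring above outlines the QAT flow under test; a condensed sketch of that prepare/convert sequence, assuming `QATConfig` is imported from `torchao.quantization.qat` (the import is outside the lines shown in this diff):

```python
import torch
import torch.nn as nn

from torchao.prototype.mx_formats import NVFP4DynamicActivationNVFP4WeightConfig
from torchao.quantization import quantize_
from torchao.quantization.qat import QATConfig  # assumed import path, not shown in the diff

model = nn.Sequential(nn.Linear(256, 256, bias=False)).to(torch.bfloat16).cuda()
base_config = NVFP4DynamicActivationNVFP4WeightConfig(use_dynamic_per_tensor_scale=True)

# Prepare: insert NVFP4 fake quantization for fine-tuning.
quantize_(model, QATConfig(base_config, step="prepare"))
# ... training loop would run here ...
# Convert: replace fake quantization with real NVFP4 quantized weights.
quantize_(model, QATConfig(base_config, step="convert"))
```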
36 changes: 23 additions & 13 deletions torchao/prototype/mx_formats/README.md
@@ -109,8 +109,8 @@ from torchao.quantization import quantize_
import torchao.prototype.mx_formats
from torchao.prototype.mx_formats.inference_workflow import (
MXDynamicActivationMXWeightConfig,
-    NVFP4InferenceConfig,
-    NVFP4MMConfig,
+    NVFP4DynamicActivationNVFP4WeightConfig,
+    NVFP4WeightOnlyConfig,
)
from torchao.quantization.quantize_.common import KernelPreference

@@ -129,6 +129,27 @@ quantize_(m_mxfp8, config=config)
m_mxfp8 = torch.compile(m_mxfp8, fullgraph=True)
y_mxfp8 = m_mxfp8(x)

+# nvfp4 dynamic quant
+
+m_nvfp4 = copy.deepcopy(m)
+config = NVFP4DynamicActivationNVFP4WeightConfig(
+    use_dynamic_per_tensor_scale=True,
+    use_triton_kernel=True,
+)
+quantize_(m_nvfp4, config=config)
+m_nvfp4 = torch.compile(m_nvfp4, fullgraph=True)
+y_nvfp4 = m_nvfp4(x)
+
+# nvfp4 weight-only quant
+
+m_nvfp4_wo = copy.deepcopy(m)
+config = NVFP4WeightOnlyConfig(
+    use_dynamic_per_tensor_scale=True,
+)
+quantize_(m_nvfp4_wo, config=config)
+m_nvfp4_wo = torch.compile(m_nvfp4_wo, fullgraph=True)
+y_nvfp4_wo = m_nvfp4_wo(x)
+
# mxfp4

m_mxfp4 = copy.deepcopy(m)
@@ -140,17 +161,6 @@ config = MXDynamicActivationMXWeightConfig(
quantize_(m_mxfp4, config=config)
m_mxfp4 = torch.compile(m_mxfp4, fullgraph=True)
y_mxfp4 = m_mxfp4(x)

-# nvfp4
-
-m_nvfp4 = copy.deepcopy(m)
-config = NVFP4InferenceConfig(
-    mm_config=NVFP4MMConfig.DYNAMIC,
-    use_dynamic_per_tensor_scale=True,
-)
-quantize_(m_nvfp4, config=config)
-m_nvfp4 = torch.compile(m_nvfp4, fullgraph=True)
-y_nvfp4 = m_nvfp4(x)
```

## MXTensor
8 changes: 4 additions & 4 deletions torchao/prototype/mx_formats/__init__.py
@@ -6,8 +6,8 @@
# Note: Prototype and subject to change
from torchao.prototype.mx_formats.inference_workflow import (
MXDynamicActivationMXWeightConfig,
-    NVFP4InferenceConfig,
-    NVFP4MMConfig,
+    NVFP4DynamicActivationNVFP4WeightConfig,
+    NVFP4WeightOnlyConfig,
)

# import mx_linear here to register the quantize_ transform logic
@@ -18,6 +18,6 @@
"MXLinearConfig",
"MXLinearRecipeName",
"MXDynamicActivationMXWeightConfig",
-    "NVFP4InferenceConfig",
-    "NVFP4MMConfig",
+    "NVFP4DynamicActivationNVFP4WeightConfig",
+    "NVFP4WeightOnlyConfig",
]
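With the `__all__` update above, both new configs are re-exported from the package root. A short usage sketch of the public import path implied by this change:

```python
# Both renamed configs are available from the package root per the updated __all__.
from torchao.prototype.mx_formats import (
    NVFP4DynamicActivationNVFP4WeightConfig,
    NVFP4WeightOnlyConfig,
)

config = NVFP4DynamicActivationNVFP4WeightConfig(use_dynamic_per_tensor_scale=True)
```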