
[nvfp4] Make per_tensor_scale optional for triton kernel path#4188

Merged
jerryzh168 merged 15 commits into main from gh/jerryzh168/74/head on Apr 2, 2026

Conversation

Contributor

@jerryzh168 jerryzh168 commented Mar 26, 2026

Stack from ghstack (oldest at bottom):

Summary:
MSLK now supports optional global scale in its triton quantize kernel
(MSLK#233, commit c01f06c). This change relaxes the corresponding
constraint in torchao so the triton kernel path can be used without
a per_tensor_scale (single-level block-wise scaling only).

Changes:

- Remove `assert per_tensor_scale is not None` from the `to_nvfp4` triton branch
- Update `mslk_quantize_nvfp4` and its custom op to accept `Optional[torch.Tensor]`,
  passing `None` through to MSLK (which treats it as `global_scale=1.0`)
- Relax `_addmm_nvfp4_dispatch` to allow mixed `per_tensor_scale` states between
  operands (treat `None` as 1.0) instead of asserting both-or-neither
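The relaxed dispatch logic in the last bullet can be sketched as follows (a minimal sketch with Python floats standing in for the scale tensors; `combine_per_tensor_scales` is a hypothetical helper name, not the actual torchao function):

```python
from typing import Optional

def combine_per_tensor_scales(
    a_scale: Optional[float], b_scale: Optional[float]
) -> Optional[float]:
    # A missing per_tensor_scale is treated as 1.0, so the combined
    # scale reduces to whichever operand's scale is present, or None
    # when both operands use single-level block-wise scaling only.
    if a_scale is not None and b_scale is not None:
        return a_scale * b_scale
    if a_scale is not None:
        return a_scale
    return b_scale
```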

Test Plan:
Requires SM100+ GPU with MSLK nightly installed.

```
python -m pytest test/prototype/mx_formats/test_nvfp4_tensor.py::test_triton_nvfp4_quantize_equivalence -v
python -m pytest test/prototype/mx_formats/test_nvfp4_tensor.py::test_nvfp4_matmul_optional_per_tensor_scale -v
python -m pytest test/prototype/mx_formats/test_inference_workflow.py::test_inference_workflow_nvfp4 -k "test_inference_workflow_nvfp4" -v
```

Performance:

with global scale:

python benchmarks/float8/float8_inference_roofline.py --recipe_name nvfp4 --enable_fusion_modeling True --skip_printing_detailed_metrics True

Parameter               Value
----------------------  ------------------------
GPU                     NVIDIA GB200
torch version           2.12.0.dev20260316+cu128
torchao version         0.17.0+git95281b63b
recipe_name             nvfp4
do_benchmarks           True
shape_gen_name          pow2
enable_fusion_modeling  True
op_name                 linear
MKN                     None None None
DHW                     None None None
kernel_size
stride                  1
padding                 0
bf16_gemm_time_sympy Max(2.0e-6, 1.13960113960114e-15*K*M*N, 2.71739130434783e-13*K*M + 2.71739130434783e-13*K*N + 2.71739130434783e-13*M*N)
bf16_ovhd_time_sympy Max(2.0e-6, 5.43478260869565e-13*K*M)
fp8_gemm_time_sympy Max(2.0e-6, 2.84900284900285e-16*K*M*N, 6.79347826086956e-14*K*M + 6.79347826086956e-14*K*N + 2.71739130434783e-13*M*N + 6.79347826086956e-14*floor(K*M/16 + K*N/16))
fp8_ovhd_time_sympy Max(2.0e-6, 6.11413043478261e-13*K*M + 1.35869565217391e-13*M*floor(K/16))
   fwd_M  fwd_K  fwd_N  r_fp8_gemm_and_ovhd_spdp  b_fp8_e2e_spdp
0   1024   1024   1024                      1.00            0.45
1   2048   2048   2048                      2.39            0.66
2   4096   4096   4096                      2.92            1.29
3   8192   8192   8192                      3.34            1.74
4  16384  16384  16384                      3.63            2.84

without global scale:

python benchmarks/float8/float8_inference_roofline.py --recipe_name nvfp4_no_global_scale --enable_fusion_modeling True --skip_printing_detailed_metrics True

Parameter               Value
----------------------  ------------------------
GPU                     NVIDIA GB200
torch version           2.12.0.dev20260316+cu128
torchao version         0.17.0+gitabb103d3b
recipe_name             nvfp4_no_global_scale
do_benchmarks           True
shape_gen_name          pow2
enable_fusion_modeling  True
op_name                 linear
MKN                     None None None
DHW                     None None None
kernel_size
stride                  1
padding                 0
bf16_gemm_time_sympy Max(2.0e-6, 1.13960113960114e-15*K*M*N, 2.71739130434783e-13*K*M + 2.71739130434783e-13*K*N + 2.71739130434783e-13*M*N)
bf16_ovhd_time_sympy Max(2.0e-6, 5.43478260869565e-13*K*M)
fp8_gemm_time_sympy Max(2.0e-6, 2.84900284900285e-16*K*M*N, 6.79347826086956e-14*K*M + 6.79347826086956e-14*K*N + 2.71739130434783e-13*M*N + 6.79347826086956e-14*floor(K*M/16 + K*N/16))
fp8_ovhd_time_sympy Max(2.0e-6, 3.39673913043478e-13*K*M + 1.35869565217391e-13*M*floor(K/16))

   fwd_M  fwd_K  fwd_N  r_fp8_gemm_and_ovhd_spdp  b_fp8_e2e_spdp
0   1024   1024   1024                      1.00            0.73
1   2048   2048   2048                      2.71            1.09
2   4096   4096   4096                      3.44            2.22
3   8192   8192   8192                      3.68            2.82
4  16384  16384  16384                      3.83            3.65

@pytorch-bot

pytorch-bot Bot commented Mar 26, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4188

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit a205345 with merge base 0c29e81:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the "CLA Signed" label Mar 26, 2026
jerryzh168 added a commit that referenced this pull request Mar 26, 2026
@jerryzh168 jerryzh168 added the "module: inference" and "module: not user facing" labels Mar 26, 2026
jerryzh168 added a commit that referenced this pull request Mar 26, 2026
@jerryzh168 jerryzh168 requested review from jbschlosser and vkuzo March 26, 2026 23:16
Comment thread test/prototype/mx_formats/test_inference_workflow.py Outdated
Comment thread test/prototype/mx_formats/test_nvfp4_tensor.py
Before:
    data_scaled = data_hp / block_scale_fp32.unsqueeze(-1)
After:
    # Multiply by reciprocal instead of dividing to match MSLK triton kernel
    # numerics (global_scale=None treated as 1.0): x * (1.0 / fp8_scale)
    data_scaled = data_hp * (1.0 / block_scale_fp32).unsqueeze(-1)
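The distinction matters because division and multiply-by-reciprocal round differently, and a one-ulp mismatch is enough to fail a bit-exact equivalence test. A standalone illustration with plain Python floats:

```python
# Dividing rounds once; taking the reciprocal and multiplying rounds
# twice, so the two paths can land on adjacent floats.
x, s = 7.0, 10.0
via_div = x / s            # correctly rounded quotient: 0.7
via_recip = x * (1.0 / s)  # reciprocal is rounded, then the product is rounded
print(via_div == via_recip)  # False: the results differ by one ulp
```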
Contributor (vkuzo):

is there evidence that this is needed?

Contributor Author (jerryzh168):
yeah, it's because the kernel equivalence numerical test fails

jerryzh168 added a commit that referenced this pull request Mar 31, 2026
@jerryzh168 jerryzh168 requested a review from vkuzo March 31, 2026 19:45
jerryzh168 added a commit that referenced this pull request Apr 1, 2026
@jerryzh168
Contributor Author

@claude review

@claude

claude Bot commented Apr 1, 2026

Claude finished @jerryzh168's task in 5m 11s


Code Review: [nvfp4] Make per_tensor_scale optional for triton kernel path


Overall this is a clean, well-scoped change. The core logic is straightforward: remove the assertion that blocked the triton kernel path when per_tensor_scale is None, and propagate Optional[torch.Tensor] through the call chain. A few observations:

torchao/prototype/mx_formats/kernels.py

  1. mslk_quantize_nvfp4 and _mslk_quantize_nvfp4_custom_op — looks good. The Optional[torch.Tensor] typing, the conditional .reciprocal(), and the register_fake default argument update are all correct.

  2. Line 1218 — trusting MSLK to handle None: The code passes global_scale=None directly to _mslk_triton_quantize_nvfp4(x, global_scale). This is fine given the PR description says MSLK commit c01f06c added this support, but it means this torchao change has a hard dependency on a specific MSLK nightly version. Consider adding a version guard or a clearer error message if someone has an older MSLK installed that doesn't support global_scale=None. As-is, they'd get an opaque error from inside MSLK.
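One way to implement such a guard (a sketch only; the signature inspection is a heuristic, and `quantize_fn` stands in for the MSLK quantize entry point rather than naming a real MSLK API):

```python
import inspect

def supports_optional_global_scale(quantize_fn) -> bool:
    # Heuristic: assume the kernel accepts global_scale=None only when
    # its signature declares that default. quantize_fn is a placeholder
    # for the MSLK quantize entry point, not an actual MSLK symbol.
    param = inspect.signature(quantize_fn).parameters.get("global_scale")
    return param is not None and param.default is None

# A caller could then raise a clear error instead of an opaque one:
# if global_scale is None and not supports_optional_global_scale(kernel):
#     raise RuntimeError("installed MSLK predates optional global_scale support")
```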

torchao/prototype/mx_formats/nvfp4_tensor.py

  1. _addmm_nvfp4_dispatch mixed-scale logic (lines 465–476) — correct but could be simplified. The four-branch if/elif/elif/else is clear. An alternative is:

    if a_scale is not None and b_scale is not None:
        scale_result = a_scale * b_scale
    elif a_scale is not None or b_scale is not None:
        scale_result = a_scale if a_scale is not None else b_scale
    else:
        scale_result = None

    But the current version is readable enough — this is a style nit, not a correctness issue.

  2. Semantic question on mixed per_tensor_scale states: When one operand has per_tensor_scale and the other doesn't, this treats None as 1.0. This is mathematically correct, but it's an unusual configuration — it implies the activation and weight were quantized with different scaling strategies. As vkuzo noted in the review, is there evidence this mixed case is needed in practice? If not, a simpler approach would be to only support matching states (both or neither) and keep the assertion, which would catch accidental misconfigurations. If it is needed (as jerryzh168 indicated for the numerical equivalence test), then the current approach is fine.

Test changes

  1. test_inference_workflow.py:184 — leftover comment. The line # use_triton_kernel without per_tensor_scale is now supported (MSLK#233) is a dangling comment that doesn't guard or explain any remaining code. vkuzo already flagged this — I agree it should be removed.

  2. test_nvfp4_matmul_optional_per_tensor_scale — well structured. The parametrization covers the three relevant combinations (a_scale_only, b_scale_only, no_scale), and the SQNR threshold of 16.0 is reasonable. Minor observations:

    • The (True, True) case is not explicitly tested here, though it's covered by existing tests elsewhere.
    • act_quant_kwargs = QuantizeTensorToNVFP4Kwargs() is only set on B_nvfp4 — this is intentional since B represents weights, but a brief comment could clarify.
  3. test_triton_nvfp4_quantize_equivalence — skip removal is correct. The use_per_tensor_scale=False case should now pass with the MSLK update.

Summary

The change is well-motivated by the performance data showing improved speedups without global scale (e.g., 3.65x vs 2.84x at 16K). The implementation is correct. Main actionable items:

  • Remove the dangling comment in test_inference_workflow.py:184 (per vkuzo's feedback)
  • Consider a version guard or error message for older MSLK versions that don't support global_scale=None

fp8_ovhd_time_sympy Max(2.0e-6, 3.39673913043478e-13*K*M + 1.35869565217391e-13*M*floor(K/16))

   fwd_M  fwd_K  fwd_N  r_fp8_gemm_and_ovhd_spdp  b_fp8_e2e_spdp
0   1024   1024   1024                      1.00            0.73
1   2048   2048   2048                      2.71            1.09
2   4096   4096   4096                      3.44            2.22
3   8192   8192   8192                      3.68            2.82
4  16384  16384  16384                      3.83            3.65

```

[ghstack-poisoned]
jerryzh168 added a commit that referenced this pull request Apr 1, 2026
ghstack-source-id: 53081a1
Pull Request resolved: #4188
jerryzh168 added a commit that referenced this pull request Apr 1, 2026
ghstack-source-id: 06fa736
Pull Request resolved: #4188
@jerryzh168 (Contributor Author) commented:

@claude review again

claude Bot commented Apr 1, 2026

Claude finished @jerryzh168's task in 3m 51s


Code Review: [nvfp4] Make per_tensor_scale optional for triton kernel path

  • Read PR diff and changed files
  • Review source changes (nvfp4_tensor.py, kernels.py, utils.py)
  • Review test changes
  • Review benchmark changes
  • Post review feedback

This revision addresses the previous feedback well — the version guard for MSLK was added, vkuzo's comment about removing the b_has_scale parameter was applied, and the dangling comment in test_inference_workflow.py was removed. The change is clean and correct. A few observations on the updated code:

torchao/prototype/mx_formats/kernels.py

  1. Version guard (lines 1219–1226) — good addition. This directly addresses the concern from the previous review about older MSLK versions failing opaquely. The error message is clear and actionable. One minor note: is_mslk_version_at_least("1.1.0") will call _is_mslk_available() which does importlib.util.find_spec("mslk") and import mslk, and then line 1220 does import mslk again. Since Python caches module imports this is not a performance issue, but you could simplify by reusing the already-imported module:

    import mslk
    if global_scale is None and parse_version(mslk.__version__) < parse_version("1.1.0"):
        raise RuntimeError(...)

    This is purely a readability nit — the current code works fine.

  2. MSLK __version__ assumption: is_mslk_version_at_least assumes MSLK has a __version__ attribute with a major.minor.patch format. If an older nightly MSLK doesn't define __version__, this would raise an AttributeError rather than a helpful message. Worth considering a hasattr guard, though this may be low-risk if MSLK has always had __version__.

torchao/prototype/mx_formats/nvfp4_tensor.py

  1. nvfp4_quantize div→mul change (line 694): Changing data_hp / block_scale_fp32 to data_hp * (1.0 / block_scale_fp32) is a subtle but intentional numerics change to match the MSLK triton kernel behavior. The comment explains the rationale. This is the right approach for numerical equivalence testing — division and multiplication-by-reciprocal can differ at fp32 precision boundaries. Just be aware this changes the reference path's numerics slightly for the per_tensor_scale is None case. Since the test uses a 16.0 dB SQNR threshold (not bit-exact), this should be fine.

  2. _addmm_nvfp4_dispatch mixed-scale logic (lines 465–474) — correct. The three-way branching is clean and handles all cases. One thing to watch: when scale_result is a single tensor (the elif branch), should_add_bias_separately at line 484 will be True, meaning bias gets added outside _scaled_mm. This is the correct path since the per-tensor scale needs to be applied first.

torchao/utils.py

  1. is_mslk_version_at_least (lines 1280–1285) — clean utility. Follows the same pattern as torch_version_at_least. No issues.

Tests

  1. test_nvfp4_matmul_optional_per_tensor_scale — well structured. The test now correctly only varies a_has_scale (activation), keeping b_scale (weight) always set, which matches the real-world use case. The use_triton_kernel parametrization ensures both quantization paths are covered. The test docstring clearly states the intent.

  2. test_inference_workflow.py — skip removal is correct. The two lines if use_triton_kernel and not use_dynamic_per_tensor_scale: pytest.skip(...) were properly removed.

Benchmark changes

  1. float8_inference_roofline.py — adding nvfp4_no_global_scale recipe. The new recipe is threaded through all the right places: get_gemm_times, the scale setup, the _scaled_mm call path, and the run() config setup. The roofline_utils.py change to case "nvfp4_static" | "nvfp4_no_global_scale" is correct — no-global-scale has the same memory traffic profile as static scaling (single kernel for quantize + gemm).

Summary

The PR is in good shape. The main improvement since the last review is the MSLK version guard, which prevents opaque failures on older MSLK versions. The code changes are correct, well-documented, and the test coverage is appropriate.

Only minor nit remaining:

  • Consider a hasattr(mslk, '__version__') guard in is_mslk_version_at_least for robustness against MSLK installations that may not define __version__. This is low-priority if MSLK is known to always have it.

Comment thread on torchao/prototype/mx_formats/kernels.py (outdated):

    if global_scale is None and not is_mslk_version_at_least("1.1.0"):
        import mslk

Contributor: what is this for

jerryzh168 (Contributor Author, Apr 1, 2026): mslk.__version__ that's used in L1224 I think

Comment thread torchao/prototype/mx_formats/kernels.py Outdated
import mslk

raise RuntimeError(
f"Optional global_scale support requires MSLK >= 1.1.0, "
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this version matching the MSLK release corresponding to PyTorch 2.11.0, or is this a later version?

Copy link
Copy Markdown
Contributor Author

@jerryzh168 jerryzh168 Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is the matching version for torch 2.11.0, just released recently

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be less confusing to do it as follows

if global_scale is None:
    assert is_mslk_version_at_least("1.1.0"), "unsupported"
    ...

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

makes sense, updated

Comment thread on torchao/utils.py:

    def is_mslk_version_at_least(min_version: str) -> bool:
        if not _is_mslk_available():
            return False
        import mslk

Contributor: is this safe?

jerryzh168 (Contributor Author): yeah, _is_mslk_available() checks the availability of the mslk library (ao/torchao/utils.py, lines 1270 to 1277 in 79159f2):

    def _is_mslk_available():
        has_mslk = importlib.util.find_spec("mslk") is not None or is_fbcode()
        if not has_mslk:
            return False
        import mslk  # noqa: F401
        return True

(we should remove the import there actually, I can put up a follow-up PR)

jerryzh168 added a commit that referenced this pull request Apr 2, 2026
ghstack-source-id: f9a4c1f
Pull Request resolved: #4188
jerryzh168 added a commit that referenced this pull request Apr 2, 2026
ghstack-source-id: 6f4f6b4
Pull Request resolved: #4188
jerryzh168 added a commit that referenced this pull request Apr 2, 2026
ghstack-source-id: 547bfc9
Pull Request resolved: #4188
@jerryzh168 jerryzh168 requested a review from vkuzo April 2, 2026 19:44
@jerryzh168 jerryzh168 changed the base branch from gh/jerryzh168/74/base to main April 2, 2026 21:48
@jerryzh168 jerryzh168 merged commit a302c10 into main Apr 2, 2026
36 of 37 checks passed

Labels

- CLA Signed (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed)
- module: inference (quantize_ api inference flow)
- module: not user facing (use this tag if you don't want this PR to show up in release notes)
