From f9ca2f86f9f6a8d11852736f1ca81ebf71506491 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Wed, 1 Oct 2025 11:07:21 -0700
Subject: [PATCH 01/12] Update

[ghstack-poisoned]
---
 .../prototype/mx_formats/test_nvfp4_tensor.py | 66 +++++++++++++++++++
 torchao/prototype/mx_formats/nvfp4_tensor.py  | 30 +++++++--
 2 files changed, 89 insertions(+), 7 deletions(-)

diff --git a/test/prototype/mx_formats/test_nvfp4_tensor.py b/test/prototype/mx_formats/test_nvfp4_tensor.py
index 62cd1b88ad..911166e875 100644
--- a/test/prototype/mx_formats/test_nvfp4_tensor.py
+++ b/test/prototype/mx_formats/test_nvfp4_tensor.py
@@ -21,6 +21,7 @@
     per_tensor_amax_to_scale,
     unpack_uint4,
 )
+from torchao.prototype.mx_formats.utils import ceil_div
 from torchao.quantization.utils import compute_error
 from torchao.testing.utils import skip_if_rocm
 from torchao.utils import (
@@ -525,3 +526,68 @@ def test_nvfp4_to_copy():
     assert x.act_quant_kwargs == y.act_quant_kwargs
     assert x.dtype == torch.float32
     assert y.dtype == torch.bfloat16
+
+
+@pytest.mark.parametrize("transpose", [False, True])
+# @pytest.mark.parametrize("transpose", [True])
+# @pytest.mark.parametrize("transpose", [False])
+@pytest.mark.parametrize("use_triton_kernel", [False, True])
+# @pytest.mark.parametrize("use_triton_kernel", [False])
+# @pytest.mark.parametrize("use_triton_kernel", [True])
+@pytest.mark.parametrize("is_swizzled_scales", [False, True])
+# @pytest.mark.parametrize("is_swizzled_scales", [False])
+# @pytest.mark.parametrize("is_swizzled_scales", [True])
+@pytest.mark.parametrize(
+    "mk",
+    (
+        (128, 64),
+        (128 + 16, 64),
+        (128, 64 + 16),
+        (128 + 16, 64 + 16),
+    ),
+)
+# @pytest.mark.parametrize("mk", ((128 + 16, 64),))
+def test_scale_shape_matches_qdata(
+    transpose, use_triton_kernel, is_swizzled_scales, mk
+):
+    if use_triton_kernel and not is_swizzled_scales:
+        pytest.skip("triton kernel requires swizzled scales")
+
+    M, K = mk
+
+    block_size = 16
+
+    # TODO(this PR): test larger tensors that don't exactly map to (128, 64) tiles,
+    # to test the padding logic
+    # context: https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
+    x_hp = torch.randn(M, K, device="cuda")
+    x = NVFP4Tensor.to_nvfp4(
+        x_hp, is_swizzled_scales=is_swizzled_scales, use_triton_kernel=use_triton_kernel
+    )
+
+    m_dim, k_dim = 0, 1
+    if transpose:
+        x_hp = x_hp.t()
+        x = x.t()
+        m_dim, k_dim = 1, 0
+
+    orig_m = x_hp.shape[m_dim]
+    expected_padded_m = orig_m
+    if is_swizzled_scales:
+        # in swizzled nvfp4, a 128x64 unpacked / 128x32 packed qdata tile maps to a 32x16 scale tile
+        expected_padded_m = ceil_div(orig_m, 128) * 32
+    actual_padded_m = x._scale_e4m3.shape[m_dim]
+    assert expected_padded_m == actual_padded_m, (
+        f"incompatible padded shape for dim {m_dim}: {expected_padded_m=}, {actual_padded_m=}, {x.shape}, {x._scale_e4m3.shape}"
+    )
+
+    orig_k = x_hp.shape[k_dim]
+    expected_padded_k = orig_k // block_size
+    if is_swizzled_scales:
+        # in swizzled nvfp4, a 128x64 unpacked / 128x32 packed qdata tile maps to a 32x16 scale tile
+        expected_padded_k = ceil_div(orig_k // block_size, 4) * 16
+    actual_padded_k = x._scale_e4m3.shape[k_dim]
+
+    assert expected_padded_k == actual_padded_k, (
+        f"incompatible padded shape for dim {k_dim}: {expected_padded_k}, {actual_padded_k=}, {x.shape}, {x._scale_e4m3.shape}"
+    )
diff --git a/torchao/prototype/mx_formats/nvfp4_tensor.py b/torchao/prototype/mx_formats/nvfp4_tensor.py
index c22f7793bb..aefad7750e 100644
--- a/torchao/prototype/mx_formats/nvfp4_tensor.py
+++ b/torchao/prototype/mx_formats/nvfp4_tensor.py
@@ -170,6 +170,9 @@ def to_nvfp4(
         Returns:
             NVFP4Tensor: Quantized tensor in NVFP4 format
         """
+        assert len(data_hp.shape) == 2, "unsupported"
+        M, K = data_hp.shape[0], data_hp.shape[1]
+
         if use_triton_kernel:
             assert is_swizzled_scales, "Triton kernel only supports swizzled scales"
             assert data_hp.shape[1] % 16 == 0, (
@@ -181,12 +184,23 @@ def to_nvfp4(
                 data_hp, block_size, per_tensor_scale
             )
             if is_swizzled_scales:
-                M, K = data_hp.shape[0], data_hp.shape[1]
                 scale_shape = (M, K // block_size)
                 blockwise_scales = to_blocked(
                     blockwise_scales.view(scale_shape)
                 ).flatten()

+        if is_swizzled_scales:
+            # a 128x64 unpacked or 128x64 packed qdata tile corresponds
+            # to a swizzled 32x16 scale tile
+            scale_M = ceil_div(M, 128) * 32
+            scale_K = ceil_div(K, 64) * 16
+        else:
+            # a 1x16 unpacked or 1x8 packed qdata tile corresponds to 1
+            # scale element
+            scale_M = M
+            scale_K = K // block_size
+        blockwise_scales = blockwise_scales.view(scale_M, scale_K)
+
         return NVFP4Tensor(
             data_lp,
             blockwise_scales,
@@ -239,13 +253,13 @@ def get_hp_scales(self) -> torch.Tensor:
         is_transposed = self.qdata.stride(0) < self.qdata.stride(1)
         if is_transposed:
             M, K = self.shape[1], self.shape[0]
+            scale_e4m3 = self._scale_e4m3.t()
         else:
             M, K = self.shape[0], self.shape[1]
+            scale_e4m3 = self._scale_e4m3

         if self._is_swizzled_scales:
-            scale_e4m3 = from_blocked(self._scale_e4m3, M, K // self._block_size)
-        else:
-            scale_e4m3 = self._scale_e4m3
+            scale_e4m3 = from_blocked(scale_e4m3, M, K // self._block_size)

         return (
             scale_e4m3.to(self._orig_dtype)
@@ -537,7 +551,7 @@ def nvfp4_t(func, types, args, kwargs):
     old = args[0]
     new = NVFP4Tensor(
         old.qdata.t(),
-        old._scale_e4m3,
+        old._scale_e4m3.t(),
         old._block_size,
         old._orig_dtype,
         old._per_tensor_scale,
@@ -576,7 +590,9 @@ def _addmm_nvfp4_dispatch(
     The only difference is whether bias is None or not.
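 
     Layout note (illustrative, based on the asserts below): `a` has contiguous
     (M, K) qdata and `b` has (K, N) qdata with `b.qdata.t()` contiguous; each
     scale is converted with `to_blocked()` unless it is already swizzled.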
""" assert a.qdata.is_contiguous() + assert a._scale_e4m3.is_contiguous() assert b.qdata.t().is_contiguous() + assert b._scale_e4m3.t().is_contiguous() assert a._block_size == 16, f"NVFP4 requires block_size=16, got {a._block_size}" assert b._block_size == 16, f"NVFP4 requires block_size=16, got {b._block_size}" @@ -591,9 +607,9 @@ def _addmm_nvfp4_dispatch( a_scale_blocked = to_blocked(a_scale) if b._is_swizzled_scales: - b_scale_blocked = b._scale_e4m3 # Already swizzled + b_scale_blocked = b._scale_e4m3.t() # Already swizzled else: - b_scale = b._scale_e4m3.view(N, K // b._block_size) + b_scale = b._scale_e4m3.t().view(N, K // b._block_size) b_scale_blocked = to_blocked(b_scale) # Merge double quant scales into 1 scale for Scale_In^D From 7da78268fbb1ecc065b583c790522f21228bb7ff Mon Sep 17 00:00:00 2001 From: vasiliy Date: Wed, 1 Oct 2025 11:30:35 -0700 Subject: [PATCH 02/12] Update [ghstack-poisoned] --- torchao/prototype/mx_formats/nvfp4_tensor.py | 36 +++++++++++++++----- torchao/prototype/mx_formats/utils.py | 18 ++++++++++ 2 files changed, 45 insertions(+), 9 deletions(-) diff --git a/torchao/prototype/mx_formats/nvfp4_tensor.py b/torchao/prototype/mx_formats/nvfp4_tensor.py index aefad7750e..bee4e6a632 100644 --- a/torchao/prototype/mx_formats/nvfp4_tensor.py +++ b/torchao/prototype/mx_formats/nvfp4_tensor.py @@ -24,7 +24,11 @@ tensor_size_fp4x2_to_hp, tensor_size_hp_to_fp4x2, ) -from torchao.prototype.mx_formats.utils import from_blocked, to_blocked +from torchao.prototype.mx_formats.utils import ( + from_blocked, + hp_data_dims_to_swizzled_scale_dims_nvfp4, + to_blocked, +) from torchao.quantization.quantize_.common import ( QuantizeTensorKwargs, ) @@ -190,15 +194,11 @@ def to_nvfp4( ).flatten() if is_swizzled_scales: - # a 128x64 unpacked or 128x64 packed qdata tile corresponds - # to a swizzled 32x16 scale tile - scale_M = ceil_div(M, 128) * 32 - scale_K = ceil_div(K, 64) * 16 + scale_M, scale_K = hp_data_dims_to_swizzled_scale_dims_nvfp4(M, K) else: # a 1x16 unpacked or 1x8 packed qdata tile corresponds to 1 # scale element - scale_M = M - scale_K = K // block_size + scale_M, scale_K = M, K // block_size blockwise_scales = blockwise_scales.view(scale_M, scale_K) return NVFP4Tensor( @@ -383,6 +383,9 @@ def nvfp4_slice(func, types, args, kwargs): M, K = x.shape[0], x.shape[1] + # the scale manipulations below assume a flattened scale + # TODO(future or this PR): update this + if x._is_swizzled_scales: scale_rows = M scale_cols = K // x._block_size @@ -421,7 +424,9 @@ def nvfp4_slice(func, types, args, kwargs): else None ) - sliced_scale = aten.slice.Tensor(x._scale_e4m3, 0, start_idx, end_idx, 1) + sliced_scale = aten.slice.Tensor( + x._scale_e4m3.flatten(), 0, start_idx, end_idx, 1 + ) sliced_data = aten.slice.Tensor(x.qdata, 0, start, end, step) elif dim == 1: @@ -476,7 +481,7 @@ def nvfp4_slice(func, types, args, kwargs): row_start = row_block * elements_per_row_block col_start = row_start + start_col_block * elements_per_block col_end = row_start + end_col_block * elements_per_block - slices_to_extract.append(x._scale_e4m3[col_start:col_end]) + slices_to_extract.append(x._scale_e4m3.flatten()[col_start:col_end]) # Concatenate all the slices sliced_scale = torch.cat(slices_to_extract, dim=0) @@ -529,6 +534,19 @@ def nvfp4_slice(func, types, args, kwargs): sliced_scale = sliced_scale.flatten() + # reshape at the end + sliced_M = sliced_data.shape[0] + # multiply by 2 to convert from bytes to num_elements + sliced_K = sliced_data.shape[1] * 2 + if x._is_swizzled_scales: + 
+        scale_M, scale_K = hp_data_dims_to_swizzled_scale_dims_nvfp4(sliced_M, sliced_K)
+    else:
+        # a 1x16 unpacked or 1x8 packed qdata tile corresponds to 1
+        # scale element
+        scale_M = sliced_M
+        scale_K = sliced_K // x._block_size
+    sliced_scale = sliced_scale.view(scale_M, scale_K)
+
     # Create result tensor
     result = NVFP4Tensor(
         sliced_data,
diff --git a/torchao/prototype/mx_formats/utils.py b/torchao/prototype/mx_formats/utils.py
index 247b17d838..58ef62bf47 100644
--- a/torchao/prototype/mx_formats/utils.py
+++ b/torchao/prototype/mx_formats/utils.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.

+from typing import Tuple
+
 import torch
 from torch.distributed._tensor import DTensor

@@ -99,6 +101,22 @@ def from_blocked(
     return padded[:original_rows, :original_cols]


+def hp_data_dims_to_swizzled_scale_dims_nvfp4(
+    hp_data_M,
+    hp_data_K,
+) -> Tuple[int, int]:
+    """
+    Given the `M` and `K` dimensions of a high precision contiguous tensor,
+    returns a 2d tuple of the dims of the swizzled nvfp4 scale corresponding to
+    that tensor.
+    """
+    # a 128x64 unpacked or 128x64 packed qdata tile corresponds
+    # to a swizzled 32x16 scale tile
+    scale_M = ceil_div(hp_data_M, 128) * 32
+    scale_K = ceil_div(hp_data_K, 64) * 16
+    return scale_M, scale_K
+
+
 def _to_blocked_single(scales: Tensor) -> Tensor:
     """Assume that we have a 128x4 block of scales in K Major order

From fa40093744226b78d29f6cf4a1ce19e569b76a00 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Wed, 1 Oct 2025 11:38:57 -0700
Subject: [PATCH 03/12] Update

[ghstack-poisoned]
---
 torchao/prototype/mx_formats/nvfp4_tensor.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/torchao/prototype/mx_formats/nvfp4_tensor.py b/torchao/prototype/mx_formats/nvfp4_tensor.py
index bee4e6a632..d1775d0812 100644
--- a/torchao/prototype/mx_formats/nvfp4_tensor.py
+++ b/torchao/prototype/mx_formats/nvfp4_tensor.py
@@ -383,8 +383,10 @@ def nvfp4_slice(func, types, args, kwargs):

     M, K = x.shape[0], x.shape[1]

-    # the scale manipulations below assume a flattened scale
-    # TODO(future or this PR): update this
+    # The scale manipulations below assume a flattened scale. For now, we
+    # flatten the scale, go through the calculations below, and then reshape
+    # it back to the format which matches the shape of `qdata`.
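+    # Illustrative example (hypothetical shapes): slicing a (256, 1024) tensor
+    # to (256, 512) takes the swizzled scale from (64, 256) to (64, 128),
+    # following (ceil_div(M, 128) * 32, ceil_div(K, 64) * 16).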
+    # TODO(future PR): update this

     if x._is_swizzled_scales:
         scale_rows = M

From e10a16e62b4806e7bca8e5e29459a40d713cb5da Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Wed, 1 Oct 2025 11:43:59 -0700
Subject: [PATCH 04/12] Update

[ghstack-poisoned]
---
 test/prototype/mx_formats/test_nvfp4_tensor.py | 10 ----------
 1 file changed, 10 deletions(-)

diff --git a/test/prototype/mx_formats/test_nvfp4_tensor.py b/test/prototype/mx_formats/test_nvfp4_tensor.py
index 911166e875..ee3f1ad7ed 100644
--- a/test/prototype/mx_formats/test_nvfp4_tensor.py
+++ b/test/prototype/mx_formats/test_nvfp4_tensor.py
@@ -529,14 +529,8 @@ def test_nvfp4_to_copy():


 @pytest.mark.parametrize("transpose", [False, True])
-# @pytest.mark.parametrize("transpose", [True])
-# @pytest.mark.parametrize("transpose", [False])
 @pytest.mark.parametrize("use_triton_kernel", [False, True])
-# @pytest.mark.parametrize("use_triton_kernel", [False])
-# @pytest.mark.parametrize("use_triton_kernel", [True])
 @pytest.mark.parametrize("is_swizzled_scales", [False, True])
-# @pytest.mark.parametrize("is_swizzled_scales", [False])
-# @pytest.mark.parametrize("is_swizzled_scales", [True])
 @pytest.mark.parametrize(
     "mk",
     (
@@ -546,7 +540,6 @@ def test_nvfp4_to_copy():
         (128 + 16, 64 + 16),
     ),
 )
-# @pytest.mark.parametrize("mk", ((128 + 16, 64),))
 def test_scale_shape_matches_qdata(
     transpose, use_triton_kernel, is_swizzled_scales, mk
 ):
@@ -557,9 +550,6 @@ def test_scale_shape_matches_qdata(

     block_size = 16

-    # TODO(this PR): test larger tensors that don't exactly map to (128, 64) tiles,
-    # to test the padding logic
-    # context: https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
     x_hp = torch.randn(M, K, device="cuda")
     x = NVFP4Tensor.to_nvfp4(
         x_hp, is_swizzled_scales=is_swizzled_scales, use_triton_kernel=use_triton_kernel

From 9d0590be7e4007bf95945018b2d5dee2fea96c2c Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Wed, 1 Oct 2025 11:52:58 -0700
Subject: [PATCH 05/12] Update

[ghstack-poisoned]
---
 torchao/prototype/mx_formats/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchao/prototype/mx_formats/utils.py b/torchao/prototype/mx_formats/utils.py
index 58ef62bf47..28a8526709 100644
--- a/torchao/prototype/mx_formats/utils.py
+++ b/torchao/prototype/mx_formats/utils.py
@@ -110,7 +110,7 @@ def hp_data_dims_to_swizzled_scale_dims_nvfp4(
     returns a 2d tuple of the dims of the swizzled nvfp4 scale corresponding to
     that tensor.
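 
     Example (illustrative, hypothetical dims): hp_data_M=144 and hp_data_K=80
     map to (ceil_div(144, 128) * 32, ceil_div(80, 64) * 16) == (64, 32).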
""" - # a 128x64 unpacked or 128x64 packed qdata tile corresponds + # a 128x64 unpacked or 128x32 packed qdata tile corresponds # to a swizzled 32x16 scale tile scale_M = ceil_div(hp_data_M, 128) * 32 scale_K = ceil_div(hp_data_K, 64) * 16 From 08e9d137496d00b29cc5251cf008efff6b56aeda Mon Sep 17 00:00:00 2001 From: vasiliy Date: Wed, 1 Oct 2025 12:31:14 -0700 Subject: [PATCH 06/12] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_nvfp4_tensor.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/prototype/mx_formats/test_nvfp4_tensor.py b/test/prototype/mx_formats/test_nvfp4_tensor.py index ee3f1ad7ed..7ab37a0dba 100644 --- a/test/prototype/mx_formats/test_nvfp4_tensor.py +++ b/test/prototype/mx_formats/test_nvfp4_tensor.py @@ -528,6 +528,10 @@ def test_nvfp4_to_copy(): assert y.dtype == torch.bfloat16 +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif( + not torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+" +) @pytest.mark.parametrize("transpose", [False, True]) @pytest.mark.parametrize("use_triton_kernel", [False, True]) @pytest.mark.parametrize("is_swizzled_scales", [False, True]) @@ -543,6 +547,8 @@ def test_nvfp4_to_copy(): def test_scale_shape_matches_qdata( transpose, use_triton_kernel, is_swizzled_scales, mk ): + if use_triton_kernel and not is_sm_at_least_100(): + pytest.skip("CUDA capability >= 10.0 required for nvfp4 triton kernel") if use_triton_kernel and not is_swizzled_scales: pytest.skip("triton kernel requires swizzled scales") From 7b76009a767f667779b674e410267b4829beb779 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Wed, 1 Oct 2025 12:53:00 -0700 Subject: [PATCH 07/12] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_inference_workflow.py | 10 ++++++++++ torchao/testing/utils.py | 12 ++++++++++++ 2 files changed, 22 insertions(+) diff --git a/test/prototype/mx_formats/test_inference_workflow.py b/test/prototype/mx_formats/test_inference_workflow.py index c3227e211c..4ab65106d5 100644 --- a/test/prototype/mx_formats/test_inference_workflow.py +++ b/test/prototype/mx_formats/test_inference_workflow.py @@ -218,3 +218,13 @@ def test_narrow_similar_to_vllm(self): gemm_kernel_choice=MXGemmKernelChoice.EMULATED, ) self._test_narrow_similar_to_vllm(config) + + # TODO(next): make this test pass by enabling 3d NVFP4Tensor, currently a lot + # of places hardcode 2d + def test_nvfp4_quantize_3d_param_similar_to_vllm(self): + config = NVFP4InferenceConfig( + mm_config=NVFP4MMConfig.WEIGHT_ONLY, + use_triton_kernel=False, + use_dynamic_per_tensor_scale=False, + ) + self._test_quantize_3d_param_similar_to_vllm(config) diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py index 5fec85fee6..7f694b56d3 100644 --- a/torchao/testing/utils.py +++ b/torchao/testing/utils.py @@ -625,6 +625,18 @@ def _test_narrow_similar_to_vllm(self, config: AOBaseConfig): f"shape mismatch: {orig_attr.shape} vs {new_attr.shape}" ) + def _test_quantize_3d_param_similar_to_vllm(self, config: AOBaseConfig): + # this happens when vLLM loads empty MoE weights and quantizes + # them + + dtype = torch.bfloat16 + with torch.device("meta"): + l = torch.nn.Linear(1024, 1024, device="cuda", dtype=dtype) + l.weight = torch.nn.Parameter( + torch.randn(60, 2816, 2048, device="cuda", dtype=dtype) + ) + quantize_(l, config) + common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase) common_utils.instantiate_parametrized_tests(TorchAOCompileTestCase) From 7ce6dcfbd224007841381b0eb599aa0844226e0d Mon Sep 17 
From: vasiliy
Date: Thu, 2 Oct 2025 04:48:34 -0700
Subject: [PATCH 08/12] Update

[ghstack-poisoned]
---
 .../prototype/mx_formats/test_nvfp4_tensor.py | 27 ++++++++++++-------
 torchao/prototype/mx_formats/kernels.py       |  8 ++++++
 2 files changed, 25 insertions(+), 10 deletions(-)

diff --git a/test/prototype/mx_formats/test_nvfp4_tensor.py b/test/prototype/mx_formats/test_nvfp4_tensor.py
index 1aa7112132..c32cd8c804 100644
--- a/test/prototype/mx_formats/test_nvfp4_tensor.py
+++ b/test/prototype/mx_formats/test_nvfp4_tensor.py
@@ -545,36 +545,43 @@ def test_nvfp4_to_copy():
 @pytest.mark.parametrize("use_triton_kernel", [False, True])
 @pytest.mark.parametrize("is_swizzled_scales", [False, True])
 @pytest.mark.parametrize(
-    "mk",
+    "shape",
     (
         (128, 64),
         (128 + 16, 64),
         (128, 64 + 16),
         (128 + 16, 64 + 16),
+        (1, 128, 64),
     ),
 )
 def test_scale_shape_matches_qdata(
-    transpose, use_triton_kernel, is_swizzled_scales, mk
+    transpose, use_triton_kernel, is_swizzled_scales, shape
 ):
     if use_triton_kernel and not is_sm_at_least_100():
         pytest.skip("CUDA capability >= 10.0 required for nvfp4 triton kernel")
     if use_triton_kernel and not is_swizzled_scales:
         pytest.skip("triton kernel requires swizzled scales")

-    M, K = mk
-
     block_size = 16

-    x_hp = torch.randn(M, K, device="cuda")
+    x_hp = torch.randn(*shape, device="cuda")
     x = NVFP4Tensor.to_nvfp4(
         x_hp, is_swizzled_scales=is_swizzled_scales, use_triton_kernel=use_triton_kernel
     )

-    m_dim, k_dim = 0, 1
-    if transpose:
-        x_hp = x_hp.t()
-        x = x.t()
-        m_dim, k_dim = 1, 0
+    if len(shape) == 2:
+        m_dim, k_dim = 0, 1
+        if transpose:
+            x_hp = x_hp.t()
+            x = x.t()
+            m_dim, k_dim = 1, 0
+    else:
+        assert len(shape) == 3, "unsupported"
+        m_dim, k_dim = 1, 2
+        if transpose:
+            x_hp = x_hp.transpose(-2, -1)
+            x = x.transpose(-2, -1)
+            m_dim, k_dim = 2, 1

     orig_m = x_hp.shape[m_dim]
     expected_padded_m = orig_m
diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py
index 5811dd9d21..4a8c899d1c 100644
--- a/torchao/prototype/mx_formats/kernels.py
+++ b/torchao/prototype/mx_formats/kernels.py
@@ -1391,6 +1391,10 @@ def triton_quantize_nvfp4(
     Since vLLM does not use dynamo guards we need to make this a custom op to avoid the triton kernel being invoked w/ the wrong use of `MASK_SCALES`
     """
+    # reshape to 2d
+    orig_leading_dims, _orig_M, orig_N = x.shape[:-2], x.shape[-2], x.shape[-1]
+    x = x.reshape(-1, orig_N)
+
     M, N = x.shape
     # assert M % 128 == 0 and N % 64 == 0
     assert N % 16 == 0, "N must be divisible by 16 for NVFP4 quantization"
@@ -1431,6 +1435,10 @@ def triton_quantize_nvfp4(
         MASK_SCALES=MASK_SCALES,
     )

+    # reshape back to original shape
+    scales = scales.view(*orig_leading_dims, -1, padded_cols)
+    xq = xq.view(*orig_leading_dims, -1, N // 2)
+
     return scales, xq.view(torch.uint8)


 @triton_quantize_nvfp4.register_fake

From dffb91c93b16cd636d090311b7aba672011f6d97 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Thu, 2 Oct 2025 04:50:43 -0700
Subject: [PATCH 09/12] Update

[ghstack-poisoned]
---
 test/prototype/mx_formats/test_inference_workflow.py | 2 --
 test/prototype/mx_formats/test_nvfp4_tensor.py       | 1 -
 2 files changed, 3 deletions(-)

diff --git a/test/prototype/mx_formats/test_inference_workflow.py b/test/prototype/mx_formats/test_inference_workflow.py
index 4ab65106d5..c9e94c8da8 100644
--- a/test/prototype/mx_formats/test_inference_workflow.py
+++ b/test/prototype/mx_formats/test_inference_workflow.py
@@ -219,8 +219,6 @@ def test_narrow_similar_to_vllm(self):
         )
         self._test_narrow_similar_to_vllm(config)

-    # TODO(next): make this test pass by enabling 3d NVFP4Tensor, currently a lot
-    # of places hardcode 2d
     def test_nvfp4_quantize_3d_param_similar_to_vllm(self):
         config = NVFP4InferenceConfig(
             mm_config=NVFP4MMConfig.WEIGHT_ONLY,
diff --git a/test/prototype/mx_formats/test_nvfp4_tensor.py b/test/prototype/mx_formats/test_nvfp4_tensor.py
index c32cd8c804..1b6450b898 100644
--- a/test/prototype/mx_formats/test_nvfp4_tensor.py
+++ b/test/prototype/mx_formats/test_nvfp4_tensor.py
@@ -58,7 +58,6 @@ def test_nvfp4_reconstruction(dtype, shape, use_per_tensor_scale):
         scale = None

     x_nvfp4 = NVFP4Tensor.to_nvfp4(x, per_tensor_scale=scale)
-    # import pdb; pdb.set_trace()
     x_reconstructed = x_nvfp4.to_dtype(dtype)

     def assert_sqnr_gt_threshold(orig, new, threshold):

From 55c361fe388082261d52f265c1416a511d729e55 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Thu, 2 Oct 2025 04:58:08 -0700
Subject: [PATCH 10/12] Update

[ghstack-poisoned]
---
 test/prototype/mx_formats/test_nvfp4_tensor.py | 14 ++++++++++++++
 torchao/prototype/mx_formats/nvfp4_tensor.py   |  4 +++-
 2 files changed, 17 insertions(+), 1 deletion(-)

diff --git a/test/prototype/mx_formats/test_nvfp4_tensor.py b/test/prototype/mx_formats/test_nvfp4_tensor.py
index 1b6450b898..0ec233ef33 100644
--- a/test/prototype/mx_formats/test_nvfp4_tensor.py
+++ b/test/prototype/mx_formats/test_nvfp4_tensor.py
@@ -602,3 +602,17 @@ def test_scale_shape_matches_qdata(
     assert expected_padded_k == actual_padded_k, (
         f"incompatible padded shape for dim {k_dim}: {expected_padded_k}, {actual_padded_k=}, {x.shape}, {x._scale_e4m3.shape}"
     )
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+"
+)
+@pytest.mark.parametrize("dims", ((1, 2), (2, 1), (-1, -2), (-2, -1)))
+@pytest.mark.parametrize("is_swizzled_scales", [True, False])
+def test_3d_transpose(dims, is_swizzled_scales):
+    x_hp = torch.randn(2, 128, 256, device="cuda")
+    x_nvfp4 = NVFP4Tensor.to_nvfp4(x_hp, is_swizzled_scales=is_swizzled_scales)
+    x_hp_t = x_hp.transpose(dims[0], dims[1])
+    x_nvfp4_t = x_nvfp4.transpose(dims[0], dims[1])
+    assert x_hp_t.shape == x_nvfp4_t.shape
diff --git a/torchao/prototype/mx_formats/nvfp4_tensor.py b/torchao/prototype/mx_formats/nvfp4_tensor.py
index fe26ab60c6..043e1160e0 100644
--- a/torchao/prototype/mx_formats/nvfp4_tensor.py
+++ b/torchao/prototype/mx_formats/nvfp4_tensor.py
@@ -596,7 +596,9 @@ def nvfp4_t(func, types, args, kwargs):
 @implements([aten.transpose.int])
 def nvfp4_transpose(func, types, args, kwargs):
     old, dim0, dim1 = args
-    assert dim0 == -2 and dim1 == -1, f"transpose unsupported for {dim0=} {dim1=}"
+    assert len(old.shape) == 3, f"unsupported rank {len(old.shape)}"
+    valid_3d_dims = ((1, 2), (2, 1), (-1, -2), (-2, -1))
+    assert (dim0, dim1) in valid_3d_dims, f"transpose unsupported for {dim0=} {dim1=}"
     new_qdata = func(old.qdata, dim0, dim1, **kwargs)
     new_scale = func(old._scale_e4m3, dim0, dim1, **kwargs)
     new = NVFP4Tensor(

From eb82d5f33f5261f3198982c817b2bb151ee25979 Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Thu, 2 Oct 2025 05:30:57 -0700
Subject: [PATCH 11/12] Update

[ghstack-poisoned]
---
 test/prototype/mx_formats/test_nvfp4_tensor.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/test/prototype/mx_formats/test_nvfp4_tensor.py b/test/prototype/mx_formats/test_nvfp4_tensor.py
index 0ec233ef33..773777d400 100644
--- a/test/prototype/mx_formats/test_nvfp4_tensor.py
+++ b/test/prototype/mx_formats/test_nvfp4_tensor.py
@@ -88,7 +88,6 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
         x_nvfp4_t = x_nvfp4.t()
         x_t = x.t()
     else:
-        # TODO(before land): also test transpose dims (1, 2), (2, 1), (-1, -2)
         x_nvfp4_t = x_nvfp4.transpose(-2, -1)
         x_t = x.transpose(-2, -1)

From baf7568e1b13aa0406b7dbf92f459d7d7a7eb12f Mon Sep 17 00:00:00 2001
From: vasiliy
Date: Thu, 2 Oct 2025 07:00:20 -0700
Subject: [PATCH 12/12] Update

[ghstack-poisoned]
---
 test/prototype/mx_formats/test_inference_workflow.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/test/prototype/mx_formats/test_inference_workflow.py b/test/prototype/mx_formats/test_inference_workflow.py
index c9e94c8da8..c61f973d03 100644
--- a/test/prototype/mx_formats/test_inference_workflow.py
+++ b/test/prototype/mx_formats/test_inference_workflow.py
@@ -219,6 +219,11 @@ def test_narrow_similar_to_vllm(self):
         )
         self._test_narrow_similar_to_vllm(config)

+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @pytest.mark.skipif(
+        not torch_version_at_least("2.8.0"),
+        reason="torch.compile requires PyTorch 2.8+",
+    )
     def test_nvfp4_quantize_3d_param_similar_to_vllm(self):
         config = NVFP4InferenceConfig(
             mm_config=NVFP4MMConfig.WEIGHT_ONLY,
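
Note: the following is an editorial sketch and not part of the patch stack. It
shows the qdata-to-scale shape mapping that the stack enforces; `ceil_div` and
the 128x64 -> 32x16 tile relationship mirror torchao/prototype/mx_formats/utils.py,
while the function name `nvfp4_scale_dims` and the example shapes are
hypothetical:

    def ceil_div(a: int, b: int) -> int:
        # integer ceiling division
        return -(a // -b)

    def nvfp4_scale_dims(M: int, K: int, block_size: int = 16, swizzled: bool = True):
        """Expected e4m3 scale shape for a contiguous (M, K) high precision tensor."""
        if not swizzled:
            # one scale per 1x16 unpacked (1x8 packed) qdata block
            return M, K // block_size
        # each 128x64 unpacked (128x32 packed) qdata tile maps to a 32x16
        # swizzled scale tile, padding up to whole tiles
        return ceil_div(M, 128) * 32, ceil_div(K, 64) * 16

    # the (M, K) cases exercised by test_scale_shape_matches_qdata
    assert nvfp4_scale_dims(128, 64) == (32, 16)
    assert nvfp4_scale_dims(128 + 16, 64) == (64, 16)
    assert nvfp4_scale_dims(128, 64 + 16) == (32, 32)
    assert nvfp4_scale_dims(128 + 16, 64 + 16) == (64, 32)
    assert nvfp4_scale_dims(128, 64, swizzled=False) == (128, 4)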