13 changes: 13 additions & 0 deletions test/prototype/mx_formats/test_inference_workflow.py
@@ -218,3 +218,16 @@ def test_narrow_similar_to_vllm(self):
gemm_kernel_choice=MXGemmKernelChoice.EMULATED,
)
self._test_narrow_similar_to_vllm(config)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
not torch_version_at_least("2.8.0"),
reason="torch.compile requires PyTorch 2.8+",
)
def test_nvfp4_quantize_3d_param_similar_to_vllm(self):
config = NVFP4InferenceConfig(
mm_config=NVFP4MMConfig.WEIGHT_ONLY,
use_triton_kernel=False,
use_dynamic_per_tensor_scale=False,
)
self._test_quantize_3d_param_similar_to_vllm(config)
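
For context, a rough standalone sketch of the path this test exercises, mirroring the `_test_quantize_3d_param_similar_to_vllm` helper added to `torchao/testing/utils.py` later in this diff (import locations are assumed, the config values and weight shape are copied from the tests, and the helper's meta-device construction is omitted):

```python
import torch

# assumed import locations; the config classes are defined under
# torchao/prototype/mx_formats in this PR
from torchao.prototype.mx_formats import NVFP4InferenceConfig, NVFP4MMConfig
from torchao.quantization import quantize_

# a Linear whose weight has been swapped for a rank-3 (MoE-style) parameter,
# similar to how vLLM materializes expert weights before quantization
lin = torch.nn.Linear(1024, 1024, dtype=torch.bfloat16, device="cuda")
lin.weight = torch.nn.Parameter(
    torch.randn(60, 2816, 2048, dtype=torch.bfloat16, device="cuda")
)

config = NVFP4InferenceConfig(
    mm_config=NVFP4MMConfig.WEIGHT_ONLY,
    use_triton_kernel=False,
    use_dynamic_per_tensor_scale=False,
)
quantize_(lin, config)  # the rank-3 weight should now be quantized to NVFP4
```
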
60 changes: 44 additions & 16 deletions test/prototype/mx_formats/test_nvfp4_tensor.py
Expand Up @@ -42,6 +42,7 @@
(torch.float32, (64, 128), False),
(torch.bfloat16, (128, 256), False),
(torch.bfloat16, (64, 128), True),
(torch.bfloat16, (1, 32, 64), False),
],
)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -83,14 +84,20 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
f"Dtype mismatch: {x.dtype} vs {x_reconstructed.dtype}"
)

x_nvfp4_t = x_nvfp4.t()
if len(x.shape) == 2:
x_nvfp4_t = x_nvfp4.t()
x_t = x.t()
else:
x_nvfp4_t = x_nvfp4.transpose(-2, -1)
x_t = x.transpose(-2, -1)

x_reconstructed_t = x_nvfp4_t.to_dtype(dtype)
assert_sqnr_gt_threshold(x.t(), x_reconstructed_t, 8.0)
assert_sqnr_gt_threshold(x_t, x_reconstructed_t, 8.0)

assert x.t().shape == x_reconstructed_t.shape, (
assert x_t.shape == x_reconstructed_t.shape, (
f"Transpose shape mismatch: {x.t().shape} vs {x_reconstructed_t.shape}"
)
assert x.t().dtype == x_reconstructed_t.dtype, (
assert x_t.dtype == x_reconstructed_t.dtype, (
f"Transpose dtype mismatch: {x.t().dtype} vs {x_reconstructed_t.dtype}"
)

@@ -103,6 +110,7 @@ def assert_sqnr_gt_threshold(orig, new, threshold):
(16, 32),
(64, 128),
(384, 128),
(1, 32, 64),
],
)
@pytest.mark.skipif(
@@ -115,8 +123,7 @@ def test_nvfp4_swizzled_scales_construction(is_swizzled_scales, shape):
that the _is_swizzled_scales flag is set correctly.
"""

M, K = shape
data = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
data = torch.randn(*shape, device="cuda", dtype=torch.bfloat16)

tensor = NVFP4Tensor.to_nvfp4(data, is_swizzled_scales=is_swizzled_scales)
assert tensor._is_swizzled_scales == is_swizzled_scales
@@ -536,36 +543,43 @@ def test_nvfp4_to_copy():
@pytest.mark.parametrize("use_triton_kernel", [False, True])
@pytest.mark.parametrize("is_swizzled_scales", [False, True])
@pytest.mark.parametrize(
"mk",
"shape",
(
(128, 64),
(128 + 16, 64),
(128, 64 + 16),
(128 + 16, 64 + 16),
(1, 128, 64),
),
)
def test_scale_shape_matches_qdata(
transpose, use_triton_kernel, is_swizzled_scales, mk
transpose, use_triton_kernel, is_swizzled_scales, shape
):
if use_triton_kernel and not is_sm_at_least_100():
pytest.skip("CUDA capability >= 10.0 required for nvfp4 triton kernel")
if use_triton_kernel and not is_swizzled_scales:
pytest.skip("triton kernel requires swizzled scales")

M, K = mk

block_size = 16

x_hp = torch.randn(M, K, device="cuda")
x_hp = torch.randn(*shape, device="cuda")
x = NVFP4Tensor.to_nvfp4(
x_hp, is_swizzled_scales=is_swizzled_scales, use_triton_kernel=use_triton_kernel
)

m_dim, k_dim = 0, 1
if transpose:
x_hp = x_hp.t()
x = x.t()
m_dim, k_dim = 1, 0
if len(shape) == 2:
m_dim, k_dim = 0, 1
if transpose:
x_hp = x_hp.t()
x = x.t()
m_dim, k_dim = 1, 0
else:
assert len(shape) == 3, "unsupported"
m_dim, k_dim = 1, 2
if transpose:
x_hp = x_hp.transpose(-2, -1)
x = x.transpose(-2, -1)
m_dim, k_dim = 2, 1

orig_m = x_hp.shape[m_dim]
expected_padded_m = orig_m
@@ -587,3 +601,17 @@ def test_scale_shape_matches_qdata(
assert expected_padded_k == actual_padded_k, (
f"incompatible padded shape for dim {k_dim}: {expected_padded_k}, {actual_padded_k=}, {x.shape}, {x._scale_e4m3.shape}"
)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
not torch_version_at_least("2.8.0"), reason="NVFP4 requires PyTorch 2.8+"
)
@pytest.mark.parametrize("dims", ((1, 2), (2, 1), (-1, -2), (-2, -1)))
@pytest.mark.parametrize("is_swizzled_scales", [True, False])
def test_3d_transpose(dims, is_swizzled_scales):
x_hp = torch.randn(2, 128, 256, device="cuda")
x_nvfp4 = NVFP4Tensor.to_nvfp4(x_hp, is_swizzled_scales=is_swizzled_scales)
x_hp_t = x_hp.transpose(dims[0], dims[1])
x_nvfp4_t = x_nvfp4.transpose(dims[0], dims[1])
assert x_hp_t.shape == x_nvfp4_t.shape
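
As a quick reference, a minimal round-trip sketch of the rank-3 support these tests exercise (shape taken from the parametrization above; exact reconstruction error is not checked here):

```python
import torch

from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor

# quantize a rank-3 bfloat16 tensor to NVFP4 and dequantize it back
x_hp = torch.randn(1, 32, 64, device="cuda", dtype=torch.bfloat16)
x_nvfp4 = NVFP4Tensor.to_nvfp4(x_hp)
x_rt = x_nvfp4.to_dtype(torch.bfloat16)

# shape and dtype are preserved; values are lossy (fp4 data + blockwise scales)
assert x_rt.shape == x_hp.shape and x_rt.dtype == x_hp.dtype
```
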
4 changes: 2 additions & 2 deletions torchao/prototype/mx_formats/inference_workflow.py
@@ -168,9 +168,9 @@ def _nvfp4_inference_linear_transform(

weight = module.weight

if weight.shape[0] % 16 != 0 or weight.shape[1] % 16 != 0:
if weight.shape[-2] % 16 != 0 or weight.shape[-1] % 16 != 0:
raise RuntimeError(
f"NVFP4 only supports weight shape divisible by 16, got {weight.shape}"
f"NVFP4 only supports weight shape with last 2 dims divisible by 16, got {weight.shape}"
)

if module.bias is not None and weight.dtype == torch.float32:
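
A small standalone illustration of the relaxed guard (the helper name here is hypothetical; only the check itself is taken from the diff):

```python
import torch

def _check_nvfp4_weight_shape(weight: torch.Tensor) -> None:
    # only the last two dims must be divisible by 16, so rank-3
    # MoE-style weights such as (num_experts, N, K) now pass as well
    if weight.shape[-2] % 16 != 0 or weight.shape[-1] % 16 != 0:
        raise RuntimeError(
            f"NVFP4 only supports weight shape with last 2 dims divisible by 16, got {weight.shape}"
        )

_check_nvfp4_weight_shape(torch.empty(60, 2816, 2048, device="meta"))  # ok
_check_nvfp4_weight_shape(torch.empty(64, 128, device="meta"))         # ok, rank-2 unchanged
```
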
8 changes: 8 additions & 0 deletions torchao/prototype/mx_formats/kernels.py
@@ -1391,6 +1391,10 @@ def triton_quantize_nvfp4(
Since VLLM does not use dynamo guards we need to make this a custom op
to avoid the triton kernel being invoked w/ the wrong use of `MASK_SCALES`
"""
# reshape to 2d
orig_leading_dims, _orig_M, orig_N = x.shape[:-2], x.shape[-2], x.shape[-1]
x = x.reshape(-1, orig_N)

M, N = x.shape
# assert M % 128 == 0 and N % 64 == 0
assert N % 16 == 0, "N must be divisible by 16 for NVFP4 quantization"
@@ -1431,6 +1435,10 @@ def triton_quantize_nvfp4(
MASK_SCALES=MASK_SCALES,
)

# reshape back to original shape
scales = scales.view(*orig_leading_dims, -1, padded_cols)
xq = xq.view(*orig_leading_dims, -1, N // 2)

return scales, xq.view(torch.uint8)

@triton_quantize_nvfp4.register_fake
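
The reshape added around the kernel is the usual "collapse leading dims, run the 2-D kernel, restore dims" pattern. A hedged standalone sketch follows; `quantize_2d` is a placeholder callable, not the kernel's real signature, and the scale column count is whatever that callable returns:

```python
import torch

def quantize_batched_via_2d(x: torch.Tensor, quantize_2d):
    # collapse all leading dims so an existing 2-D kernel can be reused
    *leading_dims, _m, n = x.shape
    x_2d = x.reshape(-1, n)
    scales_2d, xq_2d = quantize_2d(x_2d)  # shapes: (M', cols) and (M', n // 2)
    # restore the leading dims on both outputs
    scales = scales_2d.view(*leading_dims, -1, scales_2d.shape[-1])
    xq = xq_2d.view(*leading_dims, -1, n // 2)  # two fp4 values packed per byte
    return scales, xq
```
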
16 changes: 14 additions & 2 deletions torchao/prototype/mx_formats/mx_tensor.py
@@ -426,7 +426,12 @@ def tensor_size_hp_to_fp4x2(orig_size, is_contiguous):
if is_contiguous:
new_size = [*list(new_size[:-1]), new_size[-1] // 2]
else:
new_size = [new_size[0] // 2, *list(new_size[1:])]
if len(orig_size) == 2:
new_size = [new_size[0] // 2, *list(new_size[1:])]
else:
assert len(orig_size) == 3, "unsupported"
# only supporting dim0, dim1, dim2 and dim0, dim2, dim1 orders
new_size = [new_size[0], new_size[2] // 2, new_size[1]]
return new_size


@@ -435,10 +440,16 @@ def tensor_size_fp4x2_to_hp(orig_size, is_contiguous):
if is_contiguous:
new_size = [*list(new_size[:-1]), new_size[-1] * 2]
else:
new_size = [new_size[0] * 2, *list(new_size[1:])]
if len(orig_size) == 2:
new_size = [new_size[0] * 2, *list(new_size[1:])]
else:
assert len(orig_size) == 3, "unsupported"
# only supporting dim0, dim1, dim2 and dim0, dim2, dim1 orders
new_size = [new_size[0], new_size[2] * 2, new_size[1]]
return new_size


# TODO(future PR): fix this function for rank 3 and add tests
def tensor_size_hpx3_to_fp6x4(orig_size, is_contiguous):
new_size = orig_size
if is_contiguous:
@@ -448,6 +459,7 @@ def tensor_size_hpx3_to_fp6x4(orig_size, is_contiguous):
return new_size


# TODO(future PR): fix this function for rank 3 and add tests
def tensor_size_fp6x4_to_hpx3(orig_size, is_contiguous):
new_size = orig_size
if is_contiguous:
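
For the common contiguous case, the size helpers above reduce to halving or doubling the innermost dim, since two fp4 values are packed per byte. A minimal sketch of that case only; the non-contiguous rank-3 branch additionally assumes a (dim0, dim2, dim1) ordering, as noted in the comments above:

```python
def hp_to_fp4x2_contiguous(size):
    # pack two fp4 elements per uint8 along the innermost dim
    return [*size[:-1], size[-1] // 2]

def fp4x2_to_hp_contiguous(size):
    return [*size[:-1], size[-1] * 2]

assert hp_to_fp4x2_contiguous([2, 128, 256]) == [2, 128, 128]
assert fp4x2_to_hp_contiguous([2, 128, 128]) == [2, 128, 256]
```
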
72 changes: 52 additions & 20 deletions torchao/prototype/mx_formats/nvfp4_tensor.py
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import math
import sys
from dataclasses import dataclass
from enum import Enum
@@ -112,7 +113,7 @@ def __new__(

new_size = tensor_size_fp4x2_to_hp(
new_size,
qdata.stride(0) > qdata.stride(1),
qdata.stride(-2) > qdata.stride(-1),
)

self = torch.Tensor._make_wrapper_subclass(
@@ -174,21 +175,21 @@ def to_nvfp4(
Returns:
NVFP4Tensor: Quantized tensor in NVFP4 format
"""
assert len(data_hp.shape) == 2, "unsupported"
M, K = data_hp.shape[0], data_hp.shape[1]
assert len(data_hp.shape) in (2, 3), "unsupported"
leading_dims, M, K = data_hp.shape[:-2], data_hp.shape[-2], data_hp.shape[-1]
[Review comment from Contributor] just FYI, I think you can do:
    *leading_dims, M, K = data_hp.shape

if use_triton_kernel:
assert is_swizzled_scales, "Triton kernel only supports swizzled scales"
assert data_hp.shape[1] % 16 == 0, (
f"Triton kernel requires K (dim 1) to be divisible by 16, got {data_hp.shape[1]}"
assert K % 16 == 0, (
f"Triton kernel requires K (dim -1) to be divisible by 16, got {K}"
)
blockwise_scales, data_lp = triton_quantize_nvfp4(data_hp, per_tensor_scale)
else:
blockwise_scales, data_lp = nvfp4_quantize(
data_hp, block_size, per_tensor_scale
)
if is_swizzled_scales:
scale_shape = (M, K // block_size)
scale_shape = (math.prod(leading_dims) * M, K // block_size)
blockwise_scales = to_blocked(
blockwise_scales.view(scale_shape)
).flatten()
@@ -199,7 +200,7 @@ def to_nvfp4(
# a 1x16 unpacked or 1x8 packed qdata tile corresponds to 1
# scale element
scale_M, scale_K = M, K // block_size
blockwise_scales = blockwise_scales.view(scale_M, scale_K)
blockwise_scales = blockwise_scales.view(*leading_dims, scale_M, scale_K)

return NVFP4Tensor(
data_lp,
@@ -225,22 +226,26 @@ def to_dtype(self, target_dtype: torch.dtype) -> torch.Tensor:
Returns:
torch.Tensor: Dequantized tensor in the target dtype
"""
is_transposed = self.qdata.stride(0) < self.qdata.stride(1)
is_transposed = self.qdata.stride(-2) < self.qdata.stride(-1)
if is_transposed:
M, K = self.shape[1], self.shape[0]
leading_dims, M, K = self.shape[:-2], self.shape[-1], self.shape[-2]
else:
M, K = self.shape[0], self.shape[1]
data = self.qdata.t() if is_transposed else self.qdata
leading_dims, M, K = self.shape[:-2], self.shape[-2], self.shape[-1]
data = self.qdata.transpose(-2, -1) if is_transposed else self.qdata
data_unpacked = unpack_uint4(data.contiguous().view(torch.uint8))
data_f32 = f4_unpacked_to_f32(data_unpacked)

data_f32 = data_f32.view(M, K // self._block_size, self._block_size)
scale_e4m3_reshaped = self.get_hp_scales().view(M, K // self._block_size, 1)
data_f32 = data_f32.view(
*leading_dims, M, K // self._block_size, self._block_size
)
scale_e4m3_reshaped = self.get_hp_scales().view(
*leading_dims, M, K // self._block_size, 1
)
data_scaled = data_f32 * scale_e4m3_reshaped.to(torch.float32)
result = data_scaled.view(M, K).to(target_dtype)
result = data_scaled.view(*leading_dims, M, K).to(target_dtype)

if is_transposed:
result = result.t()
result = result.transpose(-2, -1)

return result

@@ -250,16 +255,18 @@ def get_hp_scales(self) -> torch.Tensor:
Returns:
torch.Tensor: Scales of the NVFP4Tensor
"""
is_transposed = self.qdata.stride(0) < self.qdata.stride(1)
is_transposed = self.qdata.stride(-2) < self.qdata.stride(-1)
if is_transposed:
M, K = self.shape[1], self.shape[0]
scale_e4m3 = self._scale_e4m3.t()
leading_dims, M, K = self.shape[:-2], self.shape[-1], self.shape[-2]
scale_e4m3 = self._scale_e4m3.transpose(-2, -1)
else:
M, K = self.shape[0], self.shape[1]
leading_dims, M, K = self.shape[:-2], self.shape[-2], self.shape[-1]
scale_e4m3 = self._scale_e4m3

if self._is_swizzled_scales:
scale_e4m3 = from_blocked(scale_e4m3, M, K // self._block_size)
scale_e4m3 = from_blocked(
scale_e4m3, math.prod(leading_dims) * M, K // self._block_size
)

return (
scale_e4m3.to(self._orig_dtype)
@@ -380,6 +387,9 @@ def nvfp4_slice(func, types, args, kwargs):
raise ValueError("Only support aten.slice with step=1")

assert x.qdata.is_contiguous(), "Only support contiguous data for now"
assert len(x.shape) == 2, (
f"only rank 2 is supported for slice, got rank {len(x.shape)}"
)

M, K = x.shape[0], x.shape[1]

@@ -583,6 +593,28 @@ def nvfp4_t(func, types, args, kwargs):
return new


@implements([aten.transpose.int])
def nvfp4_transpose(func, types, args, kwargs):
old, dim0, dim1 = args
assert len(old.shape) == 3, f"unsupported rank {len(old.shape)}"
valid_3d_dims = ((1, 2), (2, 1), (-1, -2), (-2, -1))
assert (dim0, dim1) in valid_3d_dims, f"transpose unsupported for {dim0=} {dim1=}"
new_qdata = func(old.qdata, dim0, dim1, **kwargs)
new_scale = func(old._scale_e4m3, dim0, dim1, **kwargs)
new = NVFP4Tensor(
new_qdata,
new_scale,
old._block_size,
[Review comment from Contributor] would block size change with transpose?
[Reply from Contributor Author] currently block_size is an integer for this tensor, 16 for NVFP4. If we change it to a multidimensional block, we'd have to update this code.

old._orig_dtype,
old._per_tensor_scale,
old._act_per_tensor_scale,
old._is_swizzled_scales,
old.use_triton_kernel,
old.act_quant_kwargs,
)
return new


@implements([aten.view.default])
def nvfp4_view_op(func, types, args, kwargs):
data = args[0].qdata
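
A hedged usage sketch of the new transpose override (shape copied from test_3d_transpose above; dim pairs other than the whitelisted last-two-dims combinations would hit the assert in nvfp4_transpose):

```python
import torch

from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor

x = NVFP4Tensor.to_nvfp4(
    torch.randn(2, 128, 256, device="cuda", dtype=torch.bfloat16)
)
x_t = x.transpose(-2, -1)  # routes through the aten.transpose.int override above
assert x_t.shape == (2, 256, 128)
# x.transpose(0, 1) is outside the supported dim pairs and would raise
```
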
12 changes: 12 additions & 0 deletions torchao/testing/utils.py
@@ -625,6 +625,18 @@ def _test_narrow_similar_to_vllm(self, config: AOBaseConfig):
f"shape mismatch: {orig_attr.shape} vs {new_attr.shape}"
)

def _test_quantize_3d_param_similar_to_vllm(self, config: AOBaseConfig):
# this happens when vLLM loads empty MoE weights and quantizes
# them

dtype = torch.bfloat16
with torch.device("meta"):
l = torch.nn.Linear(1024, 1024, device="cuda", dtype=dtype)
l.weight = torch.nn.Parameter(
torch.randn(60, 2816, 2048, device="cuda", dtype=dtype)
)
quantize_(l, config)


common_utils.instantiate_parametrized_tests(TorchAOBasicTestCase)
common_utils.instantiate_parametrized_tests(TorchAOCompileTestCase)