6 changes: 3 additions & 3 deletions benchmarks/mx_formats/cast_bench.py
@@ -257,12 +257,12 @@ def run(

elif mode == "dim0_nvfp4":
to_nvfp4_reference_c = torch.compile(to_nvfp4_reference)
y_d0, s_d0 = to_nvfp4_reference_c(x, use_triton_kernel=False)
y_d0, s_d0 = to_nvfp4_reference_c(x)

for _ in range(2):
__ = to_nvfp4_reference_c(x, use_triton_kernel=False)
__ = to_nvfp4_reference_c(x)
time_us = benchmark_cuda_function_in_microseconds(
lambda x: to_nvfp4_reference_c(x, use_triton_kernel=False),
lambda x: to_nvfp4_reference_c(x),
x,
)
assert y_d0.dtype == torch.uint8
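Note on the benchmark harness used above: `benchmark_cuda_function_in_microseconds` is torchao's own helper and its implementation is not part of this diff. Below is a minimal sketch, assuming a CUDA-event based timer, of what such a helper can look like; the body, name suffix, and iteration counts are assumptions, not the repository's code.

```python
import torch

def benchmark_cuda_function_in_microseconds_sketch(fn, *args, n_warmup=3, n_iter=10):
    # Hedged sketch only: torchao ships its own helper under a similar name;
    # this version just shows the usual CUDA-event timing pattern.
    for _ in range(n_warmup):
        fn(*args)  # warmup also triggers torch.compile compilation
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(n_iter):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) * 1e3 / n_iter  # elapsed_time is in ms
```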
29 changes: 24 additions & 5 deletions test/prototype/mx_formats/test_nvfp4_tensor.py
@@ -359,13 +359,17 @@ def test_nvfp4_swizzled_scales_get_scales_method():

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize(
"M", [128, 256, 512, 1024, 100, 200, 384], ids=lambda m: f"M{m}"
# "M", [128, 256, 512, 1024, 100, 200, 384], ids=lambda m: f"M{m}"
"M", [256, ], ids=lambda m: f"M{m}"
)
@pytest.mark.parametrize("N", [64, 128, 256, 512, 32, 96, 160], ids=lambda n: f"N{n}")
# @pytest.mark.parametrize("N", [64, 128, 256, 512, 32, 96, 160], ids=lambda n: f"N{n}")
@pytest.mark.parametrize("N", [128], ids=lambda n: f"N{n}")
@pytest.mark.parametrize(
"use_per_tensor_scale", [False, True], ids=["block_scale", "tensor_scale"]
# "use_per_tensor_scale", [False, True], ids=["block_scale", "tensor_scale"]
"use_per_tensor_scale", [False, ], ids=["block_scale"]
)
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=["fp32", "bf16"])
# @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=["fp32", "bf16"])
@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"])
@pytest.mark.skipif(
not is_sm_at_least_100(), reason="requires sm100+ for raw intrinsics"
)
@@ -394,7 +398,20 @@ def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
use_triton_kernel=True,
)

torch.testing.assert_close(nvfp4_pt.scale.flatten(), nvfp4_triton.scale.flatten())
# print(nvfp4_triton.scale)

s00 = nvfp4_pt.scale.reshape(2, -1, 16)[0].float()
s01 = nvfp4_pt.scale.reshape(2, -1, 16)[1].float()
s10 = nvfp4_triton.scale.reshape(2, -1, 16)[0].float()
s11 = nvfp4_triton.scale.reshape(2, -1, 16)[1].float()
# print(s00.sum(), s01.sum(), s10.sum(), s11.sum())

s0 = nvfp4_pt.scale.reshape(-1, 32 * 16).float().sum(dim=1)
s1 = nvfp4_triton.scale.reshape(-1, 32 * 16).float().sum(dim=1)
print('\n', s0)
print(s1)

# breakpoint()
pt_unpacked = unpack_uint4(nvfp4_pt.qdata)
triton_unpacked = unpack_uint4(nvfp4_triton.qdata)
torch.testing.assert_close(
@@ -404,6 +421,8 @@ def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
rtol=0,
)

torch.testing.assert_close(nvfp4_pt.scale.flatten(), nvfp4_triton.scale.flatten())

x_pt_dequant = nvfp4_pt.dequantize(dtype)
x_triton_dequant = nvfp4_triton.dequantize(dtype)

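The `s0`/`s1` debug lines added in this test reduce each swizzled scale buffer to one sum per 128x64 input tile (each such tile owns 32 * 16 = 512 FP8 scales), which makes it easy to see which output region the Triton kernel writes incorrectly. A minimal sketch of the same idea as a reusable helper; the function names and the zero tolerance are illustrative, not part of the test.

```python
import torch

def per_tile_scale_sums(scales: torch.Tensor, tile_elems: int = 32 * 16) -> torch.Tensor:
    # One number per swizzled scale tile: each [128, 64] data block is scaled
    # by a [32, 16] tile of FP8 scales, i.e. 512 values per tile.
    return scales.float().reshape(-1, tile_elems).sum(dim=1)

def first_mismatching_tile(ref_scales: torch.Tensor, triton_scales: torch.Tensor):
    # Compare the reference and Triton scale buffers tile-by-tile so a failing
    # assert can be localized to a specific output tile.
    diff = (per_tile_scale_sums(ref_scales) - per_tile_scale_sums(triton_scales)).abs()
    bad = torch.nonzero(diff > 0).flatten()
    return None if bad.numel() == 0 else int(bad[0])
```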
70 changes: 54 additions & 16 deletions torchao/prototype/mx_formats/kernels.py
@@ -1437,20 +1437,33 @@ def quantize_nvfp4_triton_kernel(
s_ptr,
stride_xm,
stride_xn,
stride_sm,
stride_sn,
M,
N,
USE_TENSOR_SCALE: tl.constexpr,
MASK_SCALES: tl.constexpr,
ROW_TILE_SIZE: tl.constexpr,
COL_TILE_SIZE: tl.constexpr,
):
"""
1. single block of data is shaped [128, 64] unpacked or [128, 32] packed
2. corresponding single unswizzled block of scales is shaped [128, 4]
    3. corresponding single swizzled block of scales is shaped [32, 16]
"""

F4_E2M1_MAX = 6.0
F8E4M3_MAX = 448.0
E4M3_EPS = 1.5258789e-05

NUM_ROW_INNER_TILES: tl.constexpr = ROW_TILE_SIZE // 128
NUM_COL_INNER_TILES: tl.constexpr = COL_TILE_SIZE // 64

pid_m = tl.program_id(1)
pid_n = tl.program_id(0)

offs_m = pid_m * 128 + tl.arange(0, 128)[:, None]
offs_n = pid_n * 64 + tl.arange(0, 64)[None, :]
offs_m = pid_m * ROW_TILE_SIZE + tl.arange(0, ROW_TILE_SIZE)[:, None]
offs_n = pid_n * COL_TILE_SIZE + tl.arange(0, COL_TILE_SIZE)[None, :]
if MASK_SCALES:
mask = (offs_m < M) & (offs_n < N)
other = 0.0
@@ -1459,11 +1472,11 @@ def quantize_nvfp4_triton_kernel(
other = None
x = tl.load(
x_ptr + offs_m * stride_xm + offs_n * stride_xn, mask=mask, other=other
) # [128, 64]
x_blocks = x.to(tl.float32).reshape(128, 4, 16) # [128, 4, 16]
) # [ROW_TILE_SIZE, COL_TILE_SIZE]
x_blocks = x.to(tl.float32).reshape(ROW_TILE_SIZE, 4, 16) # [-1, 4, 16]

# Compute block-wise scales
block_amax = tl.max(x_blocks.abs(), axis=2) # [128, 4]
block_amax = tl.max(x_blocks.abs(), axis=2) # [-1, 4]

if USE_TENSOR_SCALE:
# Two-level scaling: quantize block scales with per-tensor scale
@@ -1501,21 +1514,37 @@ def quantize_nvfp4_triton_kernel(
scales,
0.0,
)
packed_scales = scales.reshape(4, 32, 4).permute(1, 0, 2).reshape(32, 16)
offs_m = tl.arange(0, 32)[:, None]
offs_n = tl.arange(0, 16)[None, :]
# packed_scales = scales.reshape(4, 32, 4).permute(1, 0, 2).reshape(32, 16)
packed_scales = scales.reshape(NUM_ROW_INNER_TILES, 4, 32, 4).permute(0, 2, 1, 3).reshape(NUM_ROW_INNER_TILES * 32, 16)
scale_offs_m = tl.arange(0, 32 * NUM_ROW_INNER_TILES)[:, None]
scale_offs_n = tl.arange(0, 16)[None, :]
# packed_scales = tl.arange(0, 32 * NUM_ROW_INNER_TILES * 16 * NUM_COL_INNER_TILES).reshape(NUM_ROW_INNER_TILES * 32, NUM_COL_INNER_TILES * 16).to(tl.float32)

# TODO write me
scale_elements_per_outer_tile = (min(ROW_TILE_SIZE, M) // 128 * 32) * 16

# scale_offs = (scale_offs_m * 16 + scale_offs_n)
# TODO(next): debug here, offsets or masks are probably not correct here
scale_offs = (scale_offs_m * 16 + scale_offs_n)
tl.store(
s_ptr
+ (pid_m * tl.num_programs(0) + pid_n) * (32 * 16)
+ offs_m * 16
+ offs_n,
# + (pid_m * tl.num_programs(0) + pid_n) * (NUM_ROW_INNER_TILES * 32 * 16)
# + (pid_m * tl.num_programs(0) + pid_n) * (NUM_ROW_INNER_TILES * 32 * 16)
# + (pid_m * tl.num_programs(0) + pid_n) * (1 * 32 * 16)
+ (pid_m * tl.num_programs(0) + pid_n) * scale_elements_per_outer_tile
+ scale_offs,
packed_scales,
mask=(scale_offs < scale_elements_per_outer_tile),
)

# Convert to FP4
x_fp4x2 = convert_fp32_to_fp4_packed(x_blocks.reshape(128, 32, 2).split())
offs_m = pid_m * 128 + tl.arange(0, 128)[:, None]
offs_n = pid_n * 32 + tl.arange(0, 32)[None, :]
x_fp4x2 = convert_fp32_to_fp4_packed(
x_blocks.reshape(ROW_TILE_SIZE, 32, 2).split()
)
offs_m = pid_m * ROW_TILE_SIZE + tl.arange(0, ROW_TILE_SIZE)[:, None]
offs_n = (
pid_n * (COL_TILE_SIZE // 2) + tl.arange(0, COL_TILE_SIZE // 2)[None, :]
)
if MASK_SCALES:
mask = (offs_m < M) & (offs_n < N // 2)
else:
@@ -1537,7 +1566,7 @@ def triton_quantize_nvfp4(
Tuple[torch.Tensor, torch.Tensor]: Quantized tensor and scales tensor in swizzled layout.

Note:
Since VLLM does not use dyanmo guards we need to make this a custom op
Since VLLM does not use dynamo guards we need to make this a custom op
to avoid the triton kernel being invoked w/ the wrong use of `MASK_SCALES`
"""
# reshape to 2d
Expand All @@ -1557,11 +1586,16 @@ def triton_quantize_nvfp4(

# mask out scales to 0 if we are not aligned to 128 x 64
MASK_SCALES = M % 128 != 0 or N % 64 != 0
# MASK_SCALES = True

xq = x.new_empty(M, N // 2, dtype=torch.uint8)
scales = x.new_empty(padded_rows, padded_cols, dtype=torch.float8_e4m3fn)
# scales.view(torch.uint8).fill_(45)

ROW_TILE_SIZE = 128 * 2
COL_TILE_SIZE = 64

grid = (triton.cdiv(N, 64), triton.cdiv(M, 128))
grid = (triton.cdiv(N, COL_TILE_SIZE), triton.cdiv(M, ROW_TILE_SIZE))

if per_tensor_scale is None:
# Don't allocate tensor, we just steal this since it won't be used in kernel
@@ -1578,10 +1612,14 @@ def triton_quantize_nvfp4(
scales,
x.stride(0),
x.stride(1),
scales.stride(0),
scales.stride(1),
M,
N,
USE_TENSOR_SCALE=use_tensor_scale,
MASK_SCALES=MASK_SCALES,
ROW_TILE_SIZE=ROW_TILE_SIZE,
COL_TILE_SIZE=COL_TILE_SIZE,
)

# reshape back to original shape
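The docstring at the top of `quantize_nvfp4_triton_kernel` states that a [128, 64] data block produces [128, 4] block scales, stored as one [32, 16] swizzled tile. Below is a minimal host-side sketch of that layout, assuming the swizzle follows the `reshape(4, 32, 4).permute(1, 0, 2).reshape(32, 16)` pattern from the pre-tiling version of the kernel; it illustrates the indexing only and is not an authoritative definition of the swizzled scale format.

```python
import torch

# Hedged sketch of the scale swizzle described in the kernel docstring: a
# [128, 64] data block yields [128, 4] FP8 block scales (one per 16-element
# group), and those 512 scales are stored as a single [32, 16] swizzled tile.
def swizzle_block_scales(scales_128x4: torch.Tensor) -> torch.Tensor:
    assert scales_128x4.shape == (128, 4)
    return (
        scales_128x4.reshape(4, 32, 4)   # split 128 rows into 4 groups of 32
        .permute(1, 0, 2)                # interleave the 4 row-groups
        .reshape(32, 16)                 # 32 rows of 16 scales each
    )

# With ROW_TILE_SIZE = 256 the updated kernel handles two such 128-row tiles
# per program, which is what the extra NUM_ROW_INNER_TILES leading dimension in
# scales.reshape(NUM_ROW_INNER_TILES, 4, 32, 4).permute(0, 2, 1, 3) expresses.
x = torch.randn(128, 4)
print(swizzle_block_scales(x).shape)  # torch.Size([32, 16])
```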
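The Note in the `triton_quantize_nvfp4` docstring explains that the wrapper must be a custom op so callers that skip dynamo guards (such as vLLM) cannot reach the kernel with a stale `MASK_SCALES` value. A minimal sketch of that pattern with `torch.library.custom_op` (available in recent PyTorch); the op name, signature, and bodies below are illustrative assumptions, not torchao's actual registration.

```python
import torch

# Hedged sketch, not torchao's registration: a custom op gives every caller
# one stable entry point, so shape-dependent flags are recomputed per call
# instead of being specialized into a compiled graph.
@torch.library.custom_op("illustrative::quantize_rows", mutates_args=())
def quantize_rows(x: torch.Tensor) -> torch.Tensor:
    # The MASK_SCALES-style decision is made inside the op, from the runtime
    # shape, on every call.
    mask_scales = (x.shape[0] % 128 != 0) or (x.shape[1] % 64 != 0)
    return _quantize_rows_impl(x, mask_scales)

def _quantize_rows_impl(x: torch.Tensor, mask_scales: bool) -> torch.Tensor:
    # Placeholder standing in for the Triton kernel launch: it only models the
    # packed output shape (two FP4 values per uint8 byte).
    del mask_scales
    return x.new_zeros(x.shape[0], x.shape[1] // 2, dtype=torch.uint8)

@quantize_rows.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    # Metadata-only implementation used during tracing / torch.compile.
    return x.new_empty(x.shape[0], x.shape[1] // 2, dtype=torch.uint8)
```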