diff --git a/benchmarks/mx_formats/cast_bench.py b/benchmarks/mx_formats/cast_bench.py
index f4f635af1b..09a1fd0b1e 100644
--- a/benchmarks/mx_formats/cast_bench.py
+++ b/benchmarks/mx_formats/cast_bench.py
@@ -257,12 +257,12 @@ def run(
     elif mode == "dim0_nvfp4":
         to_nvfp4_reference_c = torch.compile(to_nvfp4_reference)
 
-        y_d0, s_d0 = to_nvfp4_reference_c(x, use_triton_kernel=False)
+        y_d0, s_d0 = to_nvfp4_reference_c(x)
 
         for _ in range(2):
-            __ = to_nvfp4_reference_c(x, use_triton_kernel=False)
+            __ = to_nvfp4_reference_c(x)
         time_us = benchmark_cuda_function_in_microseconds(
-            lambda x: to_nvfp4_reference_c(x, use_triton_kernel=False),
+            lambda x: to_nvfp4_reference_c(x),
             x,
         )
         assert y_d0.dtype == torch.uint8
diff --git a/test/prototype/mx_formats/test_nvfp4_tensor.py b/test/prototype/mx_formats/test_nvfp4_tensor.py
index 5889019af3..d5c94ceca7 100644
--- a/test/prototype/mx_formats/test_nvfp4_tensor.py
+++ b/test/prototype/mx_formats/test_nvfp4_tensor.py
@@ -359,13 +359,17 @@ def test_nvfp4_swizzled_scales_get_scales_method():
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize(
-    "M", [128, 256, 512, 1024, 100, 200, 384], ids=lambda m: f"M{m}"
+    # "M", [128, 256, 512, 1024, 100, 200, 384], ids=lambda m: f"M{m}"
+    "M", [256, ], ids=lambda m: f"M{m}"
 )
-@pytest.mark.parametrize("N", [64, 128, 256, 512, 32, 96, 160], ids=lambda n: f"N{n}")
+# @pytest.mark.parametrize("N", [64, 128, 256, 512, 32, 96, 160], ids=lambda n: f"N{n}")
+@pytest.mark.parametrize("N", [128], ids=lambda n: f"N{n}")
 @pytest.mark.parametrize(
-    "use_per_tensor_scale", [False, True], ids=["block_scale", "tensor_scale"]
+    # "use_per_tensor_scale", [False, True], ids=["block_scale", "tensor_scale"]
+    "use_per_tensor_scale", [False, ], ids=["block_scale"]
 )
-@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=["fp32", "bf16"])
+# @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=["fp32", "bf16"])
+@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"])
 @pytest.mark.skipif(
     not is_sm_at_least_100(), reason="requires sm100+ for raw intrinsics"
 )
@@ -394,7 +398,20 @@ def test_triton_nvfp4_quantize_equivalence(M, N, use_per_tensor_scale, dtype):
         use_triton_kernel=True,
     )
 
-    torch.testing.assert_close(nvfp4_pt.scale.flatten(), nvfp4_triton.scale.flatten())
+    # print(nvfp4_triton.scale)
+
+    s00 = nvfp4_pt.scale.reshape(2, -1, 16)[0].float()
+    s01 = nvfp4_pt.scale.reshape(2, -1, 16)[1].float()
+    s10 = nvfp4_triton.scale.reshape(2, -1, 16)[0].float()
+    s11 = nvfp4_triton.scale.reshape(2, -1, 16)[1].float()
+    # print(s00.sum(), s01.sum(), s10.sum(), s11.sum())
+
+    s0 = nvfp4_pt.scale.reshape(-1, 32 * 16).float().sum(dim=1)
+    s1 = nvfp4_triton.scale.reshape(-1, 32 * 16).float().sum(dim=1)
+    print('\n', s0)
+    print(s1)
+
+    # breakpoint()
     pt_unpacked = unpack_uint4(nvfp4_pt.qdata)
     triton_unpacked = unpack_uint4(nvfp4_triton.qdata)
     torch.testing.assert_close(
@@ -404,6 +421,8 @@
         rtol=0,
     )
 
+    torch.testing.assert_close(nvfp4_pt.scale.flatten(), nvfp4_triton.scale.flatten())
+
     x_pt_dequant = nvfp4_pt.dequantize(dtype)
     x_triton_dequant = nvfp4_triton.dequantize(dtype)
 
diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py
index 173d99f746..1b447a771f 100644
--- a/torchao/prototype/mx_formats/kernels.py
+++ b/torchao/prototype/mx_formats/kernels.py
@@ -1437,20 +1437,33 @@ def quantize_nvfp4_triton_kernel(
     s_ptr,
     stride_xm,
     stride_xn,
+    stride_sm,
+    stride_sn,
     M,
     N,
     USE_TENSOR_SCALE: tl.constexpr,
     MASK_SCALES: tl.constexpr,
+    ROW_TILE_SIZE: tl.constexpr,
+    COL_TILE_SIZE: tl.constexpr,
 ):
+    """
+    1. single block of data is shaped [128, 64] unpacked or [128, 32] packed
+    2. corresponding single unswizzled block of scales is shaped [128, 4]
+    3. corresponding single swizzled block of scales is shaped [32, 16]
+    """
+
     F4_E2M1_MAX = 6.0
     F8E4M3_MAX = 448.0
     E4M3_EPS = 1.5258789e-05
 
+    NUM_ROW_INNER_TILES: tl.constexpr = ROW_TILE_SIZE // 128
+    NUM_COL_INNER_TILES: tl.constexpr = COL_TILE_SIZE // 64
+
     pid_m = tl.program_id(1)
     pid_n = tl.program_id(0)
 
-    offs_m = pid_m * 128 + tl.arange(0, 128)[:, None]
-    offs_n = pid_n * 64 + tl.arange(0, 64)[None, :]
+    offs_m = pid_m * ROW_TILE_SIZE + tl.arange(0, ROW_TILE_SIZE)[:, None]
+    offs_n = pid_n * COL_TILE_SIZE + tl.arange(0, COL_TILE_SIZE)[None, :]
     if MASK_SCALES:
         mask = (offs_m < M) & (offs_n < N)
         other = 0.0
@@ -1459,11 +1472,11 @@ def quantize_nvfp4_triton_kernel(
         other = None
     x = tl.load(
         x_ptr + offs_m * stride_xm + offs_n * stride_xn, mask=mask, other=other
-    )  # [128, 64]
-    x_blocks = x.to(tl.float32).reshape(128, 4, 16)  # [128, 4, 16]
+    )  # [ROW_TILE_SIZE, COL_TILE_SIZE]
+    x_blocks = x.to(tl.float32).reshape(ROW_TILE_SIZE, 4, 16)  # [-1, 4, 16]
 
     # Compute block-wise scales
-    block_amax = tl.max(x_blocks.abs(), axis=2)  # [128, 4]
+    block_amax = tl.max(x_blocks.abs(), axis=2)  # [-1, 4]
 
     if USE_TENSOR_SCALE:
         # Two-level scaling: quantize block scales with per-tensor scale
@@ -1501,21 +1514,37 @@ def quantize_nvfp4_triton_kernel(
             scales,
             0.0,
         )
-    packed_scales = scales.reshape(4, 32, 4).permute(1, 0, 2).reshape(32, 16)
-    offs_m = tl.arange(0, 32)[:, None]
-    offs_n = tl.arange(0, 16)[None, :]
+    # packed_scales = scales.reshape(4, 32, 4).permute(1, 0, 2).reshape(32, 16)
+    packed_scales = scales.reshape(NUM_ROW_INNER_TILES, 4, 32, 4).permute(0, 2, 1, 3).reshape(NUM_ROW_INNER_TILES * 32, 16)
+    scale_offs_m = tl.arange(0, 32 * NUM_ROW_INNER_TILES)[:, None]
+    scale_offs_n = tl.arange(0, 16)[None, :]
+    # packed_scales = tl.arange(0, 32 * NUM_ROW_INNER_TILES * 16 * NUM_COL_INNER_TILES).reshape(NUM_ROW_INNER_TILES * 32, NUM_COL_INNER_TILES * 16).to(tl.float32)
+
+    # TODO write me
+    scale_elements_per_outer_tile = (min(ROW_TILE_SIZE, M) // 128 * 32) * 16
+
+    # scale_offs = (scale_offs_m * 16 + scale_offs_n)
+    # TODO(next): debug here, offsets or masks are probably not correct here
+    scale_offs = (scale_offs_m * 16 + scale_offs_n)
     tl.store(
         s_ptr
-        + (pid_m * tl.num_programs(0) + pid_n) * (32 * 16)
-        + offs_m * 16
-        + offs_n,
+        # + (pid_m * tl.num_programs(0) + pid_n) * (NUM_ROW_INNER_TILES * 32 * 16)
+        # + (pid_m * tl.num_programs(0) + pid_n) * (NUM_ROW_INNER_TILES * 32 * 16)
+        # + (pid_m * tl.num_programs(0) + pid_n) * (1 * 32 * 16)
+        + (pid_m * tl.num_programs(0) + pid_n) * scale_elements_per_outer_tile
+        + scale_offs,
         packed_scales,
+        mask=(scale_offs < scale_elements_per_outer_tile),
     )
 
     # Convert to FP4
-    x_fp4x2 = convert_fp32_to_fp4_packed(x_blocks.reshape(128, 32, 2).split())
-    offs_m = pid_m * 128 + tl.arange(0, 128)[:, None]
-    offs_n = pid_n * 32 + tl.arange(0, 32)[None, :]
+    x_fp4x2 = convert_fp32_to_fp4_packed(
+        x_blocks.reshape(ROW_TILE_SIZE, 32, 2).split()
+    )
+    offs_m = pid_m * ROW_TILE_SIZE + tl.arange(0, ROW_TILE_SIZE)[:, None]
+    offs_n = (
+        pid_n * (COL_TILE_SIZE // 2) + tl.arange(0, COL_TILE_SIZE // 2)[None, :]
+    )
     if MASK_SCALES:
         mask = (offs_m < M) & (offs_n < N // 2)
     else:
@@ -1537,7 +1566,7 @@ def triton_quantize_nvfp4(
         Tuple[torch.Tensor, torch.Tensor]:
             Quantized tensor and scales tensor in swizzled layout.
     Note:
-        Since VLLM does not use dyanmo guards we need to make this a custom op
+        Since VLLM does not use dynamo guards we need to make this a custom op
         to avoid the triton kernel being invoked w/ the wrong use of `MASK_SCALES`
     """
     # reshape to 2d
@@ -1557,11 +1586,16 @@
     # mask out scales to 0 if we are not aligned to 128 x 64
     MASK_SCALES = M % 128 != 0 or N % 64 != 0
+    # MASK_SCALES = True
 
     xq = x.new_empty(M, N // 2, dtype=torch.uint8)
     scales = x.new_empty(padded_rows, padded_cols, dtype=torch.float8_e4m3fn)
+    # scales.view(torch.uint8).fill_(45)
+
+    ROW_TILE_SIZE = 128 * 2
+    COL_TILE_SIZE = 64
 
-    grid = (triton.cdiv(N, 64), triton.cdiv(M, 128))
+    grid = (triton.cdiv(N, COL_TILE_SIZE), triton.cdiv(M, ROW_TILE_SIZE))
 
     if per_tensor_scale is None:
         # Don't allocate tensor, we just steal this since it won't be used in kernel
 
@@ -1578,10 +1612,14 @@
         scales,
         x.stride(0),
         x.stride(1),
+        scales.stride(0),
+        scales.stride(1),
         M,
         N,
         USE_TENSOR_SCALE=use_tensor_scale,
         MASK_SCALES=MASK_SCALES,
+        ROW_TILE_SIZE=ROW_TILE_SIZE,
+        COL_TILE_SIZE=COL_TILE_SIZE,
     )
 
     # reshape back to original shape
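
Notes, not part of the patch: two host-side sketches for cross-checking the kernel while the tiled scale-store path is debugged.

First, the scale math the kernel's constants pin down: each 1x16 block stores block_amax / F4_E2M1_MAX as an fp8 e4m3 scale, optionally divided by a per-tensor scale first (the USE_TENSOR_SCALE "two-level scaling" branch). This is a minimal sketch assuming a clamp-then-cast ordering; block_scales_sketch is a hypothetical name, not a torchao API, and the kernel's exact rounding behavior is not reproduced here.

import torch

F4_E2M1_MAX = 6.0  # largest magnitude representable in fp4 e2m1
F8E4M3_MAX = 448.0  # largest magnitude representable in fp8 e4m3
E4M3_EPS = 1.5258789e-05  # small positive floor so scales never hit zero


def block_scales_sketch(block_amax, per_tensor_scale=None):
    # block_amax: float32 tensor of per-(1x16)-block amax values.
    # Hypothetical helper (assumption), mirroring the kernel's constants.
    s = block_amax / F4_E2M1_MAX  # scale so each block's amax maps to the fp4 max
    if per_tensor_scale is not None:
        s = s / per_tensor_scale  # two-level scaling: fold out the tensor-level scale
    s = s.clamp(min=E4M3_EPS, max=F8E4M3_MAX)  # keep the scale representable in e4m3
    return s.to(torch.float8_e4m3fn)  # the scale itself is quantized to e4m3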
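
Second, the swizzled scale layout itself, matching the kernel docstring: a [128, 4] unswizzled scale tile maps to a [32, 16] swizzled tile (element (r, c) lands at row r % 32, column (r // 32) * 4 + c), and tiles are linearized in the order (pid_m * tl.num_programs(0) + pid_n). Below is a pure PyTorch reference, assuming M % 128 == 0 and N % 64 == 0 (the MASK_SCALES padding path is not modeled); swizzle_scales_reference is an illustrative name, not an existing torchao function.

import torch


def swizzle_scales_reference(scales):
    # scales: [M, N // 16] unswizzled per-block scales,
    # with M % 128 == 0 and (N // 16) % 4 == 0.
    M, K = scales.shape
    # Cut into [128, 4] tiles, row-major over the tile grid, matching the
    # (pid_m * tl.num_programs(0) + pid_n) linearization in the kernel.
    t = scales.reshape(M // 128, 128, K // 4, 4).permute(0, 2, 1, 3)
    # Each [128, 4] tile becomes a [32, 16] tile: element (r, c) lands at
    # (r % 32, (r // 32) * 4 + c), i.e. the single-tile
    # reshape(4, 32, 4).permute(1, 0, 2).reshape(32, 16) the kernel started from.
    t = t.reshape(-1, 4, 32, 4).permute(0, 2, 1, 3)
    # Tiles are stored back to back, 32 * 16 = 512 elements each.
    return t.reshape(-1, 16)

Comparing this output per 512-element tile against the Triton result (for example, summing scale.reshape(-1, 32 * 16) tile by tile, as the modified test does) localizes which outer tile's store offsets diverge once ROW_TILE_SIZE spans more than one 128-row inner tile.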