From 6a2088ad07b5272a4f0ddfefea6e2b9cb50862a3 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Fri, 17 Oct 2025 09:11:55 -0700 Subject: [PATCH] [mxfp8 moe training] simplify e8m0 -> fp32 calc stack-info: PR: https://github.com/pytorch/ao/pull/3201, branch: danielvegamyhre/stack/80 --- test/prototype/mx_formats/test_kernels.py | 14 ++++++++------ torchao/prototype/mx_formats/kernels.py | 7 ++++--- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/test/prototype/mx_formats/test_kernels.py b/test/prototype/mx_formats/test_kernels.py index 240e8eea49..a8b1991f01 100644 --- a/test/prototype/mx_formats/test_kernels.py +++ b/test/prototype/mx_formats/test_kernels.py @@ -492,8 +492,9 @@ def test_triton_mxfp8_dim1_randn(M, K): ) @pytest.mark.parametrize("M", (256, 2048, 131072)) @pytest.mark.parametrize("K", (256, 5120, 7168)) -def test_triton_mxfp8_dim0_randn(M, K): - x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") +@pytest.mark.parametrize("orig_dtype", (torch.float32, torch.bfloat16)) +def test_triton_mxfp8_dim0_randn(M, K, orig_dtype): + x = torch.randn(M, K, dtype=orig_dtype, device="cuda") x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(x, block_size=32) x_mx_t, x_s_t = triton_to_mxfp8_dim0(x, inner_block_size=32) torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0) @@ -521,8 +522,9 @@ def test_triton_mxfp8_dim0_zeros(): ) @pytest.mark.parametrize("M", (256, 2048, 131072)) @pytest.mark.parametrize("K", (256, 5120, 7168)) -def test_triton_mxfp8_dequant_dim0(M, K): - x = torch.zeros(M, K, dtype=torch.bfloat16, device="cuda") +@pytest.mark.parametrize("orig_dtype", (torch.float32, torch.bfloat16)) +def test_triton_mxfp8_dequant_dim0(M, K, orig_dtype): + x = torch.zeros(M, K, dtype=orig_dtype, device="cuda") block_size = 32 x_data, x_scales = triton_to_mxfp8_dim0_reference(x, block_size=32) hp_ref = to_dtype( @@ -530,9 +532,9 @@ def test_triton_mxfp8_dequant_dim0(M, K): x_scales, torch.float8_e4m3fn, block_size, - torch.bfloat16, + orig_dtype, ) - hp_t = triton_mxfp8_dequant_dim0(x_data, x_scales, torch.bfloat16, block_size) + hp_t = triton_mxfp8_dequant_dim0(x_data, x_scales, orig_dtype, block_size) torch.testing.assert_close(hp_t, hp_ref, rtol=0, atol=0) diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py index 31d9e96f4b..cf6177179c 100644 --- a/torchao/prototype/mx_formats/kernels.py +++ b/torchao/prototype/mx_formats/kernels.py @@ -1279,12 +1279,13 @@ def triton_to_mxfp8_dim1_reference( scale_e8m0_dim1.unsqueeze(-1), ) + @triton_op("torchao::triton_mxfp8_dequant_dim0", mutates_args={}) def triton_mxfp8_dequant_dim0( e4m3_data: torch.Tensor, e8m0_scales: torch.Tensor, out_dtype: torch.dtype, scale_block_size: int = 32, - ) -> None: + ) -> torch.Tensor: assert scale_block_size == 32, "scale_block_size must be 32 for now" assert out_dtype in (torch.bfloat16, torch.float32), ( "out_dtype must be bf16 or fp32" @@ -1300,7 +1301,7 @@ def triton_mxfp8_dequant_dim0( triton.cdiv(e4m3_data.shape[0], META["ROW_TILE_SIZE"]), triton.cdiv(e4m3_data.shape[1], META["COL_TILE_SIZE"]), ) - _dequant_mxfp8_kernel[grid]( + wrap_triton(_dequant_mxfp8_kernel)[grid]( e4m3_data, e8m0_scales.to(torch.uint8), out_buffer, @@ -1371,8 +1372,8 @@ def _dequant_mxfp8_kernel( @triton.jit def _e8m0_to_fp32(scale_e8m0): - e8m0_exponent_bias = 127 e8m0_nan_val = 255 + e8m0_exponent_bias = 127 s_offset = scale_e8m0.to(tl.int16) - e8m0_exponent_bias s_fp = tl.exp2(s_offset.to(tl.float32)) s_fp = tl.where(scale_e8m0 != e8m0_nan_val, s_fp, float("nan"))