diff --git a/test/prototype/mx_formats/test_kernels.py b/test/prototype/mx_formats/test_kernels.py
index 240e8eea49..a8b1991f01 100644
--- a/test/prototype/mx_formats/test_kernels.py
+++ b/test/prototype/mx_formats/test_kernels.py
@@ -492,8 +492,9 @@ def test_triton_mxfp8_dim1_randn(M, K):
 )
 @pytest.mark.parametrize("M", (256, 2048, 131072))
 @pytest.mark.parametrize("K", (256, 5120, 7168))
-def test_triton_mxfp8_dim0_randn(M, K):
-    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
+@pytest.mark.parametrize("orig_dtype", (torch.float32, torch.bfloat16))
+def test_triton_mxfp8_dim0_randn(M, K, orig_dtype):
+    x = torch.randn(M, K, dtype=orig_dtype, device="cuda")
     x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(x, block_size=32)
     x_mx_t, x_s_t = triton_to_mxfp8_dim0(x, inner_block_size=32)
     torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
@@ -521,8 +522,9 @@ def test_triton_mxfp8_dim0_zeros():
 )
 @pytest.mark.parametrize("M", (256, 2048, 131072))
 @pytest.mark.parametrize("K", (256, 5120, 7168))
-def test_triton_mxfp8_dequant_dim0(M, K):
-    x = torch.zeros(M, K, dtype=torch.bfloat16, device="cuda")
+@pytest.mark.parametrize("orig_dtype", (torch.float32, torch.bfloat16))
+def test_triton_mxfp8_dequant_dim0(M, K, orig_dtype):
+    x = torch.zeros(M, K, dtype=orig_dtype, device="cuda")
     block_size = 32
     x_data, x_scales = triton_to_mxfp8_dim0_reference(x, block_size=32)
     hp_ref = to_dtype(
@@ -530,9 +532,9 @@ def test_triton_mxfp8_dequant_dim0(M, K):
         x_scales,
         torch.float8_e4m3fn,
         block_size,
-        torch.bfloat16,
+        orig_dtype,
     )
-    hp_t = triton_mxfp8_dequant_dim0(x_data, x_scales, torch.bfloat16, block_size)
+    hp_t = triton_mxfp8_dequant_dim0(x_data, x_scales, orig_dtype, block_size)
     torch.testing.assert_close(hp_t, hp_ref, rtol=0, atol=0)
 
 
diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py
index 31d9e96f4b..cf6177179c 100644
--- a/torchao/prototype/mx_formats/kernels.py
+++ b/torchao/prototype/mx_formats/kernels.py
@@ -1279,12 +1279,13 @@ def triton_to_mxfp8_dim1_reference(
             scale_e8m0_dim1.unsqueeze(-1),
         )
 
+    @triton_op("torchao::triton_mxfp8_dequant_dim0", mutates_args={})
     def triton_mxfp8_dequant_dim0(
         e4m3_data: torch.Tensor,
         e8m0_scales: torch.Tensor,
         out_dtype: torch.dtype,
         scale_block_size: int = 32,
-    ) -> None:
+    ) -> torch.Tensor:
         assert scale_block_size == 32, "scale_block_size must be 32 for now"
         assert out_dtype in (torch.bfloat16, torch.float32), (
             "out_dtype must be bf16 or fp32"
@@ -1300,7 +1301,7 @@ def triton_mxfp8_dequant_dim0(
         grid = lambda META: (
             triton.cdiv(e4m3_data.shape[0], META["ROW_TILE_SIZE"]),
             triton.cdiv(e4m3_data.shape[1], META["COL_TILE_SIZE"]),
         )
-        _dequant_mxfp8_kernel[grid](
+        wrap_triton(_dequant_mxfp8_kernel)[grid](
             e4m3_data,
             e8m0_scales.to(torch.uint8),
             out_buffer,
@@ -1371,8 +1372,8 @@ def _dequant_mxfp8_kernel(
 
     @triton.jit
     def _e8m0_to_fp32(scale_e8m0):
-        e8m0_exponent_bias = 127
         e8m0_nan_val = 255
+        e8m0_exponent_bias = 127
         s_offset = scale_e8m0.to(tl.int16) - e8m0_exponent_bias
         s_fp = tl.exp2(s_offset.to(tl.float32))
         s_fp = tl.where(scale_e8m0 != e8m0_nan_val, s_fp, float("nan"))
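
For context, a minimal round-trip sketch of the ops this patch touches, mirroring the call signatures exercised in the tests above. The import path is assumed from the files being patched, and a CUDA device with Triton is required; treat this as an illustration under those assumptions, not part of the change itself.

```python
# Sketch only: import path and signatures are assumed from the patched files.
import torch

from torchao.prototype.mx_formats.kernels import (
    triton_mxfp8_dequant_dim0,
    triton_to_mxfp8_dim0,
)

# fp32 inputs are now exercised alongside bf16 (see the new orig_dtype parametrize).
x = torch.randn(256, 256, dtype=torch.float32, device="cuda")

# Quantize along dim0: float8_e4m3fn data plus one e8m0 scale per 32-element block.
x_data, x_scales = triton_to_mxfp8_dim0(x, inner_block_size=32)

# Dequantize back to the input dtype; with this patch the function is registered
# as a custom op via @triton_op and returns a tensor instead of None.
x_hp = triton_mxfp8_dequant_dim0(x_data, x_scales, x.dtype, 32)
assert x_hp.shape == x.shape and x_hp.dtype == x.dtype
```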