diff --git a/benchmarks/mx_formats/cast_bench.py b/benchmarks/mx_formats/cast_bench.py
index a9d8b18ae7..a38df8d340 100644
--- a/benchmarks/mx_formats/cast_bench.py
+++ b/benchmarks/mx_formats/cast_bench.py
@@ -13,6 +13,7 @@
 from torchao.prototype.mx_formats.config import ScaleCalculationMode
 from torchao.prototype.mx_formats.kernels import (
+    triton_to_mxfp8_dim0,
     triton_to_mxfp8_dim1,
 )
 from torchao.prototype.mx_formats.mx_tensor import to_mx
@@ -97,6 +98,7 @@ def run(
         "dim0_mxfp8_floor",
         "dim0_mxfp4_floor",
         "dim0_mxfp8_rceil",
+        "dim0_mxfp8_triton_floor",
         "dim1_mxfp8_floor",
         "dim1_mxfp8_rceil",
         "dim1_mxfp8_triton_floor",
@@ -222,6 +224,22 @@ def run(
         bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
         bps = (bytes_r + bytes_w) / (time_us / 1e6)

+    elif mode == "dim0_mxfp8_triton_floor":
+        y_d0, s_d0 = triton_to_mxfp8_dim0(x, inner_block_size=BLOCK_SIZE)
+
+        for _ in range(2):
+            __ = triton_to_mxfp8_dim0(x, inner_block_size=BLOCK_SIZE)
+        time_us = benchmark_cuda_function_in_microseconds(
+            lambda x, b: triton_to_mxfp8_dim0(x, inner_block_size=BLOCK_SIZE),
+            x,
+            BLOCK_SIZE,
+        )
+        assert y_d0.dtype == torch.float8_e4m3fn
+        assert s_d0.dtype == torch.float8_e8m0fnu
+        bytes_r = x.numel() * bytes_per_el_bf16
+        bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
+        bps = (bytes_r + bytes_w) / (time_us / 1e6)
+
     elif mode == "dim1_mxfp8_floor":
         to_mx_dim1_reference_c = torch.compile(to_mx_dim1_reference)
         y_d1, s_d1 = to_mx_dim1_reference_c(x, BLOCK_SIZE)
diff --git a/test/prototype/mx_formats/test_kernels.py b/test/prototype/mx_formats/test_kernels.py
index 024586419a..7d47f2edef 100644
--- a/test/prototype/mx_formats/test_kernels.py
+++ b/test/prototype/mx_formats/test_kernels.py
@@ -37,6 +37,7 @@
     pack_uint6,
     triton_f6_e2m3_to_bf16,
     triton_f6_e3m2_to_bf16,
+    triton_to_mxfp8_dim0,
     triton_to_mxfp8_dim1,
     triton_to_mxfp8_dim1_reference,
     unpack_uint4,
@@ -431,6 +432,23 @@ def test_fp6_e3m2_pack_unpack():
     assert torch.all(orig_vals_f6_packed_unpacked == orig_vals)


+def triton_to_mxfp8_dim0_reference(
+    x_hp: torch.Tensor, block_size
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    A reference version of `triton_to_mxfp8_dim0` for rowwise quantization.
+    """
+    from torchao.prototype.mx_formats.mx_tensor import to_mx
+
+    # cast across dim0 (rowwise) - no transpose needed
+    scale_e8m0_dim0, x_hp_d0_normalized = to_mx(x_hp, torch.float8_e4m3fn, block_size)
+    scale_e8m0_dim0 = scale_e8m0_dim0.view(torch.float8_e8m0fnu)
+    return (
+        x_hp_d0_normalized,
+        scale_e8m0_dim0.unsqueeze(-1),
+    )
+
+
 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
 @pytest.mark.skipif(
     not is_sm_at_least_89(),
@@ -446,6 +464,21 @@ def test_triton_mxfp8_dim1_randn(M, K):
     torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)


+@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
+@pytest.mark.skipif(
+    not is_sm_at_least_89(),
+    reason="float8 in triton requires CUDA capability 8.9 or greater",
+)
+@pytest.mark.parametrize("M", (256, 2048))
+@pytest.mark.parametrize("K", (256, 2048))
+def test_triton_mxfp8_dim0_randn(M, K):
+    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
+    x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(x, block_size=32)
+    x_mx_t, x_s_t = triton_to_mxfp8_dim0(x, inner_block_size=32)
+    torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
+    torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize(
     "shape",
diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py
index 5811dd9d21..45263c2884 100644
--- a/torchao/prototype/mx_formats/kernels.py
+++ b/torchao/prototype/mx_formats/kernels.py
@@ -1038,6 +1038,175 @@ def to_mxfp8_dim1_kernel(
         # TODO(future): mask this store
         tl.store(col_scale_start_ptr + col_scale_indices, col_scale_e8m0)

+    @triton.autotune(
+        configs=_get_mxfp8_dim1_kernel_autotune_configs(),
+        key=["n_rows", "n_cols", "INNER_BLOCK_SIZE"],
+    )
+    @triton.jit
+    def to_mxfp8_dim0_kernel(
+        x_ptr,  # pointer to input tensor
+        output_ptr,  # pointer to output tensor (row-normalized)
+        row_scale_ptr,  # pointer to store row-wise maximum absolute values
+        n_rows,  # number of rows in the tensor
+        n_cols,  # number of columns in the tensor
+        ROW_TILE_SIZE: tl.constexpr,
+        COL_TILE_SIZE: tl.constexpr,
+        INNER_BLOCK_SIZE: tl.constexpr,  # should be 32 for MX
+    ):
+        """
+        Quantizes a high precision tensor to mxfp8 rowwise (1x32 scaling granularity).
+
+        This is the counterpart to to_mxfp8_dim1_kernel which does columnwise quantization.
+        Instead of transposing and scaling across columns, this kernel scales across rows.
+        """
+
+        BLOCKS_PER_COL_TILE: tl.constexpr = COL_TILE_SIZE // INNER_BLOCK_SIZE
+
+        # Get program ID
+        pid_row = tl.program_id(0)
+        pid_col = tl.program_id(1)
+
+        # Calculate starting row and column for this tile
+        start_row = pid_row * ROW_TILE_SIZE
+        start_col = pid_col * COL_TILE_SIZE
+
+        # Create offsets for the block
+        row_offsets = tl.arange(0, ROW_TILE_SIZE)
+        col_offsets = tl.arange(0, COL_TILE_SIZE)
+
+        # Compute global row/col positions
+        rows = start_row + row_offsets[:, None]
+        cols = start_col + col_offsets[None, :]
+
+        # Create masks for out-of-bounds accesses
+        row_mask = rows < n_rows
+        col_mask = cols < n_cols
+        mask = row_mask & col_mask
+
+        # Compute memory offsets for row-major layout (rows, cols)
+        row_major_offsets = (rows * n_cols + cols).to(tl.int32)
+
+        # Load the entire block in a single operation
+        # shape: (ROW_TILE_SIZE, COL_TILE_SIZE)
+        x_block = tl.load(x_ptr + row_major_offsets, mask=mask)
+
+        # Reshape to inner tile size for rowwise scaling
+        # shape: (ROW_TILE_SIZE, COL_TILE_SIZE) -> (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, INNER_BLOCK_SIZE)
+        x_block_r = x_block.reshape(
+            ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, INNER_BLOCK_SIZE
+        )
+
+        # Calculate the absolute values of elements in the block
+        x_block_abs_r = tl.abs(x_block_r)
+
+        # Find the maximum absolute value for each row (across columns)
+        # shape: (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE,)
+        row_scale_r, row_scale_e8m0_r = _triton_calculate_scale(x_block_abs_r, axis=1)
+
+        # Divide each row by scale
+        # Broadcasting row_scale to match x_block's shape
+        # x_block_r shape (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, INNER_BLOCK_SIZE)
+        # row_scale shape (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE,) -> (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, 1)
+        row_normalized_r = x_block_r / row_scale_r[:, None]
+
+        # Reshape back to original tile size
+        row_normalized = tl.reshape(row_normalized_r, ROW_TILE_SIZE, COL_TILE_SIZE)
+
+        # Quantize to float8
+        row_normalized = row_normalized.to(tl.float8e4nv)
+
+        # Store the row-normalized result in row-major format
+        tl.store(output_ptr + row_major_offsets, row_normalized, mask=mask)
+
+        # reshape row_scale_e8m0_r for proper storage
+        # shape: (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE,)
+        row_scale_e8m0 = row_scale_e8m0_r.reshape(ROW_TILE_SIZE * BLOCKS_PER_COL_TILE)
+
+        row_scale_start_offsets = (
+            (pid_row * ROW_TILE_SIZE * (n_cols // COL_TILE_SIZE))
+            * BLOCKS_PER_COL_TILE  # number of blocks seen so far
+            + pid_col * BLOCKS_PER_COL_TILE  # increment BLOCKS_PER_COL_TILE
+        )
+
+        row_scale_start_ptr = row_scale_ptr + row_scale_start_offsets
+
+        # calculate row_scale_indices
+        row_scale_indices = tl.arange(0, ROW_TILE_SIZE * BLOCKS_PER_COL_TILE)
+
+        # How many values are in all the other rows for this col_pid, need to jump
+        # over them for every BLOCKS_PER_COL_TILE values
+        jump_vals_per_row = (n_cols - COL_TILE_SIZE) // INNER_BLOCK_SIZE
+
+        # example transformation (specifics depend on tile sizes):
+        # [0, 1, 2, 3, 4, 5, 6, 7] -> [0, 1, 4, 5, 8, 9, 12, 13]
+        row_scale_indices = row_scale_indices + (
+            (row_scale_indices // BLOCKS_PER_COL_TILE) * jump_vals_per_row
+        )
+
+        # Store the scales
+        tl.store(row_scale_start_ptr + row_scale_indices, row_scale_e8m0)
+
+    @triton_op("torchao::triton_to_mxfp8_dim0", mutates_args={})
+    def triton_to_mxfp8_dim0(
+        x: torch.Tensor, inner_block_size: int = 32
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Input:
+        * `x` - input tensor, in row major memory layout
+        * `inner_block_size` - size of tiles to scale across, default is 32 for MX recipes
+
+        Output:
+        * `output`: the `float8_e4m3fn` values of `x` cast to mxfp8 across dim0 (rowwise)
+        * `row_scale`: the `e8m0` values of `x_scale` used to cast `x` to mxfp8 across dim0
+        """
+        assert x.is_contiguous(), "`x` must be contiguous"
+        assert inner_block_size <= 32
+
+        # Get tensor shape
+        n_rows, n_cols = x.shape
+
+        # Masking of loads and stores is not well tested yet, so for now enforce
+        # shapes which do not need masking. Note that this condition depends on max values of
+        # ROW_TILE_SIZE and COL_TILE_SIZE, which are autotuned above.
+        # TODO(future): implement and test masking and remove this restriction
+        max_row_tile_size = 128
+        max_col_tile_size = 128
+        assert n_rows % max_row_tile_size == 0, "unsupported"
+        assert n_cols % max_col_tile_size == 0, "unsupported"
+
+        # Create output tensors
+        output = torch.empty(
+            (n_rows, n_cols), dtype=torch.float8_e4m3fn, device=x.device
+        )
+
+        # Create scale tensors for rowwise scaling
+        row_scale = torch.empty(
+            (n_rows, n_cols // inner_block_size, 1),
+            dtype=torch.uint8,
+            device=x.device,
+        )
+
+        # Calculate grid dimensions based on tile size
+        grid = lambda META: (
+            triton.cdiv(n_rows, META["ROW_TILE_SIZE"]),
+            triton.cdiv(n_cols, META["COL_TILE_SIZE"]),
+        )
+
+        # Launch the kernel
+        wrap_triton(to_mxfp8_dim0_kernel)[grid](
+            x_ptr=x,
+            output_ptr=output,
+            row_scale_ptr=row_scale,
+            n_rows=n_rows,
+            n_cols=n_cols,
+            INNER_BLOCK_SIZE=inner_block_size,
+        )
+
+        return (
+            output,
+            row_scale.view(torch.float8_e8m0fnu),
+        )
+
     @triton_op("torchao::triton_to_mxfp8_dim1", mutates_args={})
     def triton_to_mxfp8_dim1(
         x: torch.Tensor, inner_block_size: int = 32
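Usage note (not part of the patch): the sketch below mirrors test_triton_mxfp8_dim0_randn from the diff and shows what the new op returns. The concrete sizes are illustrative; the constraints taken from the diff are a contiguous input whose dimensions are multiples of 128, plus the test's requirements of Triton and CUDA compute capability 8.9 or greater.

# Minimal sketch, assuming a CUDA device with SM 8.9+ and Triton available.
import torch

from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0

M, K = 2048, 4096  # both dims multiples of 128, per the asserts in triton_to_mxfp8_dim0
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")

# Rowwise (dim0) mxfp8 cast: one e8m0 scale per 32 contiguous values within each row.
data, scales = triton_to_mxfp8_dim0(x, inner_block_size=32)

assert data.shape == (M, K) and data.dtype == torch.float8_e4m3fn
assert scales.shape == (M, K // 32, 1) and scales.dtype == torch.float8_e8m0fnu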