From 56722ebbdee11f1f1f11ab83da4f096b1d020595 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Sat, 6 Sep 2025 19:12:12 -0700 Subject: [PATCH 1/5] [mxfp8 moe training] blocked scale conversion for LHS of 2d-2d grouped gemm --- .../benchmark_2d_3d_grouped_gemms.py | 2 +- ...chmark_2d_blocked_swizzle_scale_kernels.py | 2 +- test/prototype/moe_training/test_kernels.py | 49 +++- .../kernels/mxfp8_blocked_scales.py | 257 +++++++++++++++++- .../moe_training/kernels/mxfp8_gemms.py | 2 +- 5 files changed, 302 insertions(+), 10 deletions(-) diff --git a/benchmarks/prototype/moe_training/benchmark_2d_3d_grouped_gemms.py b/benchmarks/prototype/moe_training/benchmark_2d_3d_grouped_gemms.py index ef398ac553..1dc78ec0a5 100644 --- a/benchmarks/prototype/moe_training/benchmark_2d_3d_grouped_gemms.py +++ b/benchmarks/prototype/moe_training/benchmark_2d_3d_grouped_gemms.py @@ -231,7 +231,7 @@ def bench_mxfp8_grouped_mm(A, B_t, offs, block_size=32) -> float: # Convert scales for each group to blocked format. Mg, K = A_fp8.shape A_scales_blocked, starting_row_after_padding = torch_to_blocked_per_group_2d( - A_scales, offs, Mg, K + A_scales, offs, K ) B_scales_blocked = torch_to_blocked_per_group_3d(B_scales) diff --git a/benchmarks/prototype/moe_training/benchmark_2d_blocked_swizzle_scale_kernels.py b/benchmarks/prototype/moe_training/benchmark_2d_blocked_swizzle_scale_kernels.py index 1dc6ade1df..f1185bd533 100644 --- a/benchmarks/prototype/moe_training/benchmark_2d_blocked_swizzle_scale_kernels.py +++ b/benchmarks/prototype/moe_training/benchmark_2d_blocked_swizzle_scale_kernels.py @@ -84,7 +84,7 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult: # bench torch compiled_run_torch = torch.compile(torch_to_blocked_per_group_2d) torch_out_scales, torch_group_offs = compiled_run_torch( - input_tensor, input_group_offsets, Mg, K + input_tensor, input_group_offsets, K ) torch_time_us = benchmark_cuda_function_in_microseconds( compiled_run_torch, diff --git a/test/prototype/moe_training/test_kernels.py b/test/prototype/moe_training/test_kernels.py index 1cef8c0ed4..67cb57033e 100644 --- a/test/prototype/moe_training/test_kernels.py +++ b/test/prototype/moe_training/test_kernels.py @@ -23,9 +23,12 @@ ) from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import ( compute_per_group_blocked_scale_offsets, + compute_per_group_blocked_scale_offsets_2d2d_lhs, torch_to_blocked_per_group_2d, + torch_to_blocked_per_group_2d2d_lhs, torch_to_blocked_per_group_3d, triton_mx_block_rearrange_per_group_2d, + triton_mx_block_rearrange_per_group_2d2d_lhs, triton_mx_block_rearrange_per_group_3d, ) from torchao.prototype.moe_training.utils import ( @@ -227,7 +230,7 @@ def test_mxfp8_per_group_blocked_scales_2d( # torch reference ref_out_scales, _ = torch_to_blocked_per_group_2d( - e8m0_scales, input_group_offsets, m, k, block_size=block_size + e8m0_scales, input_group_offsets, k, block_size=block_size ) # triton kernel @@ -266,3 +269,47 @@ def test_mxfp8_per_group_blocked_scales_3d( assert torch.allclose(ref_out_scales, triton_out_scales, atol=0, rtol=0), ( "blocked scales not equal" ) + + +@skip_if_rocm("ROCm enablement in progress") +@pytest.mark.parametrize("m,total_k,n_groups", [(256, 64, 2)]) +def test_mxfp8_per_group_blocked_scales_2d2d_lhs( + m: int, + total_k: int, + n_groups: int, +): + device = "cuda" + block_size = 32 + input_data = torch.randn(m, total_k, device=device) + e8m0_scales, _ = to_mx( + input_data, elem_dtype=torch.float8_e4m3fn, block_size=block_size + ) + + # Generate group end offsets along total_K, then divide by block_size to get scale group end offsets + input_group_offsets = generate_jagged_offs( + n_groups, total_k, multiple_of=block_size, device=device + ) + input_group_offsets //= block_size + + # torch reference + ref_out_scales, ref_start_cols_after_padding = torch_to_blocked_per_group_2d2d_lhs( + e8m0_scales, + input_group_offsets, + ) + + # triton kernel + _, output_group_offsets = compute_per_group_blocked_scale_offsets_2d2d_lhs( + input_group_offsets + ) + assert torch.allclose(output_group_offsets, ref_start_cols_after_padding), ( + "output scale group start offsets not equal" + ) + triton_out_scales = triton_mx_block_rearrange_per_group_2d2d_lhs( + e8m0_scales, + input_group_offsets, + output_group_offsets, + ) + breakpoint() + assert torch.allclose(ref_out_scales, triton_out_scales, atol=0, rtol=0), ( + "blocked scales not equal" + ) diff --git a/torchao/prototype/moe_training/kernels/mxfp8_blocked_scales.py b/torchao/prototype/moe_training/kernels/mxfp8_blocked_scales.py index 1febebbc7d..26d3b40170 100644 --- a/torchao/prototype/moe_training/kernels/mxfp8_blocked_scales.py +++ b/torchao/prototype/moe_training/kernels/mxfp8_blocked_scales.py @@ -1,3 +1,5 @@ +from typing import Tuple + import torch import triton import triton.language as tl @@ -8,8 +10,8 @@ def torch_to_blocked_per_group_2d( - x_scales: Tensor, group_offs: Tensor, Mg: int, K: int, block_size: int = 32 -) -> Tensor: + x_scales: Tensor, group_offs: Tensor, K: int, block_size: int = 32 +) -> Tuple[Tensor, Tensor]: """ Convert scales to blocked format for a 2D tensor (input activations / token groups) @@ -58,6 +60,59 @@ def torch_to_blocked_per_group_2d( return blocked_scales, start_row_after_padding +def torch_to_blocked_per_group_2d2d_lhs( + x_scales: Tensor, group_offs: Tensor, block_size: int = 32 +) -> Tuple[Tensor, Tensor]: + """ + Convert scales to blocked format for a 2D tensor (input activations) when scaling along the contraction dimension. + + Args: + x_scales: Tensor with per group scales in blocked format concatenated into one tensor. + group_offs: Tensor of shape (num_groups,) which contains the end index of each group along the total_k dimension. + total_K: total size of all groups summed together + + Returns: + blocked_scales: Tensor + start_row_after_padding: Tensor of shape (num_groups,) which contains the start row after padding for each group. + """ + assert x_scales.ndim == 2, "x_scales must be 2D" + assert block_size == 32, "Only block_size=32 is supported for now" + blocked_scales_list = [] + start_col_after_padding_list = [0] + group_start_idx = 0 + for i, group_end_idx in enumerate(group_offs.tolist()): + group_size = group_end_idx - group_start_idx + prev_start_row_after_padding = start_col_after_padding_list[i] + if group_size == 0: + start_col_after_padding_list.append(prev_start_row_after_padding) + continue + + # Convert group scales to blocked format + group_scales = x_scales[:, group_start_idx:group_end_idx] + group_scales_blocked = to_blocked(group_scales) + blocked_scales_list.append(group_scales_blocked) + + # Calculate the start row after padding + cols_after_padding = ceil_div(group_size, 4) * 4 + new_start_col = prev_start_row_after_padding + cols_after_padding + start_col_after_padding_list.append(new_start_col) + + # Update next group start index + group_start_idx = group_end_idx + + M = x_scales.shape[0] + padded_M = ceil_div(M, 128) * 128 + # blocked_scales = torch.cat(blocked_scales_list, dim=0) + # blocked_scales = blocked_scales.reshape(padded_M, -1) + blocked_scales = torch.cat( + [s.reshape(padded_M, -1) for s in blocked_scales_list], dim=1 + ) + start_cols_after_padding = torch.tensor( + start_col_after_padding_list, device=x_scales.device, dtype=torch.int64 + ) + return blocked_scales, start_cols_after_padding + + def torch_to_blocked_per_group_3d(weight_scales: Tensor) -> Tensor: """ Convert scales to blocked format for each group for a 3D tensor (expert weights) @@ -104,6 +159,32 @@ def compute_per_group_blocked_scale_offsets(offsets: torch.Tensor): return group_sizes, starting_row_after_padding +def compute_per_group_blocked_scale_offsets_2d2d_lhs(offsets: torch.Tensor): + """ + Performs round_up(x, 4) on each element in a 1D offsets tensor, + to compute the starting offsets of each group after scaling along the contraction dimension. + + Args: + offsets: A 1D PyTorch tensor of integers in ascending sorted order, representing the end index of each group along the Mg dimension. + + Returns: + - starting_row_after_padding: 1D integer tensor representing the starting row after padding each to blocked format. + """ + # Calculate group sizes + zero = torch.tensor([0], dtype=offsets.dtype, device=offsets.device) + group_sizes = torch.diff(offsets, prepend=zero).to(torch.int64) + + # After scaling with block_size 32, each group size up to the nearest multiple of 4 + rounded_group_sizes = ceil_div(group_sizes, 4) * 4 + + # Calculate the starting row after padding for each group + starting_col_after_padding = torch.cumsum(rounded_group_sizes, dim=0) + + # Must start with 0 + starting_col_after_padding = torch.cat([zero, starting_col_after_padding]) + return group_sizes, starting_col_after_padding + + def triton_mx_block_rearrange_per_group_2d( scales_tensor: torch.Tensor, input_group_end_offsets: torch.Tensor, @@ -125,20 +206,22 @@ def triton_mx_block_rearrange_per_group_2d( "Expected element size to be 1 byte (8 bits)" ) rows, cols = scales_tensor.shape - # Calculate blocks needed num_groups = input_group_end_offsets.numel() + # Final offset is the total number of rows in the tensor padded_rows = output_group_start_offsets[-1] + num_col_blocks = ceil_div(cols, 4) padded_cols = num_col_blocks * 4 output = scales_tensor.new_empty((padded_rows, padded_cols)) - # We probably want handle multiple blocks per tile but for now keep it simple - BLOCK_ROWS, BLOCK_COLS = 128, 4 + # Output block stride for the rearranged format + BLOCK_ROWS, BLOCK_COLS = 128, 4 output_stride_per_block = BLOCK_ROWS * BLOCK_COLS output_stride_per_row_of_blocks = ( BLOCK_ROWS * BLOCK_COLS * (padded_cols // BLOCK_COLS) ) + # We parallelize per group and per col block. # Rows per group is variable so we just loop through row blocks per group, per col block. grid = lambda META: ( @@ -176,7 +259,7 @@ def triton_scale_swizzle_per_group_2d( scale_cols, num_groups, orig_offsets, # (num_groups,) - output_scales_ptr, # (rows + num_groups * 128, tl.cdiv(K, 4) * 4) + output_scales_ptr, output_scales_stride_dim0, output_scales_group_offsets, # (num_groups,) output_stride_per_block, @@ -357,3 +440,165 @@ def triton_scale_swizzle_per_group_3d( output_ptr + block_offset + dest_indices_flat, scales_flat, ) + + +def triton_mx_block_rearrange_per_group_2d2d_lhs( + scales_tensor: torch.Tensor, + input_group_end_offsets: torch.Tensor, + output_group_start_offsets: torch.Tensor, +) -> torch.Tensor: + """ + Rearranges an E8M0 tensor scale to block-scaled swizzle format on a per group basis, + where the groups are along the contraction dimension of the GEMM. + + This format is suitable for Tmem as described in NVIDIA documentation: + https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout + + Args: + scales_tensor: Input tensor containing e8m0 scales for each logical group of a target tensor. + input_group_end_offsets: tensor of int32 values representing group end indexes for the input scales + output_group_start_offsets: tensor of int32 values representing pre-computed group start indexes after blocked format padding + Returns: + - Rearranged tensor in block-scaled swizzle format + """ + assert scales_tensor.ndim == 2, "scales tensor must be 2d" + assert scales_tensor.element_size() == 1, ( + "Expected element size to be 1 byte (8 bits)" + ) + rows, cols = scales_tensor.shape + # Calculate blocks needed + num_groups = input_group_end_offsets.numel() + num_row_blocks = ceil_div(rows, 128) + padded_rows = num_row_blocks * 128 + padded_cols = output_group_start_offsets[-1] + output = scales_tensor.new_empty((padded_rows, padded_cols)) + # We probably want handle multiple blocks per tile but for now keep it simple + BLOCK_ROWS, BLOCK_COLS = 128, 4 + # Output block stride for the rearranged format + output_stride_per_block = BLOCK_ROWS * BLOCK_COLS + num_col_blocks = padded_cols // BLOCK_COLS + output_stride_per_row_of_blocks = output_stride_per_block * num_col_blocks + + # We parallelize per group and per row of blocks. + # Cols per group is variable, so we just loop through col blocks per group, per row block. + grid = lambda META: ( + num_groups, + num_row_blocks, + ) + triton_scale_swizzle_per_group_2d2d_lhs[grid]( + # Input scales + scales_tensor.view(torch.uint8), + scales_tensor.stride(0), + scales_tensor.stride(1), + rows, + cols, + num_groups, + # Original offsets (to read from) + input_group_end_offsets, + # Output scales tensor and group offsets after padding (to write to) + output.view(torch.uint8), + output.stride(0), + output.stride(1), + output_group_start_offsets, + output_stride_per_block, + output_stride_per_row_of_blocks, + BLOCK_ROWS=BLOCK_ROWS, + BLOCK_COLS=BLOCK_COLS, + ) + return output + + +@triton.jit +def triton_scale_swizzle_per_group_2d2d_lhs( + scales_ptr, # (K, total_M//block_size) + scales_stride_dim0, + scales_stride_dim1, + scale_rows, + scale_cols, + num_groups, + orig_offsets, # (num_groups,) + output_scales_ptr, + output_scales_stride_dim0, + output_scales_stride_dim1, + output_scales_group_offsets, # (num_groups,) + output_stride_per_block, + output_stride_per_row_of_blocks_ptr, + BLOCK_ROWS: tl.constexpr, + BLOCK_COLS: tl.constexpr, + DEBUG: tl.constexpr = True, +): + group_pid = tl.program_id(0) + block_row_pid = tl.program_id(1) + # Input scales row range for this group + input_group_start_col = tl.load( + orig_offsets + group_pid - 1, mask=group_pid > 0, other=0 + ) + input_group_end_col = tl.load( + orig_offsets + group_pid, mask=group_pid < num_groups, other=0 + ) + # Output scales start row we will begin writing to + output_group_start_col = tl.load( + output_scales_group_offsets + group_pid, mask=group_pid < num_groups, other=0 + ) + + # Output stride per row of blocks is a tensor, so we need to load it + output_stride_per_row_of_blocks = tl.load(output_stride_per_row_of_blocks_ptr) + + # Calculate destination indices for each row and col in block swizzled layout. + # We can reuse this swizzle transformation on each block of data we read. + row_offs = tl.arange(0, BLOCK_ROWS)[:, None] + col_offs = tl.arange(0, BLOCK_COLS)[None, :] + r_div_32 = row_offs // 32 + r_mod_32 = row_offs % 32 + + # Rearrange to (32, 4, 4) then to final (32, 16) coordinates + dest_indices = r_mod_32 * 16 + r_div_32 * 4 + col_offs + + # Flatten + dest_indices_flat = tl.reshape(dest_indices, (BLOCK_ROWS * BLOCK_COLS)) + + # For this group and row block, we iterate through col blocks, reading (BLOCK_ROWS, BLOCK_COLS) from the input scales. + # We track how many col blocks we have iterated through. + out_group_block_id = output_group_start_col // BLOCK_COLS + + current_start_col = input_group_start_col + while current_start_col < input_group_end_col: + # Read block of input scales + block_row_offs = block_row_pid * BLOCK_ROWS + row_offs + block_col_offs = current_start_col + col_offs + block_offs = ( + block_row_offs * scales_stride_dim0 + block_col_offs * scales_stride_dim1 + ) + mask = (block_row_offs < scale_rows) & (block_col_offs < input_group_end_col) + input_scales = tl.load(scales_ptr + block_offs, mask=mask, other=0.0) + scales_flat = tl.reshape(input_scales, (BLOCK_ROWS * BLOCK_COLS)) + + # Calculate block offset using provided output block stride + tgt_row_off = block_row_pid * output_stride_per_row_of_blocks + tgt_col_off = out_group_block_id * output_stride_per_block + if DEBUG: + tl.device_print( + "block_row_pid * BLOCK_ROWS * output_scales_stride_dim0: ", + block_row_pid * BLOCK_ROWS * output_scales_stride_dim0, + ) + tl.device_print( + "block_row_pid * output_stride_per_row_of_blocks: ", + block_row_pid * output_stride_per_row_of_blocks, + ) + tl.device_print("block_row_pid: ", block_row_pid) + tl.device_print("group_pid: ", group_pid) + tl.device_print("tgt_row_block", block_row_pid) + tl.device_print("tgt_col_block", out_group_block_id) + tl.device_print("tgt_row_off: ", tgt_row_off) + tl.device_print("tgt_col_off: ", tgt_col_off) + tl.device_print("global_off:", tgt_row_off + tgt_col_off) + + output_block_offsets = tgt_row_off + tgt_col_off + # Apply swizzling for write to gmem + tl.store( + output_scales_ptr + output_block_offsets + dest_indices_flat, + scales_flat, + ) + # Update row block id to next block + out_group_block_id += 1 + current_start_col += BLOCK_COLS diff --git a/torchao/prototype/moe_training/kernels/mxfp8_gemms.py b/torchao/prototype/moe_training/kernels/mxfp8_gemms.py index 5e215eec5a..06a74cca3f 100644 --- a/torchao/prototype/moe_training/kernels/mxfp8_gemms.py +++ b/torchao/prototype/moe_training/kernels/mxfp8_gemms.py @@ -41,7 +41,7 @@ def fbgemm_mxfp8_grouped_mm_2d_3d( # Convert scales for each group to blocked format. Mg, K = A_fp8.shape A_scales_blocked, starting_row_after_padding = torch_to_blocked_per_group_2d( - A_scales, offs, Mg, K + A_scales, offs, K ) B_scales_blocked = torch_to_blocked_per_group_3d(B_scales) From 0c3b8d9f79b05500826f7e5ceee2342f04bdfee9 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Mon, 8 Sep 2025 14:31:00 -0700 Subject: [PATCH 2/5] debug --- test/prototype/moe_training/test_kernels.py | 19 ++++-- .../kernels/mxfp8_blocked_scales.py | 62 ++++++++----------- 2 files changed, 40 insertions(+), 41 deletions(-) diff --git a/test/prototype/moe_training/test_kernels.py b/test/prototype/moe_training/test_kernels.py index 67cb57033e..0de769e8ad 100644 --- a/test/prototype/moe_training/test_kernels.py +++ b/test/prototype/moe_training/test_kernels.py @@ -272,7 +272,7 @@ def test_mxfp8_per_group_blocked_scales_3d( @skip_if_rocm("ROCm enablement in progress") -@pytest.mark.parametrize("m,total_k,n_groups", [(256, 64, 2)]) +@pytest.mark.parametrize("m,total_k,n_groups", [(256, 128, 4)]) def test_mxfp8_per_group_blocked_scales_2d2d_lhs( m: int, total_k: int, @@ -280,16 +280,23 @@ def test_mxfp8_per_group_blocked_scales_2d2d_lhs( ): device = "cuda" block_size = 32 - input_data = torch.randn(m, total_k, device=device) + input_data = torch.cat( + [ + torch.ones(m // 2, total_k, device=device), + torch.full((m // 2, total_k), 999, device=device), + ] + ) + e8m0_scales, _ = to_mx( input_data, elem_dtype=torch.float8_e4m3fn, block_size=block_size ) # Generate group end offsets along total_K, then divide by block_size to get scale group end offsets - input_group_offsets = generate_jagged_offs( - n_groups, total_k, multiple_of=block_size, device=device - ) - input_group_offsets //= block_size + # input_group_offsets = generate_jagged_offs( + # n_groups, total_k, multiple_of=block_size, device=device + # ) + # input_group_offsets //= block_size + input_group_offsets = torch.tensor([1, 4], device=device, dtype=torch.int32) # torch reference ref_out_scales, ref_start_cols_after_padding = torch_to_blocked_per_group_2d2d_lhs( diff --git a/torchao/prototype/moe_training/kernels/mxfp8_blocked_scales.py b/torchao/prototype/moe_training/kernels/mxfp8_blocked_scales.py index 26d3b40170..6ce1714f4e 100644 --- a/torchao/prototype/moe_training/kernels/mxfp8_blocked_scales.py +++ b/torchao/prototype/moe_training/kernels/mxfp8_blocked_scales.py @@ -90,23 +90,17 @@ def torch_to_blocked_per_group_2d2d_lhs( # Convert group scales to blocked format group_scales = x_scales[:, group_start_idx:group_end_idx] group_scales_blocked = to_blocked(group_scales) - blocked_scales_list.append(group_scales_blocked) + cols_after_padding = ceil_div(group_size, 4) * 4 + blocked_scales_list.append(group_scales_blocked.reshape(-1, cols_after_padding)) # Calculate the start row after padding - cols_after_padding = ceil_div(group_size, 4) * 4 new_start_col = prev_start_row_after_padding + cols_after_padding start_col_after_padding_list.append(new_start_col) # Update next group start index group_start_idx = group_end_idx - M = x_scales.shape[0] - padded_M = ceil_div(M, 128) * 128 - # blocked_scales = torch.cat(blocked_scales_list, dim=0) - # blocked_scales = blocked_scales.reshape(padded_M, -1) - blocked_scales = torch.cat( - [s.reshape(padded_M, -1) for s in blocked_scales_list], dim=1 - ) + blocked_scales = torch.cat(blocked_scales_list, dim=1) start_cols_after_padding = torch.tensor( start_col_after_padding_list, device=x_scales.device, dtype=torch.int64 ) @@ -286,10 +280,13 @@ def triton_scale_swizzle_per_group_2d( col_offs = tl.arange(0, BLOCK_COLS)[None, :] r_div_32 = row_offs // 32 r_mod_32 = row_offs % 32 + # Rearrange to (32, 4, 4) then to final (32, 16) coordinates dest_indices = r_mod_32 * 16 + r_div_32 * 4 + col_offs + # Flatten dest_indices_flat = tl.reshape(dest_indices, (BLOCK_ROWS * BLOCK_COLS)) + # For this group and col block, we iterate through row blocks, reading (BLOCK_ROWS, BLOCK_COLS) from the input scales. # We track how many row blocks we have iterated through. block_row_id = 0 @@ -470,17 +467,19 @@ def triton_mx_block_rearrange_per_group_2d2d_lhs( num_groups = input_group_end_offsets.numel() num_row_blocks = ceil_div(rows, 128) padded_rows = num_row_blocks * 128 + + # output_group_start_offsets always starts with 0 and ends with the total number of cols padded_cols = output_group_start_offsets[-1] output = scales_tensor.new_empty((padded_rows, padded_cols)) - # We probably want handle multiple blocks per tile but for now keep it simple - BLOCK_ROWS, BLOCK_COLS = 128, 4 + # Output block stride for the rearranged format + BLOCK_ROWS, BLOCK_COLS = 128, 4 output_stride_per_block = BLOCK_ROWS * BLOCK_COLS num_col_blocks = padded_cols // BLOCK_COLS output_stride_per_row_of_blocks = output_stride_per_block * num_col_blocks - # We parallelize per group and per row of blocks. - # Cols per group is variable, so we just loop through col blocks per group, per row block. + # We parallelize per group and per row block. + # Cols per group is variable, so we just loop through col blocks for each group. grid = lambda META: ( num_groups, num_row_blocks, @@ -529,6 +528,7 @@ def triton_scale_swizzle_per_group_2d2d_lhs( ): group_pid = tl.program_id(0) block_row_pid = tl.program_id(1) + # Input scales row range for this group input_group_start_col = tl.load( orig_offsets + group_pid - 1, mask=group_pid > 0, other=0 @@ -541,8 +541,7 @@ def triton_scale_swizzle_per_group_2d2d_lhs( output_scales_group_offsets + group_pid, mask=group_pid < num_groups, other=0 ) - # Output stride per row of blocks is a tensor, so we need to load it - output_stride_per_row_of_blocks = tl.load(output_stride_per_row_of_blocks_ptr) + out_stride_per_row_of_blocks = tl.load(output_stride_per_row_of_blocks_ptr) # Calculate destination indices for each row and col in block swizzled layout. # We can reuse this swizzle transformation on each block of data we read. @@ -559,13 +558,12 @@ def triton_scale_swizzle_per_group_2d2d_lhs( # For this group and row block, we iterate through col blocks, reading (BLOCK_ROWS, BLOCK_COLS) from the input scales. # We track how many col blocks we have iterated through. - out_group_block_id = output_group_start_col // BLOCK_COLS - - current_start_col = input_group_start_col - while current_start_col < input_group_end_col: + curr_input_start_col = input_group_start_col + curr_out_start_col_block = output_group_start_col // BLOCK_COLS + while curr_input_start_col < input_group_end_col: # Read block of input scales block_row_offs = block_row_pid * BLOCK_ROWS + row_offs - block_col_offs = current_start_col + col_offs + block_col_offs = curr_input_start_col + col_offs block_offs = ( block_row_offs * scales_stride_dim0 + block_col_offs * scales_stride_dim1 ) @@ -574,31 +572,25 @@ def triton_scale_swizzle_per_group_2d2d_lhs( scales_flat = tl.reshape(input_scales, (BLOCK_ROWS * BLOCK_COLS)) # Calculate block offset using provided output block stride - tgt_row_off = block_row_pid * output_stride_per_row_of_blocks - tgt_col_off = out_group_block_id * output_stride_per_block + tgt_row_off = block_row_pid * out_stride_per_row_of_blocks + tgt_col_off = curr_out_start_col_block * output_stride_per_block + + output_block_offsets = tgt_row_off + tgt_col_off if DEBUG: - tl.device_print( - "block_row_pid * BLOCK_ROWS * output_scales_stride_dim0: ", - block_row_pid * BLOCK_ROWS * output_scales_stride_dim0, - ) - tl.device_print( - "block_row_pid * output_stride_per_row_of_blocks: ", - block_row_pid * output_stride_per_row_of_blocks, - ) tl.device_print("block_row_pid: ", block_row_pid) tl.device_print("group_pid: ", group_pid) tl.device_print("tgt_row_block", block_row_pid) - tl.device_print("tgt_col_block", out_group_block_id) + tl.device_print("tgt_col_block", curr_out_start_col_block) tl.device_print("tgt_row_off: ", tgt_row_off) tl.device_print("tgt_col_off: ", tgt_col_off) tl.device_print("global_off:", tgt_row_off + tgt_col_off) - output_block_offsets = tgt_row_off + tgt_col_off # Apply swizzling for write to gmem tl.store( output_scales_ptr + output_block_offsets + dest_indices_flat, scales_flat, ) - # Update row block id to next block - out_group_block_id += 1 - current_start_col += BLOCK_COLS + + # Advance to next col block + curr_input_start_col += BLOCK_COLS + curr_out_start_col_block += 1 From e86811adcc2b1544d2eeb6fb528f082a70228cc2 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Mon, 8 Sep 2025 16:08:21 -0700 Subject: [PATCH 3/5] debug --- .../moe_training/kernels/mxfp8_blocked_scales.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torchao/prototype/moe_training/kernels/mxfp8_blocked_scales.py b/torchao/prototype/moe_training/kernels/mxfp8_blocked_scales.py index 6ce1714f4e..675ccd55d0 100644 --- a/torchao/prototype/moe_training/kernels/mxfp8_blocked_scales.py +++ b/torchao/prototype/moe_training/kernels/mxfp8_blocked_scales.py @@ -491,13 +491,12 @@ def triton_mx_block_rearrange_per_group_2d2d_lhs( scales_tensor.stride(1), rows, cols, + padded_rows, num_groups, # Original offsets (to read from) input_group_end_offsets, # Output scales tensor and group offsets after padding (to write to) output.view(torch.uint8), - output.stride(0), - output.stride(1), output_group_start_offsets, output_stride_per_block, output_stride_per_row_of_blocks, @@ -514,11 +513,10 @@ def triton_scale_swizzle_per_group_2d2d_lhs( scales_stride_dim1, scale_rows, scale_cols, + padded_rows, num_groups, orig_offsets, # (num_groups,) output_scales_ptr, - output_scales_stride_dim0, - output_scales_stride_dim1, output_scales_group_offsets, # (num_groups,) output_stride_per_block, output_stride_per_row_of_blocks_ptr, @@ -577,13 +575,15 @@ def triton_scale_swizzle_per_group_2d2d_lhs( output_block_offsets = tgt_row_off + tgt_col_off if DEBUG: - tl.device_print("block_row_pid: ", block_row_pid) + tl.device_print("\nblock_row_pid: ", block_row_pid) tl.device_print("group_pid: ", group_pid) tl.device_print("tgt_row_block", block_row_pid) + tl.device_print("output_group_start_col: ", output_group_start_col) tl.device_print("tgt_col_block", curr_out_start_col_block) tl.device_print("tgt_row_off: ", tgt_row_off) tl.device_print("tgt_col_off: ", tgt_col_off) tl.device_print("global_off:", tgt_row_off + tgt_col_off) + tl.device_print("writing: ", scales_flat) # Apply swizzling for write to gmem tl.store( From 11f25ec0d204b8e7b694ce0915719bd36486d128 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Mon, 8 Sep 2025 18:17:25 -0700 Subject: [PATCH 4/5] col of blocks method --- test/prototype/moe_training/test_kernels.py | 13 +++++++-- .../kernels/mxfp8_blocked_scales.py | 27 ++++++++++--------- 2 files changed, 25 insertions(+), 15 deletions(-) diff --git a/test/prototype/moe_training/test_kernels.py b/test/prototype/moe_training/test_kernels.py index 0de769e8ad..6732da2cb8 100644 --- a/test/prototype/moe_training/test_kernels.py +++ b/test/prototype/moe_training/test_kernels.py @@ -272,7 +272,7 @@ def test_mxfp8_per_group_blocked_scales_3d( @skip_if_rocm("ROCm enablement in progress") -@pytest.mark.parametrize("m,total_k,n_groups", [(256, 128, 4)]) +@pytest.mark.parametrize("m,total_k,n_groups", [(256, 512, 4)])#, (256, 128, 4), (512, 128, 4), (1024, 128, 4), (1024, 256, 4), (1024, 512, 4), (1024, 1024, 4), (1024, 2048, 4), (1024, 4096, 4), (1024, 8192, 4), (1024, 16384, 4)]) def test_mxfp8_per_group_blocked_scales_2d2d_lhs( m: int, total_k: int, @@ -280,12 +280,15 @@ def test_mxfp8_per_group_blocked_scales_2d2d_lhs( ): device = "cuda" block_size = 32 + + # Make each group of row blocks have distinct, constinent data for debugging input_data = torch.cat( [ torch.ones(m // 2, total_k, device=device), torch.full((m // 2, total_k), 999, device=device), ] ) + #input_data= torch.randn(m, total_k, device=device) e8m0_scales, _ = to_mx( input_data, elem_dtype=torch.float8_e4m3fn, block_size=block_size @@ -296,7 +299,8 @@ def test_mxfp8_per_group_blocked_scales_2d2d_lhs( # n_groups, total_k, multiple_of=block_size, device=device # ) # input_group_offsets //= block_size - input_group_offsets = torch.tensor([1, 4], device=device, dtype=torch.int32) + input_group_offsets = torch.tensor([3, 8, 12, 16], device=device, dtype=torch.int32) + #print(input_group_offsets) # torch reference ref_out_scales, ref_start_cols_after_padding = torch_to_blocked_per_group_2d2d_lhs( @@ -316,6 +320,11 @@ def test_mxfp8_per_group_blocked_scales_2d2d_lhs( input_group_offsets, output_group_offsets, ) + print(ref_start_cols_after_padding) + with open('tmp-ref.txt', 'w') as f: + f.write(str(ref_out_scales.storage())) + with open('tmp-triton.txt', 'w') as f: + f.write(str(triton_out_scales.storage())) breakpoint() assert torch.allclose(ref_out_scales, triton_out_scales, atol=0, rtol=0), ( "blocked scales not equal" diff --git a/torchao/prototype/moe_training/kernels/mxfp8_blocked_scales.py b/torchao/prototype/moe_training/kernels/mxfp8_blocked_scales.py index 675ccd55d0..9af7c77e1d 100644 --- a/torchao/prototype/moe_training/kernels/mxfp8_blocked_scales.py +++ b/torchao/prototype/moe_training/kernels/mxfp8_blocked_scales.py @@ -91,7 +91,7 @@ def torch_to_blocked_per_group_2d2d_lhs( group_scales = x_scales[:, group_start_idx:group_end_idx] group_scales_blocked = to_blocked(group_scales) cols_after_padding = ceil_div(group_size, 4) * 4 - blocked_scales_list.append(group_scales_blocked.reshape(-1, cols_after_padding)) + blocked_scales_list.append(group_scales_blocked) # Calculate the start row after padding new_start_col = prev_start_row_after_padding + cols_after_padding @@ -100,7 +100,11 @@ def torch_to_blocked_per_group_2d2d_lhs( # Update next group start index group_start_idx = group_end_idx - blocked_scales = torch.cat(blocked_scales_list, dim=1) + # blocked_scales = torch.cat(blocked_scales_list, dim=1) + M = x_scales.shape[0] + padded_M = ceil_div(M, 128) * 128 + blocked_scales = torch.cat(blocked_scales_list) + blocked_scales = blocked_scales.reshape(padded_M, -1) start_cols_after_padding = torch.tensor( start_col_after_padding_list, device=x_scales.device, dtype=torch.int64 ) @@ -475,8 +479,8 @@ def triton_mx_block_rearrange_per_group_2d2d_lhs( # Output block stride for the rearranged format BLOCK_ROWS, BLOCK_COLS = 128, 4 output_stride_per_block = BLOCK_ROWS * BLOCK_COLS - num_col_blocks = padded_cols // BLOCK_COLS - output_stride_per_row_of_blocks = output_stride_per_block * num_col_blocks + num_row_blocks = padded_rows // BLOCK_ROWS + output_stride_per_col_of_blocks = output_stride_per_block * num_row_blocks # We parallelize per group and per row block. # Cols per group is variable, so we just loop through col blocks for each group. @@ -491,7 +495,6 @@ def triton_mx_block_rearrange_per_group_2d2d_lhs( scales_tensor.stride(1), rows, cols, - padded_rows, num_groups, # Original offsets (to read from) input_group_end_offsets, @@ -499,9 +502,10 @@ def triton_mx_block_rearrange_per_group_2d2d_lhs( output.view(torch.uint8), output_group_start_offsets, output_stride_per_block, - output_stride_per_row_of_blocks, + output_stride_per_col_of_blocks, BLOCK_ROWS=BLOCK_ROWS, BLOCK_COLS=BLOCK_COLS, + DEBUG=True, ) return output @@ -513,16 +517,15 @@ def triton_scale_swizzle_per_group_2d2d_lhs( scales_stride_dim1, scale_rows, scale_cols, - padded_rows, num_groups, orig_offsets, # (num_groups,) output_scales_ptr, output_scales_group_offsets, # (num_groups,) output_stride_per_block, - output_stride_per_row_of_blocks_ptr, + output_stride_per_col_of_blocks, BLOCK_ROWS: tl.constexpr, BLOCK_COLS: tl.constexpr, - DEBUG: tl.constexpr = True, + DEBUG: tl.constexpr = False, ): group_pid = tl.program_id(0) block_row_pid = tl.program_id(1) @@ -539,8 +542,6 @@ def triton_scale_swizzle_per_group_2d2d_lhs( output_scales_group_offsets + group_pid, mask=group_pid < num_groups, other=0 ) - out_stride_per_row_of_blocks = tl.load(output_stride_per_row_of_blocks_ptr) - # Calculate destination indices for each row and col in block swizzled layout. # We can reuse this swizzle transformation on each block of data we read. row_offs = tl.arange(0, BLOCK_ROWS)[:, None] @@ -570,8 +571,8 @@ def triton_scale_swizzle_per_group_2d2d_lhs( scales_flat = tl.reshape(input_scales, (BLOCK_ROWS * BLOCK_COLS)) # Calculate block offset using provided output block stride - tgt_row_off = block_row_pid * out_stride_per_row_of_blocks - tgt_col_off = curr_out_start_col_block * output_stride_per_block + tgt_row_off = block_row_pid * output_stride_per_block + tgt_col_off = curr_out_start_col_block * output_stride_per_col_of_blocks output_block_offsets = tgt_row_off + tgt_col_off if DEBUG: From 213f19b205b06bc56e25307aeb30342d7e82b179 Mon Sep 17 00:00:00 2001 From: Daniel Vega-Myhre Date: Mon, 8 Sep 2025 18:43:46 -0700 Subject: [PATCH 5/5] row of blocks within groups only --- .../benchmark_2d_3d_grouped_gemms.py | 4 +- ...chmark_2d_blocked_swizzle_scale_kernels.py | 14 +- test/prototype/moe_training/test_kernels.py | 62 +++---- .../kernels/mxfp8_blocked_scales.py | 169 ++++++++++-------- .../moe_training/kernels/mxfp8_gemms.py | 4 +- 5 files changed, 127 insertions(+), 126 deletions(-) diff --git a/benchmarks/prototype/moe_training/benchmark_2d_3d_grouped_gemms.py b/benchmarks/prototype/moe_training/benchmark_2d_3d_grouped_gemms.py index 1dc78ec0a5..8caadc4fe3 100644 --- a/benchmarks/prototype/moe_training/benchmark_2d_3d_grouped_gemms.py +++ b/benchmarks/prototype/moe_training/benchmark_2d_3d_grouped_gemms.py @@ -18,7 +18,7 @@ from torchao.float8.config import ScalingGranularity from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import ( - torch_to_blocked_per_group_2d, + torch_to_blocked_2d_M_groups, torch_to_blocked_per_group_3d, ) from torchao.prototype.moe_training.utils import generate_jagged_offs @@ -230,7 +230,7 @@ def bench_mxfp8_grouped_mm(A, B_t, offs, block_size=32) -> float: # Convert scales for each group to blocked format. Mg, K = A_fp8.shape - A_scales_blocked, starting_row_after_padding = torch_to_blocked_per_group_2d( + A_scales_blocked, starting_row_after_padding = torch_to_blocked_2d_M_groups( A_scales, offs, K ) B_scales_blocked = torch_to_blocked_per_group_3d(B_scales) diff --git a/benchmarks/prototype/moe_training/benchmark_2d_blocked_swizzle_scale_kernels.py b/benchmarks/prototype/moe_training/benchmark_2d_blocked_swizzle_scale_kernels.py index f1185bd533..84a8f040cb 100644 --- a/benchmarks/prototype/moe_training/benchmark_2d_blocked_swizzle_scale_kernels.py +++ b/benchmarks/prototype/moe_training/benchmark_2d_blocked_swizzle_scale_kernels.py @@ -15,9 +15,9 @@ from benchmarks.utils import benchmark_cuda_function_in_microseconds from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import ( - compute_per_group_blocked_scale_offsets, - torch_to_blocked_per_group_2d, - triton_mx_block_rearrange_per_group_2d, + compute_blocked_scale_offsets_for_M_groups, + torch_to_blocked_2d_M_groups, + triton_mx_block_rearrange_2d_M_groups, ) from torchao.prototype.moe_training.utils import generate_jagged_offs @@ -82,7 +82,7 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult: input_group_offsets = generate_jagged_offs(num_groups, Mg, multiple_of=32) # bench torch - compiled_run_torch = torch.compile(torch_to_blocked_per_group_2d) + compiled_run_torch = torch.compile(torch_to_blocked_2d_M_groups) torch_out_scales, torch_group_offs = compiled_run_torch( input_tensor, input_group_offsets, K ) @@ -95,16 +95,16 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult: ) # bench triton - _, output_group_offsets = compute_per_group_blocked_scale_offsets( + _, output_group_offsets = compute_blocked_scale_offsets_for_M_groups( input_group_offsets ) - triton_out_scales = triton_mx_block_rearrange_per_group_2d( + triton_out_scales = triton_mx_block_rearrange_2d_M_groups( input_tensor, input_group_offsets, output_group_offsets, ) triton_time_us = benchmark_cuda_function_in_microseconds( - triton_mx_block_rearrange_per_group_2d, + triton_mx_block_rearrange_2d_M_groups, input_tensor, input_group_offsets, output_group_offsets, diff --git a/test/prototype/moe_training/test_kernels.py b/test/prototype/moe_training/test_kernels.py index 6732da2cb8..e8fe088f98 100644 --- a/test/prototype/moe_training/test_kernels.py +++ b/test/prototype/moe_training/test_kernels.py @@ -22,13 +22,13 @@ triton_fp8_per_group_rowwise_scales, ) from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import ( - compute_per_group_blocked_scale_offsets, - compute_per_group_blocked_scale_offsets_2d2d_lhs, - torch_to_blocked_per_group_2d, - torch_to_blocked_per_group_2d2d_lhs, + compute_blocked_scale_offsets_for_K_groups, + compute_blocked_scale_offsets_for_M_groups, + torch_to_blocked_2d_K_groups, + torch_to_blocked_2d_M_groups, torch_to_blocked_per_group_3d, - triton_mx_block_rearrange_per_group_2d, - triton_mx_block_rearrange_per_group_2d2d_lhs, + triton_mx_block_rearrange_2d_K_groups, + triton_mx_block_rearrange_2d_M_groups, triton_mx_block_rearrange_per_group_3d, ) from torchao.prototype.moe_training.utils import ( @@ -229,15 +229,15 @@ def test_mxfp8_per_group_blocked_scales_2d( ) # torch reference - ref_out_scales, _ = torch_to_blocked_per_group_2d( + ref_out_scales, _ = torch_to_blocked_2d_M_groups( e8m0_scales, input_group_offsets, k, block_size=block_size ) # triton kernel - _, output_group_offsets = compute_per_group_blocked_scale_offsets( + _, output_group_offsets = compute_blocked_scale_offsets_for_M_groups( input_group_offsets ) - triton_out_scales = triton_mx_block_rearrange_per_group_2d( + triton_out_scales = triton_mx_block_rearrange_2d_M_groups( e8m0_scales, input_group_offsets, output_group_offsets, @@ -272,60 +272,44 @@ def test_mxfp8_per_group_blocked_scales_3d( @skip_if_rocm("ROCm enablement in progress") -@pytest.mark.parametrize("m,total_k,n_groups", [(256, 512, 4)])#, (256, 128, 4), (512, 128, 4), (1024, 128, 4), (1024, 256, 4), (1024, 512, 4), (1024, 1024, 4), (1024, 2048, 4), (1024, 4096, 4), (1024, 8192, 4), (1024, 16384, 4)]) -def test_mxfp8_per_group_blocked_scales_2d2d_lhs( +@pytest.mark.parametrize("m", [256, 512, 1024, 5120]) +@pytest.mark.parametrize("total_k", [512, 1024, 2048, 4096, 8192, 16384]) +@pytest.mark.parametrize("n_groups", [1, 4, 8, 16]) +def test_mxfp8_per_group_blocked_scales_2d2d( m: int, total_k: int, n_groups: int, ): device = "cuda" block_size = 32 - - # Make each group of row blocks have distinct, constinent data for debugging - input_data = torch.cat( - [ - torch.ones(m // 2, total_k, device=device), - torch.full((m // 2, total_k), 999, device=device), - ] - ) - #input_data= torch.randn(m, total_k, device=device) + input_data = torch.randn(m, total_k, device=device) e8m0_scales, _ = to_mx( input_data, elem_dtype=torch.float8_e4m3fn, block_size=block_size ) # Generate group end offsets along total_K, then divide by block_size to get scale group end offsets - # input_group_offsets = generate_jagged_offs( - # n_groups, total_k, multiple_of=block_size, device=device - # ) - # input_group_offsets //= block_size - input_group_offsets = torch.tensor([3, 8, 12, 16], device=device, dtype=torch.int32) - #print(input_group_offsets) + input_group_offsets = generate_jagged_offs( + n_groups, total_k, multiple_of=block_size, device=device + ) + input_group_offsets //= block_size # torch reference - ref_out_scales, ref_start_cols_after_padding = torch_to_blocked_per_group_2d2d_lhs( + ref_out_scales, ref_start_cols_after_padding = torch_to_blocked_2d_K_groups( e8m0_scales, input_group_offsets, ) # triton kernel - _, output_group_offsets = compute_per_group_blocked_scale_offsets_2d2d_lhs( + _, output_group_offsets = compute_blocked_scale_offsets_for_K_groups( input_group_offsets ) - assert torch.allclose(output_group_offsets, ref_start_cols_after_padding), ( + assert torch.equal(output_group_offsets, ref_start_cols_after_padding), ( "output scale group start offsets not equal" ) - triton_out_scales = triton_mx_block_rearrange_per_group_2d2d_lhs( + triton_out_scales = triton_mx_block_rearrange_2d_K_groups( e8m0_scales, input_group_offsets, output_group_offsets, ) - print(ref_start_cols_after_padding) - with open('tmp-ref.txt', 'w') as f: - f.write(str(ref_out_scales.storage())) - with open('tmp-triton.txt', 'w') as f: - f.write(str(triton_out_scales.storage())) - breakpoint() - assert torch.allclose(ref_out_scales, triton_out_scales, atol=0, rtol=0), ( - "blocked scales not equal" - ) + assert torch.equal(ref_out_scales, triton_out_scales), "blocked scales not equal" diff --git a/torchao/prototype/moe_training/kernels/mxfp8_blocked_scales.py b/torchao/prototype/moe_training/kernels/mxfp8_blocked_scales.py index 9af7c77e1d..48c248a7d0 100644 --- a/torchao/prototype/moe_training/kernels/mxfp8_blocked_scales.py +++ b/torchao/prototype/moe_training/kernels/mxfp8_blocked_scales.py @@ -9,16 +9,17 @@ from torchao.utils import ceil_div -def torch_to_blocked_per_group_2d( +def torch_to_blocked_2d_M_groups( x_scales: Tensor, group_offs: Tensor, K: int, block_size: int = 32 ) -> Tuple[Tensor, Tensor]: """ - Convert scales to blocked format for a 2D tensor (input activations / token groups) + Convert scales to blocked format for a 2D tensor (input activations / token groups), + where groups are along the total_M dimension (rows). Args: x_scales: Tensor with per group scales in blocked format concatenated into one tensor. - group_offs: Tensor of shape (num_groups,) which contains the end index of each group along the Mg dimension. - Mg: total size of all groups summed together + group_offs: Tensor of shape (num_groups,) which contains the end index of each group along the total_M dimension. + total_M: total size of all groups summed together K: K dim size Returns: @@ -60,11 +61,12 @@ def torch_to_blocked_per_group_2d( return blocked_scales, start_row_after_padding -def torch_to_blocked_per_group_2d2d_lhs( +def torch_to_blocked_2d_K_groups( x_scales: Tensor, group_offs: Tensor, block_size: int = 32 ) -> Tuple[Tensor, Tensor]: """ - Convert scales to blocked format for a 2D tensor (input activations) when scaling along the contraction dimension. + Convert scales to blocked format for a 2D tensor (input activations), + when groups are along the scaled (K) dimension. Args: x_scales: Tensor with per group scales in blocked format concatenated into one tensor. @@ -131,12 +133,15 @@ def torch_to_blocked_per_group_3d(weight_scales: Tensor) -> Tensor: return weight_scales_blocked -def compute_per_group_blocked_scale_offsets(offsets: torch.Tensor): +def compute_blocked_scale_offsets_for_M_groups(offsets: torch.Tensor): """ - Rounds each integer in a 1D PyTorch tensor up to the nearest multiple of 128. + Given a 1D tensor of input group offsets along the total_M dimension (rows), + compute the starting row offset of the scales for each group after padding to blocked format. + + In effect, this rrounds each integer in a 1D PyTorch tensor up to the nearest multiple of 128. Args: - offsets: A 1D PyTorch tensor of integers in ascending sorted order, representing the end index of each group along the Mg dimension. + - offsets: A 1D PyTorch tensor of integers in ascending sorted order, representing the end index of each group along the total_M dimension. Returns: - group_sizes: A 1D PyTorch tensor of integers representing the size of each group. @@ -157,13 +162,13 @@ def compute_per_group_blocked_scale_offsets(offsets: torch.Tensor): return group_sizes, starting_row_after_padding -def compute_per_group_blocked_scale_offsets_2d2d_lhs(offsets: torch.Tensor): +def compute_blocked_scale_offsets_for_K_groups(offsets: torch.Tensor): """ Performs round_up(x, 4) on each element in a 1D offsets tensor, to compute the starting offsets of each group after scaling along the contraction dimension. Args: - offsets: A 1D PyTorch tensor of integers in ascending sorted order, representing the end index of each group along the Mg dimension. + offsets: A 1D PyTorch tensor of integers in ascending sorted order, representing the end index of each group along the total_M dimension. Returns: - starting_row_after_padding: 1D integer tensor representing the starting row after padding each to blocked format. @@ -183,15 +188,18 @@ def compute_per_group_blocked_scale_offsets_2d2d_lhs(offsets: torch.Tensor): return group_sizes, starting_col_after_padding -def triton_mx_block_rearrange_per_group_2d( +def triton_mx_block_rearrange_2d_M_groups( scales_tensor: torch.Tensor, input_group_end_offsets: torch.Tensor, output_group_start_offsets: torch.Tensor, ) -> torch.Tensor: """ - Rearranges an E8M0 tensor scale to block-scaled swizzle format. + Rearranges an E8M0 tensor scale to block-scaled swizzle format, + where groups are along the total_M dimension (rows). + This format is suitable for Tmem as described in NVIDIA documentation: https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout + Args: scales_tensor: Input tensor containing e8m0 scales for each logical group of a target tensor. input_group_end_offsets: tensor of int32 values representing group end indexes for the input scales @@ -226,7 +234,7 @@ def triton_mx_block_rearrange_per_group_2d( num_groups, num_col_blocks, ) - triton_scale_swizzle_per_group_2d[grid]( + triton_scale_swizzle_M_groups[grid]( # Input scales scales_tensor.view(torch.uint8), scales_tensor.stride(0), @@ -249,7 +257,7 @@ def triton_mx_block_rearrange_per_group_2d( @triton.jit -def triton_scale_swizzle_per_group_2d( +def triton_scale_swizzle_M_groups( scales_ptr, # (M, K//block_size) scales_stride_dim0, scales_stride_dim1, @@ -282,14 +290,14 @@ def triton_scale_swizzle_per_group_2d( # We can reuse this swizzle transformation on each block of data we read. row_offs = tl.arange(0, BLOCK_ROWS)[:, None] col_offs = tl.arange(0, BLOCK_COLS)[None, :] - r_div_32 = row_offs // 32 - r_mod_32 = row_offs % 32 - - # Rearrange to (32, 4, 4) then to final (32, 16) coordinates - dest_indices = r_mod_32 * 16 + r_div_32 * 4 + col_offs - # Flatten - dest_indices_flat = tl.reshape(dest_indices, (BLOCK_ROWS * BLOCK_COLS)) + # Compute desination indices for each elem in block swizzled layout + dest_indices_flat = _dest_indices_for_block( + row_offs, + col_offs, + BLOCK_ROWS=BLOCK_ROWS, + BLOCK_COLS=BLOCK_COLS, + ) # For this group and col block, we iterate through row blocks, reading (BLOCK_ROWS, BLOCK_COLS) from the input scales. # We track how many row blocks we have iterated through. @@ -406,14 +414,22 @@ def triton_scale_swizzle_per_group_3d( input_ptr += pid_group * input_stride_dim0 output_ptr += pid_group * output_stride_dim0 - rows = tl.arange(0, BLOCK_ROWS)[:, None] - cols = tl.arange(0, BLOCK_COLS)[None, :] + row_offs = tl.arange(0, BLOCK_ROWS)[:, None] + col_offs = tl.arange(0, BLOCK_COLS)[None, :] + + # Compute desination offs for each elem in block swizzled layout + dest_indices_flat = _dest_indices_for_block( + row_offs, + col_offs, + BLOCK_ROWS=BLOCK_ROWS, + BLOCK_COLS=BLOCK_COLS, + ) # Calculate starting row and column for this tile start_row = pid_row * BLOCK_ROWS start_col = pid_col * BLOCK_COLS - global_rows = start_row + rows - global_cols = start_col + cols + global_rows = start_row + row_offs + global_cols = start_col + col_offs mask = (global_rows < scale_rows) & (global_cols < scale_cols) @@ -422,15 +438,6 @@ def triton_scale_swizzle_per_group_3d( mask=mask, other=0.0, ) - - r_div_32 = rows // 32 - r_mod_32 = rows % 32 - - # 2) Rearrange to (32, 4, 4) then to final (32, 16) coordinates - dest_indices = r_mod_32 * 16 + r_div_32 * 4 + cols - - # Flatten - dest_indices_flat = tl.reshape(dest_indices, (BLOCK_ROWS * BLOCK_COLS)) scales_flat = tl.reshape(input_scales, (BLOCK_ROWS * BLOCK_COLS)) # Calculate block offset using provided output block stride @@ -443,7 +450,7 @@ def triton_scale_swizzle_per_group_3d( ) -def triton_mx_block_rearrange_per_group_2d2d_lhs( +def triton_mx_block_rearrange_2d_K_groups( scales_tensor: torch.Tensor, input_group_end_offsets: torch.Tensor, output_group_start_offsets: torch.Tensor, @@ -479,8 +486,6 @@ def triton_mx_block_rearrange_per_group_2d2d_lhs( # Output block stride for the rearranged format BLOCK_ROWS, BLOCK_COLS = 128, 4 output_stride_per_block = BLOCK_ROWS * BLOCK_COLS - num_row_blocks = padded_rows // BLOCK_ROWS - output_stride_per_col_of_blocks = output_stride_per_block * num_row_blocks # We parallelize per group and per row block. # Cols per group is variable, so we just loop through col blocks for each group. @@ -488,13 +493,14 @@ def triton_mx_block_rearrange_per_group_2d2d_lhs( num_groups, num_row_blocks, ) - triton_scale_swizzle_per_group_2d2d_lhs[grid]( + triton_scale_swizzle_2d_K_groups[grid]( # Input scales scales_tensor.view(torch.uint8), scales_tensor.stride(0), scales_tensor.stride(1), rows, cols, + padded_rows, num_groups, # Original offsets (to read from) input_group_end_offsets, @@ -502,27 +508,26 @@ def triton_mx_block_rearrange_per_group_2d2d_lhs( output.view(torch.uint8), output_group_start_offsets, output_stride_per_block, - output_stride_per_col_of_blocks, BLOCK_ROWS=BLOCK_ROWS, BLOCK_COLS=BLOCK_COLS, - DEBUG=True, + DEBUG=False, ) return output @triton.jit -def triton_scale_swizzle_per_group_2d2d_lhs( - scales_ptr, # (K, total_M//block_size) +def triton_scale_swizzle_2d_K_groups( + scales_ptr, # (M, total_K//block_size) scales_stride_dim0, scales_stride_dim1, scale_rows, scale_cols, + padded_rows, num_groups, orig_offsets, # (num_groups,) output_scales_ptr, output_scales_group_offsets, # (num_groups,) output_stride_per_block, - output_stride_per_col_of_blocks, BLOCK_ROWS: tl.constexpr, BLOCK_COLS: tl.constexpr, DEBUG: tl.constexpr = False, @@ -534,31 +539,27 @@ def triton_scale_swizzle_per_group_2d2d_lhs( input_group_start_col = tl.load( orig_offsets + group_pid - 1, mask=group_pid > 0, other=0 ) - input_group_end_col = tl.load( - orig_offsets + group_pid, mask=group_pid < num_groups, other=0 - ) + input_group_end_col = tl.load(orig_offsets + group_pid) + # Output scales start row we will begin writing to - output_group_start_col = tl.load( - output_scales_group_offsets + group_pid, mask=group_pid < num_groups, other=0 - ) + output_group_start_col = tl.load(output_scales_group_offsets + group_pid) - # Calculate destination indices for each row and col in block swizzled layout. - # We can reuse this swizzle transformation on each block of data we read. row_offs = tl.arange(0, BLOCK_ROWS)[:, None] col_offs = tl.arange(0, BLOCK_COLS)[None, :] - r_div_32 = row_offs // 32 - r_mod_32 = row_offs % 32 - - # Rearrange to (32, 4, 4) then to final (32, 16) coordinates - dest_indices = r_mod_32 * 16 + r_div_32 * 4 + col_offs - # Flatten - dest_indices_flat = tl.reshape(dest_indices, (BLOCK_ROWS * BLOCK_COLS)) + # Compute desination offs for each elem in block swizzled layout + dest_indices_flat = _dest_indices_for_block( + row_offs, + col_offs, + BLOCK_ROWS=BLOCK_ROWS, + BLOCK_COLS=BLOCK_COLS, + ) # For this group and row block, we iterate through col blocks, reading (BLOCK_ROWS, BLOCK_COLS) from the input scales. # We track how many col blocks we have iterated through. + out_group_base_offset = output_group_start_col * padded_rows curr_input_start_col = input_group_start_col - curr_out_start_col_block = output_group_start_col // BLOCK_COLS + curr_out_start_col_block = 0 while curr_input_start_col < input_group_end_col: # Read block of input scales block_row_offs = block_row_pid * BLOCK_ROWS + row_offs @@ -570,28 +571,44 @@ def triton_scale_swizzle_per_group_2d2d_lhs( input_scales = tl.load(scales_ptr + block_offs, mask=mask, other=0.0) scales_flat = tl.reshape(input_scales, (BLOCK_ROWS * BLOCK_COLS)) - # Calculate block offset using provided output block stride - tgt_row_off = block_row_pid * output_stride_per_block - tgt_col_off = curr_out_start_col_block * output_stride_per_col_of_blocks - - output_block_offsets = tgt_row_off + tgt_col_off - if DEBUG: - tl.device_print("\nblock_row_pid: ", block_row_pid) - tl.device_print("group_pid: ", group_pid) - tl.device_print("tgt_row_block", block_row_pid) - tl.device_print("output_group_start_col: ", output_group_start_col) - tl.device_print("tgt_col_block", curr_out_start_col_block) - tl.device_print("tgt_row_off: ", tgt_row_off) - tl.device_print("tgt_col_off: ", tgt_col_off) - tl.device_print("global_off:", tgt_row_off + tgt_col_off) - tl.device_print("writing: ", scales_flat) + # Get offset within the group to add to the group's base offset + num_cols_in_group = input_group_end_col - input_group_start_col + num_col_blocks_in_group = tl.cdiv(num_cols_in_group, BLOCK_COLS) + stride_per_row_of_blocks_in_group = ( + num_col_blocks_in_group * output_stride_per_block + ) + offset_in_group = ( + block_row_pid * stride_per_row_of_blocks_in_group + + curr_out_start_col_block * output_stride_per_block + ) + final_offset = out_group_base_offset + offset_in_group # Apply swizzling for write to gmem tl.store( - output_scales_ptr + output_block_offsets + dest_indices_flat, + output_scales_ptr + final_offset + dest_indices_flat, scales_flat, ) # Advance to next col block curr_input_start_col += BLOCK_COLS curr_out_start_col_block += 1 + + +@triton.jit +def _dest_indices_for_block( + row_offs, + col_offs, + BLOCK_ROWS: tl.constexpr, + BLOCK_COLS: tl.constexpr, +): + # Calculate destination indices for each row and col in block swizzled layout. + # We can reuse this swizzle transformation on each block of data we read. + r_div_32 = row_offs // 32 + r_mod_32 = row_offs % 32 + + # Rearrange to (32, 4, 4) then to final (32, 16) coordinates + dest_indices = r_mod_32 * 16 + r_div_32 * 4 + col_offs + + # Flatten + dest_indices_flat = tl.reshape(dest_indices, (BLOCK_ROWS * BLOCK_COLS)) + return dest_indices_flat diff --git a/torchao/prototype/moe_training/kernels/mxfp8_gemms.py b/torchao/prototype/moe_training/kernels/mxfp8_gemms.py index 06a74cca3f..4f419f4c6f 100644 --- a/torchao/prototype/moe_training/kernels/mxfp8_gemms.py +++ b/torchao/prototype/moe_training/kernels/mxfp8_gemms.py @@ -3,7 +3,7 @@ import torch from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import ( - torch_to_blocked_per_group_2d, + torch_to_blocked_2d_M_groups, torch_to_blocked_per_group_3d, ) @@ -40,7 +40,7 @@ def fbgemm_mxfp8_grouped_mm_2d_3d( # Convert scales for each group to blocked format. Mg, K = A_fp8.shape - A_scales_blocked, starting_row_after_padding = torch_to_blocked_per_group_2d( + A_scales_blocked, starting_row_after_padding = torch_to_blocked_2d_M_groups( A_scales, offs, K ) B_scales_blocked = torch_to_blocked_per_group_3d(B_scales)