Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from torchao.float8.config import ScalingGranularity
from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated
from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
torch_to_blocked_per_group_2d,
torch_to_blocked_2d_M_groups,
torch_to_blocked_per_group_3d,
)
from torchao.prototype.moe_training.utils import generate_jagged_offs
Expand Down Expand Up @@ -230,8 +230,8 @@ def bench_mxfp8_grouped_mm(A, B_t, offs, block_size=32) -> float:

# Convert scales for each group to blocked format.
Mg, K = A_fp8.shape
A_scales_blocked, starting_row_after_padding = torch_to_blocked_per_group_2d(
A_scales, offs, Mg, K
A_scales_blocked, starting_row_after_padding = torch_to_blocked_2d_M_groups(
A_scales, offs, K
)
B_scales_blocked = torch_to_blocked_per_group_3d(B_scales)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@

from benchmarks.utils import benchmark_cuda_function_in_microseconds
from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
compute_per_group_blocked_scale_offsets,
torch_to_blocked_per_group_2d,
triton_mx_block_rearrange_per_group_2d,
compute_blocked_scale_offsets_for_M_groups,
torch_to_blocked_2d_M_groups,
triton_mx_block_rearrange_2d_M_groups,
)
from torchao.prototype.moe_training.utils import generate_jagged_offs

Expand Down Expand Up @@ -82,9 +82,9 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
input_group_offsets = generate_jagged_offs(num_groups, Mg, multiple_of=32)

# bench torch
compiled_run_torch = torch.compile(torch_to_blocked_per_group_2d)
compiled_run_torch = torch.compile(torch_to_blocked_2d_M_groups)
torch_out_scales, torch_group_offs = compiled_run_torch(
input_tensor, input_group_offsets, Mg, K
input_tensor, input_group_offsets, K
)
torch_time_us = benchmark_cuda_function_in_microseconds(
compiled_run_torch,
Expand All @@ -95,16 +95,16 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
)

# bench triton
_, output_group_offsets = compute_per_group_blocked_scale_offsets(
_, output_group_offsets = compute_blocked_scale_offsets_for_M_groups(
input_group_offsets
)
triton_out_scales = triton_mx_block_rearrange_per_group_2d(
triton_out_scales = triton_mx_block_rearrange_2d_M_groups(
input_tensor,
input_group_offsets,
output_group_offsets,
)
triton_time_us = benchmark_cuda_function_in_microseconds(
triton_mx_block_rearrange_per_group_2d,
triton_mx_block_rearrange_2d_M_groups,
input_tensor,
input_group_offsets,
output_group_offsets,
Expand Down
61 changes: 54 additions & 7 deletions test/prototype/moe_training/test_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,13 @@
triton_fp8_per_group_rowwise_scales,
)
from torchao.prototype.moe_training.kernels.mxfp8_blocked_scales import (
compute_per_group_blocked_scale_offsets,
torch_to_blocked_per_group_2d,
compute_blocked_scale_offsets_for_K_groups,
compute_blocked_scale_offsets_for_M_groups,
torch_to_blocked_2d_K_groups,
torch_to_blocked_2d_M_groups,
torch_to_blocked_per_group_3d,
triton_mx_block_rearrange_per_group_2d,
triton_mx_block_rearrange_2d_K_groups,
triton_mx_block_rearrange_2d_M_groups,
triton_mx_block_rearrange_per_group_3d,
)
from torchao.prototype.moe_training.utils import (
Expand Down Expand Up @@ -226,15 +229,15 @@ def test_mxfp8_per_group_blocked_scales_2d(
)

# torch reference
ref_out_scales, _ = torch_to_blocked_per_group_2d(
e8m0_scales, input_group_offsets, m, k, block_size=block_size
ref_out_scales, _ = torch_to_blocked_2d_M_groups(
e8m0_scales, input_group_offsets, k, block_size=block_size
)

# triton kernel
_, output_group_offsets = compute_per_group_blocked_scale_offsets(
_, output_group_offsets = compute_blocked_scale_offsets_for_M_groups(
input_group_offsets
)
triton_out_scales = triton_mx_block_rearrange_per_group_2d(
triton_out_scales = triton_mx_block_rearrange_2d_M_groups(
e8m0_scales,
input_group_offsets,
output_group_offsets,
Expand Down Expand Up @@ -266,3 +269,47 @@ def test_mxfp8_per_group_blocked_scales_3d(
assert torch.allclose(ref_out_scales, triton_out_scales, atol=0, rtol=0), (
"blocked scales not equal"
)


@skip_if_rocm("ROCm enablement in progress")
@pytest.mark.parametrize("m", [256, 512, 1024, 5120])
@pytest.mark.parametrize("total_k", [512, 1024, 2048, 4096, 8192, 16384])
@pytest.mark.parametrize("n_groups", [1, 4, 8, 16])
def test_mxfp8_per_group_blocked_scales_2d2d(
m: int,
total_k: int,
n_groups: int,
):
device = "cuda"
block_size = 32
input_data = torch.randn(m, total_k, device=device)

e8m0_scales, _ = to_mx(
input_data, elem_dtype=torch.float8_e4m3fn, block_size=block_size
)

# Generate group end offsets along total_K, then divide by block_size to get scale group end offsets
input_group_offsets = generate_jagged_offs(
n_groups, total_k, multiple_of=block_size, device=device
)
input_group_offsets //= block_size

# torch reference
ref_out_scales, ref_start_cols_after_padding = torch_to_blocked_2d_K_groups(
e8m0_scales,
input_group_offsets,
)

# triton kernel
_, output_group_offsets = compute_blocked_scale_offsets_for_K_groups(
input_group_offsets
)
assert torch.equal(output_group_offsets, ref_start_cols_after_padding), (
"output scale group start offsets not equal"
)
triton_out_scales = triton_mx_block_rearrange_2d_K_groups(
e8m0_scales,
input_group_offsets,
output_group_offsets,
)
assert torch.equal(ref_out_scales, triton_out_scales), "blocked scales not equal"
Loading
Loading