diff --git a/benchmarks/prototype/moe_training/bench_moe_layer.py b/benchmarks/prototype/moe_training/bench_moe_layer.py
index cee1863c67..16e8c19c0a 100644
--- a/benchmarks/prototype/moe_training/bench_moe_layer.py
+++ b/benchmarks/prototype/moe_training/bench_moe_layer.py
@@ -205,7 +205,7 @@ def warmup(model, input, labels):
     parser.add_argument(
         "--local_batch_size",
         type=int,
-        default=8,
+        default=12,
     )
     parser.add_argument(
         "--hidden_dim",
diff --git a/test/prototype/moe_training/test_scaled_grouped_mm.py b/test/prototype/moe_training/test_scaled_grouped_mm.py
index a31df4d435..8b84e8abd8 100644
--- a/test/prototype/moe_training/test_scaled_grouped_mm.py
+++ b/test/prototype/moe_training/test_scaled_grouped_mm.py
@@ -33,6 +33,7 @@
     _emulated_mxfp8_scaled_grouped_mm_2d_2d,
     _emulated_mxfp8_scaled_grouped_mm_2d_3d,
     _quantize_then_scaled_grouped_mm,
+    _to_mxfp8_then_scaled_grouped_mm,
 )
 from torchao.prototype.moe_training.utils import (
     _to_mxfp8_per_group_colwise,
@@ -317,11 +318,10 @@ def test_emulate_mxfp8_grouped_gemm_2d_2d(M, N, num_experts):
     "M,K,N", [(16640, 5120, 8192), (131072, 5120, 8192), (131072, 8192, 5120)]
 )
 @pytest.mark.parametrize("num_experts", (2, 4, 8, 16))
-def test_mxfp8_grouped_gemm_with_dq_fwd_bwd(M, K, N, num_experts):
-    from torchao.prototype.moe_training.scaled_grouped_mm import (
-        _MXFP8GroupedMM,
-    )
-
+@pytest.mark.parametrize("use_triton_for_dim0_cast", (True, False))
+def test_mxfp8_grouped_gemm_with_dq_fwd_bwd(
+    M, K, N, num_experts, use_triton_for_dim0_cast
+):
     block_size = 32
     x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
     w = torch.randn(
@@ -340,7 +340,9 @@ def test_mxfp8_grouped_gemm_with_dq_fwd_bwd(M, K, N, num_experts):
     )
 
     # Forward
-    out = _MXFP8GroupedMM.apply(x, w_t, offs, block_size, torch.bfloat16)
+    out = _to_mxfp8_then_scaled_grouped_mm(
+        x, w_t, offs, block_size, torch.bfloat16, use_triton_for_dim0_cast
+    )
     ref_out = torch._grouped_mm(x_ref, w_t_ref, offs=offs_ref, out_dtype=torch.bfloat16)
     sqnr = compute_error(ref_out, out)
     min_sqnr = 27.0
diff --git a/torchao/prototype/moe_training/scaled_grouped_mm.py b/torchao/prototype/moe_training/scaled_grouped_mm.py
index afe7babc66..2bb3166d16 100644
--- a/torchao/prototype/moe_training/scaled_grouped_mm.py
+++ b/torchao/prototype/moe_training/scaled_grouped_mm.py
@@ -5,7 +5,6 @@
 # LICENSE file in the root directory of this source tree.
 
 import logging
-from functools import partial
 from typing import Optional
 
 import torch
@@ -34,6 +33,7 @@
     ScaleCalculationMode,
 )
 from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0
+from torchao.prototype.mx_formats.mx_tensor import to_mx
 from torchao.prototype.mx_formats.utils import _to_mxfp8_dim1_kernel_wrapper
 
 logger: logging.Logger = logging.getLogger(__name__)
@@ -79,15 +79,6 @@ def _quantize_then_scaled_grouped_mm(
         raise ValueError(f"Unsupported scaling type {scaling_type}")
 
 
-# Aliases for convenience/clarity
-_to_mxfp8_then_scaled_grouped_mm = partial(
-    _quantize_then_scaled_grouped_mm, scaling_type=MoEScalingType.MXFP8
-)
-_to_fp8_rowwise_then_scaled_grouped_mm = partial(
-    _quantize_then_scaled_grouped_mm, scaling_type=MoEScalingType.FP8_ROWWISE
-)
-
-
 class _Float8GroupedMM(torch.autograd.Function):
     """Differentiable implementation of grouped GEMM with dynamic float8 quantization."""
 
@@ -304,6 +295,7 @@ def forward(
         block_size: int = 32,
         out_dtype: Optional[torch.dtype] = torch.bfloat16,
         emulated: bool = False,
+        use_triton_for_dim0_cast: bool = False,
    ) -> torch.Tensor:
         # torchao _quantize_then_scaled_grouped_mm only supports A=2D and B=3D.
         assert A.ndim == 2, "A must be 2D"
@@ -313,17 +305,28 @@ def forward(
 
         # A_data shape: (M, K)
         # A_scale shape: (M, K//block_size)
-        A_data, A_scale = triton_to_mxfp8_dim0(
-            A,
-            inner_block_size=block_size,
-        )
-
-        # B_data shape: (E, N, K)
-        # B_scale shape: (E, N, K//block_size)
-        B_data, B_scales = triton_to_mxfp8_dim0(
-            B_t.transpose(-2, -1),
-            inner_block_size=block_size,
-        )
+        if use_triton_for_dim0_cast:
+            A_data, A_scale = triton_to_mxfp8_dim0(
+                A,
+                inner_block_size=block_size,
+            )
+            # B_data shape: (E, N, K)
+            # B_scale shape: (E, N, K//block_size)
+            B_data, B_scales = triton_to_mxfp8_dim0(
+                B_t.transpose(-2, -1),
+                inner_block_size=block_size,
+            )
+        else:
+            A_scale, A_data = to_mx(
+                A,
+                elem_dtype=torch.float8_e4m3fn,
+                block_size=block_size,
+            )
+            B_scales, B_data = to_mx(
+                B_t.transpose(-2, -1),
+                elem_dtype=torch.float8_e4m3fn,
+                block_size=block_size,
+            )
 
         # Convert scales to blocked format for 2d-3d grouped mm
         _, blocked_scales_group_offsets_2d3d = (
@@ -351,6 +354,7 @@ def forward(
         ctx.block_size = block_size
         ctx.out_dtype = out_dtype
         ctx.emulated = emulated
+        ctx.use_triton_for_dim0_cast = use_triton_for_dim0_cast
         return out
 
     @staticmethod
@@ -358,12 +362,20 @@ def backward(ctx, grad_out: torch.Tensor):
         A, B_t, offs, blocked_scales_group_offsets_2d3d = ctx.saved_tensors
         block_size = ctx.block_size
         out_dtype = ctx.out_dtype
+        use_triton_for_dim0_cast = ctx.use_triton_for_dim0_cast
 
         # grad_out_data shape: (M, N)
         # grad_out_scale shape: (M, N//block_size)
-        grad_out_data, grad_out_scale = triton_to_mxfp8_dim0(
-            grad_out, inner_block_size=block_size
-        )
+        if use_triton_for_dim0_cast:
+            grad_out_data, grad_out_scale = triton_to_mxfp8_dim0(
+                grad_out, inner_block_size=block_size
+            )
+        else:
+            grad_out_scale, grad_out_data = to_mx(
+                grad_out,
+                elem_dtype=torch.float8_e4m3fn,
+                block_size=block_size,
+            )
 
         # Quantize 3d expert weights along N (contraction dimension for next grouped gemm)
         # (E, K, N) -> (E, N, K)
@@ -449,7 +461,7 @@ def backward(ctx, grad_out: torch.Tensor):
         )
         # grad_B_t shape = (E,K,N)
         grad_B_t = grad_B.transpose(-2, -1)
-        return grad_A, grad_B_t, None, None, None
+        return grad_A, grad_B_t, None, None, None, None
 
 
 def _to_mxfp8_dim1_3d(
@@ -659,3 +671,8 @@ def _emulated_mxfp8_scaled_grouped_mm_2d_2d(
 
 def round_up(x, y):
     return ((x + y - 1) // y) * y
+
+
+# Aliases for convenience/clarity
+_to_mxfp8_then_scaled_grouped_mm = _MXFP8GroupedMM.apply
+_to_fp8_rowwise_then_scaled_grouped_mm = _Float8GroupedMM.apply
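
For reviewers, a minimal usage sketch (not part of the patch) of the re-pointed `_to_mxfp8_then_scaled_grouped_mm` alias and the new `use_triton_for_dim0_cast` flag, mirroring the call in the updated test. The small shapes, the equal-split `offs` layout, and the assumption of a CUDA device with MXFP8 scaled grouped-GEMM support are illustrative, not taken from the PR.

```python
import torch

from torchao.prototype.moe_training.scaled_grouped_mm import (
    _to_mxfp8_then_scaled_grouped_mm,
)

# Illustrative sizes; the test uses much larger M/K/N.
M, K, N, num_experts, block_size = 1024, 512, 256, 4, 32

# 2D token activations and 3D per-expert weights, as the grouped mm op expects.
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
w = torch.randn(
    num_experts, N, K, dtype=torch.bfloat16, device="cuda", requires_grad=True
)
w_t = w.transpose(-2, -1)

# Cumulative row offsets assigning an equal share of rows to each expert
# (an illustrative layout, not the test's offset generation).
offs = torch.arange(
    M // num_experts, M + 1, M // num_experts, dtype=torch.int32, device="cuda"
)

# The final positional argument toggles the Triton dim0 cast (True) versus the
# newly added to_mx path (False), as exercised in the updated test.
out = _to_mxfp8_then_scaled_grouped_mm(
    x, w_t, offs, block_size, torch.bfloat16, False
)
out.sum().backward()
```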