2 changes: 1 addition & 1 deletion benchmarks/prototype/moe_training/bench_moe_layer.py
@@ -205,7 +205,7 @@ def warmup(model, input, labels):
parser.add_argument(
"--local_batch_size",
type=int,
default=8,
default=12,
)
parser.add_argument(
"--hidden_dim",
14 changes: 8 additions & 6 deletions test/prototype/moe_training/test_scaled_grouped_mm.py
@@ -33,6 +33,7 @@
_emulated_mxfp8_scaled_grouped_mm_2d_2d,
_emulated_mxfp8_scaled_grouped_mm_2d_3d,
_quantize_then_scaled_grouped_mm,
_to_mxfp8_then_scaled_grouped_mm,
)
from torchao.prototype.moe_training.utils import (
_to_mxfp8_per_group_colwise,
@@ -317,11 +318,10 @@ def test_emulate_mxfp8_grouped_gemm_2d_2d(M, N, num_experts):
"M,K,N", [(16640, 5120, 8192), (131072, 5120, 8192), (131072, 8192, 5120)]
)
@pytest.mark.parametrize("num_experts", (2, 4, 8, 16))
def test_mxfp8_grouped_gemm_with_dq_fwd_bwd(M, K, N, num_experts):
from torchao.prototype.moe_training.scaled_grouped_mm import (
_MXFP8GroupedMM,
)

@pytest.mark.parametrize("use_triton_for_dim0_cast", (True, False))
def test_mxfp8_grouped_gemm_with_dq_fwd_bwd(
M, K, N, num_experts, use_triton_for_dim0_cast
):
block_size = 32
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
w = torch.randn(
@@ -340,7 +340,9 @@ def test_mxfp8_grouped_gemm_with_dq_fwd_bwd(M, K, N, num_experts):
)

# Forward
out = _MXFP8GroupedMM.apply(x, w_t, offs, block_size, torch.bfloat16)
out = _to_mxfp8_then_scaled_grouped_mm(
x, w_t, offs, block_size, torch.bfloat16, use_triton_for_dim0_cast
)
ref_out = torch._grouped_mm(x_ref, w_t_ref, offs=offs_ref, out_dtype=torch.bfloat16)
sqnr = compute_error(ref_out, out)
min_sqnr = 27.0
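
For context on the bf16 reference the test compares against, here is a minimal sketch of what `torch._grouped_mm` computes in the 2D-activation / 3D-weight case. The shapes, the even split of rows across experts, and the cumulative-end-row convention for `offs` are illustrative assumptions for this sketch (the real test builds its offsets with a helper not shown in this diff), and it assumes hardware where bf16 `torch._grouped_mm` is supported:

```python
# Illustrative sketch only: the bf16 reference grouped GEMM, expert by expert.
# Assumption for this sketch: `offs` holds cumulative end rows per expert.
import torch

M, K, N, num_experts = 128, 64, 32, 4
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
w_t = torch.randn(num_experts, K, N, dtype=torch.bfloat16, device="cuda")
offs = torch.tensor([32, 64, 96, 128], dtype=torch.int32, device="cuda")

out = torch._grouped_mm(x, w_t, offs=offs, out_dtype=torch.bfloat16)

# The same result computed one expert slice at a time, for intuition.
starts = [0] + offs[:-1].tolist()
ref = torch.cat(
    [x[s:e] @ w_t[i] for i, (s, e) in enumerate(zip(starts, offs.tolist()))]
)
torch.testing.assert_close(out, ref, atol=1e-1, rtol=1e-2)
```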
67 changes: 42 additions & 25 deletions torchao/prototype/moe_training/scaled_grouped_mm.py
@@ -5,7 +5,6 @@
# LICENSE file in the root directory of this source tree.

import logging
from functools import partial
from typing import Optional

import torch
@@ -34,6 +33,7 @@
ScaleCalculationMode,
)
from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0
from torchao.prototype.mx_formats.mx_tensor import to_mx
from torchao.prototype.mx_formats.utils import _to_mxfp8_dim1_kernel_wrapper

logger: logging.Logger = logging.getLogger(__name__)
@@ -79,15 +79,6 @@ def _quantize_then_scaled_grouped_mm(
raise ValueError(f"Unsupported scaling type {scaling_type}")


# Aliases for convenience/clarity
_to_mxfp8_then_scaled_grouped_mm = partial(
_quantize_then_scaled_grouped_mm, scaling_type=MoEScalingType.MXFP8
)
_to_fp8_rowwise_then_scaled_grouped_mm = partial(
_quantize_then_scaled_grouped_mm, scaling_type=MoEScalingType.FP8_ROWWISE
)


class _Float8GroupedMM(torch.autograd.Function):
"""Differentiable implementation of grouped GEMM with dynamic float8 quantization."""

@@ -304,6 +295,7 @@ def forward(
block_size: int = 32,
out_dtype: Optional[torch.dtype] = torch.bfloat16,
emulated: bool = False,
use_triton_for_dim0_cast: bool = False,
) -> torch.Tensor:
# torchao _quantize_then_scaled_grouped_mm only supports A=2D and B=3D.
assert A.ndim == 2, "A must be 2D"
@@ -313,17 +305,28 @@

# A_data shape: (M, K)
# A_scale shape: (M, K//block_size)
A_data, A_scale = triton_to_mxfp8_dim0(
A,
inner_block_size=block_size,
)

# B_data shape: (E, N, K)
# B_scale shape: (E, N, K//block_size)
B_data, B_scales = triton_to_mxfp8_dim0(
B_t.transpose(-2, -1),
inner_block_size=block_size,
)
if use_triton_for_dim0_cast:
A_data, A_scale = triton_to_mxfp8_dim0(
A,
inner_block_size=block_size,
)
# B_data shape: (E, N, K)
# B_scale shape: (E, N, K//block_size)
B_data, B_scales = triton_to_mxfp8_dim0(
B_t.transpose(-2, -1),
inner_block_size=block_size,
)
else:
A_scale, A_data = to_mx(
A,
elem_dtype=torch.float8_e4m3fn,
block_size=block_size,
)
B_scales, B_data = to_mx(
B_t.transpose(-2, -1),
elem_dtype=torch.float8_e4m3fn,
block_size=block_size,
)

# Convert scales to blocked format for 2d-3d grouped mm
_, blocked_scales_group_offsets_2d3d = (
@@ -351,19 +354,28 @@ def forward(
ctx.block_size = block_size
ctx.out_dtype = out_dtype
ctx.emulated = emulated
ctx.use_triton_for_dim0_cast = use_triton_for_dim0_cast
return out

@staticmethod
def backward(ctx, grad_out: torch.Tensor):
A, B_t, offs, blocked_scales_group_offsets_2d3d = ctx.saved_tensors
block_size = ctx.block_size
out_dtype = ctx.out_dtype
use_triton_for_dim0_cast = ctx.use_triton_for_dim0_cast

# grad_out_data shape: (M, N)
# grad_out_scale shape: (M, N//block_size)
grad_out_data, grad_out_scale = triton_to_mxfp8_dim0(
grad_out, inner_block_size=block_size
)
if use_triton_for_dim0_cast:
grad_out_data, grad_out_scale = triton_to_mxfp8_dim0(
grad_out, inner_block_size=block_size
)
else:
grad_out_scale, grad_out_data = to_mx(
grad_out,
elem_dtype=torch.float8_e4m3fn,
block_size=block_size,
)

# Quantize 3d expert weights along N (contraction dimension for next grouped gemm)
# (E, K, N) -> (E, N, K)
@@ -449,7 +461,7 @@ def backward(ctx, grad_out: torch.Tensor):
)
# grad_B_t shape = (E,K,N)
grad_B_t = grad_B.transpose(-2, -1)
return grad_A, grad_B_t, None, None, None
return grad_A, grad_B_t, None, None, None, None


def _to_mxfp8_dim1_3d(
Expand Down Expand Up @@ -659,3 +671,8 @@ def _emulated_mxfp8_scaled_grouped_mm_2d_2d(

def round_up(x, y):
return ((x + y - 1) // y) * y


# Aliases for convenience/clarity
_to_mxfp8_then_scaled_grouped_mm = _MXFP8GroupedMM.apply
_to_fp8_rowwise_then_scaled_grouped_mm = _Float8GroupedMM.apply
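
A short, hedged sketch of the two dim0 (row-wise) MXFP8 casts that `use_triton_for_dim0_cast` selects between in the forward and backward passes above. The main gotcha is the flipped return order: `triton_to_mxfp8_dim0` returns `(data, scale)` while `to_mx` returns `(scale, data)`. The shapes below are illustrative:

```python
# Illustrative sketch of the two per-row (dim0) MXFP8 cast paths.
import torch

from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0
from torchao.prototype.mx_formats.mx_tensor import to_mx

block_size = 32
A = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda")

# Triton kernel path: returns (data, scale).
A_data, A_scale = triton_to_mxfp8_dim0(A, inner_block_size=block_size)

# Reference path: to_mx returns (scale, data) -- note the swapped order.
A_scale_ref, A_data_ref = to_mx(
    A, elem_dtype=torch.float8_e4m3fn, block_size=block_size
)

# Both produce fp8 data of shape (M, K); scales carry one entry per block of
# `block_size` elements along each row, i.e. shape (M, K // block_size).
assert A_data.shape == A_data_ref.shape == (128, 256)
```

The aliases at the bottom of the file then expose the autograd entry points under the descriptive `_to_*_then_scaled_grouped_mm` names used by the tests.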