From a1473c4d32a58af6dc24af5f7636b492f2b54bf8 Mon Sep 17 00:00:00 2001
From: Daniel Vega-Myhre
Date: Tue, 21 Oct 2025 10:11:46 -0700
Subject: [PATCH] [moe training] change api _scaled_grouped_mm -> _quantize_then_scaled_grouped_mm

---
 .../benchmark_scaled_grouped_mm_dq.py      |  8 ++++----
 .../moe_training/test_scaled_grouped_mm.py |  8 ++++----
 torchao/prototype/moe_training/README.md   |  2 +-
 torchao/prototype/moe_training/__init__.py |  6 ++++--
 .../moe_training/scaled_grouped_mm.py      | 20 ++++++++++++++-----
 torchao/prototype/moe_training/tensor.py   |  8 ++++----
 6 files changed, 32 insertions(+), 20 deletions(-)

diff --git a/benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py b/benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py
index a28d981e8a..5b4177c564 100644
--- a/benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py
+++ b/benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py
@@ -19,7 +19,7 @@
     bench_fwd_microseconds,
     profile_fwd_bwd,
 )
-from torchao.prototype.moe_training import _scaled_grouped_mm
+from torchao.prototype.moe_training import _quantize_then_scaled_grouped_mm
 from torchao.prototype.moe_training.conversion_utils import MoEScalingType
 from torchao.prototype.moe_training.utils import generate_jagged_offs
 
@@ -158,7 +158,7 @@ def run_experiment(
 
     # fwd_bwd scaled benchmark + profiling
     scaled_fwd_bwd_us = bench_fwd_bwd_microseconds(
-        _scaled_grouped_mm,
+        _quantize_then_scaled_grouped_mm,
         A,
         B_t,
         offs,
@@ -169,7 +169,7 @@
     )
     if args.profile:
         profile_fwd_bwd(
-            _scaled_grouped_mm,
+            _quantize_then_scaled_grouped_mm,
             A,
             B_t,
             offs,
@@ -190,7 +190,7 @@ def run_experiment(
         fullgraph=True,
     )
     scaled_fwd_us = bench_fwd_microseconds(
-        _scaled_grouped_mm,
+        _quantize_then_scaled_grouped_mm,
         A,
         B_t,
         offs,
diff --git a/test/prototype/moe_training/test_scaled_grouped_mm.py b/test/prototype/moe_training/test_scaled_grouped_mm.py
index e382351a12..a31df4d435 100644
--- a/test/prototype/moe_training/test_scaled_grouped_mm.py
+++ b/test/prototype/moe_training/test_scaled_grouped_mm.py
@@ -32,7 +32,7 @@
 from torchao.prototype.moe_training.scaled_grouped_mm import (
     _emulated_mxfp8_scaled_grouped_mm_2d_2d,
     _emulated_mxfp8_scaled_grouped_mm_2d_3d,
-    _scaled_grouped_mm,
+    _quantize_then_scaled_grouped_mm,
 )
 from torchao.prototype.moe_training.utils import (
     _to_mxfp8_per_group_colwise,
@@ -73,7 +73,7 @@ def test_valid_scaled_grouped_mm_2d_3d(m, n, k, n_groups):
     b_t = b.contiguous().transpose(-2, -1).requires_grad_(True)
 
     # Compute output.
-    out = _scaled_grouped_mm(
+    out = _quantize_then_scaled_grouped_mm(
         a,
         b_t,
         offs=offs,
@@ -142,7 +142,7 @@ def test_K_or_N_dim_not_multiple_of_16(m, n, k):
 
     # Compute output.
     with pytest.raises(AssertionError):
-        _scaled_grouped_mm(
+        _quantize_then_scaled_grouped_mm(
             a,
             b_t,
             offs=offs,
@@ -199,7 +199,7 @@ def compute_reference_forward(
         result_list.append(result[start : offs_cpu[i]])
         start = offs_cpu[i]
 
-    # Validate each actual result group from the _scaled_grouped_mm is equal to:
+    # Validate each actual result group from the _quantize_then_scaled_grouped_mm is equal to:
     # 1. A manual _scaled_mm for the group.
     # 2. A matmul_with_hp_or_float8_args for the group (which is differentiable, and thus used to validate gradients).
     outputs = []
diff --git a/torchao/prototype/moe_training/README.md b/torchao/prototype/moe_training/README.md
index a2d8d03a79..14d571844e 100644
--- a/torchao/prototype/moe_training/README.md
+++ b/torchao/prototype/moe_training/README.md
@@ -27,7 +27,7 @@ This prototype provides:
 import torch
 from torch.nn import functional as F
 from torchao.prototype.moe_training import (
-    _scaled_grouped_mm as torchao_scaled_grouped_mm
+    _quantize_then_scaled_grouped_mm as torchao_scaled_grouped_mm
 )
 from torchao.prototype.moe_training.conversion_utils import MoEScalingType
 from torchao.prototype.moe_training.utils import generate_jagged_offs
diff --git a/torchao/prototype/moe_training/__init__.py b/torchao/prototype/moe_training/__init__.py
index 8118193aff..d0832a8e87 100644
--- a/torchao/prototype/moe_training/__init__.py
+++ b/torchao/prototype/moe_training/__init__.py
@@ -1,3 +1,5 @@
-from torchao.prototype.moe_training.scaled_grouped_mm import _scaled_grouped_mm
+from torchao.prototype.moe_training.scaled_grouped_mm import (
+    _quantize_then_scaled_grouped_mm,
+)
 
-__all__ = ["_scaled_grouped_mm"]
+__all__ = ["_quantize_then_scaled_grouped_mm"]
diff --git a/torchao/prototype/moe_training/scaled_grouped_mm.py b/torchao/prototype/moe_training/scaled_grouped_mm.py
index ae8a0bc96d..afe7babc66 100644
--- a/torchao/prototype/moe_training/scaled_grouped_mm.py
+++ b/torchao/prototype/moe_training/scaled_grouped_mm.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import logging
+from functools import partial
 from typing import Optional
 
 import torch
@@ -38,7 +39,7 @@
 logger: logging.Logger = logging.getLogger(__name__)
 
 
-def _scaled_grouped_mm(
+def _quantize_then_scaled_grouped_mm(
     A: torch.Tensor,
     B_t: torch.Tensor,
     offs: Optional[torch.Tensor] = None,
@@ -46,7 +47,7 @@ def _scaled_grouped_mm(
     scaling_type: MoEScalingType = MoEScalingType.FP8_ROWWISE,
 ) -> torch.Tensor:
     """
-    This function performs dynamic float8 quantization with row-wise scaling
+    This function performs dynamic quantization with the given recipe
     on the input tensors A and B, then performs a scaled grouped GEMM and returns the results.
 
     Args:
@@ -78,6 +79,15 @@
         raise ValueError(f"Unsupported scaling type {scaling_type}")
 
 
+# Aliases for convenience/clarity
+_to_mxfp8_then_scaled_grouped_mm = partial(
+    _quantize_then_scaled_grouped_mm, scaling_type=MoEScalingType.MXFP8
+)
+_to_fp8_rowwise_then_scaled_grouped_mm = partial(
+    _quantize_then_scaled_grouped_mm, scaling_type=MoEScalingType.FP8_ROWWISE
+)
+
+
 class _Float8GroupedMM(torch.autograd.Function):
     """Differentiable implementation of grouped GEMM with dynamic float8 quantization."""
 
@@ -89,7 +99,7 @@ def forward(
         offs: Optional[torch.Tensor] = None,
         out_dtype: Optional[torch.dtype] = torch.bfloat16,
     ) -> torch.Tensor:
-        # torchao _scaled_grouped_mm only supports A=2D|3D and B=3D.
+        # torchao _quantize_then_scaled_grouped_mm only supports A=2D|3D and B=3D.
         assert A.ndim == 2 or A.ndim == 3, "A must be 2D or 3D"
         assert B_t.ndim == 3, "B must be 3D"
 
@@ -113,7 +123,7 @@ def forward(
 
         # Assert A and B dims are compatible for a scaled grouped GEMM.
         assert A.size(-1) == B_t.size(-2), (
-            f"shape {A.shape} and {B_t.shape} are not compatible for _scaled_grouped_mm"
+            f"shape {A.shape} and {B_t.shape} are not compatible for _quantize_then_scaled_grouped_mm"
        )
 
         # The left operand in the scaled grouped GEMM must be row-major due to hardware requirements.
@@ -295,7 +305,7 @@ def forward(
         out_dtype: Optional[torch.dtype] = torch.bfloat16,
         emulated: bool = False,
     ) -> torch.Tensor:
-        # torchao _scaled_grouped_mm only supports A=2D and B=3D.
+        # torchao _quantize_then_scaled_grouped_mm only supports A=2D and B=3D.
         assert A.ndim == 2, "A must be 2D"
         assert B_t.ndim == 3, "B must be 3D"
         assert block_size == 32, "Only block_size=32 is supported"
diff --git a/torchao/prototype/moe_training/tensor.py b/torchao/prototype/moe_training/tensor.py
index 0bbbda850e..287dfdc5d9 100644
--- a/torchao/prototype/moe_training/tensor.py
+++ b/torchao/prototype/moe_training/tensor.py
@@ -15,7 +15,7 @@
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.fsdp import MixedPrecisionPolicy
 
-from torchao.prototype.moe_training import _scaled_grouped_mm
+from torchao.prototype.moe_training import _quantize_then_scaled_grouped_mm
 from torchao.prototype.moe_training.conversion_utils import MoEScalingType
 
 logger: logging.Logger = logging.getLogger(__name__)
@@ -39,7 +39,7 @@ class ScaledGroupedMMTensor(torch.Tensor):
     """
     ScaledGroupedMMTensor is a simple tensor subclass that wraps a regular tensor
     and overrides the torch._grouped_mm op by dispatching to the
-    differentiable _scaled_grouped_mm autograd function.
+    differentiable _quantize_then_scaled_grouped_mm autograd function.
     """
 
     scaling_type: MoEScalingType = MoEScalingType.FP8_ROWWISE
@@ -77,7 +77,7 @@ def __init__(
 
     @classmethod
     def __torch_function__(cls, func, types, args, kwargs={}):
-        # override the grouped mm op to use the differentiable _scaled_grouped_mm
+        # override the grouped mm op to use the differentiable _quantize_then_scaled_grouped_mm
         if func.__name__ == cls.grouped_mm_func_name:
             # Use torchao scaled grouped mm with dynamic quant for
             # "2d x 3d with offsets" case (used for routed experts).
@@ -99,7 +99,7 @@ def __torch_function__(cls, func, types, args, kwargs={}):
             has_offs = kwargs.get(cls.offs_arg_name) is not None
             other_args = args[2:]
             if A_is_2d and B_is_2d_or_3d and has_offs:
-                return _scaled_grouped_mm(
+                return _quantize_then_scaled_grouped_mm(
                     A,
                     B,
                     *other_args,
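
Usage note (not part of the patch above): a minimal sketch of how a caller might invoke the renamed _quantize_then_scaled_grouped_mm, based on the signature and asserts visible in the hunks. The shapes, the hand-built equal-sized offsets (used here in place of generate_jagged_offs), and the assumption of a CUDA device with float8 support are illustrative choices, not requirements taken from the patch.

# Sketch only: assumes an fp8-capable CUDA GPU, bfloat16 inputs, and dims
# that satisfy the multiple-of-16 constraint exercised by the tests above.
import torch

from torchao.prototype.moe_training import _quantize_then_scaled_grouped_mm
from torchao.prototype.moe_training.conversion_utils import MoEScalingType

n_groups, total_m, k, n = 4, 256, 512, 1024

# A: 2D stacked token activations (total_m, k).
A = torch.randn(total_m, k, dtype=torch.bfloat16, device="cuda", requires_grad=True)

# B_t: 3D per-expert weights, transposed to (n_groups, k, n) as in the tests.
B = torch.randn(n_groups, n, k, dtype=torch.bfloat16, device="cuda")
B_t = B.contiguous().transpose(-2, -1).requires_grad_(True)

# Cumulative end-of-group offsets along M (equal-sized groups of 64 tokens here).
offs = torch.arange(
    total_m // n_groups, total_m + 1, total_m // n_groups,
    dtype=torch.int32, device="cuda",
)

out = _quantize_then_scaled_grouped_mm(
    A,
    B_t,
    offs=offs,
    out_dtype=torch.bfloat16,
    scaling_type=MoEScalingType.FP8_ROWWISE,
)
out.sum().backward()

The patch also adds partial-based aliases, _to_mxfp8_then_scaled_grouped_mm and _to_fp8_rowwise_then_scaled_grouped_mm, which pre-bind scaling_type for callers that want the quantization recipe fixed at import time.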