@@ -19,7 +19,7 @@
     bench_fwd_microseconds,
     profile_fwd_bwd,
 )
-from torchao.prototype.moe_training import _scaled_grouped_mm
+from torchao.prototype.moe_training import _quantize_then_scaled_grouped_mm
 from torchao.prototype.moe_training.conversion_utils import MoEScalingType
 from torchao.prototype.moe_training.utils import generate_jagged_offs
 
@@ -158,7 +158,7 @@ def run_experiment(
 
     # fwd_bwd scaled benchmark + profiling
     scaled_fwd_bwd_us = bench_fwd_bwd_microseconds(
-        _scaled_grouped_mm,
+        _quantize_then_scaled_grouped_mm,
         A,
         B_t,
         offs,
@@ -169,7 +169,7 @@ def run_experiment(
     )
     if args.profile:
         profile_fwd_bwd(
-            _scaled_grouped_mm,
+            _quantize_then_scaled_grouped_mm,
            A,
            B_t,
            offs,
@@ -190,7 +190,7 @@ def run_experiment(
         fullgraph=True,
     )
     scaled_fwd_us = bench_fwd_microseconds(
-        _scaled_grouped_mm,
+        _quantize_then_scaled_grouped_mm,
         A,
         B_t,
         offs,
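Note for readers unfamiliar with the harness: `bench_fwd_bwd_microseconds` takes the function under test followed by its arguments. Its implementation is not shown in this diff; the sketch below is a hypothetical CUDA-event-based forward+backward timer illustrating the general technique, not torchao's actual helper.

```python
# Hypothetical sketch of a CUDA-event fwd+bwd timer; the real
# bench_fwd_bwd_microseconds helper is not shown in this diff and may differ.
import torch


def bench_fwd_bwd_us_sketch(fn, *args, n_iters=10, **kwargs):
    # Warmup iterations keep compile/caching overhead out of the measurement.
    for _ in range(3):
        fn(*args, **kwargs).sum().backward()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(n_iters):
        fn(*args, **kwargs).sum().backward()
    end.record()
    torch.cuda.synchronize()
    # elapsed_time() returns milliseconds; convert to microseconds per iteration.
    return start.elapsed_time(end) * 1e3 / n_iters
```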
8 changes: 4 additions & 4 deletions test/prototype/moe_training/test_scaled_grouped_mm.py
@@ -32,7 +32,7 @@
 from torchao.prototype.moe_training.scaled_grouped_mm import (
     _emulated_mxfp8_scaled_grouped_mm_2d_2d,
     _emulated_mxfp8_scaled_grouped_mm_2d_3d,
-    _scaled_grouped_mm,
+    _quantize_then_scaled_grouped_mm,
 )
 from torchao.prototype.moe_training.utils import (
     _to_mxfp8_per_group_colwise,
@@ -73,7 +73,7 @@ def test_valid_scaled_grouped_mm_2d_3d(m, n, k, n_groups):
     b_t = b.contiguous().transpose(-2, -1).requires_grad_(True)
 
     # Compute output.
-    out = _scaled_grouped_mm(
+    out = _quantize_then_scaled_grouped_mm(
         a,
         b_t,
         offs=offs,
@@ -142,7 +142,7 @@ def test_K_or_N_dim_not_multiple_of_16(m, n, k):
 
     # Compute output.
     with pytest.raises(AssertionError):
-        _scaled_grouped_mm(
+        _quantize_then_scaled_grouped_mm(
             a,
             b_t,
             offs=offs,
@@ -199,7 +199,7 @@ def compute_reference_forward(
         result_list.append(result[start : offs_cpu[i]])
         start = offs_cpu[i]
 
-    # Validate each actual result group from the _scaled_grouped_mm is equal to:
+    # Validate each actual result group from the _quantize_then_scaled_grouped_mm is equal to:
     # 1. A manual _scaled_mm for the group.
     # 2. A matmul_with_hp_or_float8_args for the group (which is differentiable, and thus used to validate gradients).
     outputs = []
2 changes: 1 addition & 1 deletion torchao/prototype/moe_training/README.md
@@ -27,7 +27,7 @@ This prototype provides:
 import torch
 from torch.nn import functional as F
 from torchao.prototype.moe_training import (
-    _scaled_grouped_mm as torchao_scaled_grouped_mm
+    _quantize_then_scaled_grouped_mm as torchao_scaled_grouped_mm
 )
 from torchao.prototype.moe_training.conversion_utils import MoEScalingType
 from torchao.prototype.moe_training.utils import generate_jagged_offs
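The README import above aliases the new name back to `torchao_scaled_grouped_mm`, so downstream snippets keep working. A minimal end-to-end call, modeled on the test file in this PR, might look like the sketch below; the sizes are illustrative and the `generate_jagged_offs(n_groups, total_M)` call signature is an assumption, not confirmed by this diff.

```python
import torch
from torchao.prototype.moe_training import (
    _quantize_then_scaled_grouped_mm as torchao_scaled_grouped_mm,
)
from torchao.prototype.moe_training.conversion_utils import MoEScalingType
from torchao.prototype.moe_training.utils import generate_jagged_offs

# Illustrative sizes; K and N are multiples of 16, as the tests require.
n_groups, total_M, N, K = 8, 1024, 8192, 5120
a = torch.randn(total_M, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
b = torch.randn(n_groups, N, K, dtype=torch.bfloat16, device="cuda")
b_t = b.contiguous().transpose(-2, -1).requires_grad_(True)  # B must be 3D (G, K, N)
offs = generate_jagged_offs(n_groups, total_M)  # assumed signature

out = torchao_scaled_grouped_mm(
    a, b_t, offs=offs, scaling_type=MoEScalingType.FP8_ROWWISE
)
out.sum().backward()
```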
6 changes: 4 additions & 2 deletions torchao/prototype/moe_training/__init__.py
@@ -1,3 +1,5 @@
-from torchao.prototype.moe_training.scaled_grouped_mm import _scaled_grouped_mm
+from torchao.prototype.moe_training.scaled_grouped_mm import (
+    _quantize_then_scaled_grouped_mm,
+)
 
-__all__ = ["_scaled_grouped_mm"]
+__all__ = ["_quantize_then_scaled_grouped_mm"]
20 changes: 15 additions & 5 deletions torchao/prototype/moe_training/scaled_grouped_mm.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import logging
+from functools import partial
 from typing import Optional
 
 import torch
@@ -38,15 +39,15 @@
 logger: logging.Logger = logging.getLogger(__name__)
 
 
-def _scaled_grouped_mm(
+def _quantize_then_scaled_grouped_mm(
     A: torch.Tensor,
     B_t: torch.Tensor,
     offs: Optional[torch.Tensor] = None,
     out_dtype: Optional[torch.dtype] = torch.bfloat16,
     scaling_type: MoEScalingType = MoEScalingType.FP8_ROWWISE,
 ) -> torch.Tensor:
     """
-    This function performs dynamic float8 quantization with row-wise scaling
+    This function performs dynamic quantization with the given recipe
     on the input tensors A and B, then performs a scaled grouped GEMM and returns the results.
 
     Args:
@@ -78,6 +79,15 @@ def _scaled_grouped_mm(
         raise ValueError(f"Unsupported scaling type {scaling_type}")
 
 
+# Aliases for convenience/clarity
+_to_mxfp8_then_scaled_grouped_mm = partial(
+    _quantize_then_scaled_grouped_mm, scaling_type=MoEScalingType.MXFP8
+)
+_to_fp8_rowwise_then_scaled_grouped_mm = partial(
+    _quantize_then_scaled_grouped_mm, scaling_type=MoEScalingType.FP8_ROWWISE
+)
+
+
 class _Float8GroupedMM(torch.autograd.Function):
     """Differentiable implementation of grouped GEMM with dynamic float8 quantization."""
 
@@ -89,7 +99,7 @@ def forward(
         offs: Optional[torch.Tensor] = None,
         out_dtype: Optional[torch.dtype] = torch.bfloat16,
     ) -> torch.Tensor:
-        # torchao _scaled_grouped_mm only supports A=2D|3D and B=3D.
+        # torchao _quantize_then_scaled_grouped_mm only supports A=2D|3D and B=3D.
         assert A.ndim == 2 or A.ndim == 3, "A must be 2D or 3D"
         assert B_t.ndim == 3, "B must be 3D"
 
@@ -113,7 +123,7 @@
 
         # Assert A and B dims are compatible for a scaled grouped GEMM.
         assert A.size(-1) == B_t.size(-2), (
-            f"shape {A.shape} and {B_t.shape} are not compatible for _scaled_grouped_mm"
+            f"shape {A.shape} and {B_t.shape} are not compatible for _quantize_then_scaled_grouped_mm"
         )
 
         # The left operand in the scaled grouped GEMM must be row-major due to hardware requirements.
@@ -295,7 +305,7 @@ def forward(
         out_dtype: Optional[torch.dtype] = torch.bfloat16,
         emulated: bool = False,
     ) -> torch.Tensor:
-        # torchao _scaled_grouped_mm only supports A=2D and B=3D.
+        # torchao _quantize_then_scaled_grouped_mm only supports A=2D and B=3D.
         assert A.ndim == 2, "A must be 2D"
         assert B_t.ndim == 3, "B must be 3D"
         assert block_size == 32, "Only block_size=32 is supported"
8 changes: 4 additions & 4 deletions torchao/prototype/moe_training/tensor.py
@@ -15,7 +15,7 @@
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.fsdp import MixedPrecisionPolicy
 
-from torchao.prototype.moe_training import _scaled_grouped_mm
+from torchao.prototype.moe_training import _quantize_then_scaled_grouped_mm
 from torchao.prototype.moe_training.conversion_utils import MoEScalingType
 
 logger: logging.Logger = logging.getLogger(__name__)
@@ -39,7 +39,7 @@ class ScaledGroupedMMTensor(torch.Tensor):
     """
     ScaledGroupedMMTensor is a simple tensor subclass that wraps a regular tensor
     and overrides the torch._grouped_mm op by dispatching to the
-    differentiable _scaled_grouped_mm autograd function.
+    differentiable _quantize_then_scaled_grouped_mm autograd function.
     """
 
     scaling_type: MoEScalingType = MoEScalingType.FP8_ROWWISE
@@ -77,7 +77,7 @@ def __init__(
 
     @classmethod
     def __torch_function__(cls, func, types, args, kwargs={}):
-        # override the grouped mm op to use the differentiable _scaled_grouped_mm
+        # override the grouped mm op to use the differentiable _quantize_then_scaled_grouped_mm
         if func.__name__ == cls.grouped_mm_func_name:
             # Use torchao scaled grouped mm with dynamic quant for
             # "2d x 3d with offsets" case (used for routed experts).
@@ -99,7 +99,7 @@ def __torch_function__(cls, func, types, args, kwargs={}):
             has_offs = kwargs.get(cls.offs_arg_name) is not None
             other_args = args[2:]
             if A_is_2d and B_is_2d_or_3d and has_offs:
-                return _scaled_grouped_mm(
+                return _quantize_then_scaled_grouped_mm(
                     A,
                     B,
                     *other_args,
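`ScaledGroupedMMTensor` relies on `__torch_function__` to intercept `torch._grouped_mm` by name and reroute it to the quantized path. The standalone sketch below illustrates that dispatch mechanism with an ordinary `torch.mm`; it is the pattern the subclass uses, not the torchao class itself.

```python
import torch


class InterceptingTensor(torch.Tensor):
    """Stand-in illustrating the __torch_function__ dispatch pattern;
    torchao's ScaledGroupedMMTensor intercepts torch._grouped_mm this way."""

    target_func_name = "mm"

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func.__name__ == cls.target_func_name:
            # Where torchao would call the differentiable
            # _quantize_then_scaled_grouped_mm, this sketch unwraps the
            # operands and calls a plain matmul instead.
            a, b = (
                t.as_subclass(torch.Tensor) if isinstance(t, cls) else t
                for t in args[:2]
            )
            return torch.matmul(a, b)
        return super().__torch_function__(func, types, args, kwargs)


a = torch.randn(4, 8).as_subclass(InterceptingTensor)
b = torch.randn(8, 16)
out = torch.mm(a, b)  # intercepted and rerouted by __torch_function__
```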