[moe training] refactor configs, recipes; support converting linears + grouped gemms in a single quantize_() call#3862

Merged
danielvegamyhre merged 1 commit into main from refactor-feb10 on Feb 20, 2026

Conversation

@danielvegamyhre (Contributor) commented Feb 11, 2026

Summary

To promote the MoE training prototype to stable, we need the developer experience to be consistent with the rest of torchao, namely the fp8/mxfp8 linear training flows.

This PR:

  • Splits FP8 grouped mm and MXFP8 grouped mm related code into separate files, organized by derived dtype
  • Updates APIs and configs to align with the patterns of MXLinearConfig and MXLinear
  • Adds support for converting both linears and grouped gemms in a single quantize_() call

These files contain the important changes; everything else is peripheral:

  • torchao/prototype/moe_training/config.py
  • torchao/prototype/moe_training/conversion_utils.py
  • torchao/prototype/moe_training/mxfp8_grouped_mm.py
  • torchao/prototype/moe_training/fp8_grouped_mm.py
  • torchao/prototype/moe_training/tensor.py
  • torchao/quantization/quant_api.py

New model conversion API

```python
model = MoEModel()

config = FqnToConfig(
    fqn_to_config=OrderedDict(
        [
            (
                "re:.*experts.*",
                MXFP8GroupedMMConfig.from_recipe(MXFP8GroupedMMRecipe.MXFP8_RCEIL),
            ),
            (
                "re:^(pre_moe|post_moe)$",
                MXLinearConfig.from_recipe_name(MXLinearRecipeName.MXFP8_CUBLAS_RCEIL),
            ),
        ]
    )
)

quantize_(model, config, filter_fn=None)
```
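The `re:` prefix marks a key as a regular expression to be matched against module/parameter FQNs, with the first matching entry winning. A minimal, self-contained sketch of that resolution logic (the helper name and the exact matching semantics, fullmatch vs. search, are assumptions for illustration, not torchao's actual implementation):

```python
import re
from collections import OrderedDict

# Illustrative string stand-ins for the real torchao config objects.
fqn_to_config = OrderedDict([
    ("re:.*experts.*", "grouped_mm_config"),
    ("re:^(pre_moe|post_moe)$", "linear_config"),
])

def resolve_config(fqn: str):
    """Return the first config whose `re:`-prefixed pattern matches the FQN."""
    for key, config in fqn_to_config.items():
        if key.startswith("re:") and re.fullmatch(key[len("re:"):], fqn):
            return config
    return None  # no pattern matched; leave the module/parameter unquantized

assert resolve_config("layers.0.experts.w1") == "grouped_mm_config"
assert resolve_config("pre_moe") == "linear_config"
assert resolve_config("lm_head") is None
```

An OrderedDict is used so that more specific patterns can be listed first and take precedence over broader catch-all patterns.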

New recipes

```python
class FP8GroupedMMRecipe(Enum):
    """FP8 recipes for grouped matrix multiplication."""
    FP8_ROWWISE = "fp8_rowwise"


class MXFP8GroupedMMRecipe(Enum):
    """MXFP8 recipes for grouped matrix multiplication."""
    # TODO: add floor variants
    MXFP8_RCEIL = "mxfp8_rceil"
    MXFP8_RCEIL_WGRAD_WITH_HP = "mxfp8_rceil_wgrad_with_hp"
    MXFP8_EMULATED_RCEIL = "mxfp8_emulated_rceil"
```

New configs

```python
class GroupedMMConfig(AOBaseConfig):
    """Base configuration for grouped matrix multiplication. Not intended to be used directly."""
    pass


@dataclass
class FP8GroupedMMConfig(GroupedMMConfig):
    # Output dtype for the FP8 grouped GEMMs.
    out_dtype: Optional[torch.dtype] = torch.bfloat16

    @classmethod
    def from_recipe(
        cls,
        recipe: FP8GroupedMMRecipe,
    ) -> "FP8GroupedMMConfig":
        ...


@dataclass
class MXFP8GroupedMMConfig(GroupedMMConfig):
    # AUTO = Use the best supported kernels for quantization ops and GEMMs
    #        (CUDA and Triton for quantization, CUTLASS for MXFP8 grouped GEMMs).
    # EMULATED = Hardware-agnostic mode that can be used for debugging or development
    #            on non-SM100 machines. Uses PyTorch-native quantization ops, then
    #            dequantizes and uses emulated MXFP8 grouped GEMMs implemented in
    #            PyTorch. Not recommended for performance.
    kernel_preference: KernelPreference = KernelPreference.AUTO

    # Output dtype for the MXFP8 grouped GEMMs.
    out_dtype: Optional[torch.dtype] = torch.bfloat16

    # Whether to compute the gradient of the weights in high precision (True) or use MXFP8 (False).
    wgrad_with_hp: bool = False

    # Rounding mode to use when calculating the e8m0 scale factors.
    scale_calculation_mode: ScaleCalculationMode = ScaleCalculationMode.RCEIL

    @classmethod
    def from_recipe(
        cls,
        recipe: MXFP8GroupedMMRecipe,
    ) -> "MXFP8GroupedMMConfig":
        ...
```
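The `from_recipe` bodies are elided above. The general pattern is a classmethod that maps each recipe enum member to a fully specified config. A simplified, torch-free sketch of that pattern; the per-recipe field values below are guesses inferred from the recipe names, not the actual torchao mappings, and the string fields stand in for the real `KernelPreference`/`ScaleCalculationMode` types:

```python
from dataclasses import dataclass
from enum import Enum

class MXFP8GroupedMMRecipe(Enum):
    MXFP8_RCEIL = "mxfp8_rceil"
    MXFP8_RCEIL_WGRAD_WITH_HP = "mxfp8_rceil_wgrad_with_hp"
    MXFP8_EMULATED_RCEIL = "mxfp8_emulated_rceil"

@dataclass
class MXFP8GroupedMMConfig:
    kernel_preference: str = "auto"        # stand-in for KernelPreference.AUTO
    wgrad_with_hp: bool = False
    scale_calculation_mode: str = "rceil"  # stand-in for ScaleCalculationMode.RCEIL

    @classmethod
    def from_recipe(cls, recipe: MXFP8GroupedMMRecipe) -> "MXFP8GroupedMMConfig":
        # Each recipe member pins down a complete set of config fields,
        # so users never hand-assemble field combinations themselves.
        if recipe is MXFP8GroupedMMRecipe.MXFP8_RCEIL:
            return cls()
        if recipe is MXFP8GroupedMMRecipe.MXFP8_RCEIL_WGRAD_WITH_HP:
            return cls(wgrad_with_hp=True)
        if recipe is MXFP8GroupedMMRecipe.MXFP8_EMULATED_RCEIL:
            return cls(kernel_preference="emulated")
        raise ValueError(f"unknown recipe: {recipe}")

cfg = MXFP8GroupedMMConfig.from_recipe(MXFP8GroupedMMRecipe.MXFP8_RCEIL_WGRAD_WITH_HP)
assert cfg.wgrad_with_hp and cfg.scale_calculation_mode == "rceil"
```

The payoff of the factory approach is that tested recipe names become the stable user-facing surface, while the underlying config fields can evolve without breaking callers.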

Testing

Added new test:

  • pytest test/prototype/moe_training/test_fqn_to_config.py

Existing tests passing:

  • ./test/prototype/moe_training/test_everything.sh

Detailed change list (optional read)

Configuration Refactoring

Created new config.py file containing all MoE training configurations:

  • Added FP8GroupedMMRecipe enum with FP8_ROWWISE recipe
  • Added MXFP8GroupedMMRecipe enum with recipes: MXFP8_RCEIL, MXFP8_RCEIL_WGRAD_WITH_HP, MXFP8_EMULATED_RCEIL
  • Added GroupedMMConfig base class for type abstraction
  • Added FP8GroupedMMConfig dataclass with out_dtype field (defaults to bf16)
  • Added MXFP8GroupedMMConfig dataclass with fields: kernel_preference, out_dtype, wgrad_with_hp, scale_calculation_mode
  • Implemented .from_recipe() factory methods for both config classes
  • Added __eq__ and __hash__ methods to MXFP8GroupedMMConfig (context: these are needed for torch._dynamo.nonstrict_trace mode, which in turn is needed to accept pre-quantized MXTensor inputs while producing regular torch.Tensor outputs and backward input gradients)
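For context on why explicit `__eq__`/`__hash__` are needed at all: `@dataclass` with the default `eq=True` (and `frozen=False`) sets `__hash__` to `None`, so instances are unhashable and cannot serve as dictionary or cache keys inside tracing machinery. A minimal illustration of the underlying Python behavior (class names here are illustrative, not from the PR):

```python
from dataclasses import dataclass, fields

@dataclass
class PlainConfig:
    wgrad_with_hp: bool = False

# The generated __eq__ sets __hash__ to None, making instances unhashable.
try:
    hash(PlainConfig())
    raise AssertionError("expected TypeError")
except TypeError:
    pass

@dataclass
class HashableConfig:
    wgrad_with_hp: bool = False

    def __eq__(self, other):
        return (
            isinstance(other, HashableConfig)
            and self.wgrad_with_hp == other.wgrad_with_hp
        )

    def __hash__(self):
        # Hash over the same fields __eq__ compares, preserving the
        # equal-objects-have-equal-hashes contract.
        return hash(tuple(getattr(self, f.name) for f in fields(self)))

assert HashableConfig() == HashableConfig()
assert hash(HashableConfig()) == hash(HashableConfig())
```

Because both methods are defined explicitly in the class body, the dataclass machinery leaves them untouched, and the config becomes a valid dictionary key.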

Code Organization

Split grouped GEMM autograd func implementations:

  • Renamed scaled_grouped_mm.py → mxfp8_grouped_mm.py (contains MXFP8-specific implementation)
  • Created new fp8_grouped_mm.py (contains FP8 rowwise implementation extracted from previous file)

Refactored conversion_utils.py:

  • Removed MoEScalingType enum (replaced by recipe enums in config.py)
  • Removed MoETrainingConfig class (replaced by config classes in config.py)
  • Added target_parameter_name parameter for compatibility with FqnToConfig per-parameter quantization

Updates to tensor.py:

  • Added _quantize_then_scaled_grouped_mm dispatcher function that routes to appropriate implementation based on config type
  • Dispatcher checks if config is FP8GroupedMMConfig or MXFP8GroupedMMConfig and calls corresponding function

Support linear and grouped mm conversion in a single quantize_() call:

  • Registered FP8GroupedMMConfig and MXFP8GroupedMMConfig in CUSTOM_PARAM_QUANTIZATION_SUPPORTED_CONFIGS in quant_api.py
  • Enables using FqnToConfig to apply different quantization configs to different modules/parameters in a single quantize_() call
  • Created comprehensive test suite (test_fqn_to_config.py):
      • Tests regex pattern matching for FQN-based configuration
      • Tests selective quantization (experts only, dense only, specific parameters)
      • Tests different recipe configurations on different parameters

API Changes:

  • Deprecated API: MoETrainingConfig and MoEScalingType
  • New API: Use FP8GroupedMMConfig.from_recipe(FP8GroupedMMRecipe.FP8_ROWWISE) or MXFP8GroupedMMConfig.from_recipe(MXFP8GroupedMMRecipe.MXFP8_RCEIL)
  • New exports in __init__.py: Added _to_fp8_rowwise_then_scaled_grouped_mm to public API


@pytorch-bot (bot) commented Feb 11, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3862

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 6df211d with merge base 920c502:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla bot added the "CLA Signed" label on Feb 11, 2026.
@danielvegamyhre force-pushed the refactor-feb10 branch 4 times, most recently from 326d30b to 62159aa, on February 11, 2026 22:33.
@danielvegamyhre added the mx, module: bc-breaking, moe, topic: new feature, and topic: improvement labels, and removed the module: bc-breaking and topic: new feature labels, on Feb 11, 2026.
@danielvegamyhre force-pushed the refactor-feb10 branch 11 times, most recently from c7556e0 to 49ab389, on February 12, 2026 02:52.
Review comment thread on the following lines in torchao/quantization/quant_api.py:

```python
# Once we've updated all the transform functions to take in a custom_param kwarg, we can delete this object and the subsequent check
# TODO see https://github.com/pytorch/ao/issues/3252 for more details
from torchao.prototype.moe_training.config import (
    FP8GroupedMMConfig,
```
Contributor

hmm prototype in main api

cc @jerryzh168 is that okay?

Contributor

I think we shouldn't do this, there are some imports here that started here and moved to prototype, but we shouldn't add new ones here

Contributor Author

@jerryzh168 any thoughts on how to allowlist the training configs to use quantize_() + FqnToConfig, if we are hesitant about importing from prototype here?

Contributor

feels like we should just finish this BE task: #3252 so that all configs are supported. shouldn't be too complicated with claude

cc @jcaip as well

@jerryzh168 (Contributor) commented Feb 12, 2026

also before this, we could also drop some of the configs that have been moved to prototype

@jcaip (Contributor) commented Feb 13, 2026

Yeah I was planning on taking a stab at #3252 tonight, will tag you guys in the PRs

If you have a list of configs that have been moved to prototype @jerryzh168, please lmk.

Contributor

yeah check this: #2752, we can remove code for prototype configs, and also remove code for v1 config (basically everything implemented with AQT), and in the end we can remove AQT

Contributor Author

sounds good, thanks @jcaip lmk when that's ready and i'll rebase on top of it

Contributor

cc @danielvegamyhre if you rebase on top of #3894 you should no longer need the import for CUSTOM_PARAM_QUANTIZATION_SUPPORTED_CONFIGS

Contributor Author

rebased on top of #3894 now that it has landed, and removed the prototype import in quant_api.py.

planning to land this once CI is green

@danielvegamyhre danielvegamyhre merged commit 4a42d32 into main Feb 20, 2026
22 checks passed

Labels

  • CLA Signed
  • module: training
  • quantize_ api
  • training flow
  • moe
  • mx
  • topic: improvement

4 participants