
[mxfp8 moe training] integrate cuda kernel for 'groups along M scale blocked layout'#3556

Merged
danielvegamyhre merged 1 commit into main from dec29 on Dec 31, 2025

Conversation

@danielvegamyhre (Contributor) commented Dec 29, 2025

Summary

  • Integrate the new CUDA kernel for per-group scale conversion to blocked format into the mxfp8 grouped mm autograd function
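The core idea — padding each expert group's scale rows independently and laying the scales out in the tiled format the grouped GEMM expects — can be sketched in pure Python. This is a hypothetical reference sketch, not the actual CUDA kernel or torchao API; the 128x4 tile shape and the row-major-within-tile ordering are assumptions based on common MX scale-factor layouts.

```python
import numpy as np

# Assumed tile shape for MX scale factors (hypothetical; the real kernel's
# layout may differ).
BLOCK_ROWS, BLOCK_COLS = 128, 4

def to_blocked_per_group(scales: np.ndarray, group_offsets) -> np.ndarray:
    """Convert per-group scales (groups along M) to a blocked layout.

    scales:        [M, K // 32] array of per-block scale factors.
    group_offsets: end row index of each expert group, ascending.
    Each group is padded to a multiple of (128, 4) and emitted tile by tile.
    """
    out_tiles = []
    start = 0
    for end in group_offsets:
        g = scales[start:end]  # this expert group's rows
        rows = -(-g.shape[0] // BLOCK_ROWS) * BLOCK_ROWS  # ceil to 128
        cols = -(-g.shape[1] // BLOCK_COLS) * BLOCK_COLS  # ceil to 4
        padded = np.zeros((rows, cols), dtype=scales.dtype)
        padded[: g.shape[0], : g.shape[1]] = g
        # Split into 128x4 tiles; flatten each tile contiguously (row-major).
        t = padded.reshape(rows // BLOCK_ROWS, BLOCK_ROWS,
                           cols // BLOCK_COLS, BLOCK_COLS)
        out_tiles.append(t.transpose(0, 2, 1, 3).reshape(-1))
        start = end
    return np.concatenate(out_tiles)

# Example: one group of 130 rows x 5 scale columns pads to 256 x 8.
scales = np.arange(130 * 5, dtype=np.float32).reshape(130, 5)
blocked = to_blocked_per_group(scales, [130])
print(blocked.shape)  # (2048,)
```

Doing this padding per group (rather than once over the full M dimension) is what makes the layout compatible with grouped GEMMs, since each expert's scale tiles must start on a tile boundary.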

Tests

  • pytest test/prototype/moe_training/test_scaled_grouped_mm.py

Benchmarks for the autograd function fwd + bwd (dynamic mxfp8 quantization + mxfp8 grouped GEMM)

This change helps across all shapes, but especially for the smaller dsv3 model (16b), where the quantization kernels account for a larger percentage of overall runtime.

Before:

M,N,K,G                  recipe                  bf16_fwd_bwd_us    scaled_fwd_bwd_us  scaled_fwd_bwd_speedup      bf16_fwd_us    scaled_fwd_us  scaled_fwd_speedup
-----------------------  --------------------  -----------------  -------------------  ------------------------  -------------  ---------------  --------------------
(128000, 8192, 5120, 1)  MoEScalingType.MXFP8           32520.3              19518.3   1.666x                         10522.7           5773.22  1.823x
(128000, 8192, 5120, 2)  MoEScalingType.MXFP8           32451.6              19549.1   1.66x                          10116.2           5743.09  1.761x
(128000, 8192, 5120, 4)  MoEScalingType.MXFP8           32233.4              19376.1   1.664x                         11167.8           5711.9   1.955x
(128000, 8192, 5120, 8)  MoEScalingType.MXFP8           31674.5              19295.2   1.642x                         10106.9           5474.24  1.846x
(128000, 1536, 5120, 1)  MoEScalingType.MXFP8            6416.42              6939.65  0.925x                          1834.88          2022.42  0.907x
(128000, 1536, 5120, 2)  MoEScalingType.MXFP8            6320.22              6224.9   1.015x                          1658.82          1814.56  0.914x
(128000, 1536, 5120, 4)  MoEScalingType.MXFP8            5755.81              6218.88  0.926x                          2026.56          1820.61  1.113x
(128000, 1536, 5120, 8)  MoEScalingType.MXFP8            6302.72              5334.99  1.181x                          1810.38          1610.78  1.124x
(128000, 2048, 7168, 1)  MoEScalingType.MXFP8           11666.4               9886.19  1.18x                           3840.4           2917.47  1.316x
(128000, 2048, 7168, 2)  MoEScalingType.MXFP8           11647.9               9666.58  1.205x                          3779.55          2842.21  1.33x
(128000, 2048, 7168, 4)  MoEScalingType.MXFP8           11309.1               8625.25  1.311x                          3816.94          2634.32  1.449x
(128000, 2048, 7168, 8)  MoEScalingType.MXFP8           11490.9               8418.27  1.365x                          3389.47          2543.52  1.333x

After:

M,N,K,G                  recipe                  bf16_fwd_bwd_us    scaled_fwd_bwd_us  scaled_fwd_bwd_speedup      bf16_fwd_us    scaled_fwd_us  scaled_fwd_speedup
-----------------------  --------------------  -----------------  -------------------  ------------------------  -------------  ---------------  --------------------
(128000, 8192, 5120, 1)  MoEScalingType.MXFP8           32323.8              18576.4   1.74x                          10221.6           5311.42  1.924x
(128000, 8192, 5120, 2)  MoEScalingType.MXFP8           31286.3              18587.6   1.683x                         10188.8           5366.29  1.899x
(128000, 8192, 5120, 4)  MoEScalingType.MXFP8           35184.9              19145.8   1.838x                         10301.5           5503.01  1.872x
(128000, 8192, 5120, 8)  MoEScalingType.MXFP8           33045.6              19265.6   1.715x                         10010.7           5363.28  1.867x
(128000, 1536, 5120, 1)  MoEScalingType.MXFP8            6279.46              5737.41  1.094x                          1532.86          1470.59  1.042x
(128000, 1536, 5120, 2)  MoEScalingType.MXFP8            6260.86              5632     1.112x                          2027.46          1480.7   1.369x
(128000, 1536, 5120, 4)  MoEScalingType.MXFP8            6371.36              5297.22  1.203x                          1972.29          1498.18  1.316x
(128000, 1536, 5120, 8)  MoEScalingType.MXFP8            6428.58              5321.2   1.208x                          1761.17          1463.33  1.204x
(128000, 2048, 7168, 1)  MoEScalingType.MXFP8           11590.7               8782.91  1.32x                           3299.3           2417.25  1.365x
(128000, 2048, 7168, 2)  MoEScalingType.MXFP8           11312.1               8810.46  1.284x                          2870.37          2473.54  1.16x
(128000, 2048, 7168, 4)  MoEScalingType.MXFP8           11178.1               8475.62  1.319x                          3399.66          2412.51  1.409x
(128000, 2048, 7168, 8)  MoEScalingType.MXFP8           11716.6               8387.01  1.397x                          3316.77          2420.77  1.37x
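For reference, the speedup columns in both tables are simply the bf16 time divided by the mxfp8-scaled time; using the first row of the Before table:

```python
# speedup = bf16 time / mxfp8-scaled time (first row of the "Before" table)
bf16_fwd_bwd_us = 32520.3
scaled_fwd_bwd_us = 19518.3
print(f"{bf16_fwd_bwd_us / scaled_fwd_bwd_us:.3f}x")  # 1.666x
```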

@pytorch-bot bot commented Dec 29, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3556

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla bot added the CLA Signed label Dec 29, 2025
@danielvegamyhre added the mx, moe, and topic: improvement labels Dec 29, 2025
@danielvegamyhre force-pushed the danielvegamyhre/stack/92 branch 4 times, most recently from 7a12730 to 5bfa927, on December 31, 2025 at 17:10
@danielvegamyhre changed the base branch from danielvegamyhre/stack/92 to main on December 31, 2025 at 17:19
@danielvegamyhre merged commit 8bb433e into main Dec 31, 2025
20 checks passed