Optimize FP8 colwise scales kernel for AMD GPUs in MoE backward pass#3972
Conversation
Two interdependent changes that together yield ~4.3x end-to-end training throughput improvement on MI300X for DeepSeek-MoE-16B:

1. Remove redundant `.t().contiguous().t()` memory copies before calling `triton_fp8_per_group_colwise_scales` in the backward pass. The kernel already handles arbitrary strides via its stride parameters, so these full-tensor copies are unnecessary.
2. Use larger Triton autotune configs (`BLOCK_SIZE=128/256`, `BLOCK_SIZE_ITER=128/256`) for the colwise scales kernel on AMD GPUs. With row-major input (from change 1), larger block sizes enable contiguous column access patterns, reducing grid block count by 4-8x.

Benchmarked on 8x MI300X with DeepSeek-MoE-16B (EP=8, seq_len=4096):

- Batch size 1: 136 TPS -> 642 TPS (4.7x)
- Batch size 4: 500 TPS -> 2153 TPS (4.3x)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
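The copy-removal in change 1 can be illustrated without torch. Below is a minimal pure-Python strided view (a hypothetical stand-in, not the PR's code) showing why a plain transpose is free (it only swaps shape and strides), why `.t().contiguous().t()` materializes a full copy, and why a kernel that accepts strides directly needs no copy at all:

```python
# Hypothetical minimal sketch of a strided 2-D view (pure Python, no torch).
class StridedView:
    def __init__(self, data, shape, strides):
        self.data, self.shape, self.strides = data, shape, strides

    def __getitem__(self, idx):
        r, c = idx
        return self.data[r * self.strides[0] + c * self.strides[1]]

    def t(self):
        # Transpose is free: swap shape and strides, share the buffer.
        return StridedView(self.data, self.shape[::-1], self.strides[::-1])

    def contiguous(self):
        # Materializes a new row-major buffer: an O(rows*cols) copy.
        rows, cols = self.shape
        buf = [self[r, c] for r in range(rows) for c in range(cols)]
        return StridedView(buf, self.shape, (cols, 1))

m = StridedView(list(range(6)), (2, 3), (3, 1))   # row-major 2x3
col_major = m.t().contiguous().t()                 # same values, new buffer
assert all(col_major[r, c] == m[r, c] for r in range(2) for c in range(3))
assert col_major.data is not m.data   # the full-tensor copy happened
assert m.t().data is m.data           # plain .t() copies nothing
```

A strides-aware kernel can consume either `m` or `col_major` via `(shape, strides)` alone, which is why the copy in the backward pass was removable.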
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3972
Note: Links to docs will display an error until the docs builds have been completed. ❌ 2 New Failures as of commit b455246 with merge base 4ae435e. The following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Hi @lizamd! Thank you for your pull request and welcome to our community.

Action Required: In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process: In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations, and the pull request will be tagged accordingly.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!
This is interesting. I originally added this as a temporary measure because I found that the quantization kernels were twice as fast with the tensor in column-major layout, even with this memory-layout transform. We pivoted to mxfp8 and I never got a chance to investigate further. Are you able to benchmark this change on an H100 as well?
We can test on MI350 as well. H100 is not tested yet; we can add it.
OK, please let me know.
@alex-minooka @lizamd looks great! Will land once CI is green. For future reference, please include microbenchmarks in the PR description (e.g., in this case you would run
PR pytorch#3952 expanded Triton autotune configurations for MoE FP8 rowwise kernels on AMD GPUs (24-36 configs gated behind torch.version.hip). Benchmarking on MI300X reveals this causes:

1. ~15% kernel regression on DeepSeek V3 shapes due to the autotuner selecting suboptimal configs from the noisy microbenchmark results
2. Non-deterministic config selection across runs
3. No measurable improvement on Llama4 shapes vs the original single config (the PR's reported gains were vs torch.compile, not vs the original Triton config)

Revert to the original single config for both atomic and reduction kernels, which is near-optimal across all tested shape families.

This does NOT revert other valuable changes from pytorch#3952:

- N_GROUPS added to the autotune key in jagged_float8_scales.py
- `N_GROUPS: tl.int64` type annotation fixes

The jagged_float8_scales.py configs (from PR pytorch#3972) are also preserved, as they were carefully benchmarked and provide the 4.3x improvement.

Benchmark results on MI300X (atomic kernel, representative shapes):

| Shape | Expanded (pytorch#3952) | Single (this PR) |
|---|---|---|
| (128, 8192, 5120) | 10.56 ms | 10.43 ms |
| (128, 5120, 8192) | 10.50 ms | 10.40 ms |
| (8, 2048, 1408) | 0.068 ms | 0.072 ms |
| (8, 1408, 2048) | 0.069 ms | 0.078 ms |
| Cold-cache overhead | 4.4 s | 1.9 s |
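The non-determinism complaint can be illustrated with a toy model of min-time config selection (an illustration only, not Triton's actual autotuner): when many configs have true runtimes within the measurement noise floor, a single noisy benchmark per config makes the measured argmin flip between runs, while a single config is trivially deterministic.

```python
import random

def pick_config(true_times_ms, noise_ms, rng):
    # One noisy benchmark sample per config; select the measured minimum.
    measured = [t + rng.gauss(0.0, noise_ms) for t in true_times_ms]
    return min(range(len(measured)), key=measured.__getitem__)

# 24 hypothetical configs spaced 0.02 ms apart, with ~0.5 ms benchmark noise:
# the true-time gaps are far below the noise floor.
true_times = [10.40 + 0.02 * i for i in range(24)]
rng = random.Random(0)
picks = [pick_config(true_times, 0.5, rng) for _ in range(100)]

assert len(set(picks)) > 1   # selection varies run to run
# A single near-optimal config sidesteps the lottery entirely.
assert len({pick_config([10.43], 0.5, rng) for _ in range(100)}) == 1
```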
…pes (#4024)
…se_scales to be row major to reflect the usage of this op after pytorch#3972 was merged.
Reflects actual usage after pytorch#3972. Also add dual kernel benchmark script from upstream. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Summary

Builds on #3952, which expanded Triton autotune configs for AMD GPUs. This PR further optimizes the `_triton_fp8_per_group_colwise_scales_kernel` (used in the backward pass of `_Float8GroupedMM`) with two interdependent changes:

1. Remove redundant `.t().contiguous().t()` memory copies before calling `triton_fp8_per_group_colwise_scales` in the backward pass. The kernel already handles arbitrary strides via its stride parameters, so these full-tensor copies (which convert row-major to column-major) are unnecessary.
2. Use larger Triton autotune configs (`BLOCK_SIZE=128/256`, `BLOCK_SIZE_ITER=128/256`, `num_warps=8`) for the colwise scales kernel on AMD GPUs, replacing the 16 smaller configs (`BLOCK_SIZE=32/64`) from #3952 ("Expand Triton autotune configs for MoE FP8 kernels to improve AMD GPU performance"). With row-major input (from change 1), larger block sizes enable contiguous column access patterns, reducing grid block count by 4-8x.

These two changes are interdependent: removing the memory copy passes row-major tensors to the kernel, and the larger block sizes are optimal for row-major input. With column-major input (the old behavior), the larger block sizes actually perform worse due to strided access patterns.
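The strided-access point can be made concrete with plain stride arithmetic (illustrative shapes, not taken from the kernel source): in one iteration a program reads a run of adjacent columns of a single row. Row-major layout makes those addresses consecutive; column-major layout spaces them M elements apart.

```python
def addresses(row, col0, n_cols, strides):
    # Element offsets touched when reading n_cols adjacent columns of one row.
    return [row * strides[0] + (col0 + c) * strides[1] for c in range(n_cols)]

# Hypothetical (M, K) activation tensor.
M, K, BLOCK_SIZE = 4096, 5120, 128
row_major = (K, 1)   # strides of a row-major (M, K) tensor
col_major = (1, M)   # strides after a .t().contiguous().t() round trip

a = addresses(0, 0, BLOCK_SIZE, row_major)
b = addresses(0, 0, BLOCK_SIZE, col_major)
assert a == list(range(BLOCK_SIZE))               # fully contiguous wide load
assert all(y - x == M for x, y in zip(b, b[1:]))  # 4096-element hops
```

With `BLOCK_SIZE=32` the strided hops are few enough to tolerate; at 128/256 the row-major layout is what keeps the wide loads contiguous.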
Note: #3952's expanded configs for `float8_rowwise.py` (the 3D transpose kernel) are unchanged; this PR only modifies the `jagged_float8_scales.py` configs.

Ablation Study
The `.t().contiguous().t()` was originally added because with small block sizes, column-major input was faster. Our ablation shows neither change works alone; they must be applied together. (Ablation: end-to-end training with torchtitan on 8x MI300X with DeepSeek-MoE-16B, EP=8, batch=1, seq_len=4096, torch.compile enabled; table omitted.)

Why the interdependency? The colwise scales kernel iterates over rows in blocks of `BLOCK_SIZE_ITER` and over columns in blocks of `BLOCK_SIZE`. With small `BLOCK_SIZE` (32/64), each block covers few columns, so the strided access from column-major layout is tolerable, and the `.t().contiguous().t()` copy makes those accesses contiguous. With large `BLOCK_SIZE` (128/256), each block covers many more columns: row-major input means those columns are contiguous in memory, enabling efficient wide loads, but column-major input with large blocks creates large strides that kill performance.

Benchmarks
End-to-end training with torchtitan on 8x MI300X with DeepSeek-MoE-16B (EP=8, seq_len=4096, torch.compile enabled):

| Batch size | Baseline | This PR | Speedup |
|---|---|---|---|
| 1 | 136 TPS | 642 TPS | 4.7x |
| 4 | 500 TPS | 2153 TPS | 4.3x |
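The claimed 4-8x grid-size reduction is easy to sanity-check with ceil-division arithmetic. This sketch assumes (hypothetically, as a simplification) one program per `BLOCK_SIZE`-wide column tile per group; the column count K is illustrative, not taken from the PR:

```python
def cdiv(a, b):
    # Ceiling division: number of BLOCK_SIZE-wide tiles covering a dimension.
    return (a + b - 1) // b

K = 5120  # hypothetical column count

assert cdiv(K, 32) // cdiv(K, 128) == 4   # 160 -> 40 programs per group
assert cdiv(K, 64) // cdiv(K, 256) == 4   # 80 -> 20 programs per group
assert cdiv(K, 32) // cdiv(K, 256) == 8   # the 8x end of the 4-8x range
```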
Profiler analysis shows `_triton_fp8_per_group_colwise_scales_kernel` was the dominant bottleneck (83% of GPU time at baseline). After optimization, kernel time dropped from 11.8s to 4.1s with 10x fewer kernel launches (86,265 → 8,671).

Test plan
`pytest test/prototype/moe_training/ -v`

🤖 Generated with Claude Code