
Optimize FP8 colwise scales kernel for AMD GPUs in MoE backward pass#3972

Merged
danielvegamyhre merged 1 commit into pytorch:main from lizamd:optimize-fp8-colwise-backward
Mar 2, 2026

Conversation

@lizamd
Contributor

@lizamd lizamd commented Mar 2, 2026

Summary

Builds on #3952 which expanded Triton autotune configs for AMD GPUs. This PR further optimizes the _triton_fp8_per_group_colwise_scales_kernel (used in the backward pass of _Float8GroupedMM) with two interdependent changes:

  1. Remove redundant .t().contiguous().t() memory copies before calling triton_fp8_per_group_colwise_scales in the backward pass. The kernel already handles arbitrary strides via its stride parameters, so these full-tensor copies (which convert row-major to column-major) are unnecessary.

  2. Use larger Triton autotune configs (BLOCK_SIZE=128/256, BLOCK_SIZE_ITER=128/256, num_warps=8) for the colwise scales kernel on AMD GPUs, replacing the 16 smaller configs from #3952 (BLOCK_SIZE=32/64). With row-major input (from change 1), larger block sizes enable contiguous column access patterns, reducing the grid block count by 4-8x.

These two changes are interdependent: removing the memory copy passes row-major tensors to the kernel, and the larger block sizes are optimal for row-major input. With column-major input (the old behavior), the larger block sizes actually perform worse due to strided access patterns.
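The layout effect of the removed round-trip is visible directly in the tensor strides; a minimal sketch (the shapes here are illustrative, not the kernel's actual inputs):

```python
import torch

x = torch.randn(64, 32)     # row-major: stride (32, 1)
y = x.t().contiguous().t()  # full copy with column-major strides

assert torch.equal(x, y)        # same values...
assert x.stride() == (32, 1)    # ...but x is row-major
assert y.stride() == (1, 64)    # ...while y is column-major

# A Triton kernel that takes explicit stride arguments can read either
# layout, so the copy changes only the access pattern, not correctness.
```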

Note: #3952's expanded configs for float8_rowwise.py (the 3D transpose kernel) are unchanged — this PR only modifies the jagged_float8_scales.py configs.

Ablation Study

The .t().contiguous().t() was originally added because with small block sizes, column-major input was faster. Our ablation shows neither change works alone — they must be applied together:

End-to-end training with torchtitan on 8x MI300X with DeepSeek-MoE-16B (EP=8, batch=1, seq_len=4096, torch.compile enabled):

| .t().contiguous().t() | BLOCK_SIZE | Input Layout | TPS | Notes |
|---|---|---|---|---|
| Kept | 32/64 (#3952) | col-major | 136 | Baseline |
| Removed | 32/64 (#3952) | row-major | 137 | No improvement: small blocks still do strided column access |
| Kept | 128/256 (this PR) | col-major | 127 | Worse: large blocks + col-major = bad stride pattern |
| Removed | 128/256 (this PR) | row-major | 642 | 4.7x, both changes together |

Why the interdependency? The colwise scales kernel iterates over rows in blocks of BLOCK_SIZE_ITER and over columns in blocks of BLOCK_SIZE. With small BLOCK_SIZE (32/64), each block covers few columns so the strided access from column-major layout is tolerable, and the .t().contiguous().t() copy makes those accesses contiguous. With large BLOCK_SIZE (128/256), each block covers many more columns — row-major input means these columns are contiguous in memory, enabling efficient wide loads. But column-major input with large blocks creates large strides that kill performance.
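The 4-8x grid reduction from the larger BLOCK_SIZE is simple arithmetic; a quick check (the column count here is hypothetical, not taken from the actual model shapes):

```python
import math

K = 5120  # hypothetical number of columns processed by the kernel
blocks_per_dim = {bs: math.ceil(K / bs) for bs in (32, 64, 128, 256)}
print(blocks_per_dim)  # {32: 160, 64: 80, 128: 40, 256: 20}

# Like-for-like pairings 32 -> 128 (160 -> 40) and 64 -> 256 (80 -> 20)
# each give 4x fewer column-blocks; 32 -> 256 (160 -> 20) gives 8x.
```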

Benchmarks

End-to-end training with torchtitan on 8x MI300X with DeepSeek-MoE-16B (EP=8, seq_len=4096, torch.compile enabled):

| Batch Size | Baseline TPS (#3952) | Optimized TPS | Speedup |
|---|---|---|---|
| 1 | 136 | 642 | 4.7x |
| 4 | 500 | 1,865 | 3.7x |

Profiler analysis shows _triton_fp8_per_group_colwise_scales_kernel was the dominant bottleneck (83% of GPU time at baseline). After optimization, kernel time dropped from 11.8s to 4.1s with 10x fewer kernel launches (86,265 → 8,671).
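For reference, the kind of profiler run used to attribute time to individual kernels can be sketched as below (CPU-only here so it runs anywhere; the actual analysis would add ProfilerActivity.CUDA around a training step and sort by device time):

```python
import torch
from torch.profiler import profile, ProfilerActivity

# Profile a stand-in workload; the real analysis wraps a training step
# and includes ProfilerActivity.CUDA to capture kernel time and launches.
with profile(activities=[ProfilerActivity.CPU]) as prof:
    a = torch.randn(256, 256)
    b = torch.randn(256, 256)
    c = a @ b

# Aggregated per-op table, sorted by total time; on a GPU run the Triton
# kernel's share would appear in the CUDA-time columns.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))
```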

Test plan

  • Verified training convergence (loss decreasing normally) with torchtitan DeepSeek-MoE-16B
  • Profiled with PyTorch profiler to confirm kernel speedup
  • Ablation study confirming both changes are needed together
  • Run existing unit tests: pytest test/prototype/moe_training/ -v

🤖 Generated with Claude Code

Two interdependent changes that together yield ~4.3x end-to-end training
throughput improvement on MI300X for DeepSeek-MoE-16B:

1. Remove redundant .t().contiguous().t() memory copies before calling
   triton_fp8_per_group_colwise_scales in the backward pass. The kernel
   already handles arbitrary strides via its stride parameters, so these
   full-tensor copies are unnecessary.

2. Use larger Triton autotune configs (BLOCK_SIZE=128/256, BLOCK_SIZE_ITER=
   128/256) for the colwise scales kernel on AMD GPUs. With row-major input
   (from change 1), larger block sizes enable contiguous column access
   patterns, reducing grid block count by 4-8x.

Benchmarked on 8x MI300X with DeepSeek-MoE-16B (EP=8, seq_len=4096):
- Batch size 1: 136 TPS -> 642 TPS (4.7x)
- Batch size 4: 500 TPS -> 2153 TPS (4.3x)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@pytorch-bot

pytorch-bot bot commented Mar 2, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3972

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures

As of commit b455246 with merge base 4ae435e:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla

meta-cla bot commented Mar 2, 2026

Hi @lizamd!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

@meta-cla

meta-cla bot commented Mar 2, 2026

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 2, 2026
@danielvegamyhre danielvegamyhre self-requested a review March 2, 2026 03:30
@danielvegamyhre
Contributor

danielvegamyhre commented Mar 2, 2026

Remove redundant .t().contiguous().t()

This is interesting. I originally added this as a temporary measure because I found that the quantization kernels were twice as fast with the tensor in column-major layout, even with this memory-layout transform. We pivoted to mxfp8 and I never got a chance to investigate further.

Are you able to benchmark this change on a H100 as well?

@lizamd
Contributor Author

lizamd commented Mar 2, 2026

we can test on MI350 also. H100 is not tested yet, we can add it

@danielvegamyhre
Contributor

we can test on MI350 also. H100 is not tested yet, we can add it

ok please let me know

@alex-minooka
Contributor

we can test on MI350 also. H100 is not tested yet, we can add it

ok please let me know

[image: H100 benchmark configurations]

Here is some data on h100, with the configurations listed in the image. PR gives slight performance improvement on 1 node.

@danielvegamyhre danielvegamyhre added module: training quantize_ api training flow and removed mx labels Mar 2, 2026
@danielvegamyhre
Contributor

@alex-minooka @lizamd looks great! Will land once CI is green. For future reference, please include microbenchmarks in the PR description (e.g., in this case you would run ~/ao/benchmarks/prototype/moe_training/fp8_rowwise/bench_triton_fp8_per_group_colwise_scales.py; you can find other relevant kernel microbenchmarking scripts for fp8 moe training in there as well).

@danielvegamyhre danielvegamyhre merged commit 7bb7f06 into pytorch:main Mar 2, 2026
18 of 24 checks passed
brucechanglongxu added a commit to brucechanglongxu/ao that referenced this pull request Mar 7, 2026
PR pytorch#3952 expanded Triton autotune configurations for MoE FP8 rowwise
kernels on AMD GPUs (24-36 configs gated behind torch.version.hip).
Benchmarking on MI300X reveals this causes:

1. ~15% kernel regression on DeepSeek V3 shapes due to the autotuner
   selecting suboptimal configs from the noisy microbenchmark results
2. Non-deterministic config selection across runs
3. No measurable improvement on Llama4 shapes vs the original single
   config (the PR's reported gains were vs torch.compile, not vs the
   original Triton config)

Revert to the original single config for both atomic and reduction
kernels, which is near-optimal across all tested shape families.

This does NOT revert other valuable changes from pytorch#3952:
- N_GROUPS added to autotune key in jagged_float8_scales.py
- N_GROUPS: tl.int64 type annotation fixes

The jagged_float8_scales.py configs (from PR pytorch#3972) are also preserved,
as they were carefully benchmarked and provide 4.3x improvement.

Benchmark results on MI300X (atomic kernel, representative shapes):

| Shape             | Expanded (pytorch#3952) | Single (this PR) |
|-------------------|------------------|-------------------|
| (128, 8192, 5120) | 10.56 ms         | 10.43 ms          |
| (128, 5120, 8192) | 10.50 ms         | 10.40 ms          |
| (8, 2048, 1408)   | 0.068 ms         | 0.072 ms          |
| (8, 1408, 2048)   | 0.069 ms         | 0.078 ms          |
| Cold-cache overhead| 4.4s            | 1.9s              |
danielvegamyhre pushed a commit that referenced this pull request Mar 7, 2026
…pes (#4024)

PR #3952 expanded Triton autotune configurations for MoE FP8 rowwise
kernels on AMD GPUs (24-36 configs gated behind torch.version.hip).
Benchmarking on MI300X reveals this causes:

1. ~15% kernel regression on DeepSeek V3 shapes due to the autotuner
   selecting suboptimal configs from the noisy microbenchmark results
2. Non-deterministic config selection across runs
3. No measurable improvement on Llama4 shapes vs the original single
   config (the PR's reported gains were vs torch.compile, not vs the
   original Triton config)

Revert to the original single config for both atomic and reduction
kernels, which is near-optimal across all tested shape families.

This does NOT revert other valuable changes from #3952:
- N_GROUPS added to autotune key in jagged_float8_scales.py
- N_GROUPS: tl.int64 type annotation fixes

The jagged_float8_scales.py configs (from PR #3972) are also preserved,
as they were carefully benchmarked and provide 4.3x improvement.

Benchmark results on MI300X (atomic kernel, representative shapes):

| Shape             | Expanded (#3952) | Single (this PR) |
|-------------------|------------------|-------------------|
| (128, 8192, 5120) | 10.56 ms         | 10.43 ms          |
| (128, 5120, 8192) | 10.50 ms         | 10.40 ms          |
| (8, 2048, 1408)   | 0.068 ms         | 0.072 ms          |
| (8, 1408, 2048)   | 0.069 ms         | 0.078 ms          |
| Cold-cache overhead| 4.4s            | 1.9s              |
@danielvegamyhre danielvegamyhre added this to the FP8 Rowwise Training milestone Mar 11, 2026
wenchenvincent added a commit to wenchenvincent/ao that referenced this pull request Mar 19, 2026
…se_scales to be row major to reflect the usage of this op after pytorch#3972 was merged.
wenchenvincent added a commit to wenchenvincent/ao that referenced this pull request Mar 27, 2026
Reflects actual usage after pytorch#3972. Also add dual kernel benchmark
script from upstream.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
wenchenvincent added a commit to wenchenvincent/ao that referenced this pull request Apr 7, 2026
Reflects actual usage after pytorch#3972. Also add dual kernel benchmark
script from upstream.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

Labels

CLA Signed · module: training · quantize_ api · training flow · moe

3 participants