[mxfp8 moe training] increase num_warps in mxfp8 a2a comms kernel #3087
Use a higher num_warps to hide NVLink latency and increase occupancy. Benchmark runtime improves by ~13x, from ~2000ms to ~150ms. (num_warps was originally 16 but was changed to 1 while debugging something else; this just changes it back - providing context.)
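For reference, `num_warps` is a launch-time meta-parameter on a Triton kernel. A minimal sketch (a hypothetical copy kernel, not the actual a2a comms kernel) showing where the value is set:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def copy_kernel(src_ptr, dst_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program copies one BLOCK_SIZE chunk; more warps per program gives
    # the scheduler more in-flight memory requests to hide load latency.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    vals = tl.load(src_ptr + offsets, mask=mask)
    tl.store(dst_ptr + offsets, vals, mask=mask)

src = torch.randn(1 << 20, device="cuda")
dst = torch.empty_like(src)
grid = (triton.cdiv(src.numel(), 8192),)
# num_warps=16 launches 16 warps per program instead of 1.
copy_kernel[grid](src, dst, src.numel(), BLOCK_SIZE=8192, num_warps=16)
```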
Still, perf is much slower than a standalone `all_to_all_single_autograd`: ~10ms vs ~150ms for shape (16, 8192, 5120) with 8 splits. However, the benefit is that we avoid the d2h sync that `all_to_all_single` requires to get the input/output splits onto the CPU/host, as shown in the sketch below.

TODO: benchmark against the bf16 a2a triton + symmetric memory impl to verify that the network bandwidth saved by dynamic quantization is actually an improvement in practice.
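For context on the d2h sync mentioned above, here is a minimal sketch (assuming an initialized NCCL process group; `dispatch_tokens` and `tokens_per_rank_gpu` are hypothetical names) of why a standalone `all_to_all_single` needs the split sizes on the host:

```python
import torch
import torch.distributed as dist

def dispatch_tokens(tokens: torch.Tensor, tokens_per_rank_gpu: torch.Tensor) -> torch.Tensor:
    # Exchange per-rank send counts so each rank knows how much it will receive.
    recv_counts_gpu = torch.empty_like(tokens_per_rank_gpu)
    dist.all_to_all_single(recv_counts_gpu, tokens_per_rank_gpu)

    # all_to_all_single takes split sizes as host-side ints, so .tolist()
    # blocks until the GPU values reach the host (the d2h sync we avoid).
    input_splits = tokens_per_rank_gpu.tolist()
    output_splits = recv_counts_gpu.tolist()

    out = tokens.new_empty((sum(output_splits), tokens.shape[-1]))
    dist.all_to_all_single(out, tokens, output_splits, input_splits)
    return out
```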