[Dev][Bugfix] Fix bug in ThreadTagChecker; Add WgmmaSync rewriter and add MHA WGMMA pipelined example #128
This pull request adds a new pass that rewrites the synchronization of WGMMA (warpgroup matrix multiply-accumulate) operations and tightens the handling of thread tags. The most important changes are a new file for the WgmmaSyncRewriter class, a fix to the ThreadTagChecker class, and an update to the optimization pipeline so that it runs the new RewriteWgmmaSync pass.
New functionality for WgmmaSyncRewriter:
- `src/transform/wgmma_sync_rewriter.cc`: Added a new file that defines the `WgmmaSyncRewriter` class, which includes methods for detecting GEMM (General Matrix Multiply) operations and rewriting synchronization points. The class optimizes WGMMA synchronization for CUDA GPUs (sm90+).

Enhancements to thread tag handling:
- `src/transform/warp_specialized_rewriter.cc`: Modified the `ThreadTagChecker` class to improve the handling of thread tags, specifically checking bindings to `threadIdx.y` and `threadIdx.z` and ensuring that their extent is one. This change tightens the validation of thread bindings.

Updates to the optimization pipeline:
- `tilelang/engine/phase.py`: Updated the `OptimizeForTarget` function to include the new `RewriteWgmmaSync` pass, so that the synchronization rewriting is applied during optimization.
- `tilelang/transform/__init__.py`: Added a new function definition for `RewriteWgmmaSync` to expose the new pass as part of the transformation utilities.

New example for flash attention:
- `examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py`: Added a new example script that demonstrates flash attention with pipelined WGMMAs. The script includes configuration setup, kernel functions, and a reference program for benchmarking.
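
For readers unfamiliar with the `bshd` layout in the example's file name, the following is a minimal pure-Python sketch of the computation a multi-head attention forward reference would check, assuming `bshd` means tensors indexed as `[batch][seq][head][dim]`. It is not the reference program from the example script; the function name `attention_bshd` and the `causal` flag are hypothetical and chosen only for illustration.

```python
import math


def attention_bshd(q, k, v, causal=False):
    """Naive softmax(Q K^T / sqrt(dim)) V per batch and head.

    q, k, v are nested lists indexed [batch][seq][head][dim] (assumed layout).
    `causal` is a hypothetical flag: when True, query i attends to keys 0..i.
    """
    batch, seq_q = len(q), len(q[0])
    heads, dim = len(q[0][0]), len(q[0][0][0])
    scale = 1.0 / math.sqrt(dim)
    out = [[[[0.0] * dim for _ in range(heads)] for _ in range(seq_q)]
           for _ in range(batch)]
    for b in range(batch):
        seq_k = len(k[b])
        for h in range(heads):
            for i in range(seq_q):
                last = min(i + 1, seq_k) if causal else seq_k
                scores = [
                    scale * sum(q[b][i][h][d] * k[b][j][h][d] for d in range(dim))
                    for j in range(last)
                ]
                # Subtract the row maximum before exponentiating for stability.
                row_max = max(scores)
                weights = [math.exp(s - row_max) for s in scores]
                total = sum(weights)
                for d in range(dim):
                    out[b][i][h][d] = sum(
                        w * v[b][j][h][d] for j, w in enumerate(weights)
                    ) / total
    return out
```

A fused kernel would typically be compared against such a reference within a floating-point tolerance rather than for exact equality.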
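
The thread-tag validation described above (bindings to `threadIdx.y` and `threadIdx.z` must have extent one) can be sketched in Python as follows. This only illustrates the rule as stated in this summary; the real check is implemented in C++ inside `ThreadTagChecker`, and the function name and the dictionary representation of bindings are assumptions made for this sketch.

```python
def thread_bindings_are_valid(bindings):
    """Return True if threadIdx.y / threadIdx.z bindings all have extent 1.

    `bindings` maps a thread tag (e.g. "threadIdx.x") to its extent.
    This representation is hypothetical, not the data structure used by
    the C++ ThreadTagChecker.
    """
    for tag, extent in bindings.items():
        if tag in ("threadIdx.y", "threadIdx.z") and extent != 1:
            return False
    return True
```

For example, `{"threadIdx.x": 128, "threadIdx.y": 1}` passes the check, while `{"threadIdx.x": 128, "threadIdx.y": 2}` does not.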