@chengyupku
Contributor

This pull request adds a new pass that rewrites synchronization for WGMMA (Warp Group Matrix Multiply-Accumulate) operations and tightens the handling of thread tags. The main changes are a new file defining the WgmmaSyncRewriter class, a modified ThreadTagChecker class, and an updated optimization pipeline that includes the new RewriteWgmmaSync pass.

New functionality for WgmmaSyncRewriter:

  • src/transform/wgmma_sync_rewriter.cc: Added a new file that defines the WgmmaSyncRewriter class, which includes methods for detecting GEMM (General Matrix Multiply) operations and rewriting synchronization points. This class ensures that the synchronization is optimized for CUDA GPUs (sm90+).

Enhancements to thread tag handling:

  • src/transform/warp_specialized_rewriter.cc: Modified the ThreadTagChecker class to check threadIdx.y and threadIdx.z explicitly and to require that their extent is one. This tightens the validation of thread bindings.

Updates to the optimization pipeline:

  • tilelang/engine/phase.py: Updated the OptimizeForTarget function to include the new RewriteWgmmaSync pass, so the synchronization rewrite is applied during optimization.
  • tilelang/transform/__init__.py: Added a new function definition for RewriteWgmmaSync to expose the new pass as part of the transformation utilities.
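
To illustrate how a newly exposed pass slots into a pipeline function, here is a minimal, self-contained sketch. It is a toy model, not the code in tilelang/engine/phase.py: the module is a plain Python list, `make_pass` and `SomeEarlierPass` are hypothetical, and the position of the pass in the real pipeline is not stated in this PR. The only convention assumed is that a pass is created by calling a factory and then applied to a module.

```python
def make_pass(name):
    """Return a toy pass factory: calling it yields a function module -> module."""
    def factory():
        def run(module):
            # A real pass would rewrite the IR; this toy only records that it ran.
            return module + [name]
        return run
    return factory


# Stand-ins for illustration only. SomeEarlierPass is hypothetical, and this
# RewriteWgmmaSync is a stub that merely shares its name with the pass in the PR.
SomeEarlierPass = make_pass("SomeEarlierPass")
RewriteWgmmaSync = make_pass("RewriteWgmmaSync")


def optimize_for_target(module):
    """Apply passes in order; each pass takes a module and returns a new one."""
    module = SomeEarlierPass()(module)
    module = RewriteWgmmaSync()(module)  # the newly added step
    return module


applied = optimize_for_target([])
print(applied)
```

Because each pass returns a new module, adding a step amounts to one more line in the pipeline function, which is consistent with the small change described for phase.py.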

New example for flash attention:

…HA WGMMA pipelined example (FA3-like scheduling)

This commit introduces a new transformation pass `RewriteWgmmaSync` to optimize warp group matrix multiply accumulate (WGMMA) operations in the TileLang compiler:

- Implemented `WgmmaSyncRewriter` in `src/transform/wgmma_sync_rewriter.cc`
- Added pass registration for `RewriteWgmmaSync`
- Updated `tilelang/engine/phase.py` to include the new transformation pass
- Updated `tilelang/transform/__init__.py` to expose the new pass

The rewriter manages synchronization and dependencies between WGMMA operations, improving pipeline efficiency for complex matrix multiplication kernels.
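
The general idea of dependency-driven synchronization can be sketched as follows. This is a simplified model written for illustration, not the algorithm in `src/transform/wgmma_sync_rewriter.cc`: the `Op` type, the buffer names, and the `wait(n)` notation are all assumptions. It treats every GEMM as its own asynchronous group and reads `wait(n)` as "block until at most n of the most recently issued GEMMs are still in flight". GEMM-to-GEMM dependencies are ignored.

```python
from dataclasses import dataclass, field


@dataclass
class Op:
    name: str
    reads: set = field(default_factory=set)
    writes: set = field(default_factory=set)
    is_gemm: bool = False  # True for an asynchronous, WGMMA-style GEMM


def conflicts(gemm, op):
    """True if op touches the GEMM's output or overwrites one of its inputs."""
    return bool(gemm.writes & (op.reads | op.writes)) or bool(gemm.reads & op.writes)


def insert_waits(ops):
    """Return op names with a wait placed before the first consumer of each GEMM."""
    schedule = []
    pending = []  # issued GEMMs not yet waited on, oldest first
    for op in ops:
        if not op.is_gemm:
            hit = -1  # index of the newest pending GEMM this op depends on
            for i, gemm in enumerate(pending):
                if conflicts(gemm, op):
                    hit = i
            if hit >= 0:
                # GEMMs issued after the one we depend on may stay in flight.
                still_in_flight = len(pending) - 1 - hit
                schedule.append(f"wait({still_in_flight})")
                pending = pending[hit + 1:]
        schedule.append(op.name)
        if op.is_gemm:
            pending.append(op)
    return schedule


ops = [
    Op("gemm_0", reads={"A0", "B0"}, writes={"C0"}, is_gemm=True),
    Op("gemm_1", reads={"A1", "B1"}, writes={"C1"}, is_gemm=True),
    Op("epilogue_0", reads={"C0"}, writes={"D0"}),
    Op("epilogue_1", reads={"C1"}, writes={"D1"}),
]
schedule = insert_waits(ops)
print(schedule)
```

In this toy schedule the first epilogue waits only for `gemm_0`, so `gemm_1` keeps running alongside it. That overlap is the kind of pipeline efficiency the commit message refers to.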

Improve thread tag validation in warp specialized rewriter to prevent unintended transformations:

- Add more precise checks for threadIdx.y and threadIdx.z
- Validate thread extent to ensure only single-extent thread bindings are allowed
- Prevent warp specialization for multi-extent thread bindings in y and z dimensions
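
The rule in the list above can be expressed as a small predicate. This is a sketch in Python for readability. The actual check lives in the C++ ThreadTagChecker class, and the function name and the `(tag, extent)` pair representation here are assumptions made for illustration.

```python
def can_warp_specialize(thread_bindings):
    """Return True if no threadIdx.y / threadIdx.z binding has extent other than 1.

    thread_bindings: iterable of (thread_tag, extent) pairs,
    e.g. [("threadIdx.x", 128), ("threadIdx.y", 1)].
    """
    for tag, extent in thread_bindings:
        if tag in ("threadIdx.y", "threadIdx.z") and extent != 1:
            # A multi-extent y or z binding disables warp specialization.
            return False
    return True


print(can_warp_specialize([("threadIdx.x", 128), ("threadIdx.y", 1)]))
print(can_warp_specialize([("threadIdx.x", 128), ("threadIdx.y", 2)]))
```

Bindings on threadIdx.x are not restricted by this predicate. Only the y and z dimensions must have extent one.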
@chengyupku
Contributor Author

CI failed due to an out-of-memory (OOM) error on the test node; merged anyway.

@chengyupku chengyupku merged commit 0a2b781 into tile-ai:main Feb 28, 2025
2 of 3 checks passed