Conversation
Pull request overview
This pull request improves the Multi-Head Attention (MHA) pattern optimization by adding support for the NoT (No Transpose) variant, which occurs when using FusedMatMul operators instead of regular MatMul with Transpose. The PR corrects the capitalization from "noT" to "NoT" to align with established naming conventions (SW for Switch Where, GQA for Group Query Attention), and adds a comprehensive test to verify the optimization works correctly with FusedMatMul.
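To make the NoT variant concrete, here is a minimal NumPy illustration (not the repository code) of the numerical equivalence it relies on: the `com.microsoft` FusedMatMul operator folds the transpose into the matrix product via its `transA`/`transB` attributes, so the explicit Transpose node disappears from the graph. The shapes below are illustrative assumptions.

```python
import numpy as np

# Illustration of the equivalence behind the NoT ("No Transpose") variant:
# FusedMatMul folds the Transpose into the MatMul, so no Transpose node
# appears in the graph. Shapes here are hypothetical (batch, seq, head_dim).
rng = np.random.default_rng(0)
q = rng.standard_normal((2, 4, 8))
k = rng.standard_normal((2, 4, 8))

# Regular pattern: an explicit Transpose(K) followed by MatMul.
scores_with_transpose = q @ k.transpose(0, 2, 1)

# FusedMatMul(transB=1) computes the same product with no Transpose node;
# einsum over the shared last axis models that fused contraction.
scores_fused = np.einsum("bij,bkj->bik", q, k)

assert np.allclose(scores_with_transpose, scores_fused)
```

Because the Transpose node is absent, the attention pattern matcher needs the dedicated NoT suffix to still recognize the subgraph.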
Changes:
- Fixed capitalization of the no-transpose suffix from "noT" to "NoT" in FunctionAttentionPattern.apply()
- Added "NoT_to" prefix to MultiHeadAttention3DPattern's _prefixes_operator_name to recognize LocalAttentionNoT operators
- Added test case test_multi_head_attention_fused_matmul to verify FusedMatMul-based attention patterns are correctly optimized to MultiHeadAttention
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| experimental_experiment/xoptim/patterns_ort/llm_optim.py | Added NoT_to prefix to MultiHeadAttention3DPattern to recognize LocalAttentionNoT variants created when FusedMatMul is used |
| experimental_experiment/xoptim/patterns/onnx_attention.py | Fixed capitalization from "noT" to "NoT" in FunctionAttentionPattern.apply() for consistency, added clarifying comment for FusedMatMul branch |
| _unittests/ut_xoptim/test_graph_pattern_optimization_ort.py | Added comprehensive test for FusedMatMul-based multi-head attention pattern optimization |
```python
_prefixes_operator_name = (
    f"{FunctionAttentionPattern._operator_name}_to",
    f"{FunctionAttentionPattern._operator_name}sQ_to",
    f"{FunctionAttentionPattern._operator_name}SW_to",
    f"{FunctionAttentionPattern._operator_name}SWsQ_to",
    f"{FunctionAttentionPattern._operator_name}NoT_to",
```
The `_prefixes_operator_name` tuple is missing combination prefixes. Since the SW and NoT suffixes can occur together (when `switch_where=True` and `transpose=None`), the pattern should also include:

```python
f"{FunctionAttentionPattern._operator_name}SWNoT_to",
```

Additionally, lines 1107 and 1109 appear incorrect. `FunctionAttentionPattern` never creates `LocalAttentionsQ` or `LocalAttentionSWsQ` nodes; the `sQ` suffix only appears in the GQA variants (`GQAsQ`), not as a standalone pattern. These lines should likely be removed or replaced with GQA-related prefixes.
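The suggested fix can be sketched as follows. This is a hypothetical illustration, not the repository code: it assumes `FunctionAttentionPattern._operator_name` is `"LocalAttention"` (consistent with the operator names discussed in this review), drops the standalone `sQ` entries, and adds the missing `SWNoT` combination.

```python
# Hypothetical stand-in for FunctionAttentionPattern._operator_name,
# assumed to be "LocalAttention" per the operator names in this review.
_operator_name = "LocalAttention"

# Suggested prefix tuple: standalone sQ variants dropped,
# the SWNoT combination (switch_where=True and transpose=None) added.
_prefixes_operator_name = (
    f"{_operator_name}_to",       # base pattern
    f"{_operator_name}SW_to",     # Switch Where variant
    f"{_operator_name}NoT_to",    # No Transpose variant (FusedMatMul)
    f"{_operator_name}SWNoT_to",  # both suffixes combined
)

def is_local_attention(op_type: str) -> bool:
    """Return True if op_type starts with one of the recognized prefixes."""
    # str.startswith accepts a tuple of prefixes directly.
    return op_type.startswith(_prefixes_operator_name)
```

A quick check: `is_local_attention("LocalAttentionSWNoT_to_MHA")` now matches, while a name such as `"LocalAttentionsQ_to_MHA"` does not.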
OK, if you can add a unit test for each of those.