[Inductor][Flex-attention] Make num_head support dynamic (#126342)
Fixes #ISSUE_NUMBER

Pull Request resolved: #126342
Approved by: https://github.com/drisspg
yanboliang authored and pytorchmergebot committed May 16, 2024
1 parent f9d107a commit 0f8380d
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions torch/nn/attention/_flex_attention.py
@@ -83,9 +83,8 @@ def score_mod(
     """
 
     if torch.compiler.is_dynamo_compiling():
-        # mark head_dim and dim always to be static
+        # mark head_dim always to be static
         for x in [query, key, value]:
-            torch._dynamo.mark_static(x, 1)
             torch._dynamo.mark_static(x, -1)
         out, _ = flex_attention_hop(query, key, value, score_mod)
         return out
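The change drops `torch._dynamo.mark_static(x, 1)`, which pinned dim 1 (num_heads) of query/key/value to a fixed size, and keeps only `mark_static(x, -1)`, which pins the last dim (head_dim). Compiled graphs therefore no longer specialize on the head count. A minimal sketch of the effect (assumed setup, not from this commit; a plain matmul stands in for the flex-attention kernel):

```python
import torch

@torch.compile(dynamic=True)
def scores(q, k):
    # Stand-in for the attention score computation;
    # tensors use the (batch, num_heads, seq_len, head_dim) layout.
    return q @ k.transpose(-2, -1)

q = torch.randn(2, 8, 128, 64)    # num_heads=8, head_dim=64
k = torch.randn(2, 8, 128, 64)
scores(q, k)                      # first call triggers compilation

q2 = torch.randn(2, 16, 128, 64)  # num_heads=16, head_dim unchanged
k2 = torch.randn(2, 16, 128, 64)
scores(q2, k2)                    # with num_heads dynamic, the compiled graph
                                  # can be reused instead of specializing on dim 1
```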
