Made some minor improvements to flexattention perf + added more autot…
Chillee authored and pytorchmergebot committed May 25, 2024
1 parent 9f11fc6 commit 84e59f0
Showing 2 changed files with 8 additions and 12 deletions.
1 change: 0 additions & 1 deletion test/inductor/test_flex_attention.py
@@ -738,7 +738,6 @@ def score_mod(score, b, h, m, n):
         self.run_test(score_mod)
 
     @supported_platform
-    @skip("TODO: Figure out why this is erroring")
     @patch.object(torch._inductor.config, "max_autotune", True)
     def test_max_autotune_with_captured(self):
         head_scale = torch.randn(H, device="cuda")
19 changes: 8 additions & 11 deletions torch/_inductor/kernel/flex_attention.py
@@ -241,10 +241,8 @@ def build_subgraph_buffer(
         start_n = tl.multiple_of(start_n, BLOCK_N)
         # -- load k, v --
         k = tl.load(K_block_ptr)
-        v = tl.load(V_block_ptr)
         # -- compute qk ---
-        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
-        qk = tl.dot(q, k.to(MATMUL_PRECISION), acc=qk)
+        qk = tl.dot(q, k)
         # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
         m = offs_m[:, None]
         n = start_n + offs_n[None, :]
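The deleted pair of lines pre-allocated a float32 tile and threaded it through tl.dot's acc argument. For computing qk that is redundant: tl.dot accumulates fp16/bf16 inputs in float32 by default, which is what the one-line replacement relies on. A minimal standalone sketch of that default (illustrative names, not from the commit):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def qk_kernel(Q, K, Out, BLOCK: tl.constexpr):
        offs = tl.arange(0, BLOCK)
        idx = offs[:, None] * BLOCK + offs[None, :]
        q = tl.load(Q + idx)  # fp16 tile
        k = tl.load(K + idx)  # fp16 tile
        # fp16 x fp16 -> fp32: tl.dot's default accumulator dtype is float32,
        # so no explicit tl.zeros(..., dtype=tl.float32) accumulator is needed.
        qk = tl.dot(q, k)
        tl.store(Out + idx, qk)

    q = torch.randn(16, 16, device="cuda", dtype=torch.float16)
    k = torch.randn(16, 16, device="cuda", dtype=torch.float16)
    out = torch.empty(16, 16, device="cuda")  # float32
    qk_kernel[(1,)](q, k, out, BLOCK=16)
    torch.testing.assert_close(out, q.float() @ k.float(), atol=1e-2, rtol=1e-2)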
@@ -277,7 +275,8 @@ def build_subgraph_buffer(
         # -- scale and update acc --
         acc_scale = l_i * 0 + alpha # workaround some compiler bug
         acc *= acc_scale[:, None]
-        acc = tl.dot(p.to(MATMUL_PRECISION), v.to(MATMUL_PRECISION), acc)
+        v = tl.load(V_block_ptr)
+        acc = tl.dot(p.to(MATMUL_PRECISION), v, acc)
         # -- update m_i and l_i --
         l_i = l_i * alpha + tl.sum(p, 1)
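This is the counterpart to the early load deleted in the previous hunk: v is now loaded immediately before its only use in the second dot. Sinking the load shortens v's live range across the score-mod and softmax computation, which can lower register pressure in the main loop. A small sketch of the pattern (illustrative names; the softmax line is a stand-in for the real dependency chain):

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def pv_kernel(QK, V, Out, BLOCK: tl.constexpr):
        offs = tl.arange(0, BLOCK)
        idx = offs[:, None] * BLOCK + offs[None, :]
        qk = tl.load(QK + idx)
        # Long dependency chain on qk first (stand-in for score_mod + softmax)...
        p = tl.exp(qk - tl.max(qk, 1)[:, None])
        # ...and only then load v: its first consumer is the dot below, so v
        # does not occupy registers while p is being computed.
        v = tl.load(V + idx)
        acc = tl.dot(p.to(tl.float16), v)
        tl.store(Out + idx, acc)

    qk = torch.randn(16, 16, device="cuda")
    v = torch.randn(16, 16, device="cuda", dtype=torch.float16)
    out = torch.empty(16, 16, device="cuda")
    pv_kernel[(1,)](qk, v, out, BLOCK=16)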
@@ -402,13 +401,11 @@ def flex_attention(*args, **kwargs):
     configs: List[Tuple[int, int, int, int]] = []
     configs.append(_get_default_config_fwd(query))
     if config.max_autotune:
-        configs += [
-            (128, 64, 4, 3),
-            (128, 128, 4, 3),
-            (128, 128, 8, 2),
-            (64, 128, 4, 3),
-            (64, 64, 4, 3),
-        ]
+        for BM in [64, 128]:
+            for BN in [64, 128]:
+                for s in [3, 4, 7]:
+                    for w in [4, 8]:
+                        configs.append((BM, BN, w, s))
 
     # Note, we don't need to pass in the captured buffers explicitly
     # because they're implicitly added by the score_mod function
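The nested loops sweep the full cross product of tile sizes, pipeline stages, and warp counts: 2 × 2 × 3 × 2 = 24 autotuning candidates instead of the 5 hand-picked tuples removed above. Assuming the tuple fields are (BLOCK_M, BLOCK_N, num_warps, num_stages), as the loop variable names suggest, the loops are equivalent to:

    from itertools import product

    configs = [
        (BM, BN, w, s)  # (BLOCK_M, BLOCK_N, num_warps, num_stages)
        for BM, BN, s, w in product([64, 128], [64, 128], [3, 4, 7], [4, 8])
    ]
    assert len(configs) == 24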

1 comment on commit 84e59f0

@pytorchmergebot
Collaborator


Reverted #126811 on behalf of https://github.com/PaliC due to breaking on V100s / internal tests (comment)
