-
Notifications
You must be signed in to change notification settings - Fork 21.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Inductor] Flex attention supports dynamic shape #125994
Closed
Closed
Changes from 5 commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,7 +17,6 @@ | |
_causal, | ||
_compose, | ||
_flex_attention, | ||
_generate_alibi_bias, | ||
_identity, | ||
_rel_bias, | ||
_rel_causal, | ||
|
@@ -64,7 +63,7 @@ def create_attention(score_mod): | |
_causal, | ||
_rel_bias, | ||
_rel_causal, | ||
_generate_alibi_bias(8), | ||
# _generate_alibi_bias(8), | ||
] | ||
|
||
|
||
|
@@ -126,6 +125,19 @@ def score_mod(score, b, h, m, n): | |
|
||
|
||
class TestTemplatedSDPA(InductorTestCase): | ||
def _check_equal(self, golden_out, ref_out, compiled_out, dtype): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. wrote something pretty similar lol: |
||
compiled_error = (golden_out - compiled_out).abs().mean() | ||
ref_error = (golden_out - ref_out).abs().mean() | ||
# Note, it seems like we really are less accurate than the float32 | ||
# computation, likely due to the online softmax | ||
if dtype == torch.float32: | ||
fudge_factor = 10.0 | ||
else: | ||
fudge_factor = 1.1 | ||
if compiled_error > ref_error * fudge_factor: | ||
msg = f"Compiled error {compiled_error} is greater than ref error {ref_error} by more than {fudge_factor}X." | ||
self.assertTrue(False, msg) | ||
|
||
def run_test( | ||
self, | ||
score_mod: Callable, | ||
|
@@ -145,25 +157,128 @@ def run_test( | |
) | ||
ref_out = sdpa_partial(q, k, v) | ||
compiled_out = compiled_sdpa(q, k, v) | ||
self._check_equal(golden_out, ref_out, compiled_out, dtype) | ||
|
||
compiled_error = (golden_out - compiled_out).abs().mean() | ||
ref_error = (golden_out - ref_out).abs().mean() | ||
# Note, it seems like we really are less accurate than the float32 | ||
# computation, likely due to the online softmax | ||
if dtype == torch.float32: | ||
fudge_factor = 10.0 | ||
else: | ||
fudge_factor = 1.1 | ||
if compiled_error > ref_error * fudge_factor: | ||
msg = f"Compiled error {compiled_error} is greater than ref error {ref_error} by more than {fudge_factor}X." | ||
self.assertTrue(False, msg) | ||
def run_dynamic_test( | ||
self, | ||
score_mod: Callable, | ||
dtype: torch.dtype = torch.float16, | ||
B: int = B, | ||
H: int = H, | ||
S: int = S, | ||
D: int = D, | ||
): | ||
sdpa_partial = create_attention(score_mod) | ||
# The first eager batch, shape (B, H, S, D) | ||
q1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") | ||
k1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") | ||
v1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") | ||
golden_out1 = sdpa_partial( | ||
q1.to(torch.float64), k1.to(torch.float64), v1.to(torch.float64) | ||
) | ||
ref_out1 = sdpa_partial(q1, k1, v1) | ||
|
||
# The second eager batch, shape (B * 2, H, S / 2, D) | ||
B = int(B * 2) | ||
S = int(S / 2) | ||
q2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") | ||
k2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") | ||
v2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") | ||
golden_out2 = sdpa_partial( | ||
q2.to(torch.float64), k2.to(torch.float64), v2.to(torch.float64) | ||
) | ||
ref_out2 = sdpa_partial(q2, k2, v2) | ||
|
||
torch._dynamo.reset() | ||
yanboliang marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# Compiling with dynamic shape in the first batch. | ||
compiled_sdpa = torch.compile(sdpa_partial, dynamic=True) | ||
compiled_out1 = compiled_sdpa(q1, k1, v1) | ||
self._check_equal(golden_out1, ref_out1, compiled_out1, dtype) | ||
self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1) | ||
|
||
# No re-compilation, use the compiled dynamic shape version. | ||
compiled_out2 = compiled_sdpa(q2, k2, v2) | ||
self._check_equal(golden_out2, ref_out2, compiled_out2, dtype) | ||
self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1) | ||
|
||
def run_automatic_dynamic_test( | ||
self, | ||
score_mod: Callable, | ||
dtype: torch.dtype = torch.float16, | ||
B: int = B, | ||
H: int = H, | ||
S: int = S, | ||
D: int = D, | ||
): | ||
sdpa_partial = create_attention(score_mod) | ||
# The first eager batch, shape (B, H, S, D) | ||
q1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") | ||
k1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") | ||
v1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") | ||
golden_out1 = sdpa_partial( | ||
q1.to(torch.float64), k1.to(torch.float64), v1.to(torch.float64) | ||
) | ||
ref_out1 = sdpa_partial(q1, k1, v1) | ||
|
||
# The second eager batch, shape (B * 2, H, S / 2, D) | ||
B = int(B * 2) | ||
S = int(S / 2) | ||
q2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") | ||
k2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") | ||
v2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") | ||
golden_out2 = sdpa_partial( | ||
q2.to(torch.float64), k2.to(torch.float64), v2.to(torch.float64) | ||
) | ||
ref_out2 = sdpa_partial(q2, k2, v2) | ||
|
||
# The third eager batch, shape (B * 4, H, S / 4, D) | ||
B = int(B * 2) | ||
S = int(S / 2) | ||
q3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") | ||
k3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") | ||
v3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda") | ||
golden_out3 = sdpa_partial( | ||
q3.to(torch.float64), k3.to(torch.float64), v3.to(torch.float64) | ||
) | ||
ref_out3 = sdpa_partial(q3, k3, v3) | ||
|
||
torch._dynamo.reset() | ||
yanboliang marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# Compiling with static shape in the first batch. | ||
compiled_sdpa = torch.compile(sdpa_partial) | ||
compiled_out1 = compiled_sdpa(q1, k1, v1) | ||
self._check_equal(golden_out1, ref_out1, compiled_out1, dtype) | ||
self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1) | ||
|
||
# Automatic compiling with dynamic shape in the second batch. | ||
compiled_out2 = compiled_sdpa(q2, k2, v2) | ||
self._check_equal(golden_out2, ref_out2, compiled_out2, dtype) | ||
self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 2) | ||
|
||
# No re-compilation, use the compiled dynamic shape version. | ||
compiled_out3 = compiled_sdpa(q3, k3, v3) | ||
self._check_equal(golden_out3, ref_out3, compiled_out3, dtype) | ||
self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 2) | ||
|
||
@supported_platform | ||
@common_utils.parametrize("dtype", test_dtypes) | ||
@common_utils.parametrize("score_mod", test_score_mods) | ||
def test_builtin_score_mods(self, dtype: torch.dtype, score_mod: Callable): | ||
self.run_test(score_mod, dtype) | ||
|
||
@supported_platform | ||
@common_utils.parametrize("dtype", test_dtypes) | ||
@common_utils.parametrize("score_mod", test_score_mods) | ||
def test_builtin_score_mods_dynamic(self, dtype: torch.dtype, score_mod: Callable): | ||
self.run_dynamic_test(score_mod, dtype) | ||
|
||
@supported_platform | ||
@common_utils.parametrize("dtype", test_dtypes) | ||
@common_utils.parametrize("score_mod", test_score_mods) | ||
def test_builtin_score_mods_automatic_dynamic( | ||
self, dtype: torch.dtype, score_mod: Callable | ||
): | ||
self.run_automatic_dynamic_test(score_mod, dtype) | ||
|
||
@supported_platform | ||
@common_utils.parametrize("dtype", test_dtypes) | ||
def test_skip_odd_keys(self, dtype: torch.dtype): | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Above in this file is
does compile ignore this if dynamic=true?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes,
dynamic=True
means forcing dynamic.