[Inductor] Flex attention supports dynamic shape #125994
@@ -126,6 +126,19 @@ def score_mod(score, b, h, m, n):
class TestTemplatedSDPA(InductorTestCase):
    def _check_equal(self, golden_out, ref_out, compiled_out, dtype):

Review comment: wrote something pretty similar lol:

        compiled_error = (golden_out - compiled_out).abs().mean()
        ref_error = (golden_out - ref_out).abs().mean()
        # Note, it seems like we really are less accurate than the float32
        # computation, likely due to the online softmax
        if dtype == torch.float32:
            fudge_factor = 10.0
        else:
            fudge_factor = 1.1
        if compiled_error > ref_error * fudge_factor:
            msg = f"Compiled error {compiled_error} is greater than ref error {ref_error} by more than {fudge_factor}X."
            self.assertTrue(False, msg)

    def run_test(
        self,
        score_mod: Callable,
@@ -145,25 +158,135 @@ def run_test(
        )
        ref_out = sdpa_partial(q, k, v)
        compiled_out = compiled_sdpa(q, k, v)
        self._check_equal(golden_out, ref_out, compiled_out, dtype)

        compiled_error = (golden_out - compiled_out).abs().mean()
        ref_error = (golden_out - ref_out).abs().mean()
        # Note, it seems like we really are less accurate than the float32
        # computation, likely due to the online softmax
        if dtype == torch.float32:
            fudge_factor = 10.0
        else:
            fudge_factor = 1.1
        if compiled_error > ref_error * fudge_factor:
            msg = f"Compiled error {compiled_error} is greater than ref error {ref_error} by more than {fudge_factor}X."
            self.assertTrue(False, msg)

    def run_dynamic_test(
        self,
        score_mod: Callable,
        dtype: torch.dtype = torch.float16,
        B: int = B,
        H: int = H,
        S: int = S,
        D: int = D,
    ):
        sdpa_partial = create_attention(score_mod)
        # The first eager batch, shape (B, H, S, D)
        q1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
        k1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
        v1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
        golden_out1 = sdpa_partial(
            q1.to(torch.float64), k1.to(torch.float64), v1.to(torch.float64)
        )
        ref_out1 = sdpa_partial(q1, k1, v1)

        # The second eager batch, shape (B * 2, H, S / 2, D)
        B = int(B * 2)
        S = int(S / 2)
        q2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
        k2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
        v2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
        golden_out2 = sdpa_partial(
            q2.to(torch.float64), k2.to(torch.float64), v2.to(torch.float64)
        )
        ref_out2 = sdpa_partial(q2, k2, v2)

        # Need to clear dynamo counters, since flex attention eager mode also uses dynamo tracing.
        # We check dynamo counters["frames"]["ok"] to ensure there is no re-compilation.
        torch._dynamo.reset()

        # Compiling with dynamic shape in the first batch.
        compiled_sdpa = torch.compile(sdpa_partial, dynamic=True)
        compiled_out1 = compiled_sdpa(q1, k1, v1)
        self._check_equal(golden_out1, ref_out1, compiled_out1, dtype)
        self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)

        # No re-compilation, use the compiled dynamic shape version.
        compiled_out2 = compiled_sdpa(q2, k2, v2)
        self._check_equal(golden_out2, ref_out2, compiled_out2, dtype)
        self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)

    def run_automatic_dynamic_test(
        self,
        score_mod: Callable,
        dtype: torch.dtype = torch.float16,
        B: int = B,
        H: int = H,
        S: int = S,
        D: int = D,
    ):
        sdpa_partial = create_attention(score_mod)
        # The first eager batch, shape (B, H, S, D)
        q1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
        k1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
        v1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
        golden_out1 = sdpa_partial(
            q1.to(torch.float64), k1.to(torch.float64), v1.to(torch.float64)
        )
        ref_out1 = sdpa_partial(q1, k1, v1)

        # The second eager batch, shape (B * 2, H, S / 2, D)
        B = int(B * 2)
        S = int(S / 2)
        q2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
        k2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
        v2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
        golden_out2 = sdpa_partial(
            q2.to(torch.float64), k2.to(torch.float64), v2.to(torch.float64)
        )
        ref_out2 = sdpa_partial(q2, k2, v2)

        # The third eager batch, shape (B * 4, H, S / 4, D)
        B = int(B * 2)
        S = int(S / 2)
        q3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
        k3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
        v3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
        golden_out3 = sdpa_partial(
            q3.to(torch.float64), k3.to(torch.float64), v3.to(torch.float64)
        )
        ref_out3 = sdpa_partial(q3, k3, v3)

        # Need to clear dynamo counters, since flex attention eager mode also uses dynamo tracing.
        # We check dynamo counters["frames"]["ok"] to ensure:
        # 1, the first batch is compiled with static shape
        # 2, the second batch is compiled with dynamic shape
        # 3, no re-compilation in the third batch
        torch._dynamo.reset()

        # The first batch.
        compiled_sdpa = torch.compile(sdpa_partial)
        compiled_out1 = compiled_sdpa(q1, k1, v1)
        self._check_equal(golden_out1, ref_out1, compiled_out1, dtype)
        self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)

        # The second batch (automatic dynamic).
        compiled_out2 = compiled_sdpa(q2, k2, v2)
        self._check_equal(golden_out2, ref_out2, compiled_out2, dtype)
        self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 2)

        # The third batch (no re-compilation).
        compiled_out3 = compiled_sdpa(q3, k3, v3)
        self._check_equal(golden_out3, ref_out3, compiled_out3, dtype)
        self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 2)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes)
    @common_utils.parametrize("score_mod", test_score_mods)
    def test_builtin_score_mods(self, dtype: torch.dtype, score_mod: Callable):
        self.run_test(score_mod, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes)
    @common_utils.parametrize("score_mod", test_score_mods)
    def test_builtin_score_mods_dynamic(self, dtype: torch.dtype, score_mod: Callable):
        self.run_dynamic_test(score_mod, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes)
    @common_utils.parametrize("score_mod", test_score_mods)
    def test_builtin_score_mods_automatic_dynamic(
        self, dtype: torch.dtype, score_mod: Callable
    ):
        self.run_automatic_dynamic_test(score_mod, dtype)

    @supported_platform
    @common_utils.parametrize("dtype", test_dtypes)
    def test_skip_odd_keys(self, dtype: torch.dtype):
@@ -617,3 +617,7 @@ def is_from_defaults(source: Source):
    if isinstance(source, ChainedSource):
        return is_from_defaults(source.base)
    return False


def is_cell_contents(source: Source):
    return isinstance(source, AttrSource) and source.member == "cell_contents"

Review comment: What is this doing, out of curiosity?

Reply: This is part of the heuristic rules that determine whether we should wrap an int as a SymInt. Here we are saying that if the value comes from a cell closure, we will not make it dynamic, since cell closure contents are usually constant. We define these heuristics based on
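
For illustration, here is a small, self-contained sketch of how a source predicate like this can feed the wrap-int-as-SymInt heuristic described in the reply. The Source and AttrSource classes below are simplified stand-ins rather than the real torch._dynamo.source types, and should_wrap_int_as_symint is a hypothetical name for that decision:

from dataclasses import dataclass


class Source:
    """Simplified stand-in for torch._dynamo.source.Source."""


@dataclass
class AttrSource(Source):
    """Simplified stand-in: a value obtained as getattr(base_value, member)."""
    base: Source
    member: str


def is_cell_contents(source: Source) -> bool:
    # True when the value was read from a closure cell, i.e. cell.cell_contents.
    return isinstance(source, AttrSource) and source.member == "cell_contents"


def should_wrap_int_as_symint(source: Source) -> bool:
    # Hypothetical heuristic: ints captured from closure cells are usually
    # constants, so keep them static; other ints may become symbolic.
    return not is_cell_contents(source)


closure_int_source = AttrSource(base=Source(), member="cell_contents")
assert should_wrap_int_as_symint(closure_int_source) is False

The real logic in dynamo weighs this predicate alongside other rules (the reply above calls it "part of heuristic rules"), so this sketch only captures the cell-contents case.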
Review comment: Above in this file is ... does compile ignore this if dynamic=True?

Reply: Yes, dynamic=True means forcing dynamic.
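
To make that reply concrete: with the default settings, torch.compile specializes on the first observed shapes and only switches a dimension to dynamic after it sees a change (the automatic dynamic flow exercised by run_automatic_dynamic_test above), while dynamic=True requests dynamic shapes from the very first compile. A rough sketch, with a toy function and recompile counts that mirror what the new tests assert rather than anything specific to this file:

import torch
import torch._dynamo
from torch._dynamo.utils import counters


def f(x):
    return x.sin() + 1


torch._dynamo.reset()
g = torch.compile(f)             # no dynamic flag: automatic dynamic applies
g(torch.randn(4))                # first shape: compiled static
g(torch.randn(8))                # new shape: recompiled, now dynamic
g(torch.randn(16))               # another new shape: reuses the dynamic graph
print(counters["frames"]["ok"])  # expected: 2

# With torch.compile(f, dynamic=True), the count would be expected to stay at 1,
# because dynamic=True forces dynamic shapes from the first compile.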