[Inductor] Flex attention supports dynamic shape #125994

Closed
wants to merge 6 commits into from
21 changes: 16 additions & 5 deletions benchmarks/transformer/score_mod.py
@@ -1,3 +1,4 @@
import argparse
import itertools
from collections import defaultdict
from dataclasses import asdict, dataclass
@@ -98,7 +99,7 @@ def generate_inputs(
return query, key, value


def run_single_experiment(config: ExperimentConfig) -> ExperimentResults:
Contributor

Above in this file is

torch._dynamo.config.automatic_dynamic_shapes = False

Does compile ignore this if dynamic=True?

Contributor Author

Yes, dynamic=True forces dynamic shapes, so that config is ignored.
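
As a minimal sketch of the distinction (toy function, not from this PR):

import torch

torch._dynamo.config.automatic_dynamic_shapes = False  # only disables *automatic* dynamism

def f(x):
    return x + 1

compiled = torch.compile(f, dynamic=True)  # forces symbolic shapes up front
compiled(torch.randn(4))
compiled(torch.randn(8))  # same dynamic graph; the config above is not consulted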

def run_single_experiment(config: ExperimentConfig, dynamic=False) -> ExperimentResults:
device = torch.device("cuda")
query, key, value = generate_inputs(
config.batch_size,
@@ -113,7 +114,7 @@ def run_single_experiment(config: ExperimentConfig) -> ExperimentResults:
def eager_sdpa(query, key, value, _):
return F.scaled_dot_product_attention(query, key, value)

compiled_sdpa = torch.compile(_flex_attention)
compiled_sdpa = torch.compile(_flex_attention, dynamic=dynamic)

score_mod = config.score_mod

@@ -242,16 +243,26 @@ def generate_experiment_configs() -> List[ExperimentConfig]:
return all_configs


def main():
def main(dynamic=False):
seed = 123
np.random.seed(seed)
torch.manual_seed(seed)
results = []
for config in tqdm(generate_experiment_configs()):
results.append(Experiment(config, run_single_experiment(config)))
results.append(
Experiment(config, run_single_experiment(config, dynamic=dynamic))
)

print_results(results)


if __name__ == "__main__":
main()
parser = argparse.ArgumentParser()
parser.add_argument(
"--dynamic",
action="store_true",
help="Runs a dynamic shapes version of compiled flex attention.",
)

args = parser.parse_args()
main(args.dynamic)
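
With this flag, the benchmark runs compiled flex attention with dynamic shapes:

python benchmarks/transformer/score_mod.py --dynamic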
145 changes: 134 additions & 11 deletions test/inductor/test_flex_attention.py
@@ -126,6 +126,19 @@ def score_mod(score, b, h, m, n):


class TestTemplatedSDPA(InductorTestCase):
def _check_equal(self, golden_out, ref_out, compiled_out, dtype):
compiled_error = (golden_out - compiled_out).abs().mean()
ref_error = (golden_out - ref_out).abs().mean()
# Note, it seems like we really are less accurate than the float32
# computation, likely due to the online softmax
if dtype == torch.float32:
fudge_factor = 10.0
else:
fudge_factor = 1.1
if compiled_error > ref_error * fudge_factor:
msg = f"Compiled error {compiled_error} is greater than ref error {ref_error} by more than {fudge_factor}X."
self.assertTrue(False, msg)

def run_test(
self,
score_mod: Callable,
@@ -145,25 +158,135 @@ def run_test(
)
ref_out = sdpa_partial(q, k, v)
compiled_out = compiled_sdpa(q, k, v)
self._check_equal(golden_out, ref_out, compiled_out, dtype)

compiled_error = (golden_out - compiled_out).abs().mean()
ref_error = (golden_out - ref_out).abs().mean()
# Note, it seems like we really are less accurate than the float32
# computation, likely due to the online softmax
if dtype == torch.float32:
fudge_factor = 10.0
else:
fudge_factor = 1.1
if compiled_error > ref_error * fudge_factor:
msg = f"Compiled error {compiled_error} is greater than ref error {ref_error} by more than {fudge_factor}X."
self.assertTrue(False, msg)
def run_dynamic_test(
self,
score_mod: Callable,
dtype: torch.dtype = torch.float16,
B: int = B,
H: int = H,
S: int = S,
D: int = D,
):
sdpa_partial = create_attention(score_mod)
# The first eager batch, shape (B, H, S, D)
q1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
k1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
v1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
golden_out1 = sdpa_partial(
q1.to(torch.float64), k1.to(torch.float64), v1.to(torch.float64)
)
ref_out1 = sdpa_partial(q1, k1, v1)

# The second eager batch, shape (B * 2, H, S / 2, D)
B = int(B * 2)
S = int(S / 2)
q2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
k2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
v2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
golden_out2 = sdpa_partial(
q2.to(torch.float64), k2.to(torch.float64), v2.to(torch.float64)
)
ref_out2 = sdpa_partial(q2, k2, v2)

# Need to clear dynamo counters, since flex attention eager mode also uses dynamo tracing.
# We check dynamo counters["frames"]["ok"] to ensure there is no re-compilation.
torch._dynamo.reset()
# Compiling with dynamic shape in the first batch.
compiled_sdpa = torch.compile(sdpa_partial, dynamic=True)
compiled_out1 = compiled_sdpa(q1, k1, v1)
self._check_equal(golden_out1, ref_out1, compiled_out1, dtype)
self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)

# No re-compilation, use the compiled dynamic shape version.
compiled_out2 = compiled_sdpa(q2, k2, v2)
self._check_equal(golden_out2, ref_out2, compiled_out2, dtype)
self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)
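
The counters pattern above generalizes; a minimal standalone sketch of checking for re-compilation (toy function, not from this PR):

import torch

torch._dynamo.reset()  # clears compile caches and counters, as the test does
f = torch.compile(lambda x: x + 1, dynamic=True)
f(torch.randn(4))
f(torch.randn(8))  # served by the same dynamic graph
# exactly one successfully compiled frame means no re-compilation happened
assert torch._dynamo.utils.counters["frames"]["ok"] == 1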

def run_automatic_dynamic_test(
self,
score_mod: Callable,
dtype: torch.dtype = torch.float16,
B: int = B,
H: int = H,
S: int = S,
D: int = D,
):
sdpa_partial = create_attention(score_mod)
# The first eager batch, shape (B, H, S, D)
q1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
k1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
v1 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
golden_out1 = sdpa_partial(
q1.to(torch.float64), k1.to(torch.float64), v1.to(torch.float64)
)
ref_out1 = sdpa_partial(q1, k1, v1)

# The second eager batch, shape (B * 2, H, S / 2, D)
B = int(B * 2)
S = int(S / 2)
q2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
k2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
v2 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
golden_out2 = sdpa_partial(
q2.to(torch.float64), k2.to(torch.float64), v2.to(torch.float64)
)
ref_out2 = sdpa_partial(q2, k2, v2)

# The third eager batch, shape (B * 4, H, S / 4, D)
B = int(B * 2)
S = int(S / 2)
q3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
k3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
v3 = torch.randn((B, H, S, D), dtype=dtype, device="cuda")
golden_out3 = sdpa_partial(
q3.to(torch.float64), k3.to(torch.float64), v3.to(torch.float64)
)
ref_out3 = sdpa_partial(q3, k3, v3)

# Need to clear dynamo counters, since flex attention eager mode also uses dynamo tracing.
# We check dynamo counters["frames"]["ok"] to ensure:
# 1. the first batch is compiled with static shapes
# 2. the second batch is compiled with dynamic shapes
# 3. the third batch triggers no re-compilation
torch._dynamo.reset()
# The first batch.
compiled_sdpa = torch.compile(sdpa_partial)
compiled_out1 = compiled_sdpa(q1, k1, v1)
self._check_equal(golden_out1, ref_out1, compiled_out1, dtype)
self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)

# The second batch (automatic dynamic).
compiled_out2 = compiled_sdpa(q2, k2, v2)
self._check_equal(golden_out2, ref_out2, compiled_out2, dtype)
self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 2)

# The third batch (no re-compilation).
compiled_out3 = compiled_sdpa(q3, k3, v3)
self._check_equal(golden_out3, ref_out3, compiled_out3, dtype)
self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 2)
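
The three-step assertion mirrors PyTorch's automatic dynamic behavior; a minimal standalone sketch (toy function, not from this PR):

import torch

torch._dynamo.reset()
g = torch.compile(lambda x: x * 2)  # no dynamic= flag: automatic dynamic applies
g(torch.randn(4))    # first call compiles a static graph specialized to size 4
g(torch.randn(8))    # size changed, so one re-compile with a dynamic size
g(torch.randn(16))   # covered by the dynamic graph, no further compile
assert torch._dynamo.utils.counters["frames"]["ok"] == 2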

@supported_platform
@common_utils.parametrize("dtype", test_dtypes)
@common_utils.parametrize("score_mod", test_score_mods)
def test_builtin_score_mods(self, dtype: torch.dtype, score_mod: Callable):
self.run_test(score_mod, dtype)

@supported_platform
@common_utils.parametrize("dtype", test_dtypes)
@common_utils.parametrize("score_mod", test_score_mods)
def test_builtin_score_mods_dynamic(self, dtype: torch.dtype, score_mod: Callable):
self.run_dynamic_test(score_mod, dtype)

@supported_platform
@common_utils.parametrize("dtype", test_dtypes)
@common_utils.parametrize("score_mod", test_score_mods)
def test_builtin_score_mods_automatic_dynamic(
self, dtype: torch.dtype, score_mod: Callable
):
self.run_automatic_dynamic_test(score_mod, dtype)

@supported_platform
@common_utils.parametrize("dtype", test_dtypes)
def test_skip_odd_keys(self, dtype: torch.dtype):
4 changes: 4 additions & 0 deletions torch/_dynamo/source.py
@@ -617,3 +617,7 @@ def is_from_defaults(source: Source):
if isinstance(source, ChainedSource):
return is_from_defaults(source.base)
return False


def is_cell_contents(source: Source):
Contributor

What is this doing, out of curiosity?

Contributor Author

This is part of the heuristic rules that determine whether we should wrap an int as a SymInt. Here we are saying that if the value comes from a cell closure, we do not make it dynamic, since cell closures are usually constant. We define these heuristics based on the source.

return isinstance(source, AttrSource) and source.member == "cell_contents"
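
A hedged illustration of the value this heuristic targets (toy code, not from this PR):

def make_scaler():
    scale = 8  # captured in a closure cell; dynamo reaches it through an
               # AttrSource whose member is "cell_contents"

    def inner(x):
        return x * scale

    return inner

# When inner is compiled, `scale` is guarded as a constant rather than wrapped
# as a SymInt, since closure cells rarely change across calls.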
2 changes: 2 additions & 0 deletions torch/_dynamo/variables/builder.py
@@ -56,6 +56,7 @@
ConvertIntSource,
GetItemSource,
GradSource,
is_cell_contents,
is_constant_source,
is_from_defaults,
is_from_optimizer_source,
@@ -1165,6 +1166,7 @@ def wrap_literal(self, value):
# NN modules on the fly)
or self.source.guard_source().is_nn_module()
or is_from_defaults(self.source)
or is_cell_contents(self.source)
):
self.install_guards(GuardBuilder.CONSTANT_MATCH)
return ConstantVariable.create(value=value, source=self.source)
2 changes: 1 addition & 1 deletion torch/_inductor/kernel/flex_attention.py
@@ -162,7 +162,7 @@ def sdpa_grid(batch_size, num_heads, num_queries, d_model, meta):

# TODO generalize and add proper mask support
mask = (idx_m != -1) & (idx_d != -1)
{{store_output(("idx_z", "idx_h", "idx_m", "idx_d"), "acc")}}
{{store_output(("idx_z", "idx_h", "idx_m", "idx_d"), "acc", "mask")}}

# TODO dont want to write this if we dont require grad
if OUTPUT_LOGSUMEXP:
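
Threading the mask into store_output makes the generated store a masked store, which presumably keeps out-of-bounds rows from being written when dynamic sequence lengths are not a multiple of the tile size. For intuition, a minimal masked store in Triton (standalone sketch, not the flex attention template):

import torch
import triton
import triton.language as tl

@triton.jit
def double_kernel(x_ptr, out_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n                              # lanes past n are out of bounds
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x * 2, mask=mask)   # masked store skips OOB lanes

x = torch.randn(1000, device="cuda")             # 1000 is not a multiple of 128
out = torch.empty_like(x)
double_kernel[(triton.cdiv(1000, 128),)](x, out, 1000, BLOCK=128)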
4 changes: 4 additions & 0 deletions torch/nn/attention/_flex_attention.py
@@ -83,6 +83,10 @@ def score_mod(
"""

if torch.compiler.is_dynamo_compiling():
# always mark the num-heads dim (1) and head_dim (-1) as static
for x in [query, key, value]:
torch._dynamo.mark_static(x, 1)
torch._dynamo.mark_static(x, -1)
out, _ = flex_attention_hop(query, key, value, score_mod)
return out

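
For reference, torch._dynamo.mark_static pins one dimension so it is never made symbolic; a minimal sketch of the effect (toy shapes):

import torch

def f(x):
    return x.sum()

x = torch.randn(2, 8, 1024, 64)   # (B, H, S, D)
torch._dynamo.mark_static(x, 1)   # heads dim stays a concrete int
torch._dynamo.mark_static(x, -1)  # head_dim stays concrete too
out = torch.compile(f, dynamic=True)(x)  # only B and S can become symbolic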