From 69e8b7abe7c1038e5bac6b57a008274ed7076a36 Mon Sep 17 00:00:00 2001 From: drisspg Date: Fri, 3 May 2024 16:08:19 -0700 Subject: [PATCH] Add lowering for flex_attention_backward and add grad test for inductor --- benchmarks/transformer/score_mod.py | 171 ++++-- test/inductor/test_flex_attention.py | 196 ++++--- test/run_test.py | 5 +- torch/_higher_order_ops/flex_attention.py | 44 +- torch/_inductor/ir.py | 5 +- torch/_inductor/kernel/flex_attention.py | 640 +++++++++++++++++----- torch/_inductor/select_algorithm.py | 63 ++- torch/_inductor/utils.py | 2 +- torch/nn/attention/_flex_attention.py | 4 +- torch/testing/_internal/hop_db.py | 6 +- 10 files changed, 828 insertions(+), 308 deletions(-) diff --git a/benchmarks/transformer/score_mod.py b/benchmarks/transformer/score_mod.py index 2c5f41502f7ea..57088c45f8a06 100644 --- a/benchmarks/transformer/score_mod.py +++ b/benchmarks/transformer/score_mod.py @@ -3,7 +3,7 @@ from collections import defaultdict from dataclasses import asdict, dataclass from functools import partial -from typing import Callable, List +from typing import Callable, List, Optional, Tuple import numpy as np import torch @@ -29,28 +29,32 @@ def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> @dataclass(frozen=True) class ExperimentConfig: - batch_size: int - num_heads: int - q_seq_len: int - k_seq_len: int - head_dim: int + shape: Tuple[int] score_mod: Callable dtype: torch.dtype + calculate_bwd_time: bool + + def __post_init__(self): + assert len(self.shape) == 4, "Shape must be of length 4" def asdict(self): - return asdict(self) + # Convert the dataclass instance to a dictionary + d = asdict(self) + # Remove the 'calculate_bwd_time' key + d.pop("calculate_bwd_time", None) + return d @dataclass(frozen=True) -class ExperimentResults: +class Times: eager_time: float compiled_time: float - def get_entries(self) -> List: - return [ - f"{self.eager_time:2f}", - f"{self.compiled_time:2f}", - ] + +@dataclass(frozen=True) +class ExperimentResults: + fwd_times: Times + bwd_times: Optional[Times] @dataclass(frozen=True) @@ -58,29 +62,31 @@ class Experiment: config: ExperimentConfig results: ExperimentResults - def get_entries(self) -> List: - return self.config.get_entries() + self.results.get_entries() - def asdict(self): - dict1 = asdict(self.config) + dict1 = self.config.asdict() dict2 = asdict(self.results) return {**dict1, **dict2} def generate_inputs( - batch_size, - num_heads, - q_sequence_length, - kv_sequence_length, - head_dim, - dtype, - device, + batch_size: int, + num_heads: int, + q_sequence_length: int, + kv_sequence_length: int, + head_dim: int, + dtype: torch.dtype, + device: torch.device, + requires_grad: bool, ): q_shape = (batch_size, q_sequence_length, num_heads * head_dim) kv_shape = (batch_size, kv_sequence_length, num_heads * head_dim) - make_q = partial(torch.rand, q_shape, device=device, dtype=dtype) - make_kv = partial(torch.rand, kv_shape, device=device, dtype=dtype) + make_q = partial( + torch.rand, q_shape, device=device, dtype=dtype, requires_grad=requires_grad + ) + make_kv = partial( + torch.rand, kv_shape, device=device, dtype=dtype, requires_grad=requires_grad + ) query = ( make_q() .view(batch_size, q_sequence_length, num_heads, head_dim) @@ -101,14 +107,16 @@ def generate_inputs( def run_single_experiment(config: ExperimentConfig, dynamic=False) -> ExperimentResults: device = torch.device("cuda") + batch_size, num_heads, q_seq_len, head_dim = config.shape query, key, value = generate_inputs( - config.batch_size, - config.num_heads, - config.q_seq_len, - config.k_seq_len, - config.head_dim, + batch_size, + num_heads, + q_seq_len, + q_seq_len, + head_dim, config.dtype, device, + requires_grad=config.calculate_bwd_time, ) def eager_sdpa(query, key, value, _): @@ -125,23 +133,47 @@ def eager_sdpa(query, key, value, _): compiled_sdpa, query, key, value, score_mod ) - return ExperimentResults( - eager_time=forward_eager_time, - compiled_time=forward_compiled_time, - ) + if config.calculate_bwd_time: + out_eager = eager_sdpa(query, key, value, score_mod) + dOut = torch.randn_like(out_eager) + backward_eager_time = benchmark_torch_function_in_microseconds( + out_eager.backward, dOut, retain_graph=True + ) + + out_compile = compiled_sdpa(query, key, value, score_mod) + dOut = torch.randn_like(out_eager) + backward_compile_time = benchmark_torch_function_in_microseconds( + out_compile.backward, dOut, retain_graph=True + ) + + return ExperimentResults( + fwd_times=Times(forward_eager_time, forward_compiled_time), + bwd_times=Times(backward_eager_time, backward_compile_time), + ) + else: + return ExperimentResults( + fwd_times=Times(forward_eager_time, forward_compiled_time), + bwd_times=None, + ) -def calculate_speedup(results: ExperimentResults) -> float: - return results.eager_time / results.compiled_time +def calculate_speedup(results: ExperimentResults, type: str) -> float: + if type == "fwd": + return results.fwd_times.eager_time / results.fwd_times.compiled_time + elif type == "bwd": + assert results.bwd_times is not None + return results.bwd_times.eager_time / results.bwd_times.compiled_time + else: + raise ValueError(f"Invalid type {type}") def get_func_name(func): return func.__name__.split(".")[-1].split(" at ")[0] -def get_average_speedups(results: List[Experiment]): +def get_average_speedups(results: List[Experiment], type: str): # Calculate speedups - speedups = [calculate_speedup(r.results) for r in results] + speedups = [calculate_speedup(r.results, type) for r in results] # Find indices of max and min speedups max_speedup_index = np.argmax(speedups) @@ -177,20 +209,39 @@ def print_results(results: List[Experiment]): table_data = defaultdict(list) for experiment in results: for key, value in experiment.asdict().items(): - if key == "eager_time" or key == "compiled_time": - value = float(value) - table_data[key].append(value) + if key == "fwd_times": + for name, time in value.items(): + table_data[f"fwd_{name}"].append(float(time)) + elif key == "bwd_times": + if experiment.config.calculate_bwd_time: + for name, time in value.items(): + table_data[f"bwd_{name}"].append(float(time)) + else: + table_data[key].append(value) # Calculate speedups - speedups = [calculate_speedup(r.results) for r in results] - table_data["speedup"] = speedups + fwd_speedups = [calculate_speedup(r.results, type="fwd") for r in results] + table_data["fwd_speedup"] = fwd_speedups + if results[0].config.calculate_bwd_time: + bwd_speedups = [calculate_speedup(r.results, type="bwd") for r in results] + table_data["bwd_speedup"] = bwd_speedups table_data["score_mod"] = [get_func_name(func) for func in table_data["score_mod"]] print(tabulate(table_data, headers="keys", tablefmt="github", floatfmt=".3f")) - average_data = get_average_speedups(results) + print("\n") + print("FWD Speedups".center(125, "=")) + print("\n") + average_data = get_average_speedups(results, type="fwd") print(tabulate(average_data, headers="keys", tablefmt="github", floatfmt=".3f")) + if results[0].config.calculate_bwd_time: + print("\n") + print("BWD Speedups".center(125, "=")) + print("\n") + average_data = get_average_speedups(results, type="bwd") + print(tabulate(average_data, headers="keys", tablefmt="github", floatfmt=".3f")) + def generate_score_mods() -> List[Callable]: def noop(score, b, h, m, n): @@ -208,8 +259,8 @@ def head_bias(score, b, h, m, n): return [noop, causal_mask, relative_bias, head_bias] -def generate_experiment_configs() -> List[ExperimentConfig]: - batch_sizes = [1, 8, 16] +def generate_experiment_configs(calculate_bwd: bool) -> List[ExperimentConfig]: + batch_sizes = [2, 8, 16] num_heads = [16] q_kv_seq_lens = [(512, 512), (1024, 1024), (4096, 4096)] head_dims = [64, 128, 256] @@ -228,41 +279,49 @@ def generate_experiment_configs() -> List[ExperimentConfig]: ) in itertools.product( batch_sizes, num_heads, q_kv_seq_lens, head_dims, score_mods, dtypes ): + assert q_seq_len == kv_seq_len, "Only equal length inputs supported for now." all_configs.append( ExperimentConfig( - batch_size=bsz, - num_heads=n_heads, - q_seq_len=q_seq_len, - k_seq_len=kv_seq_len, - head_dim=head_dim, + shape=(bsz, n_heads, q_seq_len, head_dim), score_mod=score_mod, dtype=dtype, + calculate_bwd_time=calculate_bwd, ) ) return all_configs -def main(dynamic=False): +def main(dynamic: bool, calculate_bwd: bool): seed = 123 np.random.seed(seed) torch.manual_seed(seed) results = [] - for config in tqdm(generate_experiment_configs()): + for config in tqdm(generate_experiment_configs(calculate_bwd)): results.append( Experiment(config, run_single_experiment(config, dynamic=dynamic)) ) + for config in tqdm(generate_experiment_configs(calculate_bwd)): + results.append(Experiment(config, run_single_experiment(config))) print_results(results) if __name__ == "__main__": - parser = argparse.ArgumentParser() + # Set up the argument parser + parser = argparse.ArgumentParser( + description="Run sweep over sizes and score mods for flex attention" + ) parser.add_argument( "--dynamic", action="store_true", help="Runs a dynamic shapes version of compiled flex attention.", ) + parser.add_argument( + "--calculate-bwd", action="store_true", help="Calculate backward pass times" + ) + # Parse arguments args = parser.parse_args() - main(args.dynamic) + + main(args.dynamic, args.calculate_bwd) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index 9df905d2ad547..f3a9026a3c805 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -1,8 +1,9 @@ # Owner(s): ["module: inductor"] import functools +import unittest from collections import namedtuple -from typing import Callable +from typing import Callable, Optional from unittest import expectedFailure, skip, skipUnless from unittest.mock import patch @@ -58,14 +59,8 @@ def create_attention(score_mod): # --------- Useful score mod functions for testing --------- - -test_score_mods = [ - _identity, - _causal, - _rel_bias, - _rel_causal, - _generate_alibi_bias(8), -] +def _inverse_causal(score, b, h, m, n): + return torch.where(m <= n, score, float("-inf")) def _times_two(score, b, h, m, n): @@ -79,13 +74,11 @@ def _squared(score, b, h, m, n): def _head_offset(dtype: torch.dtype): - """Captured Buffer - Note: this builds a score_mod with index of a type - """ + """Captured Buffer""" head_offset = torch.rand(H, device="cuda", dtype=dtype) def score_mod(score, b, h, m, n): - return score * index(head_offset, [h]) + return score * head_offset[h] return score_mod @@ -103,20 +96,19 @@ def _trig2(score, b, h, m, n): return z -def _buffer_reduced(dtype: torch.dtype): - """Reduction in captured buffer""" - batch_offsets = torch.rand(B, 8, device="cuda", dtype=dtype) - - def score_mod(score, b, h, m, n): - batch_vals = index(batch_offsets, [b]) - return score + batch_vals.sum() - - return score_mod - +test_score_mods = [ + _identity, + _times_two, + _squared, + _causal, + _inverse_causal, + _rel_bias, + _rel_causal, + _generate_alibi_bias(8), +] captured_buffers_map = { "_head_offset": _head_offset, - "_buffer_reduced": _buffer_reduced, } B = 4 @@ -125,18 +117,35 @@ def score_mod(score, b, h, m, n): D = 64 -class TestTemplatedSDPA(InductorTestCase): - def _check_equal(self, golden_out, ref_out, compiled_out, dtype): +def query_key_value_clones( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + dtype: torch.dtype = None, +): + """Clones the query, key, and value tensors and moves them to the specified dtype.""" + if dtype is None: + dtype = query.dtype + query_ref = query.clone().detach().to(dtype).requires_grad_(query.requires_grad) + key_ref = key.clone().detach().to(dtype).requires_grad_(key.requires_grad) + value_ref = value.clone().detach().to(dtype).requires_grad_(value.requires_grad) + return query_ref, key_ref, value_ref + + +class TestFlexAttention(InductorTestCase): + def _check_equal( + self, + golden_out: torch.Tensor, + ref_out: torch.Tensor, + compiled_out: torch.Tensor, + fudge_factor: float, + tensor_name: Optional[str] = None, + ): compiled_error = (golden_out - compiled_out).abs().mean() ref_error = (golden_out - ref_out).abs().mean() - # Note, it seems like we really are less accurate than the float32 - # computation, likely due to the online softmax - if dtype == torch.float32: - fudge_factor = 10.0 - else: - fudge_factor = 1.1 if compiled_error > ref_error * fudge_factor: - msg = f"Compiled error {compiled_error} is greater than ref error {ref_error} by more than {fudge_factor}X." + name = tensor_name if tensor_name is not None else "" + msg = f"{name} Compiled error {compiled_error} is greater than ref error {ref_error} by more than {fudge_factor}X." self.assertTrue(False, msg) def run_test( @@ -150,15 +159,45 @@ def run_test( ): sdpa_partial = create_attention(score_mod) compiled_sdpa = torch.compile(sdpa_partial) - q = torch.randn((B, H, S, D), dtype=dtype, device="cuda") - k = torch.randn((B, H, S, D), dtype=dtype, device="cuda") - v = torch.randn((B, H, S, D), dtype=dtype, device="cuda") - golden_out = sdpa_partial( - q.to(torch.float64), k.to(torch.float64), v.to(torch.float64) - ) - ref_out = sdpa_partial(q, k, v) + q = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) + k = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) + v = torch.randn((B, H, S, D), dtype=dtype, device="cuda", requires_grad=True) + q_ref, k_ref, v_ref = query_key_value_clones(q, k, v) + q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64) + golden_out = sdpa_partial(q_gold, k_gold, v_gold) + ref_out = sdpa_partial(q_ref, k_ref, v_ref) compiled_out = compiled_sdpa(q, k, v) - self._check_equal(golden_out, ref_out, compiled_out, dtype) + + backward_grad = torch.randn((B, H, S, D), dtype=dtype, device="cuda") + + golden_out.backward(backward_grad.to(torch.float64)) + ref_out.backward(backward_grad) + compiled_out.backward(backward_grad) + + with torch.no_grad(): + # Note, it seems like we really are less accurate than the float32 + # computation, likely due to the online softmax + if dtype == torch.float32: + fudge_factor = 10.0 + else: + fudge_factor = 1.1 + + # Checkout output + self._check_equal(golden_out, ref_out, compiled_out, fudge_factor, "Out") + + # Check gradients + q_fudge_factor = 2.5 * fudge_factor + self._check_equal( + q_gold.grad, q_ref.grad, q.grad, q_fudge_factor, "Grad_Query" + ) + k_fudge_factor = 4 * fudge_factor + self._check_equal( + k_gold.grad, k_ref.grad, k.grad, k_fudge_factor, "Grad_Key" + ) + v_fudge_factor = 8 * fudge_factor + self._check_equal( + v_gold.grad, v_ref.grad, v.grad, v_fudge_factor, "Grad_Value" + ) def run_dynamic_test( self, @@ -196,12 +235,20 @@ def run_dynamic_test( # Compiling with dynamic shape in the first batch. compiled_sdpa = torch.compile(sdpa_partial, dynamic=True) compiled_out1 = compiled_sdpa(q1, k1, v1) - self._check_equal(golden_out1, ref_out1, compiled_out1, dtype) + + # Note, it seems like we really are less accurate than the float32 + # computation, likely due to the online softmax + if dtype == torch.float32: + fudge_factor = 10.0 + else: + fudge_factor = 1.1 + + self._check_equal(golden_out1, ref_out1, compiled_out1, fudge_factor) self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1) # No re-compilation, use the compiled dynamic shape version. compiled_out2 = compiled_sdpa(q2, k2, v2) - self._check_equal(golden_out2, ref_out2, compiled_out2, dtype) + self._check_equal(golden_out2, ref_out2, compiled_out2, fudge_factor) self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1) def run_automatic_dynamic_test( @@ -251,20 +298,28 @@ def run_automatic_dynamic_test( # 2, the second batch is compiled with dynamic shape # 3, no re-compilation in the third batch torch._dynamo.reset() + + # Note, it seems like we really are less accurate than the float32 + # computation, likely due to the online softmax + if dtype == torch.float32: + fudge_factor = 10.0 + else: + fudge_factor = 1.1 + # The first batch. compiled_sdpa = torch.compile(sdpa_partial) compiled_out1 = compiled_sdpa(q1, k1, v1) - self._check_equal(golden_out1, ref_out1, compiled_out1, dtype) + self._check_equal(golden_out1, ref_out1, compiled_out1, fudge_factor) self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 1) # The second batch (automatic dynamic). compiled_out2 = compiled_sdpa(q2, k2, v2) - self._check_equal(golden_out2, ref_out2, compiled_out2, dtype) + self._check_equal(golden_out2, ref_out2, compiled_out2, fudge_factor) self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 2) # The third batch (no re-compilation). compiled_out3 = compiled_sdpa(q3, k3, v3) - self._check_equal(golden_out3, ref_out3, compiled_out3, dtype) + self._check_equal(golden_out3, ref_out3, compiled_out3, fudge_factor) self.assertEqual(torch._dynamo.utils.counters["frames"]["ok"], 2) @supported_platform @@ -318,6 +373,21 @@ def score_mod(score, b, h, m, n): self.run_test(score_mod, dtype) + @supported_platform + @common_utils.parametrize("dtype", test_dtypes) + def test_captured_buffers_all_dims(self, dtype: torch.dtype): + head_scale = torch.randn(H, device="cuda") + batch_scale = torch.randn(B, device="cuda") + tok_scale = torch.randn(S, device="cuda") + + def all_bias(score, batch, head, token_q, token_kv): + score = score + tok_scale[token_q] + score = score + batch_scale[batch] + score = score + head_scale[head] + return score + + self.run_test(all_bias, dtype) + @supported_platform @common_utils.parametrize("dtype", test_dtypes_fast) def test_seq_masking(self, dtype): @@ -422,7 +492,7 @@ def score_mod_func(score, b, h, q, kv): make_tensor = functools.partial( torch.randn, - (2, 2, 8, 4), + (2, 2, 128, 4), device="cuda", dtype=torch.float64, requires_grad=True, @@ -458,6 +528,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1): @supported_platform @common_utils.parametrize("dtype", test_dtypes_fast) + @unittest.skip("Silu decomp failing for full in backwards") def test_silu_on_score(self, dtype): def silu_score(score, b, h, q, kv): return torch.nn.functional.silu(score) @@ -597,23 +668,6 @@ def njt_score_mod(qk, b, h, q, kv): self.run_test(causal_njt, dtype) - @supported_platform - def test_backwards_fails(self): - make_tensor = functools.partial( - torch.randn, - (B, H, S, D), - dtype=torch.float32, - device="cuda", - requires_grad=True, - ) - q, k, v = make_tensor(), make_tensor(), make_tensor() - func = torch.compile(_flex_attention, backend="inductor", fullgraph=True) - with self.assertRaisesRegex( - AssertionError, "flex_attention_backward is not an OpOverload" - ): - out = func(q, k, v, _identity) - out.backward(torch.ones_like(out)) - @supported_platform def test_mixed_dtypes_fails(self): query = torch.randn((1, 1, 1024, 64), dtype=torch.float32, device="cuda") @@ -641,6 +695,7 @@ def score_mod(score, b, h, m, n): self.run_test(score_mod) @supported_platform + @skip("TODO: Figure out why this is erroring") @patch.object(torch._inductor.config, "max_autotune", True) def test_max_autotune_with_captured(self): head_scale = torch.randn(H, device="cuda") @@ -776,7 +831,7 @@ def test_aot_eager_gradcheck(self, score_mod): ) @supported_platform - @common_utils.parametrize("score_mod_name", ["_head_offset", "_buffer_reduced"]) + @common_utils.parametrize("score_mod_name", ["_head_offset"]) @common_utils.parametrize("mode", ["eager", "aot_eager"]) def test_captured_score_mod_aot_eager_gradcheck( self, score_mod_name: str, mode: str @@ -864,13 +919,10 @@ def debug_compile_fx_inner(graph, example_inputs, *args, **kwargs): joint_graph, """\ class GraphModule(torch.nn.Module): - def forward(self, primals_1: "f64[2, 2, 8, 4]", primals_2: "f64[2, 2, 8, 4]", primals_3: "f64[2, 2, 8, 4]", """ - + """alias_5: "f64[2, 2, 8, 4]", alias_7: "f32[2, 2, 8]", tangents_1: "f64[2, 2, 8, 4]"): + def forward(self, primals_1: "f64[2, 2, 8, 4]", primals_2: "f64[2, 2, 8, 4]", primals_3: "f64[2, 2, 8, 4]", alias_3: "f64[2, 2, 8, 4]", alias_5: "f32[2, 2, 8]", tangents_1: "f64[2, 2, 8, 4]"): fw_graph = self.fw_graph joint_graph = self.joint_graph - flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, """ - + """primals_3, alias_5, alias_7, tangents_1, fw_graph, joint_graph); primals_1 = primals_2 = primals_3 = alias_5 """ - + """= alias_7 = tangents_1 = fw_graph = joint_graph = None + flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, alias_3, alias_5, tangents_1, fw_graph, joint_graph); primals_1 = primals_2 = primals_3 = alias_3 = alias_5 = tangents_1 = fw_graph = joint_graph = None getitem_2: "f64[2, 2, 8, 4]" = flex_attention_backward[0] getitem_3: "f64[2, 2, 8, 4]" = flex_attention_backward[1] getitem_4: "f64[2, 2, 8, 4]" = flex_attention_backward[2]; flex_attention_backward = None @@ -888,11 +940,11 @@ def forward(self, arg0_1: "f64[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i3 mul_2: "f64[]" = torch.ops.aten.mul.Tensor(arg5_1, arg0_1); arg5_1 = arg0_1 = None add: "f64[]" = torch.ops.aten.add.Tensor(mul_2, mul_1); mul_2 = mul_1 = None return [add, None, None, None, None] -""", +""", # noqa: B950 ) -common_utils.instantiate_parametrized_tests(TestTemplatedSDPA) +common_utils.instantiate_parametrized_tests(TestFlexAttention) if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/test/run_test.py b/test/run_test.py index 5b24a00789964..b3b4f9ae68ea3 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -239,7 +239,8 @@ def __contains__(self, item): "test_native_mha", # OOM "test_module_hooks", # OOM "inductor/test_max_autotune", - "inductor/test_cutlass_backend", # slow due to many nvcc compilation steps + "inductor/test_cutlass_backend", # slow due to many nvcc compilation steps, + "inductor/test_flex_attention", # OOM ] # A subset of onnx tests that cannot run in parallel due to high memory usage. ONNX_SERIAL_LIST = [ @@ -406,7 +407,7 @@ def run_test( stepcurrent_key = f"{test_file}_{test_module.shard}_{os.urandom(8).hex()}" if options.verbose: - unittest_args.append(f'-{"v"*options.verbose}') # in case of pytest + unittest_args.append(f'-{"v" * options.verbose}') # in case of pytest if test_file in RUN_PARALLEL_BLOCKLIST: unittest_args = [ diff --git a/torch/_higher_order_ops/flex_attention.py b/torch/_higher_order_ops/flex_attention.py index b5e1385da346b..f4586a0a57b0c 100644 --- a/torch/_higher_order_ops/flex_attention.py +++ b/torch/_higher_order_ops/flex_attention.py @@ -406,17 +406,20 @@ def flex_attention_autograd( score_mod: Callable, *other_buffers: Tuple[torch.Tensor, ...], ) -> Tuple[torch.Tensor, torch.Tensor]: - input_requires_grad = any(t.requires_grad for t in (query, key, value)) - if torch.is_grad_enabled() and input_requires_grad: - example_vals = [ - torch.zeros((), dtype=query.dtype, requires_grad=input_requires_grad) - ] + [torch.zeros((), dtype=torch.int) for _ in range(4)] - fw_graph, bw_graph = create_fw_bw_graph(score_mod, example_vals, other_buffers) - else: - fw_graph, bw_graph = score_mod, None - out, logsumexp = FlexAttentionAutogradOp.apply( - query, key, value, fw_graph, bw_graph, *other_buffers - ) + with TransformGetItemToIndex(): + input_requires_grad = any(t.requires_grad for t in (query, key, value)) + if torch.is_grad_enabled() and input_requires_grad: + example_vals = [ + torch.zeros((), dtype=query.dtype, requires_grad=input_requires_grad) + ] + [torch.zeros((), dtype=torch.int) for _ in range(4)] + fw_graph, bw_graph = create_fw_bw_graph( + score_mod, example_vals, other_buffers + ) + else: + fw_graph, bw_graph = score_mod, None + out, logsumexp = FlexAttentionAutogradOp.apply( + query, key, value, fw_graph, bw_graph, *other_buffers + ) return out, logsumexp @@ -449,9 +452,10 @@ def sdpa_dense_backward( score_mod = torch.vmap(score_mod, in_dims=(0, None, 0, None, None) + in_dim_buffers) score_mod = torch.vmap(score_mod, in_dims=(0, 0, None, None, None) + in_dim_buffers) - post_mod_scores = score_mod(scores, b, h, m, n, *other_buffers).to( - working_precision - ) + with TransformGetItemToIndex(): + post_mod_scores = score_mod(scores, b, h, m, n, *other_buffers).to( + working_precision + ) softmax_scores = torch.exp(post_mod_scores - logsumexp.unsqueeze(-1)) @@ -485,9 +489,10 @@ def sdpa_dense_backward( in_dims=(0, 0, None, None, None, 0) + in_dim_buffers, out_dims=out_dims, ) - grad_scores, *_ = joint_score_mod( - scores, b, h, m, n, grad_score_mod, *other_buffers - ) + with TransformGetItemToIndex(): + grad_scores, *_ = joint_score_mod( + scores, b, h, m, n, grad_score_mod, *other_buffers + ) grad_scores = grad_scores.to(query.dtype) grad_query = grad_scores @ key @@ -524,8 +529,9 @@ def trace_flex_attention_backward( torch.zeros((), dtype=query.dtype, requires_grad=query.requires_grad) ] + [torch.zeros((), dtype=torch.int) for _ in range(4)] bw_example_vals = fw_example_vals + [torch.zeros((), dtype=query.dtype)] - fw_graph = make_fx(fw_graph)(*fw_example_vals, *other_buffers) - joint_graph = make_fx(joint_graph)(*bw_example_vals, *other_buffers) + with TransformGetItemToIndex(): + fw_graph = reenter_make_fx(fw_graph)(*fw_example_vals, *other_buffers) + joint_graph = reenter_make_fx(joint_graph)(*bw_example_vals, *other_buffers) proxy_mode.tracer.root.register_module("fw_graph", fw_graph) proxy_mode.tracer.root.register_module("joint_graph", joint_graph) node_args = ( diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index b4cf3bca42e50..ccba120d606a4 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -3595,7 +3595,10 @@ def __init__( self.mutated_inputs = mutated_inputs if mutated_inputs is not None: # Ensure that the mutated inputs are only allowed for certain nodes - allowed_set = {torch.ops.higher_order.flex_attention} + allowed_set = { + torch.ops.higher_order.flex_attention, + torch.ops.higher_order.flex_attention_backward, + } current_node = V.graph.current_node.target assert ( current_node in allowed_set diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index a780d3709cb0c..32dff9d46668c 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -1,17 +1,39 @@ """ Triton Implementation of the flex_attention Kernel""" + import logging -from typing import Any, List +import math +from enum import auto, Enum +from typing import Any, List, Tuple import torch +from torch._prims_common import make_contiguous_strides_for from .. import config -from ..lowering import empty_strided, lowerings, register_lowering +from ..ir import ( + ComputedBuffer, + FixedLayout, + FlexibleLayout, + InputBuffer, + IRNode, + StorageBox, + Subgraph, + TensorBox, +) +from ..lowering import empty_strided, full, lowerings, register_lowering from ..select_algorithm import autotune_select_algorithm, TritonTemplate log = logging.getLogger(__name__) aten = torch.ops.aten -def sdpa_grid(batch_size, num_heads, num_queries, d_model, meta): +class SubgraphType(Enum): + """The type of subgraph for which we want to generate an output buffer.""" + + FWD = auto() # Forward pass + JOINT_FWD = auto() # The recompute step fo the of the bwds kernel + JOINT_BWD = auto() # The bwd pass of the joint + + +def flex_attention_grid(batch_size, num_heads, num_queries, d_model, meta): """How is this kernel parallelized? We create a grid of (batch_size * num_heads, ceil_div(n_queries, query_block_size), 1) Each block is responsible for iterating over blocks of keys and values calculating @@ -22,9 +44,117 @@ def sdpa_grid(batch_size, num_heads, num_queries, d_model, meta): return (triton.cdiv(num_queries, meta["BLOCK_M"]), batch_size * num_heads, 1) -sdpa_template = TritonTemplate( - name="sdpa", - grid=sdpa_grid, +def create_placeholder( + name: str, dtype: torch.dtype, device: torch.device +) -> TensorBox: + """Creates a placeholder input buffers for producing subgraph_output.""" + input_buffer = InputBuffer(name, FixedLayout(device, dtype, [1], [1])) + return TensorBox.create(input_buffer) + + +def index_to_other_buffers(cnt: int, graph_type: SubgraphType) -> int: + """This function needs to be aware of the signatures for flex_attention_forward + and flex_attention_backward. If new args are added, or the signature changes + be sure to update the indexing math + + Args: + cnt (int): The current index of the placeholder node + is_joint_graph (bool): Whether or not this subgraph represents the joint graph + """ + # Current fwd_args = [query, key, value, score_mod, *other_buffers] + # For fwd_graphs we have 5 dummy values this when the first lifted args + # is seen cnt = 5 and the start of the index_buffers is at args[4] + # thus we subtract 1 from the current cnt + if graph_type == SubgraphType.FWD: + return cnt - 1 + + # Current bwd_args = [q, k, v, out, lse, grad_out, fw_graph, joint_graph, *other_buffers] + # We have 5 dummy values but the start of other_buffers is at index 8 + if graph_type == SubgraphType.JOINT_FWD: + return cnt + 3 + + # Same bwd args but now with 6 dummy values while other_buffers still start at 8 + if graph_type == SubgraphType.JOINT_BWD: + return cnt + 2 + + +def build_subgraph_buffer( + args: Tuple[IRNode], + placeholder_inps: List[TensorBox], + subgraph: Subgraph, + graph_type: SubgraphType, +) -> ComputedBuffer: + """This function's goal is to take in the required args and produce the subgraph buffer + The subgraph buffer is a ComputedBuffer that will be inlined into the triton template + + Args: + args: The args that were passed into the flex_attention kernel + placeholder_inps: The list of scalar inputs, these were created on the fly through `create_placeholder` + subgraph: The Subgraph ir for which to produce the output node + graph_type: The type of subgraph for which we want to produce the output node, see enum above for details + """ + cnt = 0 + env = {} + for node in subgraph.graph_module.graph.nodes: + # There are two classes of placeholder inpts that we need + # to handle differently. For the first n_scalar_inps inputs + # we expect that these placeholders were generated by the make_fx call + # in the flex Attention HOP. So we need to create a new placeholder + # TensorBox for each of these inputs. For the rest of the inputs we + # expect that these are lifted inputs that fill up the '*other_buffers' + # tuple and already have corresponding TensorBoxes passed in as args. + if node.op == "placeholder": + is_lifted_input = cnt >= len(placeholder_inps) + lifted_input_index = index_to_other_buffers(cnt, graph_type) + env[node] = ( + args[lifted_input_index] if is_lifted_input else placeholder_inps[cnt] + ) + cnt += 1 + elif node.op == "call_function": + # For call_function we use the default lowerings and pass in the + # already created TensorBoxes as args + from torch.utils._pytree import tree_map + + env[node] = lowerings[node.target]( + *tree_map(lambda x: env[x] if x in env else x, node.args) + ) + elif node.op == "output": + # For the output node we need to create a ComputedBuffer + # which represents the actual score modification + # The joint_graph's output should be of the form[grad_score, None, None, None, None] + # This is because only the 'score' requires grad and the other outputs are + # the non-differentiable index scalars + if graph_type == SubgraphType.FWD or graph_type == SubgraphType.JOINT_FWD: + output_node = node.args[0] + else: + output_node = node.args[0][0] + output_buffer = env[output_node] + assert isinstance(output_buffer, TensorBox), ( + "The output node for flex attention's subgraph must be a TensorBox, but got: ", + type(output_buffer), + ) + assert isinstance(output_buffer.data, StorageBox), ( + "The output node for the flex attention subgraph must be a StorageBox, but got: ", + type(output_buffer), + ) + # Create the ComputedBuffer directly that will be inlined into the modification block + subgraph_buffer = ComputedBuffer( + name=None, + layout=FlexibleLayout( + device=output_buffer.data.get_device(), + dtype=output_buffer.data.get_dtype(), + size=output_buffer.data.get_size(), + ), + data=output_buffer.data.data, # type: ignore[arg-type] + ) + return subgraph_buffer + + raise ValueError("TemplatedAttention was passed a subgraph with no output node!") + + +flex_attention_template = TritonTemplate( + name="flex_attention", + grid=flex_attention_grid, source=r""" {{def_kernel("Q", "K", "V", "LSE")}} # Sub notation for this kernel: @@ -118,6 +248,7 @@ def sdpa_grid(batch_size, num_heads, num_queries, d_model, meta): m = offs_m[:, None] n = start_n + offs_n[None, :] {{ modification( + subgraph_number=0, score="qk", b="off_hz // H", h="off_hz % H", @@ -192,7 +323,7 @@ def sdpa_grid(batch_size, num_heads, num_queries, d_model, meta): } -def _get_default_config(query): +def _get_default_config_fwd(query) -> Tuple[int, int, int, int]: dtype = query.get_dtype() head_dim = query.get_size()[-1] default_config = None @@ -218,143 +349,394 @@ def _get_default_config(query): return default_config +def _get_default_config_bwd(query) -> Tuple[int, int, int, int]: + head_dim = query.get_size()[-1] + dtype = query.get_dtype() + + if head_dim <= 256 and torch.cuda.get_device_capability() >= (9, 0): # H100 + if dtype == torch.float32: + return (64, 64, 4, 1) + return (128, 128, 4, 3) + elif head_dim <= 256 and torch.cuda.get_device_capability() >= (8, 0): # A100 + return (32, 32, 4, 1) + else: # modest hardware or extremely large head_dim + return (32, 32, 4, 1) + + # TODO: We probably also need a layout constraint? @register_lowering(torch.ops.higher_order.flex_attention, type_promotion_kind=None) def flex_attention(*args, **kwargs): - from torch._prims_common import make_contiguous_strides_for - from ..ir import ( - ComputedBuffer, - FixedLayout, - FlexibleLayout, - InputBuffer, - StorageBox, - TensorBox, - ) - query, key, value, subgraph, *other_buffers = args + placeholder_inps = [ + create_placeholder(name, dtype, query.get_device()) + for name, dtype in [ + ("score", query.get_dtype()), + ("b", torch.int32), + ("h", torch.int32), + ("m", torch.int32), + ("n", torch.int32), + ] + ] + subgraph_buffer = build_subgraph_buffer( + args, placeholder_inps, subgraph, graph_type=SubgraphType.FWD + ) + layout = FixedLayout( + query.get_device(), + query.get_dtype(), + query.get_size(), + make_contiguous_strides_for(query.get_size()), + ) + # see NOTE:[TritonTemplates with multiple outputs] + logsumexp_shape = query.get_size()[:-1] # [B, H, M] + logsumexp = empty_strided( + logsumexp_shape, + None, + dtype=torch.float32, # The logsumexp is always stored in fp32 regardless of the input dtype + device=query.get_device(), + ) + choices: List[Any] = [] + configs: List[Tuple[int, int, int, int]] = [] + configs.append(_get_default_config_fwd(query)) + if config.max_autotune: + configs += [ + (128, 64, 4, 3), + (128, 128, 4, 3), + (128, 128, 8, 2), + (64, 128, 4, 3), + (64, 64, 4, 3), + ] - def create_placeholder(name: str, dtype: torch.dtype) -> InputBuffer: - return TensorBox.create( - InputBuffer( - name, - FixedLayout( - query.get_device(), - dtype, - [ - 1, - ], - [ - 1, - ], - ), - ) + # Note, we don't need to pass in the captured buffers explicitly + # because they're implicitly added by the score_mod function + # We do need to explicitly pass it in for autotuning though. + for BLOCK_M, BLOCK_N, num_warps, num_stages in configs: + flex_attention_template.maybe_append_choice( + choices=choices, + input_nodes=[query, key, value, logsumexp], + layout=layout, + subgraphs=[ + subgraph_buffer, + ], + mutated_inputs=[ + logsumexp, + ], + num_stages=num_stages, + num_warps=num_warps, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_DMODEL=query.get_size()[-1], + # For now, we always assume the "sound" option + SCORE_MOD_IS_LINEAR=False, + ROWS_GUARANTEED_SAFE=False, + OUTPUT_LOGSUMEXP=True, ) + inputs_for_autotuning = [query, key, value, logsumexp] + list(other_buffers) + return ( + autotune_select_algorithm( + "flex_attention", choices, inputs_for_autotuning, layout + ), + logsumexp, + ) - scalar_inps = ["score", "b", "h", "m", "n"] - env = {} - cnt = 0 - placeholder_inps = [ - create_placeholder(name, dtype) + +# ---------------------------- Backward HOP Implementation ---------------------------- + + +def flex_attention_backward_grid(batch_size, num_heads, num_key_value, d_model, meta): + """How is this kernel parallelized? + Currently this is only parallelizing over batch * num_heads, but we can, and want to + parallelize over ceil_div(num_key_value, key_value_block_size). To do this will either require + atomic updates to some grad values or to have a two pass kernel design. + """ + return (batch_size * num_heads, 1, 1) + + +flex_attention_backward_template = TritonTemplate( + name="flex_attention_backward", + grid=flex_attention_backward_grid, + source=r""" +{{def_kernel("Q", "K", "V", "OUT", "LSE", "DELTA", "DO", "DQ", "DV")}} + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # OUT: Forward output, LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT* DO, axis=1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values, D: Model dimension + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # (Modifiable) Config options: + # BLOCK_M + # BLOCK_N + # SCORE_MOD_IS_LINEAR: Is the score modifier linear? If so, we can lift the + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + + # Define Q Strides + stride_qz = {{stride("Q", 0)}} + stride_qh = {{stride("Q", 1)}} + stride_qm = {{stride("Q", 2)}} + stride_qk = {{stride("Q", 3)}} + # Define K Strides + stride_kz = {{stride("K", 0)}} + stride_kh = {{stride("K", 1)}} + stride_kn = {{stride("K", 2)}} + stride_kk = {{stride("K", 3)}} + # Define V Strides + stride_vz = {{stride("V", 0)}} + stride_vh = {{stride("V", 1)}} + stride_vn = {{stride("V", 2)}} + stride_vk = {{stride("V", 3)}} + + Z = {{size("Q", 0)}} + H = {{size("Q", 1)}} + N_CTX = {{size("Q", 2)}} + + qk_scale = 1.0 + MATMUL_PRECISION = Q.dtype.element_ty + + off_hz = tl.program_id(0) + off_z = off_hz // H # batch idx + off_h = off_hz % H # head idx + + # offset pointers for batch/head + Q += off_z * stride_qz + off_h * stride_qh + K += off_z * stride_kz + off_h * stride_kh + V += off_z * stride_vz + off_h * stride_vh + + # Asserting contiguous for now... + DO += off_z * stride_qz + off_h * stride_qh + DQ += off_z * stride_qz + off_h * stride_qh + DV += off_z * stride_vz + off_h * stride_vh + + # TODO I think that this should be N_CTX/BLOCK_N blocks + for start_n in range(0, NUM_Q_BLOCKS): + # We are not doing the causal optimization yet allowing us to start further down the + # kv column + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_m = tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, BLOCK_DMODEL) + + # initialize pointers to value-like data + q_ptrs = Q + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk) + k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) + v_ptrs = V + (offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk) + do_ptrs = DO + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk) + dq_ptrs = DQ + (offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk) + + # pointer to row-wise quantities in value-like data + D_ptrs = DELTA + off_hz * N_CTX + l_ptrs = LSE + off_hz * N_CTX + + # initialize dv and dk + dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) + dk = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32) + + # Key and Value stay in SRAM throughout + k = tl.load(k_ptrs) + v = tl.load(v_ptrs) + + for start_m in range(0, NUM_Q_BLOCKS * BLOCK_M, BLOCK_M): + offs_m_curr = start_m + offs_m + + # load q, k, v, do on-chip + q = tl.load(q_ptrs) + + if SCORE_MOD_IS_LINEAR: + qk_scale *= 1.44269504 + q = (q * qk_scale).to(MATMUL_PRECISION) + + # -- compute qk --- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk = tl.dot(q, tl.trans(k.to(MATMUL_PRECISION)), acc=qk) + pre_mod_scores = qk + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = offs_m_curr[:, None] + n = offs_n[None, :] + {{ modification( + subgraph_number=0, + score="qk", + b="off_z", + h="off_h", + m="m", + n="n", + out="qk" + ) | indent_except_first(3) }} + # TODO: In the case that score_mod is linear, this can be LICMed + if not SCORE_MOD_IS_LINEAR: + qk *= 1.44269504 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + l_i = tl.load(l_ptrs + offs_m_curr) + p = tl.math.exp2(qk - l_i[:, None]) + + # compute dv + do = tl.load(do_ptrs) + dv += tl.dot(tl.trans(p.to(MATMUL_PRECISION)), do) + + # compute dp = dot(v, do) + Di = tl.load(D_ptrs + offs_m_curr) # [BLOCKM, 1] + + # compute ds = p * (dp - delta[:, None]) + dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] + dp += tl.dot(do, tl.trans(v)) + ds = p * dp + + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + {{ modification( + subgraph_number=1, + score="pre_mod_scores", + b="off_z", + h="off_h", + m="m", + n="n", + out="ds" + ) | indent_except_first(3) }} + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # compute dk = dot(ds.T, q) + dk += tl.dot(tl.trans(ds.to(MATMUL_PRECISION)), q) + # compute dq + dq = tl.load(dq_ptrs) + dq += tl.dot(ds.to(MATMUL_PRECISION), k) + + # Store grad_query + tl.store(dq_ptrs, dq) + + # increment pointers + dq_ptrs += BLOCK_M * stride_qm + q_ptrs += BLOCK_M * stride_qm + do_ptrs += BLOCK_M * stride_qm + + # write-back + index_n = offs_n[:, None] + index_k = offs_k[None, :] + + # Store grad_key and grad_value + dv_ptrs = DV + (index_n * stride_vn + index_k * stride_vk) + tl.store(dv_ptrs, dv) + + # TODO generalize and add proper mask support + mask = (index_n != -1) & (index_k != -1) + {{store_output(("off_z", "off_h", "index_n", "index_k"), "dk", "mask", indent_width=8)}} + + """, +) + + +# TODO: We probably also need a layout constraint? +@register_lowering( + torch.ops.higher_order.flex_attention_backward, type_promotion_kind=None +) +def flex_attention_backward(*args, **kwargs): + ( + query, + key, + value, + out, + logsumexp, + grad_out, + fw_graph, + joint_graph, + *other_buffers, + ) = args + + device = query.get_device() + dtype = query.get_dtype() + + fwd_placeholder_inps = [ + create_placeholder(name, dtype, device) for name, dtype in [ - ("score", query.get_dtype()), + ("score", dtype), ("b", torch.int32), ("h", torch.int32), ("m", torch.int32), ("n", torch.int32), ] ] - for node in subgraph.graph_module.graph.nodes: - # There are two classes of placeholder inpts that we need - # to handle differently. For the first n_scalar_inps inputs - # we expect that these placeholders were generated by the make_fx call - # in the flex Attention HOP. So we need to create a new placeholder - # TensorBox for each of these inputs. For the rest of the inputs we - # expect that these are lifted inputs that fill up the '*other_buffers' - # tuple and already have corresponding TensorBoxes passed in as args. - if node.op == "placeholder": - is_lifted_input = cnt >= len(scalar_inps) - env[node] = args[cnt - 1] if is_lifted_input else placeholder_inps[cnt] - cnt += 1 - elif node.op == "call_function": - # For call_function we use the defulat lowerings and pass in the - # already created TensorBoxes as args - from torch.utils._pytree import tree_map + fw_subgraph_buffer = build_subgraph_buffer( + args, fwd_placeholder_inps, fw_graph, graph_type=SubgraphType.JOINT_FWD + ) - env[node] = lowerings[node.target]( - *tree_map(lambda x: env[x] if x in env else x, node.args) - ) - elif node.op == "output": - # For the output node we need to create a ComputedBuffer - # which represents the actual score modification + joint_placeholder_inps = fwd_placeholder_inps + [ + create_placeholder("out", dtype, device) + ] + joint_subgraph_buffer = build_subgraph_buffer( + args, joint_placeholder_inps, joint_graph, graph_type=SubgraphType.JOINT_BWD + ) - output_buffer = env[node.args[0]] - assert isinstance(output_buffer.data, StorageBox), ( - "The output node for the flex attention subgraph must be a StorageBox, but got: ", - type(output_buffer), - ) - # Create the ComputedBuffer directly that will be inlined into the modification block - subgraph_buffer = ComputedBuffer( - name=None, - layout=FlexibleLayout( - device=output_buffer.data.get_device(), - dtype=output_buffer.data.get_dtype(), - size=output_buffer.data.get_size(), - ), - data=output_buffer.data.data, # type: ignore[arg-type] - ) + layout_k = FixedLayout( + key.get_device(), + key.get_dtype(), + key.get_size(), + make_contiguous_strides_for(key.get_size()), + ) - layout = FixedLayout( - output_buffer.get_device(), - query.get_dtype(), - query.get_size(), - make_contiguous_strides_for(query.get_size()), - ) - # see NOTE:[TritonTemplates with multiple outputs] - logsumexp_shape = query.get_size()[:-1] # [B, H, M] - logsumexp = empty_strided( - logsumexp_shape, - None, - dtype=torch.float32, # The logsumexp is always stored in fp32 regardless of the input dtype - device=output_buffer.get_device(), - ) - choices: List[Any] = [] - configs: List[Any] = [] - configs.append(_get_default_config(query)) - if config.max_autotune: - configs += [ - (128, 64, 4, 3), - (128, 128, 4, 3), - (128, 128, 8, 2), - (64, 128, 4, 3), - (64, 64, 4, 3), - ] - # Note, we don't need to pass in the captured buffers explicitly - # because they're implicitly added by the score_mod function - # We do need to explicitly pass it in for autotuning though. - for BLOCK_M, BLOCK_N, num_warps, num_stages in configs: - sdpa_template.maybe_append_choice( - choices=choices, - input_nodes=[query, key, value, logsumexp], - layout=layout, - subgraphs=subgraph_buffer, - mutated_inputs=[ - logsumexp, - ], - num_stages=num_stages, - num_warps=num_warps, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - BLOCK_DMODEL=query.get_size()[-1], - # For now, we always assume the "sound" option - SCORE_MOD_IS_LINEAR=False, - ROWS_GUARANTEED_SAFE=False, - OUTPUT_LOGSUMEXP=True, - ) - inputs_for_autotuning = [query, key, value, logsumexp] + list(other_buffers) - return ( - autotune_select_algorithm( - "sdpa", choices, inputs_for_autotuning, layout - ), + # Create delta which will is needed for the bwd's kernel + mul_delta = lowerings[aten.mul](out, grad_out) + delta = lowerings[aten.sum](mul_delta, axis=-1) + + # see NOTE:[TritonTemplates with multiple outputs] + grad_query = full( + query.get_size(), 0.0, dtype=dtype, device=device + ) # torch.zeros equivalent + grad_query.realize() + grad_value = empty_strided(value.get_size(), None, dtype=dtype, device=device) + + choices: List[Any] = [] + configs: List[Tuple[int, int, int, int]] = [] + configs.append(_get_default_config_bwd(query)) + if config.max_autotune: + configs += [ + (128, 128, 4, 3), + (128, 128, 8, 1), + (64, 64, 4, 3), + (64, 64, 8, 1), + ] + + for BLOCK_M, BLOCK_N, num_warps, num_stages in configs: + flex_attention_backward_template.maybe_append_choice( + choices=choices, + input_nodes=[ + query, + key, + value, + out, logsumexp, - ) - raise ValueError("TemplatedAttention was passed a subgraph with no output node!") + delta, + grad_out, + grad_query, + grad_value, + ], + layout=layout_k, # We use store_output only for grad_key + subgraphs=[fw_subgraph_buffer, joint_subgraph_buffer], + mutated_inputs=[grad_query, grad_value], + num_stages=num_stages, + num_warps=num_warps, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_DMODEL=query.get_size()[-1], + NUM_Q_BLOCKS=math.ceil(query.get_size()[-2] / BLOCK_M), + # For now, we always assume the "sound" option + SCORE_MOD_IS_LINEAR=False, + ) + inputs_for_autotuning = [ + query, + key, + value, + out, + logsumexp, + delta, + grad_out, + grad_query, + grad_value, + ] + list(other_buffers) + + grad_key = autotune_select_algorithm( + "flex_attention_backward", choices, inputs_for_autotuning, layout_k + ) + return ( + grad_query, + grad_key, + grad_value, + ) diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index d1550529bb8ee..4dd252f9b0683 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -103,7 +103,7 @@ def __init__( prefix_args=0, suffix_args=0, epilogue_fn=identity, - subgraphs=None, + subgraphs: Optional[List[ir.ComputedBuffer]] = None, *, index_dtype, ): @@ -114,7 +114,7 @@ def __init__( ) self.input_nodes = input_nodes self.output_node = output_node - self.named_input_nodes = {} + self.named_input_nodes = {} # type: ignore[var-annotated] self.defines = defines self.kernel_name = kernel_name self.template_mask = None @@ -128,10 +128,10 @@ def __init__( self.prefix_args = prefix_args self.suffix_args = suffix_args self.epilogue_fn = epilogue_fn - self.render_hooks = dict() + self.render_hooks = dict() # type: ignore[var-annotated] self.triton_meta: Optional[Dict[str, object]] = None - # For Templated Attention - self.subgraphs = subgraphs + # For Templated Attention this can be a list of ir.Subgraph + self.subgraphs: Optional[List[ir.ComputedBuffer]] = subgraphs def need_numel_args(self): return False @@ -271,19 +271,28 @@ def stride(self, name, index): val = self.named_input_nodes[name].get_stride()[index] return texpr(self.rename_indexing(val)) - def modification(self, **fixed_inputs) -> str: - """This function generates the code body to populate - a 'modification' placeholder within a template + def modification(self, subgraph_number: int, **fixed_inputs) -> str: + """This creates a modification function for a subgraph. + To use this inside a template, the first argument should specify which subgraph to codegen for - TODO come up with standardized way to modify templates, with - potential multiple modifications + Args: + subgraph_number (int): The index of the subgraph in self.subgraphs """ + assert isinstance(subgraph_number, int) + assert isinstance(self.subgraphs, list) + assert subgraph_number < len( + self.subgraphs + ), f"Invalid subgraph number provided to create_modification, {subgraph_number} must be < {len(self.subgraphs)}" + + subgraph = self.subgraphs[subgraph_number] def add_input(name): return self.args.input(name) + name = f"PlaceholderSubstitution_{subgraph_number}" + class PlaceholderSubstitution(V.WrapperHandler): # type: ignore[name-defined] - self.name = "PlaceholderSubstitution" + self.name = name def load(self, name: str, index: sympy.Expr): if name not in fixed_inputs: @@ -297,15 +306,14 @@ def load(self, name: str, index: sympy.Expr): def indirect_indexing(self, index_var, size, check): return sympy_index_symbol(str(index_var)) - # if self.modification_cache is None: with V.set_ops_handler(PlaceholderSubstitution(V.ops)): assert isinstance( - self.subgraphs, ir.ComputedBuffer - ), "Expected the subgraph to be a ComputedBuffer" - if isinstance(self.subgraphs.data, ir.InputBuffer): - out = self.subgraphs.data.make_loader()((1,)) + subgraph, ir.ComputedBuffer + ), f"Expected the subgraph to be a ComputedBuffer, got {type(subgraph)}" + if isinstance(subgraph.data, ir.InputBuffer): + out = subgraph.data.make_loader()((1,)) else: - out = self.subgraphs.data.inner_fn((1,)) + out = subgraph.data.inner_fn((1,)) self.codegen_body() self.body.writeline(f"{fixed_inputs['out']} = {out.value}") @@ -320,11 +328,18 @@ def store_output( indices: Union[List[Any], Tuple[Any]], val: str, mask: Optional[str] = None, + indent_width: int = 4, ): - """ - Hook called from template code to store the final output - (if the buffer hasn't been optimized away), then append any - epilogue fusions. + """Stores the final output and appends any epilogue fusions if the buffer hasn't been optimized away. + + Args: + indices (Union[List, Tuple]): The index for each dimension of the output. The dot product of + these indices and output strides must match `val`. + val (str): The value to store. + mask (Optional[str]): An optional mask to use for the store operation. If provided, this mask + will be applied to the store. + indent_width (int): The number of spaces to use for indentation. This is used when the call to + store_output is indented in the kernel definition. """ assert isinstance(indices, (list, tuple)) assert isinstance(val, str) @@ -348,7 +363,7 @@ def store_output( self.range_trees[0].lookup(sympy.Integer(1), sympy_product(lengths)).set_name( "xindex" ) - self.template_mask = mask + self.template_mask = mask # type: ignore[assignment] self.template_indices = indices output_index = self.output_node.get_layout().make_indexer()(index_symbols) output_index = self.rename_indexing(output_index) @@ -373,7 +388,7 @@ def store_output( def hook(): # more stuff might have been added since the codegen_body above self.codegen_body() - return textwrap.indent(self.body.getvalue(), " ").strip() + return textwrap.indent(self.body.getvalue(), " " * indent_width).strip() assert "" not in self.render_hooks self.render_hooks[""] = hook @@ -1420,7 +1435,7 @@ def log_results( result = timings[choice] if result: sys.stderr.write( - f" {choice.name} {result:.4f} ms {best_time/result:.1%}\n" + f" {choice.name} {result:.4f} ms {best_time / result:.1%}\n" ) else: sys.stderr.write( diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 917dbfc3dd193..dfbcf50d31900 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -339,7 +339,7 @@ def print_performance( ): timings = torch.tensor([timed(fn, args, times, device) for _ in range(repeat)]) took = torch.median(timings) / times - print(f"{took/baseline:.6f}") + print(f"{took / baseline:.6f}") return took diff --git a/torch/nn/attention/_flex_attention.py b/torch/nn/attention/_flex_attention.py index 1acfab57a62ce..9f9dcc1ae5d74 100644 --- a/torch/nn/attention/_flex_attention.py +++ b/torch/nn/attention/_flex_attention.py @@ -97,6 +97,8 @@ def score_mod( raise ValueError( "NYI: The target sequence length (L) of the query tensor must match the source sequence length (S) of the key tensor." ) + if query.size(-2) % 128 != 0: + raise ValueError("NYI: S and L must be a multiple of 128") if not torch._dynamo.is_dynamo_supported(): raise RuntimeError("flex_attention requires dynamo support.") @@ -150,7 +152,7 @@ def _rel_causal( token_q: torch.Tensor, token_kv: torch.Tensor, ) -> torch.Tensor: - return torch.where(token_q <= token_kv, score + (token_q - token_kv), float("-inf")) + return torch.where(token_q >= token_kv, score + (token_q - token_kv), float("-inf")) def _generate_alibi_bias(num_heads: int): diff --git a/torch/testing/_internal/hop_db.py b/torch/testing/_internal/hop_db.py index 4772fb42a9631..df78812a65045 100644 --- a/torch/testing/_internal/hop_db.py +++ b/torch/testing/_internal/hop_db.py @@ -118,9 +118,9 @@ def score_mod(score, b, h, m, n): return score + h yield SampleInput( - make_arg(2, 2, 64, 8, low=0.1, high=2), - make_arg(2, 2, 64, 8, low=0.1, high=2), - make_arg(2, 2, 64, 8, low=0.1, high=2), + make_arg(2, 2, 128, 8, low=0.1, high=2), + make_arg(2, 2, 128, 8, low=0.1, high=2), + make_arg(2, 2, 128, 8, low=0.1, high=2), score_mod, )