Adds support up until aot_eager for running FlexAttention backwards
drisspg committed Apr 24, 2024
1 parent 9c2ac44 commit c797070
Showing 4 changed files with 479 additions and 16 deletions.
29 changes: 25 additions & 4 deletions test/inductor/test_templated_attention.py
@@ -53,6 +53,10 @@ def _causal_mod(score, b, h, token_q, token_kv):
     return torch.where(token_q >= token_kv, score, float("-inf"))
 
 
+def _times_two_mod(score, b, h, m, n):
+    return score * 2
+
+
 class TestTemplatedSDPA(InductorTestCase):
     def run_test(self, score_mod: Callable, dtype: torch.dtype = torch.float16):
         sdpa_partial = create_attention(score_mod)
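For orientation: a score_mod such as _identity_mod, _causal_mod, or the new _times_two_mod is an elementwise rewrite of the pre-softmax attention scores, given broadcastable batch/head/query/key-value index tensors. A minimal eager sketch of that contract (illustrative only; reference_attention is a made-up helper that ignores scaling details and is not the create_attention used by these tests):

import torch

def _times_two_mod(score, b, h, m, n):
    return score * 2

def reference_attention(query, key, value, score_mod):
    # Pre-softmax scores for every (batch, head, q_index, kv_index) pair.
    scores = query @ key.transpose(-2, -1)
    B, H, M, N = scores.shape
    b = torch.arange(B, device=scores.device)[:, None, None, None]
    h = torch.arange(H, device=scores.device)[None, :, None, None]
    m = torch.arange(M, device=scores.device)[None, None, :, None]
    n = torch.arange(N, device=scores.device)[None, None, None, :]
    # Apply the user-supplied modification, then finish attention as usual.
    scores = score_mod(scores, b, h, m, n)
    return scores.softmax(dim=-1) @ value

out = reference_attention(
    torch.randn(2, 2, 8, 4), torch.randn(2, 2, 8, 4), torch.randn(2, 2, 8, 4),
    _times_two_mod,
)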
@@ -160,7 +164,7 @@ def score_mod(score, b, h, m, n):
         self.run_test(score_mod, dtype)
 
     @supported_platform
-    def test_backwards_fails(self):
+    def test_backwards_fails_inductor(self):
         make_tensor = functools.partial(
             torch.randn,
             (4, 8, 2048, 64),
@@ -169,10 +173,11 @@ def test_backwards_fails(self):
             requires_grad=True,
         )
         q, k, v = make_tensor(), make_tensor(), make_tensor()
-        out = _templated_attention(q, k, v, _identity_mod)
+        func = torch.compile(_templated_attention, backend="inductor", fullgraph=True)
         with self.assertRaisesRegex(
-            RuntimeError, "Autograd not implemented for templated_attention"
+            AssertionError, "templated_attention_backward is not an OpOverload"
         ):
+            out = func(q, k, v, _identity_mod)
             out.backward(torch.ones_like(out))
 
     @supported_platform
@@ -231,7 +236,7 @@ def sdpa_hop(q, k, v, score_mod):
         # x_ref = ∑_i e^(scores[i])
         # x_compiled = ∑_i 2^(log2(e) * scores[i])
 
-        self.assertTrue(ref_lse.dtype == torch.float32)
+        self.assertTrue(ref_lse.dtype == torch.float64)
         self.assertTrue(compiled_lse.dtype == torch.float32)
         ref_lse = ref_lse * torch.log2(torch.tensor(torch.e))
 
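For reference, the conversion on the last line above is just a change of log base: log2(∑_i e^(s_i)) = log2(e) * ln(∑_i e^(s_i)), so scaling the natural-log LSE by log2(e) makes it comparable with the base-2 LSE the compiled kernel accumulates. A standalone numeric check of that identity (illustrative, not part of the commit):

import torch

scores = torch.randn(128, dtype=torch.float64)
lse_natural = torch.logsumexp(scores, dim=0)          # ln(sum_i e^(s_i))
lse_base2 = torch.log2(torch.exp(scores).sum(dim=0))  # log2(sum_i e^(s_i))
assert torch.allclose(lse_natural * torch.log2(torch.tensor(torch.e)), lse_base2)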
@@ -291,6 +296,22 @@ def func(q, k, v, score_mod):
         # Ensure that two kernels are generated
         FileCheck().check_count(".run(", 2, True).run(code[0])
 
+    @supported_platform
+    @common_utils.parametrize("score_mod", [_identity_mod, _causal_mod, _times_two_mod])
+    def test_aot_eager_gradcheck(self, score_mod):
+        make_tensor = functools.partial(
+            torch.randn,
+            (2, 2, 8, 4),
+            device="cuda",
+            dtype=torch.float64,
+            requires_grad=True,
+        )
+        query, key, value = make_tensor(), make_tensor(), make_tensor()
+
+        func = torch.compile(_templated_attention, backend="aot_eager", fullgraph=True)
+
+        self.assertTrue(torch.autograd.gradcheck(func, (query, key, value, score_mod)))
+
 
 common_utils.instantiate_parametrized_tests(TestTemplatedSDPA)
 
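torch.autograd.gradcheck verifies the analytical gradients produced by backward against finite-difference estimates, which is why the new test uses small float64 inputs. The same pattern on a stand-in function (a sketch; the test above passes the aot_eager-compiled _templated_attention instead of this toy):

import torch

def attention_like(q, k, v):
    # Toy differentiable stand-in for the compiled attention call.
    return (q @ k.transpose(-2, -1)).softmax(dim=-1) @ v

q, k, v = (
    torch.randn(2, 4, 8, dtype=torch.float64, requires_grad=True) for _ in range(3)
)
assert torch.autograd.gradcheck(attention_like, (q, k, v))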
3 changes: 2 additions & 1 deletion torch/_dynamo/variables/higher_order_ops.py
@@ -1425,7 +1425,8 @@ def call_function(
         # Proxying user defined functions is not supported.
         inp_args, _ = proxy_args_kwargs(proxied_args, {})
 
-        # Why is this here? Unlike other HOPs, the subgrpah's output for this hop is unrelated
+        # Note:[TemplatedAttention out example value]
+        # Why is this here? Unlike other HOPs, the subgraph's output for this hop is unrelated
         # to what the overall HOP returns, we create the correct output proxy by calling the
         # hop (self.value) with the example values.
         with torch._guards.TracingContext.try_get().fake_mode:
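In code form, the pattern the new note names boils down to roughly the following (a sketch with assumed names; it is only meaningful inside an active TracingContext during compilation):

import torch

def hop_example_value(hop, *fake_args):
    # Call the HOP itself on fake tensors under the ambient fake mode so the
    # output proxy gets a faithful example value, instead of reusing the traced
    # subgraph's (unrelated) output.
    fake_mode = torch._guards.TracingContext.try_get().fake_mode
    with fake_mode:
        return hop(*fake_args)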
2 changes: 1 addition & 1 deletion torch/_higher_order_ops/map.py
@@ -50,7 +50,7 @@ def create_fw_bw_graph(f, num_mapped_args, *args):
     mapped_xs = args[:num_mapped_args]
     pos_args = args[num_mapped_args:]
 
-    # Note: We create "clean" environments for make_fx by suspending all dispatch keys
+    # Note:[HOP create fw_bw graph] We create "clean" environments for make_fx by suspending all dispatch keys
     # between Autograd and Python key. Currently, we only suspend functionalization but more can be
     # added when required. Will encounter two problems if we don't suspend functionalization:
     #
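The "clean environment" the note refers to can be sketched as: suspend functionalization, then hand the inputs to make_fx. A rough illustration (a simplification under assumed usage; the real create_fw_bw_graph does considerably more work to build the joint forward/backward graph):

import torch
from torch._dispatch.python import suspend_functionalization
from torch.fx.experimental.proxy_tensor import make_fx

def trace_without_functionalization(fn, *args):
    # Suspend the Functionalize dispatch key so make_fx traces fn on plain
    # tensors rather than on functional wrappers.
    with suspend_functionalization():
        return make_fx(fn)(*args)

# Example: an FX graph of sin traced with functionalization suspended.
gm = trace_without_functionalization(lambda x: x.sin(), torch.randn(3))
print(gm.graph)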