
Add backwards support to FlexAttention (#123902)
# Summary
This is part one of adding backwards support to FlexAttention.

This PR focuses on the eager implementation and wires up enough of templated_attention_backward (name change coming soon 😉) to get through aot_eager.

Notably, this does not yet wire up the Triton template, in order to make this PR easier to review. That will be the next follow-up PR.

#### Structure
We pass both the forward graph and the joint (backward) graph to the backward HOP, since both need to be inlined into the backward computation (a minimal eager sketch follows this list):
- the forward graph is needed to re-compute the scores
- the joint graph is needed to construct the correct gradients after the softmax_grad calculation
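
For intuition, here is a minimal eager-mode sketch (illustrative only, not the actual HOP or Triton kernel from this PR): the `score_mod` stands in for the forward graph that gets re-run on the raw scores, and plain autograd through that recomputation plays the role of the joint graph producing the gradients after the softmax backward step.

```python
import torch

def eager_flex_attention(q, k, v, score_mod):
    # Recompute the modified scores with the "forward graph" (score_mod).
    B, H, M, D = q.shape
    N = k.shape[-2]
    scores = q @ k.transpose(-2, -1)  # [B, H, M, N]
    # Broadcastable index tensors matching the (score, b, h, m, n) signature.
    b = torch.arange(B, device=q.device).view(B, 1, 1, 1)
    h = torch.arange(H, device=q.device).view(1, H, 1, 1)
    m = torch.arange(M, device=q.device).view(1, 1, M, 1)
    n = torch.arange(N, device=q.device).view(1, 1, 1, N)
    scores = score_mod(scores, b, h, m, n)
    return torch.softmax(scores, dim=-1) @ v

q, k, v = (torch.randn(2, 2, 8, 4, dtype=torch.float64, requires_grad=True) for _ in range(3))
out = eager_flex_attention(q, k, v, lambda s, b, h, m, n: s * s)  # like _squared in the tests
# Backward differentiates through the recomputed scores and the score_mod itself,
# which is what the forward/joint graphs let the backward HOP do symbolically.
out.sum().backward()
```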

### Attached AOT Graph
https://gist.github.com/drisspg/ce4c041f8df8a5a7983c5174705cf2b5

Pull Request resolved: #123902
Approved by: https://github.com/Chillee
drisspg authored and pytorchmergebot committed Apr 29, 2024
1 parent 720e5f3 commit 8c21925
Showing 7 changed files with 707 additions and 38 deletions.
222 changes: 215 additions & 7 deletions test/inductor/test_templated_attention.py
@@ -8,6 +8,8 @@
from unittest.mock import patch

import torch

from torch._dynamo.testing import CompileCounterWithBackend, normalize_gm
from torch._higher_order_ops.templated_attention import (
templated_attention as templated_attention_hop,
)
@@ -56,6 +58,9 @@ def create_attention(score_mod):
if common_utils.TEST_WITH_ROCM:
test_dtypes = [torch.float32]


# --------- Useful score mod functions for testing ---------

test_score_mods = [
_identity,
_causal,
@@ -65,9 +70,56 @@ def create_attention(score_mod):
]


def _causal_mod(score, b, h, token_q, token_kv):
return torch.where(token_q >= token_kv, score, float("-inf"))
def _times_two(score, b, h, m, n):
"""Joint graph needed for correctness"""
return score * 2


def _squared(score, b, h, m, n):
"""Joint graph needed for correctness"""
return score * score


def _head_offset(dtype: torch.dtype):
"""Captured Buffer
Note: this builds a score_mod that indexes into a captured buffer
"""
head_offset = torch.rand(H, device="cuda", dtype=dtype)

def score_mod(score, b, h, m, n):
return score * index(head_offset, [h])

return score_mod


def _trig(score, b, h, m, n):
"""Joint graph needed for correctness"""
return torch.sin(torch.cos(score)) + torch.tan(b)


def _trig2(score, b, h, m, n):
"""Branching joint graph"""
cos_score = torch.cos(score)
sin_score = torch.sin(score)
z = cos_score * sin_score + torch.tan(b)
return z


def _buffer_reduced(dtype: torch.dtype):
"""Reduction in captured buffer"""
batch_offsets = torch.rand(B, 8, device="cuda", dtype=dtype)

def score_mod(score, b, h, m, n):
batch_vals = index(batch_offsets, [b])
return score + batch_vals.sum()

return score_mod


captured_buffers_map = {
"_head_offset": _head_offset,
"_buffer_reduced": _buffer_reduced,
}

B = 4
H = 8
@@ -106,6 +158,14 @@ def run_test(self, score_mod: Callable, dtype: torch.dtype = torch.float16):
def test_builtin_score_mods(self, dtype: torch.dtype, score_mod: Callable):
self.run_test(score_mod, dtype)

@supported_platform
@common_utils.parametrize("dtype", test_dtypes)
def test_skip_odd_keys(self, dtype: torch.dtype):
def score_mod(score, b, h, q, kv):
return torch.where(kv % 2 == 0, score, float("-inf"))

self.run_test(score_mod, dtype)

@supported_platform
@common_utils.parametrize("dtype", test_dtypes)
def test_function_composition(self, dtype: torch.dtype):
@@ -250,7 +310,7 @@ def njt_score_mod(qk, b, h, q, kv):

return njt_score_mod

causal_njt = create_njt_wrapper(_causal_mod, offsets, seq_idx)
causal_njt = create_njt_wrapper(_causal, offsets, seq_idx)

self.run_test(causal_njt, dtype)

@@ -264,10 +324,11 @@ def test_backwards_fails(self):
requires_grad=True,
)
q, k, v = make_tensor(), make_tensor(), make_tensor()
out = _templated_attention(q, k, v, _identity)
func = torch.compile(_templated_attention, backend="inductor", fullgraph=True)
with self.assertRaisesRegex(
RuntimeError, "Autograd not implemented for templated_attention"
AssertionError, "templated_attention_backward is not an OpOverload"
):
out = func(q, k, v, _identity)
out.backward(torch.ones_like(out))

@supported_platform
@@ -319,6 +380,14 @@ def test_logsumexp_correctness(self, dtype, score_mod):
def sdpa_hop(q, k, v, score_mod):
return templated_attention_hop(q, k, v, score_mod)

@torch.compile(backend="aot_eager")
def eager_sdpa_hop(q, k, v, score_mod):
"""The main entrypoint for FlexAttention doesnt return LSE.
Besides dropping LSE it also ensures that the hop is compiled with aot-eager
backend. We need to replicate this.
"""
return templated_attention_hop(q, k, v, score_mod)

make_tensor = functools.partial(
torch.randn,
(B, H, S, D),
@@ -328,7 +397,7 @@ def sdpa_hop(q, k, v, score_mod):
)
q, k, v = make_tensor(), make_tensor(), make_tensor()

ref_out, ref_lse = templated_attention_hop(
ref_out, ref_lse = eager_sdpa_hop(
q.to(torch.float64), k.to(torch.float64), v.to(torch.float64), score_mod
)
compiled_out, compiled_lse = sdpa_hop(q, k, v, score_mod)
@@ -341,7 +410,7 @@ def sdpa_hop(q, k, v, score_mod):
# x_ref = sum(_i e^(scores[i]))
# x_compiled = sum(_i 2^(log2(e) * scores[i]))

self.assertTrue(ref_lse.dtype == torch.float32)
self.assertTrue(ref_lse.dtype == torch.float64)
self.assertTrue(compiled_lse.dtype == torch.float32)
ref_lse = ref_lse * torch.log2(torch.tensor(torch.e))

@@ -401,6 +470,145 @@ def func(q, k, v, score_mod):
# Ensure that two kernels are generated
FileCheck().check_count(".run(", 2, True).run(code[0])

@supported_platform
@common_utils.parametrize(
"score_mod", [_identity, _causal, _times_two, _squared, _trig, _trig2]
)
def test_aot_eager_gradcheck(self, score_mod):
make_tensor = functools.partial(
torch.randn,
(2, 2, 8, 4),
device="cuda",
dtype=torch.float64,
requires_grad=True,
)
query, key, value = make_tensor(), make_tensor(), make_tensor()

func = torch.compile(_templated_attention, backend="aot_eager", fullgraph=True)

self.assertTrue(
torch.autograd.gradcheck(
func, (query, key, value, score_mod), raise_exception=True
)
)

@supported_platform
@common_utils.parametrize("score_mod_name", ["_head_offset", "_buffer_reduced"])
@common_utils.parametrize("mode", ["eager", "aot_eager"])
def test_captured_score_mod_aot_eager_gradcheck(
self, score_mod_name: str, mode: str
):
make_tensor = functools.partial(
torch.randn,
(2, 2, 8, 4),
device="cuda",
dtype=torch.float64,
requires_grad=True,
)
query, key, value = make_tensor(), make_tensor(), make_tensor()

func = torch.compile(_templated_attention, backend=mode, fullgraph=True)
score_mod = captured_buffers_map[score_mod_name](torch.float64)

self.assertTrue(
torch.autograd.gradcheck(
func, (query, key, value, score_mod), raise_exception=True
)
)

@supported_platform
def test_fw_bw_graph_correctness(self):
cnt = CompileCounterWithBackend("aot_eager")
make_tensor = functools.partial(
torch.randn,
(2, 2, 8, 4),
device="cuda",
dtype=torch.float64,
requires_grad=True,
)
query, key, value = make_tensor(), make_tensor(), make_tensor()

func = torch.compile(_templated_attention, backend=cnt, fullgraph=True)
out = func(query, key, value, _squared)
out.sum().backward()
self.assertEqual(cnt.frame_count, 1)
self.assertEqual(len(cnt.graphs), 1)
graph = cnt.graphs[0]
norm_graph = normalize_gm(graph.print_readable(print_output=False))
self.assertExpectedInline(
norm_graph,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_args_0_ : torch.Tensor, L_args_1_ : torch.Tensor, L_args_2_ : torch.Tensor):
l_args_0_ = L_args_0_
l_args_1_ = L_args_1_
l_args_2_ = L_args_2_
new_empty = l_args_0_.new_empty([], requires_grad = True)
new_empty_1 = l_args_0_.new_empty([], dtype = torch.int32)
new_empty_2 = l_args_0_.new_empty([], dtype = torch.int32)
new_empty_3 = l_args_0_.new_empty([], dtype = torch.int32)
new_empty_4 = l_args_0_.new_empty([], dtype = torch.int32)
templated_attention_0 = self.templated_attention_0
templated_attention = torch.ops.higher_order.templated_attention(l_args_0_, """
+ """l_args_1_, l_args_2_, templated_attention_0); l_args_0_ = l_args_1_ = l_args_2_ = templated_attention_0 = None
out = templated_attention[0]; templated_attention = None
return (out,)
class GraphModule(torch.nn.Module):
def forward(self, new_empty, new_empty_1, new_empty_2, new_empty_3, new_empty_4):
mul = new_empty * new_empty; new_empty = None
return mul
""",
)
# Save the AOT graphs
aot_graphs = []
from torch._inductor import compile_fx

def debug_compile_fx_inner(graph, example_inputs, *args, **kwargs):
aot_graphs.append(graph)
return graph

backend = functools.partial(
compile_fx.compile_fx, inner_compile=debug_compile_fx_inner
)
func = torch.compile(func, backend=backend, fullgraph=True)
out = func(query, key, value, _squared)
out.sum().backward()

joint_graph = normalize_gm(aot_graphs[1].print_readable(print_output=False))

self.assertExpectedInline(
joint_graph,
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f64[2, 2, 8, 4]", primals_2: "f64[2, 2, 8, 4]", primals_3: "f64[2, 2, 8, 4]", """
+ """alias_5: "f64[2, 2, 8, 4]", alias_7: "f32[2, 2, 8]", tangents_1: "f64[2, 2, 8, 4]"):
fw_graph = self.fw_graph
joint_graph = self.joint_graph
templated_attention_backward = torch.ops.higher_order.templated_attention_backward(primals_1, primals_2, """
+ """primals_3, alias_5, alias_7, tangents_1, fw_graph, joint_graph); primals_1 = primals_2 = primals_3 = alias_5 """
+ """= alias_7 = tangents_1 = fw_graph = joint_graph = None
getitem_2: "f64[2, 2, 8, 4]" = templated_attention_backward[0]
getitem_3: "f64[2, 2, 8, 4]" = templated_attention_backward[1]
getitem_4: "f64[2, 2, 8, 4]" = templated_attention_backward[2]; templated_attention_backward = None
return [getitem_2, getitem_3, getitem_4]
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f64[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i32[]", arg4_1: "i32[]"):
mul: "f64[]" = torch.ops.aten.mul.Tensor(arg0_1, arg0_1); arg0_1 = None
return mul
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f64[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i32[]", arg4_1: "i32[]", arg5_1: "f64[]"):
mul: "f64[]" = torch.ops.aten.mul.Tensor(arg0_1, arg0_1)
mul_1: "f64[]" = torch.ops.aten.mul.Tensor(arg5_1, arg0_1)
mul_2: "f64[]" = torch.ops.aten.mul.Tensor(arg5_1, arg0_1); arg5_1 = arg0_1 = None
add: "f64[]" = torch.ops.aten.add.Tensor(mul_2, mul_1); mul_2 = mul_1 = None
return [add, None, None, None, None]
""",
)


common_utils.instantiate_parametrized_tests(TestTemplatedSDPA)

19 changes: 10 additions & 9 deletions torch/_dynamo/variables/higher_order_ops.py
@@ -1532,7 +1532,7 @@ def create_scalar():

proxy_args = (body_node,) + lifted_args

return proxy_args, {}
return proxy_args

def call_function(
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
@@ -1541,29 +1541,30 @@ def call_function(

query, key, value, score_mod = self.normalize_to_args(args, kwargs)

p_args, p_kwargs = self.create_wrapped_node(tx, query, score_mod)
p_args = self.create_wrapped_node(tx, query, score_mod)
proxied_args = [query, key, value]

# Store the invocation as a call
# Norm_kwargs contains the score_function and we don't want to proxy this because
# proxying user-defined functions is not supported.
inp_args, _ = proxy_args_kwargs(proxied_args, {})

# Why is this here? Unlike other HOPs, the subgraph's output for this hop is unrelated
# to what the overall HOP returns, we create the correct output proxy by calling the
# hop (self.value) with the example values.
query_meta = query.as_proxy().node.meta["example_value"]
logsumexp_shape = query_meta.size()[:-1] # [B, H, M]
with torch._guards.TracingContext.try_get().fake_mode:
example_args = pytree.tree_map_only(
torch.fx.Proxy, lambda a: a.node.meta["example_value"], inp_args
out_meta = torch.empty_like(
query_meta, memory_format=torch.contiguous_format
)
example_value = self.value(*example_args, score_mod)
lse_meta = query_meta.new_empty(logsumexp_shape, dtype=torch.float32)
example_value = (out_meta, lse_meta)

return wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
self.value,
args=inp_args + p_args,
kwargs=p_kwargs,
kwargs={},
),
example_value=example_value,
)
2 changes: 1 addition & 1 deletion torch/_higher_order_ops/__init__.py
@@ -1,3 +1,3 @@
from .cond import cond
from .while_loop import while_loop
from .templated_attention import templated_attention
from .templated_attention import templated_attention, templated_attention_backward
2 changes: 1 addition & 1 deletion torch/_higher_order_ops/map.py
@@ -50,7 +50,7 @@ def create_fw_bw_graph(f, num_mapped_args, *args):
mapped_xs = args[:num_mapped_args]
pos_args = args[num_mapped_args:]

# Note: We create "clean" environments for make_fx by suspending all dispatch keys
# Note:[HOP create fw_bw graph] We create "clean" environments for make_fx by suspending all dispatch keys
# between Autograd and Python key. Currently, we only suspend functionalization but more can be
# added when required. Will encounter two problems if we don't suspend functionalization:
#
