match sdpa patterns from HF (#100609)
Adds SDPA patterns seen in HF models.

To actually make the patterns match, we need constant folding to remove the addition of an all-zeros mask, and we still need to figure out what to do with low-memory dropout.
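
For illustration, here is a minimal sketch of the masking issue; the tensor names and shapes are placeholders for exposition and are not taken from this PR. HF attention blocks usually add an attention mask even when it is all zeros at trace time, and that extra add node keeps the traced graph from lining up with the registered pattern until constant folding removes it:

import math
import torch

q, k, v = (torch.randn(2, 8, 4, 16) for _ in range(3))
attention_mask = torch.zeros(2, 8, 4, 4)  # all-zeros mask, as HF models often pass at trace time

scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
scores = scores + attention_mask  # numerically a no-op, but an extra graph node
attn = torch.softmax(scores, dim=-1) @ v  # lines up with the SDPA pattern only after the add is folded away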

Pull Request resolved: #100609
Approved by: https://github.com/jansel
ngimel authored and pytorchmergebot committed May 17, 2023
1 parent 8e51521 commit f33725b
Showing 4 changed files with 268 additions and 2 deletions.
90 changes: 88 additions & 2 deletions test/inductor/test_fused_attention.py
@@ -8,7 +8,10 @@
from torch._dynamo.utils import counters
from torch._inductor import config
from torch._inductor.utils import run_and_get_code
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FUSED_SDPA
from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_FUSED_SDPA,
    SM80OrLater,
)
from torch.testing._internal.common_utils import IS_LINUX, TEST_WITH_ROCM
from torch.testing._internal.inductor_utils import HAS_CUDA

@@ -59,7 +62,8 @@ def _check_common(
        if contains:
            # many of the patterns get re-expanded in dispatcher
            self.assertIn(
                "aten._scaled_dot_product_efficient_attention", source_code
                "aten._scaled_dot_product",
                source_code,
            )
        if not has_dropout:
            self.assertEqual(result1, result2, atol=atol, rtol=1.3e-6)
@@ -199,6 +203,88 @@ def sfdp_pattern_6(query, key, value):

        self._check_common(sfdp_pattern_6, contains=False, has_dropout=True)

    def test_sdpa_rewriter_7(self):
        def sfdp_pattern_7(query, key, value):
            q = query.permute(0, 2, 1, 3)
            k = key.permute(0, 2, 1, 3)
            v = value.permute(0, 2, 1, 3)
            div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
            div = div.to(torch.float32)
            attn_weight = torch.softmax(div, dim=-1)
            # very small dropout to make sure test passes
            attn_weight = torch.dropout(attn_weight, 0.0001, True)
            attn_weight = attn_weight.to(torch.float16)
            return attn_weight @ v

        args = (
            torch.empty((2, 8, 4, 16), device="cuda", dtype=torch.half),
            torch.empty((2, 8, 4, 16), device="cuda", dtype=torch.half),
            torch.empty((2, 8, 4, 16), device="cuda", dtype=torch.half),
        )

        self._check_common(sfdp_pattern_7, args, contains=SM80OrLater)

    def test_sdpa_rewriter_8(self):
        def sfdp_pattern_8(query, key, value):
            q = query.permute(0, 2, 1, 3)
            k = key.permute(0, 2, 1, 3)
            v = value.permute(0, 2, 1, 3)
            div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
            div = div.to(torch.float32)
            attn_weight = torch.softmax(div, dim=-1)
            attn_weight = attn_weight.to(torch.float16)
            return attn_weight @ v

        args = (
            torch.empty((2, 8, 4, 16), device="cuda", dtype=torch.half),
            torch.empty((2, 8, 4, 16), device="cuda", dtype=torch.half),
            torch.empty((2, 8, 4, 16), device="cuda", dtype=torch.half),
        )

        self._check_common(sfdp_pattern_8, args)

    def test_sdpa_rewriter_9(self):
        def sfdp_pattern_9(query, key, value):
            q = query.permute(0, 2, 1, 3)
            k = key.permute(0, 2, 1, 3)
            v = value.permute(0, 2, 1, 3)
            q = q / math.sqrt(q.size(-1))
            div = q @ k.transpose(-2, -1)
            div = div.to(torch.float32)
            attn_weight = torch.softmax(div, dim=-1)
            # very low dropout to make test pass
            attn_weight = torch.dropout(attn_weight, 0.0001, True)
            attn_weight = attn_weight.to(torch.float16)
            return attn_weight @ v

        args = (
            torch.empty((2, 8, 4, 16), device="cuda", dtype=torch.half),
            torch.empty((2, 8, 4, 16), device="cuda", dtype=torch.half),
            torch.empty((2, 8, 4, 16), device="cuda", dtype=torch.half),
        )

        self._check_common(sfdp_pattern_9, args, contains=SM80OrLater)

    def test_sdpa_rewriter_10(self):
        def sfdp_pattern_10(query, key, value):
            q = query.permute(0, 2, 1, 3)
            k = key.permute(0, 2, 1, 3)
            v = value.permute(0, 2, 1, 3)
            q = q / math.sqrt(q.size(-1))
            div = q @ k.transpose(-2, -1)
            div = div.to(torch.float32)
            attn_weight = torch.softmax(div, dim=-1)
            attn_weight = attn_weight.to(torch.float16)
            return attn_weight @ v

        args = (
            torch.empty((2, 8, 4, 16), device="cuda", dtype=torch.half),
            torch.empty((2, 8, 4, 16), device="cuda", dtype=torch.half),
            torch.empty((2, 8, 4, 16), device="cuda", dtype=torch.half),
        )

        self._check_common(sfdp_pattern_10, args)

    def test_pattern_fails_with_tensor_factor(self):
        # https://github.com/pytorch/pytorch/issues/99124
        class Model(torch.nn.Module):
154 changes: 154 additions & 0 deletions torch/_inductor/fx_passes/fuse_attention.py
@@ -141,6 +141,127 @@ def _sfdp_replacement_6(query, key, value, attn_mask, dropout_p):
    )


def _sfdp_pattern_7(query, key, value, dropout_p):
    # in real workloads inputs to matmul are permuted
    # causing matmul to expand to a series of expand and clone calls
    # we want the same to happen during pattern tracing
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    div = div.to(torch.float32)
    attn_weight = torch.softmax(div, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, True)
    attn_weight = attn_weight.to(torch.float16)
    return attn_weight @ v


def _sfdp_replacement_7(query, key, value, dropout_p):
    # sdpa prefers inputs in permuted format
    # it makes a copy to put them in this format
    # if they aren't already
    # to make replacement efficient ensure that inputs to sdpa
    # are in required order
    counters["inductor"]["fuse_attention"] += 1
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    return aten.scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=None,  # attn_mask,
        dropout_p=dropout_p,
        is_causal=False,
    )


def _sfdp_pattern_8(query, key, value):
    # no dropout version of pattern 7
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    div = div.to(torch.float32)
    attn_weight = torch.softmax(div, dim=-1)
    attn_weight = attn_weight.to(torch.float16)
    return attn_weight @ v


def _sfdp_replacement_8(query, key, value):
    counters["inductor"]["fuse_attention"] += 1
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    return aten.scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=None,  # attn_mask,
        dropout_p=0.0,
        is_causal=False,
    )


def _sfdp_pattern_9(query, key, value, dropout_p):
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    q = q / math.sqrt(q.size(-1))
    div = q @ k.transpose(-2, -1)
    div = div.to(torch.float32)
    attn_weight = torch.softmax(div, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, True)
    attn_weight = attn_weight.to(torch.float16)
    return attn_weight @ v


def _sfdp_replacement_9(query, key, value, dropout_p):
    counters["inductor"]["fuse_attention"] += 1
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    return aten.scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=None,  # attn_mask,
        dropout_p=dropout_p,
        is_causal=False,
    )


def _sfdp_pattern_10(query, key, value):
    # no dropout version of 9
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    q = q / math.sqrt(q.size(-1))
    div = q @ k.transpose(-2, -1)
    div = div.to(torch.float32)
    attn_weight = torch.softmax(div, dim=-1)
    attn_weight = attn_weight.to(torch.float16)
    return attn_weight @ v


def _sfdp_replacement_10(query, key, value):
    counters["inductor"]["fuse_attention"] += 1
    q = query.permute(0, 2, 1, 3)
    k = key.permute(0, 2, 1, 3)
    v = value.permute(0, 2, 1, 3)
    return aten.scaled_dot_product_attention(
        q,
        k,
        v,
        attn_mask=None,  # attn_mask,
        dropout_p=0.0,
        is_causal=False,
    )


# TODO(jansel): make these patterns work with lowmem_dropout=True


def _sfdp_params_check(match):
    assert all(k in match.kwargs for k in ("query", "key", "value"))
    query = match.kwargs["query"].meta["val"]
@@ -194,6 +315,9 @@ def _sfdp_init():
    # sizes/values don't actually matter for initial trace
    # once we get a possible match we re-trace with the actual values and verify the match still holds
    g = functools.partial(torch.empty, (2, 4, 8, 16), device=device, requires_grad=True)
    gp = functools.partial(
        torch.empty, (2, 8, 4, 16), device=device, requires_grad=True, dtype=torch.half
    )
    b = functools.partial(torch.empty, (1, 1, 8, 8), device=device)
    c = functools.partial(torch.tensor, 2.0, device=device)
    # workaround https://github.com/pytorch/pytorch/issues/97894
@@ -243,6 +367,34 @@ def _sfdp_init():
            d,
            _sfdp_params_check,
        ),
        (
            _sfdp_pattern_7,
            _sfdp_replacement_7,
            [gp(), gp(), gp()],
            d,
            _sfdp_params_check,
        ),
        (
            _sfdp_pattern_8,
            _sfdp_replacement_8,
            [gp(), gp(), gp()],
            {},
            _sfdp_params_check,
        ),
        (
            _sfdp_pattern_9,
            _sfdp_replacement_9,
            [gp(), gp(), gp()],
            d,
            _sfdp_params_check,
        ),
        (
            _sfdp_pattern_10,
            _sfdp_replacement_10,
            [gp(), gp(), gp()],
            {},
            _sfdp_params_check,
        ),
    ]:
        args = [*args, *workaround.values()]
        register_replacement(
@@ -263,3 +415,5 @@ def _sfdp_init():
            extra_check=extra_check,
            scalar_workaround=workaround,
        )

    counters["inductor"].clear()  # clear view matches encountered during sdpa tracing
14 changes: 14 additions & 0 deletions torch/_inductor/fx_passes/joint_graph.py
@@ -69,3 +69,17 @@ def pointless_convert(match: Match, arg, dtype1, dtype2):
        repl.meta.update(node.meta)
        node.replace_all_uses_with(repl)
    match.erase_nodes(graph)


@register_graph_pattern(
    CallFunction(torch.ops.aten.view.default, KeywordArg("arg"), KeywordArg("size")),
    pass_dict=patterns,
)
def pointless_view(match: Match, arg, size):
    """Remove no-op view"""
    graph = match.graph
    node = match.output_node()
    arg_size = list(node.args[0].meta["val"].shape)
    if size == arg_size:
        node.replace_all_uses_with(node.args[0])
        match.erase_nodes(graph)
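
For context, a minimal illustration (not the pass itself) of what pointless_view targets: an aten.view whose requested size equals the input's size is a no-op, which is exactly the kind of view introduced while tracing the permuted SDPA patterns above.

import torch

x = torch.randn(2, 8, 4, 16)
y = torch.ops.aten.view.default(x, list(x.shape))  # view to the same size -> no-op
# pointless_view rewrites every use of `y` to point at `x` and erases the view node
assert y.shape == x.shape
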
12 changes: 12 additions & 0 deletions torch/_inductor/pattern_matcher.py
@@ -854,6 +854,18 @@ def record_joint_graph(joint_graph, inputs, **kwargs):
            enable_log=False,
        )(*args)

    from .fx_passes.joint_graph import pointless_view

    matcher_pass = PatternMatcherPass()

    pattern = CallFunction(
        torch.ops.aten.view.default, KeywordArg("arg"), KeywordArg("size")
    )
    GraphPatternEntry(
        pattern=pattern, handler=pointless_view, extra_check=_return_true
    ).register(matcher_pass.patterns)
    matcher_pass.apply(gm.graph)

    # remove in/out specs
    gm.graph._codegen = torch.fx.graph.CodeGen()
    gm.graph.eliminate_dead_code()

1 comment on commit f33725b

@pytorchmergebot
Collaborator

Reverted #100609 on behalf of https://github.com/izaitsevfb due to Based on #100064, which needs to be reverted due to diff-train issues. (comment)
