38 changes: 38 additions & 0 deletions test/dynamo/test_sdpa.py
@@ -5,6 +5,7 @@
import torch._dynamo.testing
from torch._dynamo.testing import CompileCounter
from torch.backends.cuda import SDPAParams
from torch.nn.attention import _cur_sdpa_kernel_backends, sdpa_kernel, SDPBackend


@contextlib.contextmanager
@@ -99,6 +100,43 @@ def fn(q, k, v, m):
self.assert_ref_equals_params(o, expected)
self.assertEqual(counter.frame_count, 1)

def test_sdpa_c_functions_no_graph_break(self):
counter = CompileCounter()

@torch.compile(fullgraph=True, backend=counter)
def test_cur_sdpa_kernel_backends():
return _cur_sdpa_kernel_backends()

result = test_cur_sdpa_kernel_backends()

self.assertIsInstance(result, list)
self.assertEqual(counter.frame_count, 1)

def test_sdpa_kernel_decorator_with_compile(self):
SDPA_BACKEND_PRIORITY = [
SDPBackend.MATH,
SDPBackend.EFFICIENT_ATTENTION,
SDPBackend.FLASH_ATTENTION,
]

@sdpa_kernel(backends=SDPA_BACKEND_PRIORITY, set_priority=True)
def scaled_dot_product_attention(q, k, v, *args, **kwargs):
return torch.nn.functional.scaled_dot_product_attention(
q, k, v, *args, **kwargs
)

counter = CompileCounter()

@torch.compile(fullgraph=True, backend=counter)
def f(x):
return scaled_dot_product_attention(x, x, x)

x = torch.rand(128, 64, 64, 256, dtype=torch.float16)
result = f(x)

self.assertEqual(result.shape, x.shape)
self.assertEqual(counter.frame_count, 1)


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
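For context, the pattern these new tests exercise looks roughly like the following when used directly; the backend list, tensor shape, and dtype below are illustrative assumptions rather than part of the change:

import torch
from torch.nn.attention import sdpa_kernel, SDPBackend

# Constrain SDPA to an explicit backend priority list while compiling with Dynamo.
@sdpa_kernel(
    backends=[SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.FLASH_ATTENTION],
    set_priority=True,
)
def attention(q, k, v):
    return torch.nn.functional.scaled_dot_product_attention(q, k, v)

@torch.compile(fullgraph=True)
def f(x):
    return attention(x, x, x)

x = torch.rand(8, 4, 64, 32, dtype=torch.float16)  # (batch, heads, seq_len, head_dim); illustrative
out = f(x)  # with this change, the sdpa_kernel context no longer forces a graph break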
2 changes: 2 additions & 0 deletions torch/_dynamo/trace_rules.py
@@ -683,6 +683,7 @@
"torch._C._get_mem_efficient_sdp_enabled",
"torch._C._get_mkldnn_enabled",
"torch._C._get_cudnn_sdp_enabled",
"torch._C._get_overrideable_sdp_enabled",
"torch._C._set_sdp_use_cudnn",
"torch._C._get_mobile_model_contained_types_from_buffer",
"torch._C._get_mobile_model_contained_types",
@@ -1219,6 +1220,7 @@
"torch._C._set_sdp_use_math",
"torch._C._set_math_sdp_allow_fp16_bf16_reduction",
"torch._C._set_sdp_use_mem_efficient",
"torch._C._set_sdp_use_overrideable",
"torch._C._set_should_use_format_with_string_table",
"torch._C._set_sm_carveout_experimental",
"torch._C._set_storage_access_error_msg",
5 changes: 5 additions & 0 deletions torch/_dynamo/variables/builder.py
@@ -2982,6 +2982,11 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
torch.backends.cuda.is_flash_attention_available,
torch.backends.cuda.can_use_flash_attention,
torch.backends.cuda.can_use_efficient_attention,
torch._C._get_cudnn_sdp_enabled,
torch._C._get_flash_sdp_enabled,
torch._C._get_mem_efficient_sdp_enabled,
torch._C._get_math_sdp_enabled,
torch._C._get_overrideable_sdp_enabled,
"is_integer",
]
+ list(supported_const_comparison_op_values.keys())
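The torch._C getters added above return plain Python booleans; registering them alongside the existing torch.backends.cuda queries in this list lets Dynamo handle their results without a graph break. A minimal sketch of the kind of user code this enables (the function name and shapes are illustrative assumptions):

import torch

@torch.compile(fullgraph=True)
def pick_path(x):
    # Query an SDPA backend flag inside a compiled region; with this change the
    # boolean result is handled during tracing instead of causing a graph break.
    if torch._C._get_flash_sdp_enabled():
        return torch.nn.functional.scaled_dot_product_attention(x, x, x)
    return x @ x.transpose(-2, -1)

x = torch.rand(2, 4, 16, 8)
out = pick_path(x)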