Skip to content

Commit 9fabef0

Browse files
committed
fix test ordr
1 parent 332816c commit 9fabef0

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

test/test_transformers.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2576,7 +2576,7 @@ def test_fused_attention_different_dk_dv(self, device):
25762576

25772577

25782578
@skipIfRocm # No cuDNN Attention
2579-
@unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cuDNN Attention is not supported on this system")
2579+
@unittest.skipIf(True, "broken as of cuDNN 9.10")
25802580
def test_cudnn_attention_fail_d128(self, device):
25812581
# Test that cuDNN attention dispatching correctly bails out on d > 128
25822582
b, h = 1, 2
@@ -2591,7 +2591,6 @@ def test_cudnn_attention_fail_d128(self, device):
25912591
ISSM90 = device_cap == (9, 0)
25922592
ISSM100 = device_cap == (10, 0)
25932593
with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]):
2594-
# SM90/100 support d <= 256 as of cuDNN 9.5.1+
25952594
if (ISSM90 or ISSM100) and torch.backends.cudnn.version() >= 90501:
25962595
torch.nn.functional.scaled_dot_product_attention(q, k, v)
25972596
else:
@@ -3030,7 +3029,9 @@ def test_fused_sdp_choice(self, device, type: str):
30303029
device_capability = None
30313030
if "cuda" in str(device):
30323031
device_capability = torch.cuda.get_device_capability()
3033-
prefer_cudnn = device_capability and (device_capability == (9, 0) or device_capability == (10, 0))
3032+
prefer_cudnn = False
3033+
# TODO(eqy): uncomment the following condition
3034+
# device_capability and (device_capability == (9, 0) or device_capability == (10, 0))
30343035

30353036
# TODO we are currently disabling this by default, lets assert that this returns
30363037
# FlashAttention, we need to change when we make remove opt-in for cudnn

0 commit comments

Comments
 (0)