diff --git a/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py b/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py
index 9ede05918e..758332b315 100644
--- a/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py
+++ b/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py
@@ -439,14 +439,10 @@ def _execute_cutlass_blackwell_attn_varlen(
                 seqlen_k,
                 batch_size,
                 is_mqa,
-                window_size,
-                sm_scale,
             )
             for seqlen_k in [64, 128, 256, 1024]
             for batch_size in [1, 2]
             for is_mqa in [True]
-            for window_size in [(-1, -1), (0, 0), (0, 128), (128, 0), (1024, 0)]
-            for sm_scale in [None, 1.0 / 128]
         ]
     )
     def test_decode(
@@ -454,8 +450,6 @@ def test_decode(
         seqlen_k: int,
         batch_size: int,
         is_mqa: bool,
-        window_size: tuple[int, int],
-        sm_scale: Optional[float],
         q_heads: int = 8,
         dtype: torch.dtype = torch.float8_e4m3fn,
     ) -> None:
@@ -473,10 +467,12 @@ def test_decode(
             head_dim=128,
             dtype=dtype,
             causal=causal,
-            window_size=window_size,
+            # Decode kernel does not support sliding window attention yet
+            window_size=(-1, -1),
             fwd_only=True,
             deterministic=False,
-            sm_scale=sm_scale,
+            # Decode kernel does not support sm_scale
+            sm_scale=None,
         )
 
     @skip_cuda_lt_sm100