Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -439,23 +439,17 @@ def _execute_cutlass_blackwell_attn_varlen(
seqlen_k,
batch_size,
is_mqa,
window_size,
sm_scale,
)
for seqlen_k in [64, 128, 256, 1024]
for batch_size in [1, 2]
for is_mqa in [True]
for window_size in [(-1, -1), (0, 0), (0, 128), (128, 0), (1024, 0)]
for sm_scale in [None, 1.0 / 128]
]
)
def test_decode(
self,
seqlen_k: int,
batch_size: int,
is_mqa: bool,
window_size: tuple[int, int],
sm_scale: Optional[float],
q_heads: int = 8,
dtype: torch.dtype = torch.float8_e4m3fn,
) -> None:
Expand All @@ -473,10 +467,12 @@ def test_decode(
head_dim=128,
dtype=dtype,
causal=causal,
window_size=window_size,
# Decode kernel does not support sliding window attention yet
window_size=(-1, -1),
fwd_only=True,
deterministic=False,
sm_scale=sm_scale,
# Decode kernel does not support sm_scale
sm_scale=None,
)

@skip_cuda_lt_sm100
Expand Down
Loading