diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
index 8fa2a506f35f3..202fe5a256499 100644
--- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
+++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
@@ -341,7 +341,7 @@ bool can_use_flash_attention(sdp_params const& params, bool debug) {
   constexpr auto dense_constraints = array_of(
       check_batch_size_and_num_heads_dense,
       check_nonzero_sequence_lengths_dense,
-      check_last_dim_stride_equals_1_dense);
+      check_last_dim_stride_equals_1_dense<true /*ignore_singleton_dim=*/>);
   for (auto& constraint : dense_constraints) {
     if (!constraint(params, debug)) {
       return false;
@@ -399,7 +399,7 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
   constexpr auto dense_constraints = array_of(
       check_batch_size_and_num_heads_dense,
       check_nonzero_sequence_lengths_dense,
-      check_last_dim_stride_equals_1_dense);
+      check_last_dim_stride_equals_1_dense<false /*ignore_singleton_dim=*/>);
   for (auto& constraint : dense_constraints) {
     if (!constraint(params, debug)) {
       return false;
diff --git a/aten/src/ATen/native/transformers/sdp_utils_cpp.cpp b/aten/src/ATen/native/transformers/sdp_utils_cpp.cpp
index 85b0d304cdf03..f2159f8f0fdbd 100644
--- a/aten/src/ATen/native/transformers/sdp_utils_cpp.cpp
+++ b/aten/src/ATen/native/transformers/sdp_utils_cpp.cpp
@@ -46,7 +46,7 @@ bool use_flash_attention_cpp(sdp_params const& params, bool debug) {
       check_attn_mask_shape,
       check_head_dim_size_cpp,
       check_nonzero_sequence_lengths_dense,
-      check_last_dim_stride_equals_1_dense);
+      check_last_dim_stride_equals_1_dense<true /*ignore_singleton_dim=*/>);
   for (auto& constraint : constraints) {
     if (!constraint(params, debug)) {
       return false;
diff --git a/aten/src/ATen/native/transformers/sdp_utils_cpp.h b/aten/src/ATen/native/transformers/sdp_utils_cpp.h
index 2099ca01a1c26..d5d136c79609e 100644
--- a/aten/src/ATen/native/transformers/sdp_utils_cpp.h
+++ b/aten/src/ATen/native/transformers/sdp_utils_cpp.h
@@ -431,6 +431,7 @@ inline bool check_nonzero_sequence_lengths_dense(sdp_params const& params, bool
   return true;
 }
 
+template <bool ignore_singleton_dim>
 inline bool check_last_dim_stride_equals_1_dense(sdp_params const& params, bool debug) {
   // The stride checking for NestedTensors is done within the kernel
   // And .contiguous will be called if needed
@@ -439,6 +440,13 @@ inline bool check_last_dim_stride_equals_1_dense(sdp_params const& params, bool
   // fused_attention have stride 1
   bool qkv_strides_equal_1 = params.query.sym_stride(-1) == 1 &&
       params.key.sym_stride(-1) == 1 && params.value.sym_stride(-1) == 1;
+
+  // https://github.com/pytorch/pytorch/issues/116333
+  // If the head_dim is size 1 the stride won't matter, but we
+  // check this condition before padding the head_dim to 1
+  if (ignore_singleton_dim){
+    qkv_strides_equal_1 = qkv_strides_equal_1 || params.query.sym_size(-1) == 1;
+  }
   bool mask_stride_equal_1 = params.attn_mask.has_value()
       ? params.attn_mask.value().sym_stride(-1) == 1
       : true;
diff --git a/test/test_transformers.py b/test/test_transformers.py
index 7b5039b85eceb..a3577248baa16 100644
--- a/test/test_transformers.py
+++ b/test/test_transformers.py
@@ -2121,6 +2121,16 @@ def test_mem_eff_attention_non_contig_mask_bug(self, device):
         max_diff = (out - out_contig).abs().mean()
         self.assertTrue(max_diff.item() < 1e-7)
 
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Fused SDPA was not built for this system")
+    def test_singelton_head_dim_stride_ne_1(self, device):
+        query = torch.tensor([[[[1, 2]]]], dtype=torch.float16, device=device)
+        query = query.transpose(-1, -2)
+        key = torch.tensor([[[[1]]]], dtype=torch.float16, device=device)
+        value = torch.tensor([[[[1]]]], dtype=torch.float16, device=device)
+
+        with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False):
+            scaled_dot_product_attention(query, key, value)
+
     @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
     @parametrize("type", ["dense", "nested"])
     @parametrize("is_contiguous", [True, False])
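
For context, a minimal standalone sketch of the case this change enables, mirroring the new test (assumes a CUDA build with flash attention available; the printed stride values are illustrative):

```python
import torch
import torch.nn.functional as F

# Repro for https://github.com/pytorch/pytorch/issues/116333: transposing the
# last two dims produces head_dim == 1 with stride(-1) != 1. The flash
# constraint used to reject this even though the stride is irrelevant for a
# size-1 dimension.
device = "cuda"
query = torch.tensor([[[[1.0, 2.0]]]], dtype=torch.float16, device=device).transpose(-1, -2)
key = torch.tensor([[[[1.0]]]], dtype=torch.float16, device=device)
value = torch.tensor([[[[1.0]]]], dtype=torch.float16, device=device)

print(query.shape, query.stride())  # torch.Size([1, 1, 2, 1]) (2, 2, 1, 2)

# Restrict dispatch to the flash backend; with ignore_singleton_dim the
# singleton head_dim now passes the last-dim stride check.
with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(query, key, value)
print(out)
```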