Fixes last_dim stride check for singleton dimensions #117001

Closed · wants to merge 1 commit
aten/src/ATen/native/transformers/cuda/sdp_utils.cpp (4 changes: 2 additions & 2 deletions)

@@ -341,7 +341,7 @@ bool can_use_flash_attention(sdp_params const& params, bool debug) {
   constexpr auto dense_constraints = array_of<bool (*)(sdp_params const&, bool)>(
       check_batch_size_and_num_heads_dense,
       check_nonzero_sequence_lengths_dense,
-      check_last_dim_stride_equals_1_dense);
+      check_last_dim_stride_equals_1_dense<true /*ignore_singleton_dim=*/>);
   for (auto& constraint : dense_constraints) {
     if (!constraint(params, debug)) {
       return false;
@@ -399,7 +399,7 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
   constexpr auto dense_constraints = array_of<bool (*)(sdp_params const&, bool)>(
       check_batch_size_and_num_heads_dense,
       check_nonzero_sequence_lengths_dense,
-      check_last_dim_stride_equals_1_dense);
+      check_last_dim_stride_equals_1_dense<false /*ignore_singleton_dim=*/>);
   for (auto& constraint : dense_constraints) {
     if (!constraint(params, debug)) {
       return false;
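
For context (not part of the diff): the new template argument is the only thing that differs between the two call sites above. The flash-attention path passes true, so a head dimension of size 1 is tolerated regardless of its stride, while the memory-efficient path (and the CPU flash path below) keeps the strict stride-1 requirement by passing false. A minimal Python sketch of the relaxed predicate, assuming query, key, and value are ordinary dense torch.Tensor inputs (the real check operates on symbolic strides in C++):

import torch

def last_dim_stride_equals_1(query: torch.Tensor,
                             key: torch.Tensor,
                             value: torch.Tensor,
                             ignore_singleton_dim: bool) -> bool:
    # All three inputs need a contiguous last (head) dimension ...
    strides_equal_1 = (
        query.stride(-1) == 1
        and key.stride(-1) == 1
        and value.stride(-1) == 1
    )
    # ... unless the backend opts in to ignoring a singleton head_dim,
    # whose stride cannot affect element addressing (issue #116333).
    if ignore_singleton_dim:
        strides_equal_1 = strides_equal_1 or query.size(-1) == 1
    return strides_equal_1

Flash attention corresponds to ignore_singleton_dim=True; memory-efficient attention and the CPU flash path keep False.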
aten/src/ATen/native/transformers/sdp_utils_cpp.cpp (2 changes: 1 addition & 1 deletion)

@@ -46,7 +46,7 @@ bool use_flash_attention_cpp(sdp_params const& params, bool debug) {
       check_attn_mask_shape,
       check_head_dim_size_cpp,
       check_nonzero_sequence_lengths_dense,
-      check_last_dim_stride_equals_1_dense);
+      check_last_dim_stride_equals_1_dense<false /*ignore_singleton_dim*/>);
   for (auto& constraint : constraints) {
     if (!constraint(params, debug)) {
       return false;
aten/src/ATen/native/transformers/sdp_utils_cpp.h (8 changes: 8 additions & 0 deletions)

@@ -431,6 +431,7 @@ inline bool check_nonzero_sequence_lengths_dense(sdp_params const& params, bool debug) {
   return true;
 }
 
+template<bool ignore_singleton_dim>
 inline bool check_last_dim_stride_equals_1_dense(sdp_params const& params, bool debug) {
   // The stride checking for NestedTensors is done within the kernel
   // And .contiguous will be called if needed
@@ -439,6 +440,13 @@ inline bool check_last_dim_stride_equals_1_dense(sdp_params const& params, bool debug) {
   // fused_attention have stride 1
   bool qkv_strides_equal_1 = params.query.sym_stride(-1) == 1 &&
       params.key.sym_stride(-1) == 1 && params.value.sym_stride(-1) == 1;
+
+  // https://github.com/pytorch/pytorch/issues/116333
+  // If the head_dim is size 1 the stride won't matter, but we
+  // check this condition before padding the head_dim to 1
+  if (ignore_singleton_dim){
+    qkv_strides_equal_1 = qkv_strides_equal_1 || params.query.sym_size(-1) == 1;
+  }
   bool mask_stride_equal_1 = params.attn_mask.has_value()
       ? params.attn_mask.value().sym_stride(-1) == 1
       : true;
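
A short illustration (not part of the diff) of why the relaxation is safe: when a dimension has size 1, the only valid index along it is 0, so its stride never participates in addressing. Two tensors that differ only in that stride therefore view exactly the same elements; the as_strided call below is just a hand-built example of such a layout:

import torch

x = torch.arange(4.0).reshape(4, 1)   # contiguous: size (4, 1), stride (1, 1)
y = x.as_strided((4, 1), (1, 7))      # same storage, last-dim stride 7

print(x.stride(), y.stride())         # (1, 1) (1, 7)
print(torch.equal(x, y))              # True: the stride of a size-1 dim never
                                      # enters the address computation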
test/test_transformers.py (10 changes: 10 additions & 0 deletions)

@@ -2121,6 +2121,16 @@ def test_mem_eff_attention_non_contig_mask_bug(self, device):
         max_diff = (out - out_contig).abs().mean()
         self.assertTrue(max_diff.item() < 1e-7)
 
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Fused SDPA was not built for this system")
+    def test_singelton_head_dim_stride_ne_1(self, device):
+        query = torch.tensor([[[[1, 2]]]], dtype=torch.float16, device=device)
+        query = query.transpose(-1, -2)
+        key = torch.tensor([[[[1]]]], dtype=torch.float16, device=device)
+        value = torch.tensor([[[[1]]]], dtype=torch.float16, device=device)
+
+        with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False):
+            scaled_dot_product_attention(query, key, value)
+
     @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
     @parametrize("type", ["dense", "nested"])
     @parametrize("is_contiguous", [True, False])
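
For reference, a standalone sketch of what the test constructs (hypothetical script, no test harness): the transpose turns a contiguous [1, 1, 1, 2] query into a [1, 1, 2, 1] view whose last-dim stride is 2, which previously failed the stride check even though head_dim is 1 and is exactly the case the flash check now accepts.

import torch
import torch.nn.functional as F

device = "cuda"  # the test is skipped unless flash attention was built
query = torch.tensor([[[[1, 2]]]], dtype=torch.float16, device=device)
print(query.shape, query.stride())   # torch.Size([1, 1, 1, 2]) (2, 2, 2, 1)

query = query.transpose(-1, -2)
print(query.shape, query.stride())   # torch.Size([1, 1, 2, 1]) (2, 2, 1, 2)
                                     # last-dim size 1, last-dim stride 2

key = torch.tensor([[[[1]]]], dtype=torch.float16, device=device)
value = torch.tensor([[[[1]]]], dtype=torch.float16, device=device)

# With only the flash backend enabled, this input is accepted because
# check_last_dim_stride_equals_1_dense<true> ignores the singleton head_dim.
with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(query, key, value)
print(out.shape)                     # torch.Size([1, 1, 2, 1])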