diff --git a/aten/src/ATen/native/transformers/attention.cpp b/aten/src/ATen/native/transformers/attention.cpp
index 988d79fe2dbae..df40dec2d3d8a 100644
--- a/aten/src/ATen/native/transformers/attention.cpp
+++ b/aten/src/ATen/native/transformers/attention.cpp
@@ -590,9 +590,14 @@ c10::optional<Tensor> convert_boolean_attn_mask(const c10::optional<Tensor>& att
 
 // We apply this function to the top level SDPA so that
 // if padding is done it will be tracked for backward automatically
-template <int alignment>
-bool is_aligned(const SymInt& size){
-  return size % alignment == 0;
+template <int alignment>
+bool aligned_tensor(const at::Tensor& tensor){
+  for(const auto i : c10::irange(tensor.dim() - 1)){
+    if(tensor.sym_stride(i) % alignment != 0){
+      return false;
+    }
+  }
+  return tensor.sym_stride(-1) == 1;
 }
 
 template <int alignment>
@@ -608,31 +613,16 @@ at::Tensor preprocess_mask(
     const at::Tensor& query,
     const at::Tensor& key,
     const at::Tensor& value) {
-  constexpr int mem_eff_alignment = 16;
-  // Expand to 4d case
-  at::Tensor attn_mask = mask.expand_symint(
+  constexpr int mem_eff_alignment = 8;
+  at::Tensor result_mask = mask;
+  if (!aligned_tensor<mem_eff_alignment>(mask)) {
+    result_mask = pad_bias<mem_eff_alignment>(mask);
+  }
+  return result_mask.expand_symint(
       {query.sym_size(0),
        query.sym_size(1),
        query.sym_size(2),
        key.sym_size(2)});
-
-  bool aligned_last_dim = is_aligned<mem_eff_alignment>(attn_mask.sym_size(-1));
-  // Apply pad_bias and store the result in attn_mask
-  if (!aligned_last_dim) {
-    return pad_bias<mem_eff_alignment>(attn_mask);
-  }
-  // Check and make the tensor contiguous if needed
-  auto needs_contig = [](const c10::SymInt& stride) {
-    return (stride % 16 != 0) || (stride == 0);
-  };
-  if (needs_contig(attn_mask.sym_stride(0)) ||
-      needs_contig(attn_mask.sym_stride(1)) ||
-      needs_contig(attn_mask.sym_stride(2)) ||
-      needs_contig(attn_mask.sym_stride(3))) {
-    return attn_mask.contiguous();
-  }
-
-  return attn_mask;
 }
 
 } // namespace
diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h
index 0a5bb1db0433a..ec5a4a8a6ef5f 100644
--- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h
+++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h
@@ -1197,7 +1197,7 @@ struct AttentionBackwardKernel {
         "value is not correctly aligned (strideH)");
     TORCH_CHECK(
         p.num_batches <= 1 || p.q_strideB % kMinimumAlignment == 0,
-        "query is not correctly aligned (strideB)");
+        "query is not correctly aligned (strideB).");
     TORCH_CHECK(
         p.num_batches <= 1 || p.k_strideB % kMinimumAlignment == 0,
         "key is not correctly aligned (strideB)");
@@ -1216,13 +1216,19 @@ struct AttentionBackwardKernel {
     if (p.bias_ptr) {
       TORCH_CHECK(
           p.num_batches <= 1 || p.bias_strideB % kMinimumAlignment == 0,
-          "attn_bias is not correctly aligned (strideB)");
+          "attn_bias is not correctly aligned (strideB). ",
+          "attn_bias.stride(0) = ", p.bias_strideB, ", and should be a "
+          "multiple of ", kMinimumAlignment, ".");
       TORCH_CHECK(
           p.num_heads <= 1 || p.bias_strideH % kMinimumAlignment == 0,
-          "attn_bias is not correctly aligned (strideH)");
+          "attn_bias is not correctly aligned (strideH). "
+          "attn_bias.stride(1) = ", p.bias_strideH, ", and should be a "
+          "multiple of ", kMinimumAlignment, ".");
       TORCH_CHECK(
-          p.bias_strideM % kMinimumAlignment == 0,
-          "attn_bias is not correctly aligned (strideM)");
+          p.num_queries <= 1 || p.bias_strideM % kMinimumAlignment == 0,
+          "attn_bias is not correctly aligned (strideM). "
" + "attn_bias.stride(2) = ", p.bias_strideM, ", and should be a ", + "multiple of ", kMinimumAlignment, "."); } if (p.grad_bias_ptr) { TORCH_CHECK( diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h index 3a8189af09c4d..2e81480086d9d 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h @@ -578,13 +578,19 @@ struct AttentionKernel { CHECK_ALIGNED_PTR(p.attn_bias_ptr, kAlignmentQ); TORCH_CHECK( p.num_batches <= 1 || p.bias_strideB % kAlignmentQ == 0, - "attn_bias is not correctly aligned (strideB)"); + "attn_bias is not correctly aligned (strideB). ", + "attn_bias.stride( 0) = ", p.bias_strideB, ", and should be a " + "multiple of ", kAlignmentQ, "."); TORCH_CHECK( p.num_heads <= 1 || p.bias_strideH % kAlignmentQ == 0, - "attn_bias is not correctly aligned (strideH)"); + "attn_bias is not correctly aligned (strideH). " + "attn_bias.stride(1) = ", p.bias_strideH, ", and should be a " + "multiple of ", kAlignmentQ, "."); TORCH_CHECK( - p.bias_strideM % kAlignmentQ == 0, - "attn_bias is not correctly aligned"); + p.num_queries <= 1 || p.bias_strideM % kAlignmentQ == 0, + "attn_bias is not correctly aligned (strideM). " + "attn_bias.stride(2) = ", p.bias_strideM, ", and should be a " + "multiple of ", kAlignmentQ, "."); } TORCH_CHECK( p.q_strideM % kAlignmentQ == 0, diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index e6f1ea5a0561a..f326a3d2b5c31 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -6441,6 +6441,52 @@ def fn_or(x, y): (torch.randn(32), torch.randn(32)), ) + @requires_cuda() + @unittest.skipIf( + not PLATFORM_SUPPORTS_FUSED_SDPA, + "Does not support mem_eff_attention", + ) + @skipIfRocm + def test_sdpa_unaligned_mask(self): + def foo( + arg0_1: "f32[8, 8, 16, 16]", + arg1_1: "f32[8, 8, 15, 16]", + arg2_1: "f32[8, 8, 15, 16]", + arg3_1: "f32[1, 1, 16, 15]", + ): + constant_pad_nd: "f32[1, 1, 16, 16]" = ( + torch.ops.aten.constant_pad_nd.default(arg3_1, [0, 1], 0.0) + ) + arg3_1 = None + slice_1: "f32[1, 1, 16, 15]" = torch.ops.aten.slice.Tensor( + constant_pad_nd, -1, 0, 15 + ) + constant_pad_nd = None + expand: "f32[8, 8, 16, 15]" = torch.ops.aten.expand.default( + slice_1, [8, 8, 16, 15] + ) + slice_1 = None + _scaled_dot_product_efficient_attention = ( + torch.ops.aten._scaled_dot_product_efficient_attention.default( + arg0_1, arg1_1, arg2_1, expand, False + ) + ) + arg0_1 = arg1_1 = arg2_1 = expand = None + getitem: "f32[8, 8, 16, 16]" = _scaled_dot_product_efficient_attention[0] + _scaled_dot_product_efficient_attention = None + return (getitem,) + + query = torch.rand(8, 8, 16, 16, device="cuda") + key = torch.rand(8, 8, 15, 16, device="cuda") + value = torch.rand(8, 8, 15, 16, device="cuda") + bias = torch.rand(1, 1, 16, 15, device="cuda") + self.common( + foo, + (query, key, value, bias), + atol=0.02, + rtol=1e4, + ) + @skipIfRocm def test_conv_with_as_strided(self): class Model(nn.Module): diff --git a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py index 21ff02aced7ce..e6d3dd27b996b 100644 --- a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py @@ -259,6 +259,7 @@ def run(*ex, **kwargs): 
"test_zero_dim_reductions_dynamic_shapes": TestFailure( ("cpu", "cuda"), is_skip=True ), + "test_sdpa_unaligned_mask_dynamic_shapes": TestFailure(("cpu",), is_skip=True), # # The following tests do not support dynamic shapes yet: # diff --git a/test/test_transformers.py b/test/test_transformers.py index 1314755450feb..18ed523558e2e 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -1760,7 +1760,6 @@ def test_mem_eff_attention_long_sequence_mask(self, device, dtype): out = F.scaled_dot_product_attention(query, key, value, mask) out.sum().backward() - @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Fused SDPA was not built for this system") @parametrize("type", ["dense", "nested"]) @parametrize("is_contiguous", [True, False]) @@ -1801,6 +1800,24 @@ def test_scaled_dot_product_attention_fused_kernels(self, device, type: str, is_ self.assertEqual(actual[0].contiguous(), math_ref[0].contiguous(), atol=1e-3, rtol=1e-2) + @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Fused SDPA was not built for this system") + def test_mem_eff_attention_non_contig_mask_bug(self, device): + dtype = torch.float32 + make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True) + batch, num_heads, head_dim = 1, 16, 128 + seq_len_q, seq_len_kv = 1, 16 + query = make_tensor(batch, seq_len_q, num_heads * head_dim).view(batch, seq_len_q, num_heads, head_dim).transpose(1, 2) + kv_shape = (batch, seq_len_kv, head_dim) + key, value = make_tensor(kv_shape).unsqueeze(1), make_tensor(kv_shape).unsqueeze(1) + key = key.expand(-1, num_heads, -1, -1) + value = value.expand(-1, num_heads, -1, -1) + mask = torch.ones((1, 1, seq_len_q, seq_len_kv), device=device, dtype=torch.bool) + with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]): + out = F.scaled_dot_product_attention(query, key, value, mask) + out_no_mask = F.scaled_dot_product_attention(query, key, value, None) + max_diff = (out - out_no_mask).abs().mean() + assert max_diff.item() < 1e-9 + @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Fused SDPA was not built for this system") @parametrize("type", ["dense", "nested"]) @parametrize("is_contiguous", [True, False]) diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 13f8a0dfc1a80..a8d58cda0d7ef 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -1866,10 +1866,91 @@ def apply_constraint(arg, fx_arg): make_fallback(aten._fused_moving_avg_obs_fq_helper_functional) make_fallback(aten.grid_sampler_2d_backward, require_dense) make_fallback(aten.randperm) -make_fallback(aten._scaled_dot_product_efficient_attention) -make_fallback(aten._scaled_dot_product_efficient_attention_backward) -make_fallback(aten._scaled_dot_product_flash_attention) -make_fallback(aten._scaled_dot_product_flash_attention_backward) + + +def sdpa_constraint(fx_node, *args, **kwargs): + # sdpa requires dense last dimension + def apply_constraint(arg, fx_arg): + if not isinstance(arg, ir.IRNode): + return arg + + meta_val = fx_arg.meta["val"] + if not meta_val.is_cuda: + return arg + + stride_order = ir.get_stride_order(meta_val.stride()) + if stride_order and stride_order[-1] != 0: + # contiguous stride order + stride_order = list(reversed(range(len(arg.get_size())))) + + # This is the minimum alignment required by SDPA kernels for attention_bias. 
+        # This value can be found in pytorch/aten/src/ATen/native/transformers/attention.cpp preprocess_mask
+        ALIGNMENT = 8
+
+        is_backward = fx_node.target in (
+            aten._scaled_dot_product_efficient_attention_backward.default,
+            aten._scaled_dot_product_flash_attention_backward.default,
+        )
+
+        def is_aligned(x):
+            return (V.graph.sizevars.size_hint(x.get_size()[-1]) % ALIGNMENT) == 0
+
+        assert isinstance(arg, TensorBox)
+
+        # This correctly handles the forward case:
+        if isinstance(arg.data, (ir.SliceView, ir.ExpandView)):
+            if not is_aligned(arg):
+                # input is padded, requiring_stride_order will unwrap the view and unpad.
+                # Would be nice to be able to require certain padding from inductor ir, nyi
+                if is_aligned(arg.unwrap_view()):
+                    return arg
+
+        def is_aligned_backward(x):
+            aligned_strides = all(
+                (V.graph.sizevars.size_hint(x.get_stride()[i]) % ALIGNMENT) == 0
+                for i in range(len(x.get_stride()) - 1)
+            )
+            return (
+                V.graph.sizevars.size_hint(x.get_stride()[-1])
+            ) == 1 and aligned_strides
+
+        if (
+            isinstance(arg.data, ir.StorageBox)
+            and arg.data.is_input_buffer()
+            and is_backward
+        ):
+            if len(arg.data.get_size()) == 4 and is_aligned_backward(arg):
+                return arg
+
+        return ir.ExternKernel.require_stride_order(arg, stride_order)
+
+    args = tuple(
+        apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args)
+    )
+    kwargs = {k: apply_constraint(v, fx_node.kwargs[k]) for k, v in kwargs.items()}
+    return args, kwargs
+
+
+make_fallback(
+    aten._scaled_dot_product_efficient_attention,
+    sdpa_constraint,
+    warn=False,
+)
+make_fallback(
+    aten._scaled_dot_product_efficient_attention_backward,
+    sdpa_constraint,
+    warn=False,
+)
+make_fallback(
+    aten._scaled_dot_product_flash_attention,
+    sdpa_constraint,
+    warn=False,
+)
+make_fallback(
+    aten._scaled_dot_product_flash_attention_backward,
+    sdpa_constraint,
+    warn=False,
+)
 make_fallback(aten.sort)
 make_fallback(aten.sort.stable)
 make_fallback(aten._sparse_coo_tensor_with_dims_and_tensors)
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 56593916a4a15..0dee68f46c243 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -4985,12 +4985,14 @@ def meta__scaled_dot_product_efficient_backward(
     )
     grad_bias = None
     if attn_bias is not None and grad_input_mask[3]:
-        grad_bias = torch.empty_strided(
-            attn_bias.size(),
-            attn_bias.stride(),
-            dtype=attn_bias.dtype,
-            device=attn_bias.device,
+        lastDim = attn_bias.size(-1)
+        lastDimAligned = lastDim if lastDim % 16 == 0 else lastDim + 16 - lastDim % 16
+        new_sizes = list(attn_bias.size())
+        new_sizes[-1] = lastDimAligned
+        grad_bias = torch.empty(
+            new_sizes, dtype=attn_bias.dtype, device=attn_bias.device
         )
+        grad_bias = grad_bias[..., :lastDim]
 
     return grad_q, grad_k, grad_v, grad_bias
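Reviewer note, appended after the patch: the meta registration above has to match the layout the real backward kernel produces for grad_bias, namely a buffer whose last dimension is padded up to a multiple of 16 and then viewed back at the logical width. Below is a minimal standalone sketch of that trick in plain PyTorch; it is not part of the patch, and attn_bias is a hypothetical stand-in shaped like the bias in the unit tests above.

    import torch

    # Hypothetical bias: last dim 15 is not a multiple of 16.
    attn_bias = torch.rand(1, 1, 16, 15)

    last_dim = attn_bias.size(-1)
    # Round the last dim up to the next multiple of 16, as the meta function does.
    last_dim_aligned = last_dim if last_dim % 16 == 0 else last_dim + 16 - last_dim % 16

    new_sizes = list(attn_bias.size())
    new_sizes[-1] = last_dim_aligned                 # [1, 1, 16, 16]
    grad_bias = torch.empty(new_sizes, dtype=attn_bias.dtype)
    grad_bias = grad_bias[..., :last_dim]            # logical shape [1, 1, 16, 15]

    # Slicing keeps the padded row stride, so every leading stride stays a
    # multiple of 16 while the last stride is 1 (dense).
    assert grad_bias.stride(-1) == 1
    assert all(s % 16 == 0 for s in grad_bias.stride()[:-1])

The same pad-then-slice idea, with alignment 8, is what preprocess_mask in attention.cpp now relies on: pad_bias pads the mask's last dimension and slices it back, so the expanded mask presents the stride pattern that aligned_tensor and the kernel-side TORCH_CHECKs accept without changing the mask's logical shape.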