Require less alignment for attn bias (#114173) #114837

Merged 1 commit on Dec 5, 2023
38 changes: 14 additions & 24 deletions aten/src/ATen/native/transformers/attention.cpp
@@ -590,9 +590,14 @@ c10::optional<Tensor> convert_boolean_attn_mask(const c10::optional<Tensor>& att
// We apply this function to the top level SDPA so that
// if padding is done it will be tracked for backward automatically

template <int alignment>
bool is_aligned(const SymInt& size){
return size % alignment == 0;
template<int alignment>
bool aligned_tensor(const at::Tensor& tensor){
for(const auto i : c10::irange(tensor.dim() - 1)){
if(tensor.sym_stride(i) % alignment != 0){
return false;
}
}
return tensor.sym_stride(-1) == 1;
}

template <int alignment>
@@ -608,31 +613,16 @@ at::Tensor preprocess_mask(
const at::Tensor& query,
const at::Tensor& key,
const at::Tensor& value) {
constexpr int mem_eff_alignment = 16;
// Expand to 4d case
at::Tensor attn_mask = mask.expand_symint(
constexpr int mem_eff_alignment = 8;
at::Tensor result_mask = mask;
if (!aligned_tensor<mem_eff_alignment>(mask)) {
result_mask = pad_bias<mem_eff_alignment>(mask);
}
return result_mask.expand_symint(
{query.sym_size(0),
query.sym_size(1),
query.sym_size(2),
key.sym_size(2)});

bool aligned_last_dim = is_aligned<mem_eff_alignment>(attn_mask.sym_size(-1));
// Apply pad_bias and store the result in attn_mask
if (!aligned_last_dim) {
return pad_bias<mem_eff_alignment>(attn_mask);
}
// Check and make the tensor contiguous if needed
auto needs_contig = [](const c10::SymInt& stride) {
return (stride % 16 != 0) || (stride == 0);
};
if (needs_contig(attn_mask.sym_stride(0)) ||
needs_contig(attn_mask.sym_stride(1)) ||
needs_contig(attn_mask.sym_stride(2)) ||
needs_contig(attn_mask.sym_stride(3))) {
return attn_mask.contiguous();
}

return attn_mask;
}

} // namespace
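[Note] The new aligned_tensor check requires every stride except the last to be a multiple of the (now smaller) 8-element alignment, and the last dimension to be dense (stride 1); only when that fails does preprocess_mask pad the bias before expanding it to 4-D. A rough Python sketch of the same rule, for illustration only (the helper names here are made up, and it assumes pad_bias pads the last dimension up to the next multiple of the alignment and slices back):

import torch
import torch.nn.functional as F

ALIGNMENT = 8  # mem_eff_alignment after this change (was 16)

def aligned_tensor_sketch(t: torch.Tensor, alignment: int = ALIGNMENT) -> bool:
    # every stride except the last must be a multiple of `alignment`,
    # and the last dimension must be dense (stride 1)
    return all(s % alignment == 0 for s in t.stride()[:-1]) and t.stride(-1) == 1

def preprocess_mask_sketch(mask, query, key):
    if not aligned_tensor_sketch(mask):
        last = mask.size(-1)
        pad = (-last) % ALIGNMENT
        # pad the last dim, then slice back: values are unchanged but the
        # storage (and therefore the non-last strides) become aligned
        mask = F.pad(mask, (0, pad))[..., :last]
    return mask.expand(query.size(0), query.size(1), query.size(2), key.size(2))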
@@ -1197,7 +1197,7 @@ struct AttentionBackwardKernel {
"value is not correctly aligned (strideH)");
TORCH_CHECK(
p.num_batches <= 1 || p.q_strideB % kMinimumAlignment == 0,
"query is not correctly aligned (strideB)");
"query is not correctly aligned (strideB).");
TORCH_CHECK(
p.num_batches <= 1 || p.k_strideB % kMinimumAlignment == 0,
"key is not correctly aligned (strideB)");
@@ -1216,13 +1216,19 @@
if (p.bias_ptr) {
TORCH_CHECK(
p.num_batches <= 1 || p.bias_strideB % kMinimumAlignment == 0,
"attn_bias is not correctly aligned (strideB)");
"attn_bias is not correctly aligned (strideB). ",
"attn_bias.stride(0) = ", p.bias_strideB, ", and should be a "
"multiple of ", kMinimumAlignment, ".");
TORCH_CHECK(
p.num_heads <= 1 || p.bias_strideH % kMinimumAlignment == 0,
"attn_bias is not correctly aligned (strideH)");
"attn_bias is not correctly aligned (strideH) ."
"attn_bias.stride(1) = ", p.bias_strideH, ", and should be a "
"multiple of ", kMinimumAlignment, ".");
TORCH_CHECK(
p.bias_strideM % kMinimumAlignment == 0,
"attn_bias is not correctly aligned (strideM)");
p.num_queries <= 1 || p.bias_strideM % kMinimumAlignment == 0,
"attn_bias is not correctly aligned (strideM). "
"attn_bias.stride(2) = ", p.bias_strideM, ", and should be a ",
"multiple of ", kMinimumAlignment, ".");
}
if (p.grad_bias_ptr) {
TORCH_CHECK(
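[Note] Besides printing the offending stride in each message, the strideM check above gains a num_queries <= 1 escape hatch, mirroring the existing num_batches/num_heads ones for strideB and strideH; presumably with a single query row the kernel never steps along the M dimension, so that stride's alignment does not matter. A small illustration (kMinimumAlignment here is a made-up value; the real constant depends on the kernel instantiation):

import torch

kMinimumAlignment = 4  # illustrative only

bias = torch.rand(2, 8, 1, 33)           # num_queries == 1, odd last dim
strideB, strideH, strideM, _ = bias.stride()
num_queries = bias.size(2)

# old check: strideM % kMinimumAlignment == 0             -> 33 % 4 != 0
# new check: num_queries <= 1 or strideM % alignment == 0 -> passes
print(strideM, num_queries <= 1 or strideM % kMinimumAlignment == 0)  # 33 True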
@@ -578,13 +578,19 @@ struct AttentionKernel {
CHECK_ALIGNED_PTR(p.attn_bias_ptr, kAlignmentQ);
TORCH_CHECK(
p.num_batches <= 1 || p.bias_strideB % kAlignmentQ == 0,
"attn_bias is not correctly aligned (strideB)");
"attn_bias is not correctly aligned (strideB). ",
"attn_bias.stride( 0) = ", p.bias_strideB, ", and should be a "
"multiple of ", kAlignmentQ, ".");
TORCH_CHECK(
p.num_heads <= 1 || p.bias_strideH % kAlignmentQ == 0,
"attn_bias is not correctly aligned (strideH)");
"attn_bias is not correctly aligned (strideH). "
"attn_bias.stride(1) = ", p.bias_strideH, ", and should be a "
"multiple of ", kAlignmentQ, ".");
TORCH_CHECK(
p.bias_strideM % kAlignmentQ == 0,
"attn_bias is not correctly aligned");
p.num_queries <= 1 || p.bias_strideM % kAlignmentQ == 0,
"attn_bias is not correctly aligned (strideM). "
"attn_bias.stride(2) = ", p.bias_strideM, ", and should be a "
"multiple of ", kAlignmentQ, ".");
}
TORCH_CHECK(
p.q_strideM % kAlignmentQ == 0,
46 changes: 46 additions & 0 deletions test/inductor/test_torchinductor.py
@@ -6441,6 +6441,52 @@ def fn_or(x, y):
(torch.randn(32), torch.randn(32)),
)

@requires_cuda()
@unittest.skipIf(
not PLATFORM_SUPPORTS_FUSED_SDPA,
[Review comment] atalman (Contributor), Nov 30, 2023:
On the main PR this is called PLATFORM_SUPPORTS_MEM_EFF_ATTENTION; why is it named PLATFORM_SUPPORTS_FUSED_SDPA on this cherry-pick?

[Author reply]
This flag was added in between releases and is not in this cherry-pick. That said, PLATFORM_SUPPORTS_FUSED_SDPA is equivalent to PLATFORM_SUPPORTS_MEM_EFF_ATTENTION in the 2.1 release.

"Does not support mem_eff_attention",
)
@skipIfRocm
def test_sdpa_unaligned_mask(self):
def foo(
arg0_1: "f32[8, 8, 16, 16]",
arg1_1: "f32[8, 8, 15, 16]",
arg2_1: "f32[8, 8, 15, 16]",
arg3_1: "f32[1, 1, 16, 15]",
):
constant_pad_nd: "f32[1, 1, 16, 16]" = (
torch.ops.aten.constant_pad_nd.default(arg3_1, [0, 1], 0.0)
)
arg3_1 = None
slice_1: "f32[1, 1, 16, 15]" = torch.ops.aten.slice.Tensor(
constant_pad_nd, -1, 0, 15
)
constant_pad_nd = None
expand: "f32[8, 8, 16, 15]" = torch.ops.aten.expand.default(
slice_1, [8, 8, 16, 15]
)
slice_1 = None
_scaled_dot_product_efficient_attention = (
torch.ops.aten._scaled_dot_product_efficient_attention.default(
arg0_1, arg1_1, arg2_1, expand, False
)
)
arg0_1 = arg1_1 = arg2_1 = expand = None
getitem: "f32[8, 8, 16, 16]" = _scaled_dot_product_efficient_attention[0]
_scaled_dot_product_efficient_attention = None
return (getitem,)

query = torch.rand(8, 8, 16, 16, device="cuda")
key = torch.rand(8, 8, 15, 16, device="cuda")
value = torch.rand(8, 8, 15, 16, device="cuda")
bias = torch.rand(1, 1, 16, 15, device="cuda")
self.common(
foo,
(query, key, value, bias),
atol=0.02,
rtol=1e4,
)
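[Note] The captured graph above is exactly the pattern inductor now emits for an unaligned bias: pad the last dimension from 15 to 16, slice back to 15, then expand. The resulting bias keeps aligned strides without a contiguous copy, which is what the kernel checks earlier in this PR care about. A quick way to see the resulting strides outside of inductor (plain eager ops, for illustration):

import torch

bias = torch.rand(1, 1, 16, 15)
padded = torch.ops.aten.constant_pad_nd.default(bias, [0, 1], 0.0)  # f32[1, 1, 16, 16]
sliced = padded[..., :15]
expanded = sliced.expand(8, 8, 16, 15)

print(sliced.stride())    # (256, 256, 16, 1) -- strideM is now 16, a multiple of 8
print(expanded.stride())  # (0, 0, 16, 1)     -- strideB/strideH collapse to 0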

@skipIfRocm
def test_conv_with_as_strided(self):
class Model(nn.Module):
1 change: 1 addition & 0 deletions test/inductor/test_torchinductor_codegen_dynamic_shapes.py
@@ -259,6 +259,7 @@ def run(*ex, **kwargs):
"test_zero_dim_reductions_dynamic_shapes": TestFailure(
("cpu", "cuda"), is_skip=True
),
"test_sdpa_unaligned_mask_dynamic_shapes": TestFailure(("cpu",), is_skip=True),
#
# The following tests do not support dynamic shapes yet:
#
19 changes: 18 additions & 1 deletion test/test_transformers.py
@@ -1760,7 +1760,6 @@ def test_mem_eff_attention_long_sequence_mask(self, device, dtype):
out = F.scaled_dot_product_attention(query, key, value, mask)
out.sum().backward()


@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Fused SDPA was not built for this system")
@parametrize("type", ["dense", "nested"])
@parametrize("is_contiguous", [True, False])
@@ -1801,6 +1800,24 @@ def test_scaled_dot_product_attention_fused_kernels(self, device, type: str, is_

self.assertEqual(actual[0].contiguous(), math_ref[0].contiguous(), atol=1e-3, rtol=1e-2)

@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Fused SDPA was not built for this system")
def test_mem_eff_attention_non_contig_mask_bug(self, device):
dtype = torch.float32
make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True)
batch, num_heads, head_dim = 1, 16, 128
seq_len_q, seq_len_kv = 1, 16
query = make_tensor(batch, seq_len_q, num_heads * head_dim).view(batch, seq_len_q, num_heads, head_dim).transpose(1, 2)
kv_shape = (batch, seq_len_kv, head_dim)
key, value = make_tensor(kv_shape).unsqueeze(1), make_tensor(kv_shape).unsqueeze(1)
key = key.expand(-1, num_heads, -1, -1)
value = value.expand(-1, num_heads, -1, -1)
mask = torch.ones((1, 1, seq_len_q, seq_len_kv), device=device, dtype=torch.bool)
with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
out = F.scaled_dot_product_attention(query, key, value, mask)
out_no_mask = F.scaled_dot_product_attention(query, key, value, None)
max_diff = (out - out_no_mask).abs().mean()
assert max_diff.item() < 1e-9

@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Fused SDPA was not built for this system")
@parametrize("type", ["dense", "nested"])
@parametrize("is_contiguous", [True, False])
89 changes: 85 additions & 4 deletions torch/_inductor/lowering.py
@@ -1866,10 +1866,91 @@ def apply_constraint(arg, fx_arg):
make_fallback(aten._fused_moving_avg_obs_fq_helper_functional)
make_fallback(aten.grid_sampler_2d_backward, require_dense)
make_fallback(aten.randperm)
make_fallback(aten._scaled_dot_product_efficient_attention)
make_fallback(aten._scaled_dot_product_efficient_attention_backward)
make_fallback(aten._scaled_dot_product_flash_attention)
make_fallback(aten._scaled_dot_product_flash_attention_backward)


def sdpa_constraint(fx_node, *args, **kwargs):
# sdpa requires dense last dimension
def apply_constraint(arg, fx_arg):
if not isinstance(arg, ir.IRNode):
return arg

meta_val = fx_arg.meta["val"]
if not meta_val.is_cuda:
return arg

stride_order = ir.get_stride_order(meta_val.stride())
if stride_order and stride_order[-1] != 0:
# contiguous stride order
stride_order = list(reversed(range(len(arg.get_size()))))

# This is the minimum alignment required by SDPA kernels for attention_bias.
# This value can be found in pytorch/aten/src/ATen/native/transformers/attention.cpp preprocess_mask
ALIGNMENT = 8

is_backward = fx_node.target in (
aten._scaled_dot_product_efficient_attention_backward.default,
aten._scaled_dot_product_flash_attention_backward.default,
)

def is_aligned(x):
return (V.graph.sizevars.size_hint(x.get_size()[-1]) % ALIGNMENT) == 0

assert isinstance(arg, TensorBox)

# This correctly handles the forward case:
if isinstance(arg.data, (ir.SliceView, ir.ExpandView)):
if not is_aligned(arg):
# input is padded, requiring_stride_order will unwrap the view and unpad.
# Would be nice to be able to require certain padding from inductor ir, nyi
if is_aligned(arg.unwrap_view()):
return arg

def is_aligned_backward(x):
aligned_strides = all(
(V.graph.sizevars.size_hint(x.get_stride()[i]) % ALIGNMENT) == 0
for i in range(len(x.get_stride()) - 1)
)
return (
V.graph.sizevars.size_hint(x.get_stride()[-1])
) == 1 and aligned_strides

if (
isinstance(arg.data, ir.StorageBox)
and arg.data.is_input_buffer()
and is_backward
):
if len(arg.data.get_size()) == 4 and is_aligned_backward(arg):
return arg

return ir.ExternKernel.require_stride_order(arg, stride_order)

args = tuple(
apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args)
)
kwargs = {k: apply_constraint(v, fx_node.kwargs[k]) for k, v in kwargs.items()}
return args, kwargs


make_fallback(
aten._scaled_dot_product_efficient_attention,
sdpa_constraint,
warn=False,
)
make_fallback(
aten._scaled_dot_product_efficient_attention_backward,
sdpa_constraint,
warn=False,
)
make_fallback(
aten._scaled_dot_product_flash_attention,
sdpa_constraint,
warn=False,
)
make_fallback(
aten._scaled_dot_product_flash_attention_backward,
sdpa_constraint,
warn=False,
)
make_fallback(aten.sort)
make_fallback(aten.sort.stable)
make_fallback(aten._sparse_coo_tensor_with_dims_and_tensors)
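[Note] The forward-path special case in sdpa_constraint: when the bias arrives as a SliceView/ExpandView whose visible last dimension is unaligned but whose underlying buffer is aligned, the lowering leaves it alone instead of forcing require_stride_order, since the padded storage already satisfies the kernel. A plain-tensor illustration of the condition being detected (eager tensors, not inductor IR, purely to show the size/stride relationship):

import torch

ALIGNMENT = 8

base = torch.rand(1, 1, 16, 16)   # storage padded to an aligned last dim
view = base[..., :15]             # user-visible bias: last dim is 15

print(view.size(-1) % ALIGNMENT)  # 7 -> the view itself looks unaligned
print(base.size(-1) % ALIGNMENT)  # 0 -> the unwrapped buffer is aligned
print(view.stride())              # (256, 256, 16, 1): strides stay aligned,
                                  # so no contiguous copy is needed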
12 changes: 7 additions & 5 deletions torch/_meta_registrations.py
@@ -4985,12 +4985,14 @@ def meta__scaled_dot_product_efficient_backward(
)
grad_bias = None
if attn_bias is not None and grad_input_mask[3]:
grad_bias = torch.empty_strided(
attn_bias.size(),
attn_bias.stride(),
dtype=attn_bias.dtype,
device=attn_bias.device,
lastDim = attn_bias.size(-1)
lastDimAligned = lastDim if lastDim % 16 == 0 else lastDim + 16 - lastDim % 16
new_sizes = list(attn_bias.size())
new_sizes[-1] = lastDimAligned
grad_bias = torch.empty(
new_sizes, dtype=attn_bias.dtype, device=attn_bias.device
)
grad_bias = grad_bias[..., :lastDim]

return grad_q, grad_k, grad_v, grad_bias
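[Note] The meta function now rounds the last dimension of grad_bias up to the next multiple of 16, allocates the padded tensor, and slices back, so the fake tensor mirrors a padded-then-sliced allocation rather than blindly reusing attn_bias's strides. The rounding, spelled out with concrete (illustrative) numbers:

import torch

attn_bias = torch.empty(2, 8, 16, 15, device="meta")

lastDim = attn_bias.size(-1)                                                     # 15
lastDimAligned = lastDim if lastDim % 16 == 0 else lastDim + 16 - lastDim % 16   # 16
new_sizes = list(attn_bias.size())
new_sizes[-1] = lastDimAligned                                                   # [2, 8, 16, 16]
grad_bias = torch.empty(new_sizes, dtype=attn_bias.dtype, device=attn_bias.device)
grad_bias = grad_bias[..., :lastDim]

print(grad_bias.shape)     # torch.Size([2, 8, 16, 15])
print(grad_bias.stride())  # (2048, 256, 16, 1): last dim padded in storage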
