Change flash attention outputs to be SymInt instead of int (#110533)
Fixes #110322

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: #110533
Approved by: https://github.com/albanD
ezyang authored and pytorchmergebot committed Oct 5, 2023
1 parent f1d8113 commit 6a974be
Showing 10 changed files with 65 additions and 15 deletions.
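The flash attention ops return max_q and max_k, the (maximum) query and key sequence lengths. Returning them as plain ints meant that under torch.compile with dynamic shapes they were burned in as constants, forcing a recompile for every new sequence length (#110322); returning them as SymInt keeps them symbolic. A quick way to confirm the updated operator schemas from Python (a sketch, assuming a build that includes this change):

```python
import torch

# max_q and max_k in these schemas should now print as SymInt rather than int.
print(torch.ops.aten._scaled_dot_product_flash_attention.default._schema)
print(torch.ops.aten._scaled_dot_product_flash_attention_backward.default._schema)
print(torch.ops.aten._flash_attention_forward.default._schema)
print(torch.ops.aten._flash_attention_backward.default._schema)
```

The new test_flash_attention_dynamic test added below checks the end-to-end behavior.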
8 changes: 4 additions & 4 deletions aten/src/ATen/native/native_functions.yaml
@@ -14349,14 +14349,14 @@
variants: function
tags: nondeterministic_seeded

- func: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
- func: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
dispatch:
CPU: _scaled_dot_product_flash_attention_cpu
CUDA: _scaled_dot_product_flash_attention_cuda
NestedTensorCUDA: _scaled_dot_product_flash_attention_nestedtensor_cuda
tags: nondeterministic_seeded

- func: _scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value)
- func: _scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value)
device_check: NoCheck
variants: function
dispatch:
@@ -14375,13 +14375,13 @@
CUDA: _scaled_dot_product_efficient_attention_backward_cuda
tags: nondeterministic_seeded

- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, int? max_q, int? max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt? max_q, SymInt? max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
variants: function
dispatch:
CUDA: _flash_attention_forward
tags: nondeterministic_seeded

- func: _flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor, Tensor, Tensor)
- func: _flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor, Tensor, Tensor)
device_check: NoCheck
variants: function
dispatch:
@@ -220,8 +220,8 @@ std::tuple<
Tensor,
Tensor,
Tensor,
int64_t,
int64_t,
c10::SymInt,
c10::SymInt,
Tensor,
Tensor,
Tensor>
4 changes: 2 additions & 2 deletions aten/src/ATen/native/transformers/attention.cpp
@@ -744,8 +744,8 @@ std::tuple<
at::Tensor,
at::Tensor,
at::Tensor,
int64_t,
int64_t,
c10::SymInt,
c10::SymInt,
at::Tensor,
at::Tensor,
at::Tensor>
2 changes: 1 addition & 1 deletion aten/src/ATen/native/transformers/cuda/attention.cu
@@ -668,7 +668,7 @@ std::tuple<Tensor, Tensor> native_multi_head_attention_cuda(
}
return std::make_tuple(std::move(proj), std::move(qkt));
}
std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t, int64_t, Tensor, Tensor, Tensor> _scaled_dot_product_flash_attention_cuda(
std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt, Tensor, Tensor, Tensor> _scaled_dot_product_flash_attention_cuda(
const Tensor& query,
const Tensor& key,
const Tensor& value,
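In eager mode these outputs are still backed by concrete values, so ordinary callers see no difference. A minimal sketch (assuming a CUDA build with flash attention support; the shapes below are hypothetical):

```python
import torch

# batch=2, heads=8, seq_len=128, head_dim=64 (a supported fp16 configuration)
q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = torch.ops.aten._scaled_dot_product_flash_attention(q, k, v)
# out[4] and out[5] are max_q and max_k; for dense inputs they equal the query
# and key sequence lengths, here 128.
print(out[4], out[5])
```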
48 changes: 48 additions & 0 deletions test/inductor/test_cuda_repro.py
@@ -5,13 +5,16 @@

import torch
import torch._dynamo.config as dynamo_config
import torch.backends.cuda
import torch.nn.functional as F
from torch import nn
from torch._dynamo.debug_utils import same_two_models
from torch._dynamo.testing import rand_strided
from torch._dynamo.utils import same
from torch._inductor import config
from torch._inductor.compile_fx import compile_fx_inner
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
from torch.testing._internal.common_utils import (
DeterministicGuard,
freeze_rng_state,
@@ -982,6 +985,51 @@ def fn(x, y, z):

self.assertEqual(ref, res)

@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "flash attention not supported"
)
def test_flash_attention_dynamic(self):
class Model(nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

self.q = nn.Linear(1024, 1024)
self.k = nn.Linear(1024, 1024)
self.v = nn.Linear(1024, 1024)

def forward(self, x):
batch_size, seq_len, _ = x.size()

queries = self.q(x).view(batch_size, seq_len, 8, 128).transpose(2, 1)
keys = self.k(x).view(batch_size, seq_len, 8, 128).transpose(2, 1)
values = self.v(x).view(batch_size, seq_len, 8, 128).transpose(2, 1)

attn = F.scaled_dot_product_attention(
queries,
keys,
values,
)

return attn

cnts = torch._dynamo.testing.CompileCounterWithBackend("inductor")

model = Model().cuda().half()
model = torch.compile(model, backend=cnts, dynamic=True)

with torch.backends.cuda.sdp_kernel(
enable_flash=True, enable_math=False, enable_mem_efficient=False
):
input1 = torch.rand(5, 512, 1024, device="cuda", dtype=torch.float16)
input2 = torch.rand(5, 513, 1024, device="cuda", dtype=torch.float16)
input3 = torch.rand(5, 514, 1024, device="cuda", dtype=torch.float16)

out1 = model(input1)
out2 = model(input2)
out3 = model(input3)

self.assertEqual(cnts.frame_count, 1)

@config.patch({"triton.cudagraphs": True})
def test_index_put_no_fallback_cudagraph(self):
def fn(x, y, z):
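The assertion on frame_count is what guards against the regression: with dynamic=True, the 512/513/514 sequence lengths must all run through a single compiled graph. A minimal, self-contained illustration of the counter used above (a sketch, not specific to flash attention):

```python
import torch
from torch._dynamo.testing import CompileCounterWithBackend

cnts = CompileCounterWithBackend("inductor")

@torch.compile(backend=cnts, dynamic=True)
def f(x):
    return torch.nn.functional.relu(x) + 1

f(torch.randn(4))
f(torch.randn(8))  # a different input size should not trigger a second compile
assert cnts.frame_count == 1
```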
4 changes: 2 additions & 2 deletions tools/autograd/derivatives.yaml
@@ -2764,9 +2764,9 @@
output_differentiability: [True, False, False, False]
query, key, value, attn_bias: _scaled_dot_product_efficient_attention_backward(grad, query, key, value, attn_bias, output, log_sumexp, philox_seed, philox_offset, dropout_p, grad_input_mask, is_causal, scale)

- name: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
- name: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
output_differentiability: [True, False, False, False, False, False, False, False, False]
query, key, value: _scaled_dot_product_flash_attention_backward(grad, query, key, value, output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale)
query, key, value: _scaled_dot_product_flash_attention_backward_symint(grad, query, key, value, output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale)

# - name: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, int? max_q, int? max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None) -> (Tensor output, Tensor query_padded, Tensor key_padded, Tensor value_padded, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
# output_differentiability: [True, False, False, False, False, False, False, False]
2 changes: 1 addition & 1 deletion torch/_C/return_types.pyi.in
@@ -16,7 +16,7 @@ from typing import (
Union,
)

from torch import contiguous_format, Generator, inf, memory_format, strided, Tensor
from torch import contiguous_format, Generator, inf, memory_format, strided, Tensor, SymInt
from torch.types import (
_bool,
_device,
2 changes: 2 additions & 0 deletions torch/_inductor/ir.py
@@ -3984,6 +3984,8 @@ def generate_output(output, indices):
)
elif isinstance(output, int):
return output
elif isinstance(output, torch.SymInt):
return output.node.expr
else:
assert (
output is None
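The Inductor change above teaches fallback-kernel output handling about torch.SymInt: a symbolic integer wraps a SymNode, and .node.expr exposes the underlying sympy expression that Inductor reasons about. A standalone sketch of the same unwrapping logic (simplified, for illustration only; unwrap_fallback_output is a hypothetical helper, not part of the codebase):

```python
import torch

def unwrap_fallback_output(output):
    # Plain Python ints pass through unchanged, as in the existing branch.
    if isinstance(output, int):
        return output
    # New case from this PR: reduce a torch.SymInt to its sympy expression.
    if isinstance(output, torch.SymInt):
        return output.node.expr
    return output
```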
4 changes: 2 additions & 2 deletions torch/csrc/inductor/aoti_torch/shim_common.cpp
@@ -228,8 +228,8 @@ AOTITorchError aoti_torch__scaled_dot_product_flash_attention(
at::Tensor* ret3_tensor = new at::Tensor(std::move(r3));
*ret3 = tensor_pointer_to_tensor_handle(ret3_tensor);
}
*ret4 = r4;
*ret5 = r5;
*ret4 = r4.expect_int();
*ret5 = r5.expect_int();
at::Tensor* ret6_tensor = new at::Tensor(std::move(r6));
*ret6 = tensor_pointer_to_tensor_handle(ret6_tensor);
at::Tensor* ret7_tensor = new at::Tensor(std::move(r7));
2 changes: 1 addition & 1 deletion torchgen/api/python.py
@@ -1129,7 +1129,7 @@ def dispatch_lambda_arg(cpp_arg: Binding) -> DispatchLambdaArgument:
"::std::tuple<at::Tensor,::std::vector<at::Tensor>>",
"::std::vector<at::Tensor>",
# Needed for flash attention forw/backward
"::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,int64_t,int64_t,at::Tensor,at::Tensor,at::Tensor>",
"::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,c10::SymInt,c10::SymInt,at::Tensor,at::Tensor,at::Tensor>",
"at::Scalar",
"bool",
"int64_t",
