Adding Backward Support for NestedTensors and FlashAttention (#97485)
# Summary
<!--
copilot:summary
-->
### <samp>🤖 Generated by Copilot at 318764f</samp>

This pull request implements the CUDA backward pass of the SDPA flash-attention kernel for nested tensors, which enables efficient training of transformer models on variable-length sequences. It adds a new dispatch key, a backward function, a unit test, and some helper functions for the kernel. It modifies `test/test_transformers.py`, `aten/src/ATen/native/native_functions.yaml`, `aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctionsBackward.cpp`, and `aten/src/ATen/native/nested/cuda/NestedTensorTransformerUtils.h`.

<!--
copilot:poem
-->
### <samp>🤖 Generated by Copilot at ed4a773</samp>

> _Fused kernels of doom, unleash the flash attention_
> _Nested tensors on fire, reshape and pad with caution_
> _Backward pass of power, dispatch the CUDA key_
> _Test the gradients of hell, warn the user if they disagree_
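
Beyond the generated summary, here is a minimal Python sketch of the path this change enables. The sequence lengths, dtype, and the use of `to_padded_tensor` to reduce to a scalar loss are illustrative assumptions, not taken from the unit test added in `test/test_transformers.py`:

```python
import torch
import torch.nn.functional as F

# Nested q, k, v in (batch, num_heads, seq_len_i, head_dim) layout, where each
# batch element has its own seq_len_i. requires_grad=True is what routes the
# backward through the new NestedTensorCUDA dispatch entry.
num_heads, head_dim = 8, 64
seq_lens = [32, 17, 50]

def make_nt():
    return torch.nested.nested_tensor(
        [torch.randn(num_heads, s, head_dim) for s in seq_lens],
        device="cuda", dtype=torch.float16, requires_grad=True)

q, k, v = make_nt(), make_nt(), make_nt()

# Restrict dispatch to the flash-attention kernel so the nested flash backward
# is the path that gets exercised.
with torch.backends.cuda.sdp_kernel(
        enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v)

# Reduce through the padded view and backprop; gradients come back as nested tensors.
torch.nested.to_padded_tensor(out, 0.0).sum().backward()
assert q.grad is not None and k.grad is not None and v.grad is not None
```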

Pull Request resolved: #97485
Approved by: https://github.com/jbschlosser
drisspg authored and pytorchmergebot committed Oct 10, 2023
1 parent 77e5f5d commit 5183760
Showing 7 changed files with 428 additions and 69 deletions.
1 change: 1 addition & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -14366,6 +14366,7 @@
dispatch:
CPU: _scaled_dot_product_flash_attention_backward_cpu
CUDA: _scaled_dot_product_flash_attention_backward_cuda
NestedTensorCUDA: _scaled_dot_product_flash_attention_backward_nested

- func: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset)
dispatch:
21 changes: 21 additions & 0 deletions aten/src/ATen/native/nested/NestedTensorTransformerUtils.h
@@ -18,6 +18,27 @@ sdpa_nested_preprocessing(
const Tensor& key,
const Tensor& value);

/**
* This function takes nested grad_out, query, key, value, and out tensors
* and preprocesses them so they can be run through the backward pass of
* either the flash-attention or efficient-attention kernel.
* A single helper is shared by both so that the preprocessing of
* cumulative_sequence_length_q and cumulative_sequence_length_kv
* does not have to be duplicated.
* @return A tuple containing all the necessary data for running the fused
* kernels
*/
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor>
sdpa_nested_preprocessing_backward(
const at::Tensor& grad_out_,
const at::Tensor& query,
const at::Tensor& key,
const at::Tensor& value,
const at::Tensor& out,
const Tensor& cumulative_sequence_length_q,
const Tensor& cumulative_sequence_length_kv,
const int64_t max_seqlen_batch_q,
const int64_t max_seqlen_batch_kv);

} // namespace preprocessing
} // namespace native
} // namespace at
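
For reference, a small sketch of how the `cumulative_sequence_length_*` and `max_seqlen_batch_*` arguments above relate to a batch of ragged sequences (the helper below is hypothetical, not part of this PR):

```python
import itertools
from typing import List, Tuple
import torch

def cu_seqlens_and_max(seq_lens: List[int]) -> Tuple[torch.Tensor, int]:
    """cu_seqlens[i] is the row offset where sequence i starts in the packed
    (Nnz, num_heads, head_dim) buffer; cu_seqlens[-1] is the total Nnz."""
    cu = torch.tensor([0] + list(itertools.accumulate(seq_lens)), dtype=torch.int32)
    return cu, max(seq_lens)

cu_q, max_q = cu_seqlens_and_max([32, 17, 50])
# cu_q == tensor([ 0, 32, 49, 99], dtype=torch.int32), max_q == 50
```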
@@ -324,5 +324,68 @@ _scaled_dot_product_efficient_attention_nestedtensor_cuda(
return std::make_tuple(std::move(attention), std::move(log_sumexp), std::move(seed), std::move(offset));
}

std::tuple<at::Tensor, at::Tensor, at::Tensor> _scaled_dot_product_flash_attention_backward_nested(
const at::Tensor& grad_out_,
const at::Tensor& query,
const at::Tensor& key,
const at::Tensor& value,
const at::Tensor& out,
const at::Tensor& logsumexp,
const Tensor& cumulative_sequence_length_q,
const Tensor& cumulative_sequence_length_k,
const int64_t max_seqlen_batch_q,
const int64_t max_seqlen_batch_k,
double dropout_p,
bool is_causal,
const at::Tensor& philox_seed,
const at::Tensor& philox_offset,
c10::optional<double> scale){
if (!grad_out_.defined()) {
return std::make_tuple(Tensor{}, Tensor{}, Tensor{});
}
Tensor grad_out_buffer_reshaped, query_buffer_reshaped, key_buffer_reshaped,
value_buffer_reshaped, output_buffer_reshaped;
std::tie(
grad_out_buffer_reshaped,
query_buffer_reshaped,
key_buffer_reshaped,
value_buffer_reshaped,
output_buffer_reshaped) =
preprocessing::sdpa_nested_preprocessing_backward(
grad_out_,
query,
key,
value,
out,
cumulative_sequence_length_q,
cumulative_sequence_length_k,
max_seqlen_batch_q,
max_seqlen_batch_k);

Tensor grad_q, grad_k, grad_v;
std::tie(grad_q, grad_k, grad_v) = at::_flash_attention_backward(
grad_out_buffer_reshaped,
query_buffer_reshaped,
key_buffer_reshaped,
value_buffer_reshaped,
output_buffer_reshaped,
logsumexp,
cumulative_sequence_length_q,
cumulative_sequence_length_k,
max_seqlen_batch_q,
max_seqlen_batch_k,
dropout_p,
is_causal,
philox_seed,
philox_offset,
scale);

grad_q = wrap_buffer(grad_q.view(-1), query.transpose(1,2)._nested_tensor_size()).transpose(1,2);
grad_k = wrap_buffer(grad_k.view(-1), key.transpose(1,2)._nested_tensor_size()).transpose(1,2);
grad_v = wrap_buffer(grad_v.view(-1), value.transpose(1,2)._nested_tensor_size()).transpose(1,2);

return std::make_tuple(grad_q, grad_k, grad_v);
}

} // namespace native
} // namespace at
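
The `wrap_buffer(grad.view(-1), ...).transpose(1, 2)` calls above re-nest the packed gradients returned by `at::_flash_attention_backward`. A conceptual Python analogue (helper name and shapes are illustrative) is:

```python
from typing import List
import torch

def renest_packed_grad(grad_packed: torch.Tensor,   # (sum(seq_lens), num_heads, head_dim)
                       seq_lens: List[int]) -> torch.Tensor:
    # Split the packed rows back into per-sequence chunks of shape
    # (seq_len_i, num_heads, head_dim), then transpose each chunk back to the
    # (num_heads, seq_len_i, head_dim) layout of the original nested inputs.
    chunks = torch.split(grad_packed, seq_lens, dim=0)
    return torch.nested.nested_tensor([c.transpose(0, 1) for c in chunks])

grad_nt = renest_packed_grad(torch.randn(99, 8, 64), [32, 17, 50])
```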
197 changes: 155 additions & 42 deletions aten/src/ATen/native/nested/cuda/NestedTensorTransformerUtils.cpp
@@ -85,55 +85,78 @@ int64_t get_nnz(Tensor nestedtensor) {
* @return A boolean indicating if contiguous needs to be called for the input
*/
bool is_safe_to_get_storage_as_tensor(const NestedTensorImpl* tensor) {
const int64_t *tensor_offsets_ptr = tensor->get_storage_offsets().data_ptr<int64_t>();
const Tensor& tensor_sizes = tensor->get_nested_sizes();
const Tensor& tensor_strides = tensor->get_nested_strides();

const int64_t n_tensors = tensor_strides.size(0);
const int64_t n_dims = tensor_strides.size(1);

if (n_tensors <= 1) {
return true;
}

int64_t* previous_tensor_stride = tensor_strides.data_ptr<int64_t>();
// Check initially that they are in strictly descending order
for (int i{1}; i < n_dims; i++) {
if (previous_tensor_stride[i - 1] <= previous_tensor_stride[i]) {
return false;
const int64_t* tensor_offsets_ptr =
tensor->get_storage_offsets().data_ptr<int64_t>();
const Tensor& tensor_sizes = tensor->get_nested_sizes();
const Tensor& tensor_strides = tensor->get_nested_strides();

const int64_t n_tensors = tensor_strides.size(0);
constexpr int64_t n_dims = 3;
// This is safe since head_dim is assured to be consistent
const int64_t num_heads = tensor->opt_size(2).value();
const int64_t tensor_stride_0 = tensor_strides.stride(0);

if (n_tensors <= 1) {
return true;
}
}
// Check that each tensor i in the nested tensor has the same strides
auto tensor_stride_0 = tensor_strides.stride(0);

for (int i{1}; i < n_tensors; i++) {
for (const int64_t j : c10::irange(n_dims)) {
if (previous_tensor_stride[j] !=
previous_tensor_stride[i * tensor_stride_0 + j]) {
int64_t* previous_tensor_stride = tensor_strides.data_ptr<int64_t>();

// Check initially that the first tensor's strides
// are in strictly descending order
// NOTE: If num_heads is equal to 1 then we skip checking stride[0].
// Why, you may ask? Because if n_heads == 1, then as long as the
// last stride == 1 it does not matter what the strides are for the
// other dimensions.
//
if (num_heads == 1) {
if (previous_tensor_stride[0] <= previous_tensor_stride[2]) {
// This would mean that the last stride is greater than the seq_len
// stride
return false;
}
} else {
for (int i{1}; i < n_dims; i++) {
if (previous_tensor_stride[i - 1] <= previous_tensor_stride[i]) {
return false;
}
}
// Check that each tensor i in the nested tensor has the same strides
for (int i{1}; i < n_tensors; i++) {
for (const int64_t j : c10::irange(n_dims)) {
if (previous_tensor_stride[j] !=
previous_tensor_stride[i * tensor_stride_0 + j]) {
return false;
}
}
}
}
}
// Check the offsets are a constant multiple from the previous numels
const int64_t* tensor_size_ptr = tensor_sizes.data_ptr<int64_t>();
const int64_t* tensor_stride_ptr = tensor_strides.data_ptr<int64_t>();

int64_t numel_0 = (tensor_size_ptr[0] * tensor_stride_ptr[0]);
TORCH_INTERNAL_ASSERT(numel_0 > 0, "numels must be positive!");

int64_t offset_constant = (tensor_offsets_ptr[1] - tensor_offsets_ptr[0]) / numel_0;
for (int64_t i = 2; i < n_tensors; i++) {
// TODO: When 0 seq_len nested tensors are allowed we need to guard against this
int64_t previous_numel = tensor_size_ptr[(i - 1) * tensor_stride_0] * tensor_stride_ptr[(i - 1) * tensor_stride_0];
TORCH_INTERNAL_ASSERT(previous_numel > 0, "numels must be positive!");
int64_t current_offset_constant = (tensor_offsets_ptr[i] - tensor_offsets_ptr[i - 1]) / previous_numel;
if (current_offset_constant != offset_constant) {
return false;

// Check the offsets are a constant multiple from the previous numels
const int64_t* tensor_size_ptr = tensor_sizes.data_ptr<int64_t>();
const int64_t* tensor_stride_ptr = tensor_strides.data_ptr<int64_t>();

int64_t numel_0 = (tensor_size_ptr[0] * tensor_stride_ptr[0]);
TORCH_INTERNAL_ASSERT(numel_0 > 0, "numels must be positive!");

int64_t offset_constant =
(tensor_offsets_ptr[1] - tensor_offsets_ptr[0]) / numel_0;
for (int64_t i = 2; i < n_tensors; i++) {
// TODO: When 0 seq_len nested tensors are allowed we need to guard
// against this
int64_t previous_numel = tensor_size_ptr[(i - 1) * tensor_stride_0] *
tensor_stride_ptr[(i - 1) * tensor_stride_0];
TORCH_INTERNAL_ASSERT(previous_numel > 0, "numels must be positive!");
int64_t current_offset_constant =
(tensor_offsets_ptr[i] - tensor_offsets_ptr[i - 1]) / previous_numel;
if (current_offset_constant != offset_constant) {
return false;
}
}
// Congrats you made it!
return true;
}
// Congrats you made it!
return true;
}
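
A condensed Python rendering of the same safety check, operating on plain per-component metadata lists (purely illustrative; it mirrors the logic above rather than any PyTorch API):

```python
def is_safe_to_view_storage(sizes, strides, offsets, num_heads):
    """sizes/strides: one [s0, s1, s2] triple per 3-D nested component;
    offsets: storage offset of each component in the shared buffer."""
    n_tensors = len(strides)
    if n_tensors <= 1:
        return True
    first = strides[0]
    if num_heads == 1:
        # Only the last stride matters when there is a single head.
        if first[0] <= first[2]:
            return False
    else:
        # Strides must be strictly descending and identical across components.
        if not (first[0] > first[1] > first[2]):
            return False
        if any(strides[i] != first for i in range(1, n_tensors)):
            return False
    # Offsets must advance by a constant multiple of the previous component's numel.
    numel_0 = sizes[0][0] * strides[0][0]
    offset_constant = (offsets[1] - offsets[0]) // numel_0
    for i in range(2, n_tensors):
        previous_numel = sizes[i - 1][0] * strides[i - 1][0]
        if (offsets[i] - offsets[i - 1]) // previous_numel != offset_constant:
            return False
    return True
```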

/**
* Process an individual NestedTensor to reshape and view as a DenseTensor
* Generally the approach for q, k, v is to
@@ -449,6 +472,96 @@
output_shape);
}

std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor>
sdpa_nested_preprocessing_backward(
const at::Tensor& grad_out_,
const at::Tensor& query,
const at::Tensor& key,
const at::Tensor& value,
const at::Tensor& out,
const Tensor& cumulative_sequence_length_q,
const Tensor& cumulative_sequence_length_kv,
const int64_t max_seqlen_batch_q,
const int64_t max_seqlen_batch_kv) {
const int64_t q_batch_size = query.size(0);
const int64_t k_batch_size = key.size(0);

const int64_t v_batch_size = value.size(0);

const int64_t q_num_heads = query.size(1);
const int64_t k_num_heads = key.size(1);
const int64_t v_num_heads = value.size(1);

if (!(q_batch_size == k_batch_size && q_batch_size == v_batch_size) ||
!(q_num_heads == k_num_heads && k_num_heads == v_num_heads)) {
TORCH_CHECK(false, "Broadcasted NestedTensor inputs is currently not supported for backwards.");
}

const int64_t num_heads = query.size(1);
const int64_t head_dim_qk = query.size(3);
const int64_t head_dim_v = value.size(3);

Tensor q_t = query.transpose(1, 2);
Tensor k_t = key.transpose(1, 2);
Tensor v_t = value.transpose(1, 2);
Tensor grad_out_t = grad_out_.transpose(1, 2);
Tensor out_t = out.transpose(1, 2);

const int64_t Nnz_q = get_nnz(q_t);
const int64_t Nnz_kv = get_nnz(k_t);

Tensor query_buffer_reshaped;
Tensor key_buffer_reshaped;
Tensor value_buffer_reshaped;
Tensor grad_out_buffer_reshaped;
Tensor output_buffer_reshaped;

const auto* query_impl = get_nested_tensor_impl(q_t);
const auto* key_impl = get_nested_tensor_impl(k_t);
const auto* value_impl = get_nested_tensor_impl(v_t);
const auto* grad_out_impl = get_nested_tensor_impl(grad_out_t);
const auto* out_impl = get_nested_tensor_impl(out_t);

// If the physical layout of the NestedTensor's storage
// is not: batch, {seq_len}, num_heads, head_dim then we need
// to call contiguous
if (!q_t.is_contiguous() && !is_safe_to_get_storage_as_tensor(query_impl)) {
q_t = q_t.contiguous();
}
if (!k_t.is_contiguous() && !is_safe_to_get_storage_as_tensor(key_impl)) {
k_t = k_t.contiguous();
}
if (!v_t.is_contiguous() && !is_safe_to_get_storage_as_tensor(value_impl)) {
v_t = v_t.contiguous();
}
if (!grad_out_t.is_contiguous() && !is_safe_to_get_storage_as_tensor(grad_out_impl)) {
grad_out_t = grad_out_t.contiguous();
}
if (!out_t.is_contiguous() && !is_safe_to_get_storage_as_tensor(out_impl)) {
out_t = out_t.contiguous();
}

query_buffer_reshaped = view_as_dense(q_t, Nnz_q, num_heads, head_dim_qk);
key_buffer_reshaped = view_as_dense(k_t, Nnz_kv, num_heads, head_dim_qk);
value_buffer_reshaped = view_as_dense(v_t, Nnz_kv, num_heads, head_dim_v);

grad_out_buffer_reshaped =
view_as_dense(grad_out_t, Nnz_q, num_heads, head_dim_v);
output_buffer_reshaped = view_as_dense(out_t, Nnz_q, num_heads, head_dim_v);

auto output_shape = get_nested_sizes(q_t).clone();
if (head_dim_v != head_dim_qk) {
output_shape.select(1, -1).fill_(head_dim_v);
}

return std::make_tuple(
grad_out_buffer_reshaped,
query_buffer_reshaped,
key_buffer_reshaped,
value_buffer_reshaped,
output_buffer_reshaped);
}

} // namespace preprocessing
} // namespace native
} // namespace at
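
Conceptually, the transpose + `view_as_dense` sequence above packs every nested component into a single dense buffer in the (Nnz, num_heads, head_dim) layout the varlen kernels expect; a rough Python equivalent over already-contiguous components (illustrative helper, not this file's API) is:

```python
from typing import List
import torch

def pack_nested(components: List[torch.Tensor]) -> torch.Tensor:
    """components: per-sequence tensors of shape (num_heads, seq_len_i, head_dim).
    Returns a dense (sum(seq_len_i), num_heads, head_dim) buffer."""
    return torch.cat([c.transpose(0, 1).contiguous() for c in components], dim=0)

q_packed = pack_nested([torch.randn(8, s, 64) for s in (32, 17, 50)])
print(q_packed.shape)  # torch.Size([99, 8, 64])
```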
@@ -77,7 +77,6 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
// in order to determine whether we are using varlen or dense forward
if (cumulative_sequence_length_q.defined()) {
// Varlen forward
TORCH_CHECK(false, "Dont go down this path yet");
auto [dQuery, dKey, dValue, dSoftmax] = pytorch_flash::mha_varlen_bwd(
contiguous_grad_out,
query,
14 changes: 14 additions & 0 deletions aten/src/ATen/native/transformers/sdp_utils_cpp.h
@@ -214,6 +214,13 @@ inline bool check_for_seq_len_0_nested_tensor(sdp_params params, bool debug) {
q_num_heads == k_num_heads && q_num_heads == v_num_heads;

if (!same_num_heads) {
if (input_requires_grad(params)){
if (debug) {
TORCH_WARN(
"Both fused kernels do not support training with broadcasted NT inputs.");
}
return false;
}
return try_broadcast_param_size(
q_num_heads, k_num_heads, v_num_heads, "num heads ", debug);
}
Expand Down Expand Up @@ -316,6 +323,13 @@ inline bool check_batch_size_and_num_heads(sdp_params params, bool debug) {
if (has_nested_input) {
bool broadcastable_batch_size = true;
if (!same_batch_size) {
if (input_requires_grad(params)){
if (debug) {
TORCH_WARN(
"Both fused kernels do not support training with broadcasted NT inputs.");
}
return false;
}
// try to broadcast batchsize
broadcastable_batch_size = try_broadcast_param_size(
q_batch_size, k_batch_size, v_batch_size, "batch size ", debug);
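
A simplified Python rendering of the new guard (names and structure are illustrative; the real check lives in the C++ `sdp_utils` helpers above):

```python
import warnings

def fused_kernels_allow_broadcast(q_size, k_size, v_size, requires_grad, debug=False):
    """Mirrors the added check: broadcasting a batch or num_heads dimension
    across q, k, v is only attempted when no input requires grad."""
    if q_size == k_size == v_size:
        return True
    if requires_grad:
        if debug:
            warnings.warn(
                "Both fused kernels do not support training with broadcasted NT inputs.")
        return False
    # Otherwise fall back to the broadcast rule: mismatched sizes must be 1.
    max_size = max(q_size, k_size, v_size)
    return all(s in (1, max_size) for s in (q_size, k_size, v_size))
```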