[BE] Create grad check util (#126991)
# Summary
Add a small utility function for deciding if we should compute LSE, and update it to also check GradMode.
Pull Request resolved: #126991
Approved by: https://github.com/cpuhrsch
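
A minimal standalone sketch of the predicate this change factors out (not part of the commit; it assumes only public ATen headers and re-implements the helper locally as `should_compute_logsumexp_sketch`), showing how the added GradMode check turns the result off under a no-grad guard even when an input still has `requires_grad=true`:

```cpp
// Sketch only: re-implements the new helper outside ATen internals to show
// the behavioral difference introduced by the GradMode check.
#include <ATen/ATen.h>
#include <ATen/core/grad_mode.h>
#include <iostream>

static bool should_compute_logsumexp_sketch(const at::Tensor& query,
                                            const at::Tensor& key,
                                            const at::Tensor& value) {
  const bool any_inputs_require_grad =
      query.requires_grad() || key.requires_grad() || value.requires_grad();
  // New behavior: also require autograd to be globally enabled.
  return any_inputs_require_grad && at::GradMode::is_enabled();
}

int main() {
  auto q = at::randn({2, 4, 128, 64});
  auto k = at::randn({2, 4, 128, 64});
  auto v = at::randn({2, 4, 128, 64});
  q.requires_grad_(true);

  // The old predicate (requires_grad only) would be true in both cases below.
  std::cout << should_compute_logsumexp_sketch(q, k, v) << "\n";  // 1: grad mode on

  {
    at::NoGradGuard no_grad;  // e.g. inference
    std::cout << should_compute_logsumexp_sketch(q, k, v) << "\n";  // 0: LSE can be skipped
  }
  return 0;
}
```

The point of the extra check is that, during inference under a no-grad guard, the cuDNN and efficient-attention backends no longer materialize the logsumexp output just because an input tensor happens to carry `requires_grad=true`.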
drisspg authored and pytorchmergebot committed May 24, 2024
1 parent 27594be commit cfb374d
Showing 2 changed files with 9 additions and 8 deletions.
15 changes: 9 additions & 6 deletions aten/src/ATen/native/transformers/attention.cpp
@@ -1,3 +1,4 @@
+#include <ATen/core/TensorBody.h>
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
 #include <ATen/core/Tensor.h>
 #include <ATen/TensorOperators.h>
@@ -608,6 +609,12 @@ int64_t handle_private_use(const Tensor& query_, const Tensor& key, const Tensor
   return choice_int;
 }
 
+bool should_compute_logsumexp(const Tensor& query, const Tensor& key, const Tensor& value) {
+  const bool any_inputs_require_grad = query.requires_grad() || key.requires_grad() || value.requires_grad();
+  const bool gradmode_enabled = at::GradMode::is_enabled();
+  return any_inputs_require_grad && gradmode_enabled;
+}
+
 } // namespace
 
 // Computes scaled dot product attention on query, key and value tensors, using
@@ -665,9 +672,7 @@ Tensor scaled_dot_product_attention(
   std::optional<Tensor> attn_mask = convert_boolean_attn_mask(attn_mask_, query_.dtype());
   switch (backend) {
     case sdp::SDPBackend::cudnn_attention: {
-      bool compute_logsumexp =
-          (query_.requires_grad() || key.requires_grad() ||
-           value.requires_grad());
+      bool compute_logsumexp = should_compute_logsumexp(query_, key, value);
       auto out_lse_softmax = at::_scaled_dot_product_cudnn_attention(
           query_, key, value, dropout_p, is_causal, compute_logsumexp, scale);
       return std::get<0>(out_lse_softmax);
@@ -689,9 +694,7 @@ Tensor scaled_dot_product_attention(
           query_, key, value, dropout_p, is_causal, attn_mask, scale));
     }
     case sdp::SDPBackend::efficient_attention: {
-      bool compute_logsumexp =
-          (query_.requires_grad() || key.requires_grad() ||
-           value.requires_grad());
+      bool compute_logsumexp = should_compute_logsumexp(query_, key, value);
       if (attn_mask.has_value()) {
         attn_mask.value() = preprocess_mask(attn_mask.value(), query_, key, value);;
       }
2 changes: 0 additions & 2 deletions aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
@@ -18,8 +18,6 @@
 
 #include <c10/core/SymInt.h>
 #include <c10/util/string_view.h>
-#include <cmath>
-#include <functional>
 
 #if USE_ROCM
 #include <aotriton/flash.h>
