PR #11444: [XLA:GPU] disable mask in cuDNN attention
Imported from GitHub PR openxla/xla#11444

1. The cuDNN attention mask does not apply masking with -inf; it multiplies the scores by the mask instead, which is incorrect. Hence, disable patterns with a mask.
2. A follow-up PR will clean up the remaining mask-related logic.
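The difference between the two masking schemes can be illustrated with a minimal Python sketch (made-up logit values, not XLA code): multiplying a masked logit by 0 leaves it at 0.0, and exp(0) = 1 still contributes to the softmax denominator, so the masked position leaks attention weight; adding -inf drives its probability to exactly zero.

```python
import math

def softmax(xs):
    # Numerically stable softmax over a list of logits.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical attention logits; position 1 should be masked out.
scores = [2.0, 1.0, 0.5]
mask = [1.0, 0.0, 1.0]  # 0.0 marks the masked position

# Correct scheme: additive -inf masking removes the position entirely,
# since exp(-inf) == 0 contributes nothing to the denominator.
additive = softmax([s if m else float("-inf") for s, m in zip(scores, mask)])

# Incorrect scheme (what the multiplicative mask does): the masked logit
# becomes 0.0, but exp(0) == 1 still leaks weight to that position.
multiplicative = softmax([s * m for s, m in zip(scores, mask)])

print(additive[1])        # exactly 0.0
print(multiplicative[1])  # nonzero, ~0.0996
```

This is why multiplicative masking cannot stand in for -inf masking, and why the rewriter now skips any matched pattern that carries a mask input.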
Copybara import of the project:

--
acf95b6cc7e1084026eaf87c0119ba3801ba8f8c by cjkkkk <ske@nvidia.com>:

disable mask

Merging this change closes #11444

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#11444 from Cjkkkk:remove_mask acf95b6cc7e1084026eaf87c0119ba3801ba8f8c
PiperOrigin-RevId: 624057479
Cjkkkk authored and tensorflower-gardener committed Apr 12, 2024
1 parent 8c5a9b0 commit 3b20da8
Showing 3 changed files with 70 additions and 2,480 deletions.
10 changes: 4 additions & 6 deletions third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter.cc
@@ -454,12 +454,6 @@ absl::StatusOr<bool> IsFlashAttention(
   bool is_flash_attention = is_seqlen_supported && is_hidden_dim_supported;
   if (!is_flash_attention) return false;
 
-  // TODO(hebecker): The fMHA rewriter is triggering some miscompile when used
-  // with cuDNN 8.9.6+. So this is temporarily disabling all the capabilities
-  // added by 8.9.6 and beyond:
-  cudnn_version =
-      std::min(cudnn_version, stream_executor::dnn::VersionInfo(8, 9, 5));
-
   // going backwards to check compatibility
   if ((is_training && (s_q < 64 || s_kv < 64)) &&
       !IsComputeCapabilityAndCudnnSupported(
@@ -1865,6 +1859,10 @@ absl::StatusOr<bool> CudnnFusedMHARewriter::Run(
     if (!matched_result.has_match) {
       continue;
     }
+    // disable cuDNN mask input
+    if (matched_result.matched_mask) {
+      continue;
+    }
     // We check the validity of bmms here before canonicalization so we don't
     // modify the graph if mha fusion is not possible
     // Relax 512 constraint if it is flash attention