PR #11478: [XLA:GPU] add guards for flash attention graph with cuDNN >= 8.9.4

Imported from GitHub PR openxla/xla#11478

* Building XLA against cuDNN 8.6 fails to compile because the flash attention graph uses `CudnnfMHAUid`, which is only defined for cuDNN > 8.8.
* Add a guard so the flash attention graph is compiled only with cuDNN >= 8.9.4. The resulting logic: FMHA compiles only with cuDNN > 8.8, and flash attention compiles only with cuDNN >= 8.9.4 (a sketch of the guard pattern follows).
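For reference, cuDNN encodes its version as `major * 1000 + minor * 100 + patch`, so 8.9.4 corresponds to `CUDNN_VERSION >= 8904`. Below is a minimal, self-contained sketch of the guard pattern this change introduces; the `CudnnGraph` stand-in type and the function name here are simplified placeholders, and the real implementation is in the diff that follows:

```cpp
// Sketch only: illustrates the compile-time guard added in this commit.
// The actual code lives in xla/stream_executor/cuda/cuda_dnn.cc.
#include <string>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "cudnn.h"  // defines CUDNN_VERSION = major*1000 + minor*100 + patch

// Placeholder for the real cudnn_frontend graph type used by XLA.
struct CudnnGraph {
  std::string debug_string;
};

absl::StatusOr<CudnnGraph> GetFlashAttentionGraphSketch() {
#if CUDNN_VERSION >= 8904
  // With cuDNN >= 8.9.4 the flash attention symbols (e.g. CudnnfMHAUid)
  // exist in the headers, so the graph can be built normally.
  return CudnnGraph{"flash attention operation graph"};
#else
  // With older cuDNN the guarded branch would not even compile, so the
  // feature is reported as unimplemented at runtime instead.
  return absl::UnimplementedError(
      "Cudnn flash attention only supported with Cudnn >= 8.9.4");
#endif
}
```

Because the check happens at preprocessing time, a build against cuDNN 8.6 never sees the flash attention symbols it cannot resolve, while runtime callers still receive a well-formed `absl::Status` they can handle.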
Copybara import of the project:

--
a1aa585f4e6ce42c7486336549447151cd5f7690 by cjkkkk <ske@nvidia.com>:

add guards for flash attention graph with at least 8.9.4

Merging this change closes #11478

PiperOrigin-RevId: 625425491
Cjkkkk authored and tensorflower-gardener committed Apr 16, 2024
1 parent 8a675f0 commit 2194c88
Showing 1 changed file with 10 additions and 0 deletions: third_party/xla/xla/stream_executor/cuda/cuda_dnn.cc
@@ -6348,6 +6348,7 @@ absl::StatusOr<CudnnGraph> GetCudnnFlashAttentionOperationGraph(
const std::optional<double> dropout_rate, const bool is_causal_mask) {
using cudnn_frontend::graph::Tensor_attributes;

+#if CUDNN_VERSION >= 8904
if (VLOG_IS_ON(4)) {
VLOG(4) << "\n bmm1_lhs(q): " << q_descriptor.ToString()
<< "\n bmm1_rhs(k): " << k_descriptor.ToString()
@@ -6473,6 +6474,10 @@ absl::StatusOr<CudnnGraph> GetCudnnFlashAttentionOperationGraph(
VLOG(4) << "\b flash attention operation graph: " << graph;
}
return cudnnGraph;
+#else
+  return absl::UnimplementedError(
+      "Cudnn flash attention only supported with Cudnn >= 8.9.4");
+#endif
}

absl::StatusOr<CudnnGraph> GetCudnnFlashAttentionBackwardOperationGraph(
@@ -6487,6 +6492,7 @@ absl::StatusOr<CudnnGraph> GetCudnnFlashAttentionBackwardOperationGraph(
std::optional<double> dropout_rate, std::optional<int64_t> seed,
double scale, bool use_dropout = false, bool use_mask = false,
bool use_bias = false, bool use_causal_mask = false) {
+#if CUDNN_VERSION >= 8904
if (VLOG_IS_ON(4)) {
VLOG(4) << "\n bmm1_grad_gemm1_rhs(q): " << q_desc.ToString()
<< "\n bmm1_grad_gemm2_rhs(k): " << k_desc.ToString()
@@ -6643,6 +6649,10 @@ absl::StatusOr<CudnnGraph> GetCudnnFlashAttentionBackwardOperationGraph(
}

return cudnnGraph;
+#else
+  return absl::UnimplementedError(
+      "Cudnn flash attention only supported with Cudnn >= 8.9.4");
+#endif
}

absl::Status CudnnSupport::DoPrepareForConvolution(
