Skip to content

Commit

Permalink
[xla:gpu] cudnn_fused_mha: do not require Compute Capability's minor …
Browse files Browse the repository at this point in the history
…== 0

So that we can run this on GPUs with Compute Capability >= 8.0 regardless
of the minor number, e.g. A6000 (8.6).

While at it, simplify the expression.

PiperOrigin-RevId: 608999459
  • Loading branch information
cota authored and tensorflower-gardener committed Feb 21, 2024
1 parent 6928b6f commit abf144e
Showing 1 changed file with 8 additions and 10 deletions.
18 changes: 8 additions & 10 deletions third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -309,17 +309,15 @@ bool IsComputeCapabilityAndCudnnSupported(
stream_executor::CudaComputeCapability cc,
stream_executor::dnn::VersionInfo cudnn_version,
stream_executor::dnn::VersionInfo supported_cudnn_version) {
if (!((cc.IsAtLeast(se::CudaComputeCapability::AMPERE) && cc.minor == 0) &&
(cudnn_version >= supported_cudnn_version))) {
VLOG(2) << absl::StrFormat(
"CudnnFusedMHARewriter did not run. Unsupported compute "
"capability(==8.0) or cudnn version(>=%d.%d.%d)",
supported_cudnn_version.major_version(),
supported_cudnn_version.minor_version(),
supported_cudnn_version.patch());
return false;
if (cc.IsAtLeastAmpere() && cudnn_version >= supported_cudnn_version) {
return true;
}
return true;
VLOG(2) << absl::StrFormat(
"CudnnFusedMHARewriter did not run. Unsupported compute "
"capability(%s; should be >= 8.0) or cudnn version(%s; should be >= %s)",
cc.ToString(), cudnn_version.ToString(),
supported_cudnn_version.ToString());
return false;
}

bool IsSupportedPrimitiveType(const HloInstruction* bmm) {
Expand Down

0 comments on commit abf144e

Please sign in to comment.