From abf144e113f6095d559b48fe513648a3707790c9 Mon Sep 17 00:00:00 2001 From: Emilio Cota Date: Wed, 21 Feb 2024 08:13:34 -0800 Subject: [PATCH] [xla:gpu] cudnn_fused_mha: do not require Compute Capability's minor == 0 So that we can run this on GPUs with Compute Capability >= 8.0 regardless of the minor number, e.g. A6000 (8.6). While at it, simplify the expression. PiperOrigin-RevId: 608999459 --- .../service/gpu/cudnn_fused_mha_rewriter.cc | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter.cc b/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter.cc index 891a15280be601..4fb1248a4d6c9c 100644 --- a/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter.cc +++ b/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter.cc @@ -309,17 +309,15 @@ bool IsComputeCapabilityAndCudnnSupported( stream_executor::CudaComputeCapability cc, stream_executor::dnn::VersionInfo cudnn_version, stream_executor::dnn::VersionInfo supported_cudnn_version) { - if (!((cc.IsAtLeast(se::CudaComputeCapability::AMPERE) && cc.minor == 0) && - (cudnn_version >= supported_cudnn_version))) { - VLOG(2) << absl::StrFormat( - "CudnnFusedMHARewriter did not run. Unsupported compute " - "capability(==8.0) or cudnn version(>=%d.%d.%d)", - supported_cudnn_version.major_version(), - supported_cudnn_version.minor_version(), - supported_cudnn_version.patch()); - return false; + if (cc.IsAtLeastAmpere() && cudnn_version >= supported_cudnn_version) { + return true; } - return true; + VLOG(2) << absl::StrFormat( + "CudnnFusedMHARewriter did not run. Unsupported compute " + "capability(%s; should be >= 8.0) or cudnn version(%s; should be >= %s)", + cc.ToString(), cudnn_version.ToString(), + supported_cudnn_version.ToString()); + return false; } bool IsSupportedPrimitiveType(const HloInstruction* bmm) {