diff --git a/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter.cc b/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter.cc index 891a15280be601..4fb1248a4d6c9c 100644 --- a/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter.cc +++ b/third_party/xla/xla/service/gpu/cudnn_fused_mha_rewriter.cc @@ -309,17 +309,15 @@ bool IsComputeCapabilityAndCudnnSupported( stream_executor::CudaComputeCapability cc, stream_executor::dnn::VersionInfo cudnn_version, stream_executor::dnn::VersionInfo supported_cudnn_version) { - if (!((cc.IsAtLeast(se::CudaComputeCapability::AMPERE) && cc.minor == 0) && - (cudnn_version >= supported_cudnn_version))) { - VLOG(2) << absl::StrFormat( - "CudnnFusedMHARewriter did not run. Unsupported compute " - "capability(==8.0) or cudnn version(>=%d.%d.%d)", - supported_cudnn_version.major_version(), - supported_cudnn_version.minor_version(), - supported_cudnn_version.patch()); - return false; + if (cc.IsAtLeastAmpere() && cudnn_version >= supported_cudnn_version) { + return true; } - return true; + VLOG(2) << absl::StrFormat( + "CudnnFusedMHARewriter did not run. Unsupported compute " + "capability(%s; should be >= 8.0) or cudnn version(%s; should be >= %s)", + cc.ToString(), cudnn_version.ToString(), + supported_cudnn_version.ToString()); + return false; } bool IsSupportedPrimitiveType(const HloInstruction* bmm) {