[XLA:GPU] Let TritonFusion depend on CUDA or ROCm headers (transitively).

So far, we guard the dep with `#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM` and return an unimplemented error otherwise. However, having `"triton"` in the build graph makes no sense if neither the CUDA nor the ROCm toolkit is available at build time. This PR moves the `if_gpu_is_configured` branching upwards in the build graph to `"ir_emitter_unnested"`.

PiperOrigin-RevId: 636875057
thomasjoerg authored and tensorflower-gardener committed May 24, 2024
1 parent e6538d9 commit 4549f45
Showing 3 changed files with 2 additions and 12 deletions.
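
For context, the change hinges on the `if_gpu_is_configured` build macro. As a rough illustration only (the actual definition lives in XLA's build_defs .bzl files and differs in detail), it behaves like this simplified Starlark sketch:

    # Simplified, hypothetical sketch of if_gpu_is_configured -- NOT the
    # real definition. It returns the given list only when a CUDA or ROCm
    # toolchain was configured for this build, and if_false (by default,
    # an empty list) otherwise.
    def if_gpu_is_configured(if_true, if_false = []):
        if cuda_is_configured() or rocm_is_configured():
            return if_true
        return if_false

Because `ir_emitter_unnested` now pulls in `//xla/service/gpu/fusions` (and, transitively, `"triton"`) through this macro, GPU-less builds never see the Triton targets at all, which is what lets the `#if`/`UnimplementedError` fallback in triton.cc below be deleted.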
third_party/xla/xla/service/gpu/BUILD (2 changes: 1 addition & 1 deletion)
@@ -358,7 +358,6 @@ cc_library(
         "//xla/service:custom_call_target_registry",
         "//xla/service:global_device_id",
         "//xla/service:name_uniquer",
-        "//xla/service/gpu/fusions",
         "//xla/service/gpu/fusions:fusion_emitter",
         "//xla/service/gpu/fusions:thunk_util",
         "//xla/service/gpu/kernels:custom_kernel",
@@ -434,6 +433,7 @@ cc_library(
         "@triton//:TritonDialects",
     ] + if_gpu_is_configured([
         ":ir_emitter_triton",
+        "//xla/service/gpu/fusions",
         "//xla/service/gpu/runtime:cholesky_thunk",
         "//xla/service/gpu/runtime:cub_sort_thunk",
         "//xla/service/gpu/runtime:gpublas_lt_matmul_thunk",
third_party/xla/xla/service/gpu/fusions/BUILD (1 change: 0 additions & 1 deletion)
@@ -611,7 +611,6 @@ cc_library(
     name = "triton",
     srcs = ["triton.cc"],
     hdrs = ["triton.h"],
-    local_defines = if_cuda_is_configured(["GOOGLE_CUDA=1"]),
     deps = [
         ":fusion_emitter",
         "//xla:shape_util",
third_party/xla/xla/service/gpu/fusions/triton.cc (11 changes: 1 addition & 10 deletions)
@@ -36,6 +36,7 @@ limitations under the License.
 #include "xla/service/gpu/hlo_traversal.h"
 #include "xla/service/gpu/ir_emission_utils.h"
 #include "xla/service/gpu/ir_emitter_context.h"
+#include "xla/service/gpu/ir_emitter_triton.h"
 #include "xla/service/gpu/kernel_arguments.h"
 #include "xla/service/gpu/kernel_reuse_cache.h"
 #include "xla/service/gpu/launch_dimensions.h"
@@ -48,12 +49,6 @@ limitations under the License.
 #include "xla/status_macros.h"
 #include "tsl/platform/statusor.h"
 
-#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-#include "xla/service/gpu/ir_emitter_triton.h"
-#else
-#include "absl/status/status.h"
-#endif
-
 namespace xla {
 namespace gpu {
 namespace {
@@ -104,7 +99,6 @@ absl::StatusOr<FusionEmissionResult> TritonFusion::Emit(
     IrEmitterContext& ir_emitter_context,
     const HloFusionInstruction& fusion) const {
   llvm::IRBuilder builder(ir_emitter_context.llvm_module()->getContext());
-#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
   VLOG(3) << fusion.ToString();
   std::string suggested_kernel_name = std::string(fusion.name());
   TF_ASSIGN_OR_RETURN(
@@ -217,9 +211,6 @@ absl::StatusOr<FusionEmissionResult> TritonFusion::Emit(
       entry->launch_dimensions, entry->cluster_dim, entry->shmem_bytes));
 
   return result;
-#else
-  return absl::UnimplementedError("Triton support requires CUDA or ROCm");
-#endif
 }
 
 std::optional<LaunchDimensions> TritonFusion::launch_dimensions() const {
