From b08e2063af508844dc5c610c9774d06fcc312578 Mon Sep 17 00:00:00 2001
From: Elfie Guo
Date: Fri, 24 May 2024 04:59:00 -0700
Subject: [PATCH] PR #13020: Support lowering XLA clamp instruction to cuDNN.

Imported from GitHub PR https://github.com/openxla/xla/pull/13020

Support lowering XLA clamp instruction to cuDNN.
cc @sergachev

Copybara import of the project:

--
47dc71f2a0d5887461a0b7d985328442e0e8da2f by Elfie Guo:

Support lowering clamp instruction to cuDNN.

Merging this change closes #13020

PiperOrigin-RevId: 636875273
---
 .../xla/service/gpu/cudnn_fusion_compiler.cc  | 122 ++++++++++++------
 .../xla/xla/service/gpu/fusions/cudnn_test.cc |  26 ++++
 .../xla/xla/service/gpu/triton_support.cc     |   2 +-
 3 files changed, 106 insertions(+), 44 deletions(-)

diff --git a/third_party/xla/xla/service/gpu/cudnn_fusion_compiler.cc b/third_party/xla/xla/service/gpu/cudnn_fusion_compiler.cc
index 3782301595a36a..ee8374d46cdc95 100644
--- a/third_party/xla/xla/service/gpu/cudnn_fusion_compiler.cc
+++ b/third_party/xla/xla/service/gpu/cudnn_fusion_compiler.cc
@@ -349,6 +349,30 @@ HandleConstantHloToCudnnGraph(const HloInstruction& hlo, graph::Graph& graph) {
   }
 }
 
+std::optional<std::shared_ptr<graph::Tensor_attributes>>
+HandleClampToCudnnGraph(
+    const HloInstruction& hlo, graph::Graph& graph,
+    absl::flat_hash_map<const HloInstruction*,
+                        std::shared_ptr<graph::Tensor_attributes>>&
+        hlo_to_cudnn,
+    fe::DataType_t data_type, fe::DataType_t compute_dtype) {
+  CHECK(hlo.opcode() == HloOpcode::kClamp)
+      << "HLO is not a clamp: " << hlo.ToShortString();
+  CHECK(hlo.operands().size() == 3)
+      << "Clamp requires to have 3 operands: " << hlo.ToShortString();
+  // clamp = max(lower, min(value, upper));
+  const auto min_attrs = graph::Pointwise_attributes()
+                             .set_mode(fe::PointwiseMode_t::MIN)
+                             .set_compute_data_type(compute_dtype);
+  std::shared_ptr<graph::Tensor_attributes> min_tensor = graph.pointwise(
+      hlo_to_cudnn[hlo.operand(1)], hlo_to_cudnn[hlo.operand(2)], min_attrs);
+  min_tensor->set_data_type(data_type).set_name(std::string(hlo.name()));
+  const auto max_attrs = graph::Pointwise_attributes()
+                             .set_mode(fe::PointwiseMode_t::MAX)
+                             .set_compute_data_type(compute_dtype);
+  return graph.pointwise(min_tensor, hlo_to_cudnn[hlo.operand(0)], max_attrs);
+}
+
 // Traverses fusion computations and creates cuDNN graphs out of them.
 absl::StatusOr<std::optional<se::gpu::CudnnGraph>> HloFusionToCuDnnGraph(
     const HloFusionInstruction& fusion) {
@@ -407,6 +431,11 @@ absl::StatusOr<std::optional<se::gpu::CudnnGraph>> HloFusionToCuDnnGraph(
     auto operand = [&hlo_to_cudnn, &hlo](int i) {
       return hlo_to_cudnn[hlo->operand(i)];
     };
+    const auto data_type = ToCudnnDataType(hlo->shape().element_type());
+    if (!data_type.has_value()) {
+      VLOG(3) << "Unimplemented data type: " << hlo->shape().element_type();
+      return std::nullopt;
+    }
     if (hlo->opcode() == HloOpcode::kParameter) {
       CHECK(hlo_to_cudnn.contains(hlo));
       continue;
@@ -442,53 +471,65 @@ absl::StatusOr<std::optional<se::gpu::CudnnGraph>> HloFusionToCuDnnGraph(
       // All these are accounted for separately as transformations of strides.
       hlo_to_cudnn[hlo] = operand(0);
     } else if (hlo->IsElementwise()) {
-      const auto mode = GetElementwiseMode(*hlo);
-      if (!mode.has_value()) {
-        VLOG(3) << "Unsupported elementwise operation.";
-        return std::nullopt;
-      }
       const auto compute_dtype =
           GetComputeDataType(hlo->shape().element_type());
       if (!compute_dtype.has_value()) {
         return std::nullopt;
       }
-      const auto attrs = graph::Pointwise_attributes()
-                             .set_mode(mode.value())
-                             .set_compute_data_type(compute_dtype.value());
-      if (hlo->operand_count() == 1) {
-        hlo_to_cudnn[hlo] = graph.pointwise(operand(0), attrs);
-        // Sets the dimensions for unary ops whose operands are broadcast for
-        // cuDNN to infer its inputs' shapes. constant has dimension [1] while
-        // cuDNN requires constant to have dimension [1,1,1]. Not setting output
-        // of the unary shapes results in the rejection of the cuDNN graph.
-        if (hlo->operand(0)->opcode() == HloOpcode::kBroadcast) {
-          const auto scope = adapter->analysis_.QueryInstructionScope(*hlo);
-          std::vector<int64_t> dimensions;
-          std::vector<int64_t> strides;
-          if (!scope.has_value()) {
-            LOG(FATAL) << "No scope for instruction: " << hlo->ToShortString();
+      if (hlo->opcode() == HloOpcode::kClamp) {
+        const auto clamp =
+            HandleClampToCudnnGraph(*hlo, graph, hlo_to_cudnn,
+                                    data_type.value(), compute_dtype.value());
+        if (!clamp.has_value()) {
+          return std::nullopt;
+        }
+        hlo_to_cudnn[hlo] = clamp.value();
+      } else {
+        const auto mode = GetElementwiseMode(*hlo);
+        if (!mode.has_value()) {
+          VLOG(3) << "Unsupported elementwise operation.";
+          return std::nullopt;
+        }
+        const auto attrs = graph::Pointwise_attributes()
+                               .set_mode(mode.value())
+                               .set_compute_data_type(compute_dtype.value());
+        if (hlo->operand_count() == 1) {
+          hlo_to_cudnn[hlo] = graph.pointwise(operand(0), attrs);
+          // Sets the dimensions for unary ops whose operands are broadcast
+          // for cuDNN to infer its inputs' shapes. constant has dimension [1]
+          // while cuDNN requires constant to have dimension [1,1,1]. Not
+          // setting output of the unary shapes results in the rejection of
+          // the cuDNN graph.
+          if (hlo->operand(0)->opcode() == HloOpcode::kBroadcast) {
+            const auto scope = adapter->analysis_.QueryInstructionScope(*hlo);
+            std::vector<int64_t> dimensions;
+            std::vector<int64_t> strides;
+            if (!scope.has_value()) {
+              LOG(FATAL) << "No scope for instruction: "
+                         << hlo->ToShortString();
+            }
+            if (!adapter->DimensionsAndStrides(*hlo, scope.value(), dimensions,
+                                               strides)) {
+              VLOG(3) << "Unsupported hlo for querying dimensions: "
+                      << hlo->ToShortString();
+            } else {
+              hlo_to_cudnn[hlo]->set_dim(dimensions);
+            }
           }
-          if (!adapter->DimensionsAndStrides(*hlo, scope.value(), dimensions,
-                                             strides)) {
-            VLOG(3) << "Unsupported hlo for querying dimensions: "
-                    << hlo->ToShortString();
-          } else {
-            hlo_to_cudnn[hlo]->set_dim(dimensions);
+        } else if (hlo->operand_count() == 2) {
+          hlo_to_cudnn[hlo] = graph.pointwise(operand(0), operand(1), attrs);
+        } else if (hlo->operand_count() == 3) {
+          if (hlo->opcode() != HloOpcode::kSelect) {
+            VLOG(3) << "Unexpected ternary operation: " << hlo->ToString();
+            return std::nullopt;
          }
-        }
-      } else if (hlo->operand_count() == 2) {
-        hlo_to_cudnn[hlo] = graph.pointwise(operand(0), operand(1), attrs);
-      } else if (hlo->operand_count() == 3) {
-        if (hlo->opcode() != HloOpcode::kSelect) {
-          VLOG(3) << "Unexpected ternary operation: " << hlo->ToString();
+          // Operand order for select differs between HLO and cuDNN.
+          hlo_to_cudnn[hlo] =
+              graph.pointwise(operand(1), operand(2), operand(0), attrs);
+        } else {
+          VLOG(3) << "Unimplemented elementwise operation.";
          return std::nullopt;
        }
-        // Operand order for select differs between HLO and cuDNN.
-        hlo_to_cudnn[hlo] =
-            graph.pointwise(operand(1), operand(2), operand(0), attrs);
-      } else {
-        VLOG(3) << "Unimplemented elementwise operation.";
-        return std::nullopt;
      }
     } else if (hlo->opcode() == HloOpcode::kDot) {
       const auto compute_dtype =
@@ -508,11 +549,6 @@ absl::StatusOr<std::optional<se::gpu::CudnnGraph>> HloFusionToCuDnnGraph(
       VLOG(3) << "Creation of the operation failed.";
       return std::nullopt;
     }
-    const auto data_type = ToCudnnDataType(hlo->shape().element_type());
-    if (!data_type.has_value()) {
-      VLOG(3) << "Unimplemented data type: " << hlo->shape().element_type();
-      return std::nullopt;
-    }
     hlo_to_cudnn[hlo]
         ->set_data_type(data_type.value())
         .set_name(std::string(hlo->name()));
diff --git a/third_party/xla/xla/service/gpu/fusions/cudnn_test.cc b/third_party/xla/xla/service/gpu/fusions/cudnn_test.cc
index 051040693daed4..4234c935954003 100644
--- a/third_party/xla/xla/service/gpu/fusions/cudnn_test.cc
+++ b/third_party/xla/xla/service/gpu/fusions/cudnn_test.cc
@@ -556,6 +556,32 @@ ENTRY e {
                             ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
 }
 
+TEST_F(CuDnnFusionLevel2Test, ClampExecutesCorrectly) {
+  EXPECT_TRUE(RunAndCompare(R"(
+fusion1 {
+  x = bf16[16,32] parameter(0)
+  y = bf16[32,16] parameter(1)
+  x_const_lower = bf16[] constant(3e-3)
+  x_const_upper = bf16[] constant(1e-1)
+  y_const_lower = bf16[] constant(3e-3)
+  y_const_upper = bf16[] constant(1e-1)
+  x_const_bcast_lower = bf16[16,32] broadcast(x_const_lower), dimensions={}
+  x_const_bcast_upper = bf16[16,32] broadcast(x_const_upper), dimensions={}
+  y_const_bcast_lower = bf16[32,16] broadcast(y_const_lower), dimensions={}
+  y_const_bcast_upper = bf16[32,16] broadcast(y_const_upper), dimensions={}
+  x_clamp = bf16[16,32] clamp(x_const_bcast_lower, x, x_const_bcast_upper)
+  y_clamp = bf16[32,16] clamp(y_const_bcast_lower, y, y_const_bcast_upper)
+  ROOT dot_a = f32[16,16] dot(x_clamp, y_clamp), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+}
+ENTRY e {
+  p0 = bf16[16,32] parameter(0)
+  p1 = bf16[32,16] parameter(1)
+  ROOT _ = f32[16,16] fusion(p0, p1), kind=kCustom, calls=fusion1,
+    backend_config={"fusion_backend_config": {kind: "__cudnn$fusion"}}
+})",
+                            ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
+}
+
 TEST_F(CuDnnFusionLevel2Test, DotF8ExecutesCorrectly) {
   EXPECT_TRUE(RunAndCompare(R"(
 
diff --git a/third_party/xla/xla/service/gpu/triton_support.cc b/third_party/xla/xla/service/gpu/triton_support.cc
index 05eea435bd70dd..155ef105d645fe 100644
--- a/third_party/xla/xla/service/gpu/triton_support.cc
+++ b/third_party/xla/xla/service/gpu/triton_support.cc
@@ -123,7 +123,7 @@ std::vector<HloOpcode> TritonSupportedBinaryElementwise(
 
 std::vector<HloOpcode> TritonSupportedTernaryElementwise(
     PrimitiveType element_type) {
-  return {HloOpcode::kSelect};
+  return {HloOpcode::kSelect, HloOpcode::kClamp};
 }
 
 bool IsTritonSupportedElementwise(HloOpcode opcode,
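
For context, outside the patch itself: the new HandleClampToCudnnGraph lowering relies on the decomposition clamp(lower, value, upper) == max(lower, min(value, upper)), which it expresses as two cuDNN pointwise nodes (MIN then MAX). A minimal standalone C++ sketch of that decomposition on plain scalars is below; the helper name ClampViaMinMax is illustrative only and is not part of XLA or cuDNN.

// Standalone sketch: the min/max decomposition of clamp used by the new
// cuDNN lowering, demonstrated on scalars. Names here are illustrative only.
#include <algorithm>
#include <cassert>

float ClampViaMinMax(float lower, float value, float upper) {
  // clamp = max(lower, min(value, upper));
  return std::max(lower, std::min(value, upper));
}

int main() {
  assert(ClampViaMinMax(0.0f, -1.0f, 1.0f) == 0.0f);  // below range -> lower bound
  assert(ClampViaMinMax(0.0f, 0.5f, 1.0f) == 0.5f);   // inside range -> value
  assert(ClampViaMinMax(0.0f, 2.0f, 1.0f) == 1.0f);   // above range -> upper bound
  return 0;
}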