Commit b08e206

PR #13020: Support lowering XLA clamp instruction to cuDNN.
Imported from GitHub PR openxla/xla#13020

Support lowering XLA clamp instruction to cuDNN.
cc @sergachev
Copybara import of the project:

--
47dc71f2a0d5887461a0b7d985328442e0e8da2f by Elfie Guo <elfieg@nvidia.com>:

Support lowering clamp instruction to cuDNN.

Merging this change closes #13020

PiperOrigin-RevId: 636875273
elfiegg authored and tensorflower-gardener committed May 24, 2024
1 parent 4549f45 commit b08e206
Showing 3 changed files with 106 additions and 44 deletions.
122 changes: 79 additions & 43 deletions third_party/xla/xla/service/gpu/cudnn_fusion_compiler.cc
@@ -349,6 +349,30 @@ HandleConstantHloToCudnnGraph(const HloInstruction& hlo, graph::Graph& graph) {
}
}

std::optional<std::shared_ptr<graph::Tensor_attributes>>
HandleClampToCudnnGraph(
    const HloInstruction& hlo, graph::Graph& graph,
    absl::flat_hash_map<const HloInstruction*,
                        std::shared_ptr<graph::Tensor_attributes>>&
        hlo_to_cudnn,
    fe::DataType_t data_type, fe::DataType_t compute_dtype) {
  CHECK(hlo.opcode() == HloOpcode::kClamp)
      << "HLO is not a clamp: " << hlo.ToShortString();
  CHECK(hlo.operands().size() == 3)
      << "Clamp requires 3 operands: " << hlo.ToShortString();
  // clamp = max(lower, min(value, upper));
  const auto min_attrs = graph::Pointwise_attributes()
                             .set_mode(fe::PointwiseMode_t::MIN)
                             .set_compute_data_type(compute_dtype);
  std::shared_ptr<graph::Tensor_attributes> min_tensor = graph.pointwise(
      hlo_to_cudnn[hlo.operand(1)], hlo_to_cudnn[hlo.operand(2)], min_attrs);
  min_tensor->set_data_type(data_type).set_name(std::string(hlo.name()));
  const auto max_attrs = graph::Pointwise_attributes()
                             .set_mode(fe::PointwiseMode_t::MAX)
                             .set_compute_data_type(compute_dtype);
  return graph.pointwise(min_tensor, hlo_to_cudnn[hlo.operand(0)], max_attrs);
}
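For intuition, the MIN/MAX chain above implements the usual scalar definition of clamp. A minimal reference sketch, not part of the commit and with illustrative names only:

#include <algorithm>

// Scalar semantics of HLO clamp(lower, value, upper), matching the
// MAX(lower, MIN(value, upper)) decomposition built above.
template <typename T>
T ClampReference(T lower, T value, T upper) {
  return std::max(lower, std::min(value, upper));
}
// Example: ClampReference(0.003f, 0.5f, 0.1f) yields 0.1f.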

// Traverses fusion computations and creates cuDNN graphs out of them.
absl::StatusOr<std::optional<se::gpu::CudnnGraph>> HloFusionToCuDnnGraph(
const HloFusionInstruction& fusion) {
@@ -407,6 +431,11 @@ absl::StatusOr<std::optional<se::gpu::CudnnGraph>> HloFusionToCuDnnGraph(
auto operand = [&hlo_to_cudnn, &hlo](int i) {
return hlo_to_cudnn[hlo->operand(i)];
};
const auto data_type = ToCudnnDataType(hlo->shape().element_type());
if (!data_type.has_value()) {
VLOG(3) << "Unimplemented data type: " << hlo->shape().element_type();
return std::nullopt;
}
if (hlo->opcode() == HloOpcode::kParameter) {
CHECK(hlo_to_cudnn.contains(hlo));
continue;
@@ -442,53 +471,65 @@ absl::StatusOr<std::optional<se::gpu::CudnnGraph>> HloFusionToCuDnnGraph(
// All these are accounted for separately as transformations of strides.
hlo_to_cudnn[hlo] = operand(0);
    } else if (hlo->IsElementwise()) {
      const auto compute_dtype =
          GetComputeDataType(hlo->shape().element_type());
      if (!compute_dtype.has_value()) {
        return std::nullopt;
      }
      if (hlo->opcode() == HloOpcode::kClamp) {
        const auto clamp =
            HandleClampToCudnnGraph(*hlo, graph, hlo_to_cudnn,
                                    data_type.value(), compute_dtype.value());
        if (!clamp.has_value()) {
          return std::nullopt;
        }
        hlo_to_cudnn[hlo] = clamp.value();
      } else {
        const auto mode = GetElementwiseMode(*hlo);
        if (!mode.has_value()) {
          VLOG(3) << "Unsupported elementwise operation.";
          return std::nullopt;
        }
        const auto attrs = graph::Pointwise_attributes()
                               .set_mode(mode.value())
                               .set_compute_data_type(compute_dtype.value());
        if (hlo->operand_count() == 1) {
          hlo_to_cudnn[hlo] = graph.pointwise(operand(0), attrs);
          // Set the output dimensions of unary ops whose operand is a
          // broadcast so that cuDNN can infer the inputs' shapes: a constant
          // has dimensions [1], while cuDNN requires [1,1,1]. Leaving the
          // unary op's output shape unset leads to rejection of the cuDNN
          // graph.
          if (hlo->operand(0)->opcode() == HloOpcode::kBroadcast) {
            const auto scope = adapter->analysis_.QueryInstructionScope(*hlo);
            std::vector<int64_t> dimensions;
            std::vector<int64_t> strides;
            if (!scope.has_value()) {
              LOG(FATAL) << "No scope for instruction: "
                         << hlo->ToShortString();
            }
            if (!adapter->DimensionsAndStrides(*hlo, scope.value(),
                                               dimensions, strides)) {
              VLOG(3) << "Unsupported hlo for querying dimensions: "
                      << hlo->ToShortString();
            } else {
              hlo_to_cudnn[hlo]->set_dim(dimensions);
            }
          }
        } else if (hlo->operand_count() == 2) {
          hlo_to_cudnn[hlo] = graph.pointwise(operand(0), operand(1), attrs);
        } else if (hlo->operand_count() == 3) {
          if (hlo->opcode() != HloOpcode::kSelect) {
            VLOG(3) << "Unexpected ternary operation: " << hlo->ToString();
            return std::nullopt;
          }
          // Operand order for select differs between HLO and cuDNN.
          hlo_to_cudnn[hlo] =
              graph.pointwise(operand(1), operand(2), operand(0), attrs);
        } else {
          VLOG(3) << "Unimplemented elementwise operation.";
          return std::nullopt;
        }
      }
} else if (hlo->opcode() == HloOpcode::kDot) {
const auto compute_dtype =
@@ -508,11 +549,6 @@ absl::StatusOr<std::optional<se::gpu::CudnnGraph>> HloFusionToCuDnnGraph(
VLOG(3) << "Creation of the operation failed.";
return std::nullopt;
}
hlo_to_cudnn[hlo]
->set_data_type(data_type.value())
.set_name(std::string(hlo->name()));
26 changes: 26 additions & 0 deletions third_party/xla/xla/service/gpu/fusions/cudnn_test.cc
@@ -556,6 +556,32 @@ ENTRY e {
ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
}

TEST_F(CuDnnFusionLevel2Test, ClampExecutesCorrectly) {
EXPECT_TRUE(RunAndCompare(R"(
fusion1 {
x = bf16[16,32] parameter(0)
y = bf16[32,16] parameter(1)
x_const_lower = bf16[] constant(3e-3)
x_const_upper = bf16[] constant(1e-1)
y_const_lower = bf16[] constant(3e-3)
y_const_upper = bf16[] constant(1e-1)
x_const_bcast_lower = bf16[16,32] broadcast(x_const_lower), dimensions={}
x_const_bcast_upper = bf16[16,32] broadcast(x_const_upper), dimensions={}
y_const_bcast_lower = bf16[32,16] broadcast(y_const_lower), dimensions={}
y_const_bcast_upper = bf16[32,16] broadcast(y_const_upper), dimensions={}
x_clamp = bf16[16,32] clamp(x_const_bcast_lower, x, x_const_bcast_upper)
y_clamp = bf16[32,16] clamp(y_const_bcast_lower, y, y_const_bcast_upper)
ROOT dot_a = f32[16,16] dot(x_clamp, y_clamp), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
ENTRY e {
p0 = bf16[16,32] parameter(0)
p1 = bf16[32,16] parameter(1)
ROOT _ = f32[16,16] fusion(p0, p1), kind=kCustom, calls=fusion1,
backend_config={"fusion_backend_config": {kind: "__cudnn$fusion"}}
})",
ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
}
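As a sanity check on the numerics this test pins down, here is a minimal reference sketch of one output element. Illustrative only: the [3e-3, 1e-1] bounds and the 32-element contraction size come from the HLO above; the function and parameter names are assumptions, not the actual test harness.

#include <algorithm>
#include <array>

// One element of dot_a: clamp each operand into [3e-3, 1e-1]
// (x_clamp / y_clamp above), then contract over the shared dimension.
float ClampedDotElement(const std::array<float, 32>& x_row,
                        const std::array<float, 32>& y_col) {
  constexpr float kLower = 3e-3f;
  constexpr float kUpper = 1e-1f;
  float acc = 0.0f;
  for (int k = 0; k < 32; ++k) {
    acc += std::clamp(x_row[k], kLower, kUpper) *
           std::clamp(y_col[k], kLower, kUpper);
  }
  return acc;
}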

TEST_F(CuDnnFusionLevel2Test, DotF8ExecutesCorrectly) {
EXPECT_TRUE(RunAndCompare(R"(
2 changes: 1 addition & 1 deletion third_party/xla/xla/service/gpu/triton_support.cc
@@ -123,7 +123,7 @@ std::vector<HloOpcode> TritonSupportedBinaryElementwise(

std::vector<HloOpcode> TritonSupportedTernaryElementwise(
PrimitiveType element_type) {
  return {HloOpcode::kSelect, HloOpcode::kClamp};
}

bool IsTritonSupportedElementwise(HloOpcode opcode,
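For context, callers can consult the list above with a simple membership check. A hedged sketch; the helper name and the use of absl::c_linear_search are assumptions, not code from this file:

#include "absl/algorithm/container.h"

// Hypothetical helper: true if `opcode` is a Triton-supported ternary
// elementwise op for `element_type`, per TritonSupportedTernaryElementwise.
bool IsSupportedTernary(HloOpcode opcode, PrimitiveType element_type) {
  return absl::c_linear_search(
      TritonSupportedTernaryElementwise(element_type), opcode);
}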
