[XLA:GPU] Simplify TritonSupport tests by providing a standard ENTRY computation.

PiperOrigin-RevId: 644694885
dimitar-asenov authored and tensorflower-gardener committed Jun 19, 2024
1 parent cdafce8 commit 160e760
Showing 4 changed files with 72 additions and 108 deletions.
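
The gist of the change, condensed from the diffs below: each test's HLO template previously spelled out both the computation under test and a boilerplate ENTRY computation that fused it with the `__triton` backend kind. After this commit, the template contains only the computation under test, marked directly as ENTRY, and the test harness synthesizes the fusion wrapper. A before/after sketch, taken from the BitcastOrReshapeTest template in the first diff hunk ($0 and $1 are template parameters that the fixture replaces with a data type and an opcode):

Before:

triton_computation {
  parameter_0 = $0[1,16,4]{2,1,0} parameter(0)
  ROOT bitcast_or_reshape = $0[64]{0} $1(parameter_0)
}

ENTRY e {
  parameter_0 = $0[1,16,4]{2,1,0} parameter(0)
  ROOT root_op = $0[64]{0} fusion(parameter_0),
    kind=kCustom, calls=triton_computation,
    backend_config={"fusion_backend_config":{"kind":"__triton"}}
}

After:

ENTRY triton_computation {
  parameter_0 = $0[1,16,4]{2,1,0} parameter(0)
  ROOT bitcast_or_reshape = $0[64]{0} $1(parameter_0)
}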
3 changes: 3 additions & 0 deletions third_party/xla/xla/service/gpu/BUILD
@@ -671,6 +671,7 @@ cc_library(
     deps = [
         ":gpu_device_info_for_tests",
         ":gpu_float_support",
+        ":ir_emission_utils",
         ":ir_emitter_triton",
         ":matmul_utils",
         "//xla:shape_util",
@@ -684,6 +685,7 @@ cc_library(
         "//xla/stream_executor:device_description",
         "//xla/tests:filecheck",
         "//xla/tests:verified_hlo_module",
+        "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/status",
         "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings",
@@ -693,6 +695,7 @@ cc_library(
         "@llvm-project//llvm:Support",
         "@llvm-project//llvm:ir_headers",
         "@llvm-project//mlir:IR",
+        "@local_tsl//tsl/platform:errors",
         "@local_tsl//tsl/platform:statusor",
     ],
 )
120 changes: 12 additions & 108 deletions third_party/xla/xla/service/gpu/triton_support_test.cc
@@ -72,16 +72,9 @@ using BitcastOrReshapeTest = TritonSupportTestWithParam;
 TEST_P(BitcastOrReshapeTest, IsTritonSupportedBitcastOrReshape) {
   auto [data_type, opcode] = GetParam();
   const std::string kHloTestTemplate = R"(
-triton_computation {
+ENTRY triton_computation {
   parameter_0 = $0[1,16,4]{2,1,0} parameter(0)
   ROOT bitcast_or_reshape = $0[64]{0} $1(parameter_0)
-}
-
-ENTRY e {
-  parameter_0 = $0[1,16,4]{2,1,0} parameter(0)
-  ROOT root_op = $0[64]{0} fusion(parameter_0),
-    kind=kCustom, calls=triton_computation,
-    backend_config={"fusion_backend_config":{"kind":"__triton"}}
 })";
   TF_ASSERT_OK_AND_ASSIGN(
       TestedInstruction ti,
@@ -120,17 +113,10 @@ TEST_P(UnaryElementwiseTest, IsTritonSupportedUnaryElementwise) {
   }
 
   const std::string kHloTestTemplate = R"(
-triton_computation {
+ENTRY triton_computation {
   parameter_0 = $0[33,68]{1,0} parameter(0)
   unary = $0[33,68]{1,0} $1(parameter_0)
   ROOT convert = f32[33,68]{1,0} convert(unary)
-}
-
-ENTRY e {
-  parameter_0 = $0[33,68]{1,0} parameter(0)
-  ROOT root_op = f32[33,68]{1,0} fusion(parameter_0),
-    kind=kCustom, calls=triton_computation,
-    backend_config={"fusion_backend_config":{"kind":"__triton"}}
 })";
   TF_ASSERT_OK_AND_ASSIGN(
       TestedInstruction ti,
@@ -184,18 +170,10 @@ TEST_P(BinaryElementwiseTest, IsTritonSupportedBinaryElementwise) {
   }
 
   const std::string kHloTestTemplate = R"(
-triton_computation {
+ENTRY triton_computation {
   parameter_0 = $0[11,63]{1,0} parameter(0)
   parameter_1 = $0[11,63]{1,0} parameter(1)
   ROOT binary = $0[11,63]{1,0} $1(parameter_0, parameter_1)
-}
-
-ENTRY e {
-  parameter_0 = $0[11,63]{1,0} parameter(0)
-  parameter_1 = $0[11,63]{1,0} parameter(1)
-  ROOT triton_op = $0[11,63]{1,0} fusion(parameter_0, parameter_1),
-    kind=kCustom, calls=triton_computation,
-    backend_config={"fusion_backend_config":{"kind":"__triton"}}
 })";
   TF_ASSERT_OK_AND_ASSIGN(
       TestedInstruction ti,
@@ -251,19 +229,11 @@ TEST_P(CompareTest, IsTritonSupportedCompare) {
   }
 
   const std::string kHloTestTemplate = R"(
-triton_computation {
+ENTRY triton_computation {
   parameter_0 = $0[11,63]{1,0} parameter(0)
   parameter_1 = $0[11,63]{1,0} parameter(1)
   compare = pred[11,63]{1,0} $1(parameter_0, parameter_1), direction=GE
   ROOT convert = f32[11,63]{1,0} convert(compare)
-}
-
-ENTRY e {
-  parameter_0 = $0[11,63]{1,0} parameter(0)
-  parameter_1 = $0[11,63]{1,0} parameter(1)
-  ROOT triton_op = f32[11,63]{1,0} fusion(parameter_0, parameter_1),
-    kind=kCustom, calls=triton_computation,
-    backend_config={"fusion_backend_config":{"kind":"__triton"}}
 })";
   TF_ASSERT_OK_AND_ASSIGN(
       TestedInstruction ti,
@@ -298,21 +268,12 @@ TEST_P(TernaryElementwiseTest, IsTritonSupportedTernaryElementwise) {
   }
 
   const std::string kHloTestTemplate = R"(
-triton_computation {
+ENTRY triton_computation {
   parameter_0 = $0[13,63]{1,0} parameter(0)
   parameter_1 = $0[13,63]{1,0} parameter(1)
   parameter_2 = pred[13,63]{1,0} parameter(2)
   ternary = $0[13,63]{1,0} $1(parameter_2, parameter_0, parameter_1)
   ROOT convert = f32[13,63]{1,0} convert(ternary)
-}
-
-ENTRY e {
-  parameter_0 = $0[13,63]{1,0} parameter(0)
-  parameter_1 = $0[13,63]{1,0} parameter(1)
-  parameter_2 = pred[13,63]{1,0} parameter(2)
-  ROOT triton_op = f32[13,63]{1,0} fusion(parameter_0, parameter_1, parameter_2),
-    kind=kCustom, calls=triton_computation,
-    backend_config={"fusion_backend_config":{"kind":"__triton"}}
 })";
   TF_ASSERT_OK_AND_ASSIGN(
       TestedInstruction ti,
@@ -353,18 +314,10 @@ add {
   ROOT add = $0[] add(Arg_0, Arg_1)
 }
 
-triton_computation {
+ENTRY triton_computation {
   parameter_0 = $0[125,127]{1,0} parameter(0)
   constant_0 = $0[] constant(0)
   ROOT reduce = $0[125]{0} $1(parameter_0, constant_0), dimensions={1}, to_apply=add
-}
-
-ENTRY main {
-  parameter_0 = $0[125,127]{1,0} parameter(0)
-  ROOT triton_op = $0[125]{0} fusion(parameter_0),
-    kind=kCustom, calls=triton_computation,
-    backend_config={"fusion_backend_config":
-      {"kind":"__triton"}}
 })";
   TF_ASSERT_OK_AND_ASSIGN(
       TestedInstruction ti,
@@ -406,19 +359,11 @@ add {
   ROOT add = f32[] add(Arg_0, Arg_1)
 }
 
-triton_computation {
+ENTRY triton_computation {
   parameter_0 = f32[125,127]{1,0} parameter(0)
   constant_0 = bf16[] constant(0)
   convert_0 = f32[] convert(constant_0)
   ROOT reduce = f32[125]{0} reduce(parameter_0, convert_0), dimensions={1}, to_apply=add
-}
-
-ENTRY main {
-  parameter_0 = f32[125,127]{1,0} parameter(0)
-  ROOT triton_op = f32[125]{0} fusion(parameter_0), kind=kCustom,
-    calls=triton_computation,
-    backend_config={"fusion_backend_config":
-      {"kind":"__triton"}}
 })";
   TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti,
                           ParseTemplateAndGetInstruction(
@@ -442,18 +387,10 @@ add {
   ROOT add = f32[] add(Arg_0, Arg_1)
 }
 
-triton_computation {
+ENTRY triton_computation {
   parameter_0 = f32[2,125,127]{2,1,0} parameter(0)
   constant_0 = f32[] constant(0)
   ROOT reduce = f32[2]{0} reduce(parameter_0, constant_0), dimensions={1,2}, to_apply=add
-}
-
-ENTRY main {
-  parameter_0 = f32[2,125,127]{2,1,0} parameter(0)
-  ROOT triton_op = f32[2]{0} fusion(parameter_0),
-    kind=kCustom, calls=triton_computation,
-    backend_config={"fusion_backend_config":
-      {"kind":"__triton"}}
 })";
   TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti,
                           ParseTemplateAndGetInstruction(
@@ -479,18 +416,10 @@ add {
   ROOT add = f32[] add(Arg_0, Arg_1)
 }
 
-triton_computation {
+ENTRY triton_computation {
   parameter_0 = f32[125,127]{1,0} parameter(0)
   constant_0 = f32[] constant(0)
   ROOT reduce = f32[127]{0} reduce(parameter_0, constant_0), dimensions={0}, to_apply=add
-}
-
-ENTRY main {
-  parameter_0 = f32[125,127]{1,0} parameter(0)
-  ROOT triton_op = f32[127]{0} fusion(parameter_0),
-    kind=kCustom, calls=triton_computation,
-    backend_config={"fusion_backend_config":
-      {"kind":"__triton"}}
 })";
   TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti,
                           ParseTemplateAndGetInstruction(
@@ -520,19 +449,11 @@ add {
   ROOT pair = (f32[], f32[]) tuple(add_0, add_1)
 }
 
-triton_computation {
+ENTRY triton_computation {
   parameter_0 = f32[125,127] parameter(0)
   constant_0 = f32[] constant(0)
   tuple_0 = (f32[125]{0}, f32[125]{0}) reduce(parameter_0, parameter_0, constant_0, constant_0), dimensions={1}, to_apply=add
   ROOT reduce = f32[125]{0} get-tuple-element(tuple_0), index=0
-}
-
-ENTRY main {
-  parameter_0 = f32[125,127]{1,0} parameter(0)
-  ROOT triton_op = f32[125]{0} fusion(parameter_0),
-    kind=kCustom, calls=triton_computation,
-    backend_config={"fusion_backend_config":
-      {"kind":"__triton"}}
 })";
   TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti,
                           ParseTemplateAndGetInstruction(
@@ -557,19 +478,10 @@ add {
   ROOT add = f32[] add(Arg_0, Arg_1)
 }
 
-triton_computation {
+ENTRY triton_computation {
   parameter_0 = f32[125,127]{1,0} parameter(0)
   init = f32[] parameter(1)
   ROOT reduce = f32[125]{0} reduce(parameter_0, init), dimensions={1}, to_apply=add
-}
-
-ENTRY main {
-  parameter_0 = f32[125,127]{1,0} parameter(0)
-  parameter_1 = f32[] parameter(1)
-  ROOT triton_op = f32[125]{0} fusion(parameter_0, parameter_1),
-    kind=kCustom, calls=triton_computation,
-    backend_config={"fusion_backend_config":
-      {"kind":"__triton"}}
 })";
   TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti,
                           ParseTemplateAndGetInstruction(
@@ -599,18 +511,10 @@ custom_call {
   ROOT custom_call = f32[] custom-call(Arg_0, Arg_1), custom_call_target="foo"
 }
 
-triton_computation {
+ENTRY triton_computation {
   parameter_0 = f32[125,127]{1,0} parameter(0)
   constant_0 = f32[] constant(0)
   ROOT reduce = f32[125]{0} reduce(parameter_0, constant_0), dimensions={1}, to_apply=custom_call
-}
-
-ENTRY main {
-  parameter_0 = f32[125,127]{1,0} parameter(0)
-  ROOT triton_op = f32[125]{0} fusion(parameter_0),
-    kind=kCustom, calls=triton_computation,
-    backend_config={"fusion_backend_config":
-      {"kind":"__triton"}}
 })";
   TF_ASSERT_OK_AND_ASSIGN(TestedInstruction ti,
                           ParseTemplateAndGetInstruction(
50 changes: 50 additions & 0 deletions third_party/xla/xla/service/gpu/triton_test_utils.cc
@@ -20,8 +20,10 @@ limitations under the License.
 #include <tuple>
 #include <utility>
 #include <variant>
+#include <vector>
 
 #include <gtest/gtest.h>
+#include "absl/log/check.h"
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "absl/strings/str_cat.h"
@@ -41,13 +43,15 @@ limitations under the License.
 #include "xla/service/float_normalization.h"
 #include "xla/service/gpu/gpu_device_info_for_tests.h"
 #include "xla/service/gpu/gpu_float_support.h"
+#include "xla/service/gpu/ir_emission_utils.h"
 #include "xla/service/gpu/ir_emitter_triton.h"
 #include "xla/service/gpu/model/tiled_hlo_computation.h"
 #include "xla/service/hlo_pass_pipeline.h"
 #include "xla/status_macros.h"
 #include "xla/stream_executor/device_description.h"
 #include "xla/tests/filecheck.h"
 #include "xla/tests/verified_hlo_module.h"
+#include "tsl/platform/errors.h"
 #include "tsl/platform/statusor.h"
 
 namespace xla::gpu {
@@ -140,6 +144,46 @@ std::string TritonSupportTestParamsToString(
       absl::StrReplaceAll(HloOpcodeString(opcode), {{"-", "_"}}));
 }
 
+namespace {
+
+// This function does nothing if the input module already has an entry
+// computation whose root is a fusion. Otherwise, it creates a new entry
+// computation whose root is a fusion instruction that calls the original
+// entry computation, using the generic Triton backend kind.
+absl::Status ConvertEntryToTritonFusion(HloModule* module) {
+  if (module->entry_computation()->root_instruction()->opcode() ==
+      HloOpcode::kFusion) {
+    return absl::OkStatus();
+  }
+  auto builder = HloComputation::Builder("entry");
+  std::vector<HloInstruction*> params;
+  for (auto& param : module->entry_computation()->parameter_instructions()) {
+    TF_ASSIGN_OR_RETURN(
+        auto param_clone,
+        builder.AddParameter(HloInstruction::CreateParameter(
+            param->parameter_number(), param->shape(),
+            absl::StrCat("param_", param->parameter_number()))));
+    params.push_back(param_clone);
+  }
+
+  auto fusion = builder.AddInstruction(HloInstruction::CreateFusion(
+      module->entry_computation()->root_instruction()->shape(),
+      HloInstruction::FusionKind::kCustom, params,
+      module->entry_computation()));
+
+  gpu::GpuBackendConfig gpu_config;
+  gpu_config.mutable_fusion_backend_config()->set_kind(kTritonFusionKind);
+  TF_RETURN_IF_ERROR(fusion->set_backend_config(gpu_config));
+
+  auto new_entry =
+      module->AddComputationAndUnifyNamesAndIds(builder.Build(),
+                                                /*is_entry=*/false);
+  module->ReplaceEntryComputation(new_entry);
+  return absl::OkStatus();
+}
+
+}  // namespace
+
 absl::StatusOr<TritonSupportTest::TestedInstruction>
 TritonSupportTest::ParseTemplateAndGetInstruction(
     absl::string_view hlo_template, xla::PrimitiveType data_type,
@@ -149,8 +193,14 @@ TritonSupportTest::ParseTemplateAndGetInstruction(
                                        HloOpcodeString(opcode));
   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
                       ParseAndReturnVerifiedModule(hlo_text));
+  TF_RETURN_IF_ERROR(ConvertEntryToTritonFusion(module.get()));
   const HloComputation* computation =
       module->GetComputationWithName("triton_computation");
+  if (computation == module->entry_computation()) {
+    return absl::InvalidArgumentError(
+        "The `triton_computation` and the module's entry computation cannot "
+        "be the same.");
+  }
   const HloFusionInstruction* fusion = DynCast<HloFusionInstruction>(
       module->entry_computation()->root_instruction());
   if (fusion == nullptr) {
7 changes: 7 additions & 0 deletions third_party/xla/xla/service/gpu/triton_test_utils.h
@@ -129,6 +129,13 @@ class TritonSupportTest : public TritonFilecheckTest {
   // The provided template must contain a computation called
   // `triton_computation`. If the template contains parameters $0 and $1, they
   // will be replaced with the data type and opcode respectively.
+  // If the template's entry computation does not have a root fusion
+  // instruction, a new entry computation will be created. The new entry
+  // computation will have the same parameters as the `triton_computation`,
+  // and its root will be a fusion instruction that calls the
+  // `triton_computation` using the generic Triton emitter. Tests that
+  // need the `__triton_gemm` backend kind should provide their own ENTRY
+  // computation.
   absl::StatusOr<TestedInstruction> ParseTemplateAndGetInstruction(
       absl::string_view hlo_template, xla::PrimitiveType data_type,
       xla::HloOpcode opcode);
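
For illustration, a rough sketch of the module that `ConvertEntryToTritonFusion` synthesizes for the bitcast template above, after the fixture substitutes f32 and bitcast for $0 and $1. This is an assumption about the printed form, not tool output: the computation name `entry` and parameter name `param_0` follow the `HloComputation::Builder("entry")` and `absl::StrCat("param_", ...)` calls in the helper (actual names may be uniquified), and `__triton` is the generic kind set via `kTritonFusionKind`:

triton_computation {
  parameter_0 = f32[1,16,4]{2,1,0} parameter(0)
  ROOT bitcast_or_reshape = f32[64]{0} bitcast(parameter_0)
}

ENTRY entry {
  param_0 = f32[1,16,4]{2,1,0} parameter(0)
  ROOT fusion = f32[64]{0} fusion(param_0), kind=kCustom,
    calls=triton_computation,
    backend_config={"fusion_backend_config":{"kind":"__triton"}}
}

Once the synthesized entry is in place, `triton_computation` is no longer the module's entry computation, which is what the new `InvalidArgumentError` check in `ParseTemplateAndGetInstruction` verifies.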
