Skip to content

Commit

Permalink
[XLA:GPU] Initial simplification of Triton Support test.
Browse files Browse the repository at this point in the history
This is a first change that demonstrates the intended approach to simplifying the Triton support tests. Once we agree on the general direction, I will submit additional CLs to simplify all tests.

PiperOrigin-RevId: 642308346
  • Loading branch information
dimitar-asenov authored and tensorflower-gardener committed Jun 11, 2024
1 parent 95d33fd commit 33ac847
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 31 deletions.
2 changes: 1 addition & 1 deletion third_party/xla/xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1272,6 +1272,7 @@ xla_test(
":matmul_utils",
":triton_fusion_analysis",
":triton_support",
":triton_test_utils",
"//xla:error_spec",
"//xla:shape_util",
"//xla:xla_data_proto_cc",
Expand All @@ -1280,7 +1281,6 @@ xla_test(
"//xla/hlo/utils:hlo_query",
"//xla/service:float_normalization",
"//xla/service:hlo_pass_pipeline",
"//xla/service/gpu/tests:gpu_codegen_test",
"//xla/stream_executor:device_description",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/status",
Expand Down
39 changes: 14 additions & 25 deletions third_party/xla/xla/service/gpu/triton_support_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ limitations under the License.
#include "xla/service/gpu/gpu_float_support.h"
#include "xla/service/gpu/ir_emitter_triton.h"
#include "xla/service/gpu/matmul_utils.h"
#include "xla/service/gpu/tests/gpu_codegen_test.h"
#include "xla/service/gpu/triton_fusion_analysis.h"
#include "xla/service/gpu/triton_test_utils.h"
#include "xla/service/hlo_pass_pipeline.h"
#include "xla/stream_executor/device_description.h"
#include "xla/xla.pb.h"
Expand All @@ -60,14 +60,8 @@ namespace xla {
namespace gpu {
namespace {

class TritonSupportTest : public GpuCodegenTest {
class TritonSupportTest : public TritonFilecheckTest {
public:
se::CudaComputeCapability GetCudaComputeCapability() {
return backend()
.default_stream_executor()
->GetDeviceDescription()
.cuda_compute_capability();
}
absl::StatusOr<bool> ApplyFloatNormalization(HloModule* module) {
const GpuFloatSupport bf16_support(GetCudaComputeCapability(), BF16);
HloPassPipeline pipeline("hlo float normalization");
Expand Down Expand Up @@ -141,38 +135,33 @@ TEST_P(UnaryElementwiseTest, IsTritonSupportedExecutesCorrectlyForUnary) {
}

const std::string kHloTestTemplate = R"(
triton_gemm___computation {
parameter_0 = f32[15,33]{1,0} parameter(0)
parameter_1 = $0[33,68]{1,0} parameter(1)
unary = $0[33,68]{1,0} $1(parameter_1)
convert = f32[33,68]{1,0} convert(unary)
ROOT dot = f32[15,68]{1,0} dot(parameter_0, convert),
lhs_contracting_dims={1}, rhs_contracting_dims={0},
operand_precision={HIGH, HIGH}
triton_computation {
parameter_0 = $0[33,68]{1,0} parameter(0)
unary = $0[33,68]{1,0} $1(parameter_0)
ROOT convert = f32[33,68]{1,0} convert(unary)
}
ENTRY e {
parameter_0 = f32[15,33]{1,0} parameter(0)
parameter_1 = $0[33,68]{1,0} parameter(1)
ROOT triton_gemm = f32[15,68]{1,0} fusion(parameter_0, parameter_1),
kind=kCustom, calls=triton_gemm___computation,
backend_config={"fusion_backend_config":{"kind":"__triton_gemm"}}
parameter_0 = $0[33,68]{1,0} parameter(0)
ROOT root_op = f32[33,68]{1,0} fusion(parameter_0),
kind=kCustom, calls=triton_computation,
backend_config={"fusion_backend_config":{"kind":"__triton"}}
})";
const std::string hlo_test = absl::Substitute(
kHloTestTemplate, primitive_util::LowercasePrimitiveTypeName(data_type),
HloOpcodeString(opcode));
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_test));
const HloComputation* computation =
module->GetComputationWithName("triton_gemm___computation");
module->GetComputationWithName("triton_computation");
ASSERT_TRUE(computation != nullptr);
const HloInstruction* instr =
hlo_query::GetFirstInstructionWithOpcode(*computation, opcode);
if (IsTritonSupportedInstruction(*instr, GetCudaComputeCapability())) {
float tolerance = getTolerance(data_type);
TF_EXPECT_OK(ApplyFloatNormalization(module.get()));
EXPECT_TRUE(RunAndCompareNoHloPasses(
std::move(module), ErrorSpec{/*aabs=*/tolerance, /*arel=*/tolerance}));
TF_EXPECT_OK(CreateTritonIrAndFileCheck(
*computation, /*config=*/{}, /*output_tile_sizes=*/{1, 32}, EmitGeneric,
"CHECK: tt.func @triton_fn"));
} else {
// TODO(b/331632717): update the check to use SymbolicTileAnalysis to avoid
// tiling failures and check triton emitter fails gracefully.
Expand Down
19 changes: 14 additions & 5 deletions third_party/xla/xla/service/gpu/triton_test_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ limitations under the License.
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "xla/hlo/ir/hlo_casting_utils.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/service/gpu/fusions/triton.h"
#include "xla/service/gpu/gpu_device_info_for_tests.h"
Expand All @@ -41,6 +42,7 @@ limitations under the License.
#include "tsl/platform/statusor.h"

namespace xla::gpu {

bool TritonTest::SkipBF16Tests() {
if (std::holds_alternative<stream_executor::RocmComputeCapability>(
GpuComputeComp())) {
Expand Down Expand Up @@ -68,13 +70,20 @@ absl::Status TritonFilecheckTest::CreateTritonIrAndFileCheck(
absl::string_view triton_fusion_name, absl::string_view filecheck_pattern) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<VerifiedHloModule> verified_module,
ParseAndReturnVerifiedModule(hlo_text));
auto* comp = verified_module->GetComputationWithName(triton_fusion_name);
TF_RET_CHECK(comp != nullptr);
return CreateTritonIrAndFileCheck(*comp, config, output_tile_sizes, emitter,
filecheck_pattern);
}

absl::Status TritonFilecheckTest::CreateTritonIrAndFileCheck(
const HloComputation& computation, const TritonGemmConfig& config,
std::vector<int64_t> output_tile_sizes, TritonIrEmitter emitter,
absl::string_view filecheck_pattern) {
auto* fusion = Cast<HloFusionInstruction>(computation.FusionInstruction());

auto* computation =
verified_module->GetComputationWithName(triton_fusion_name);
auto* fusion = Cast<HloFusionInstruction>(computation->FusionInstruction());
TF_RET_CHECK(computation != nullptr);
TF_ASSIGN_OR_RETURN(auto analysis,
TritonFusionAnalysis::Execute(*computation));
TritonFusionAnalysis::Execute(computation));

auto fusion_analysis = HloFusionAnalysis::Create(fusion, &device_desc());

Expand Down
7 changes: 7 additions & 0 deletions third_party/xla/xla/service/gpu/triton_test_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@ limitations under the License.
#define XLA_SERVICE_GPU_TRITON_TEST_UTILS_H_

#include <cstdint>
#include <memory>
#include <vector>

#include "absl/status/status.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/service/gpu/ir_emitter_triton.h"
#include "xla/service/gpu/matmul_utils.h"
#include "xla/service/gpu/tests/gpu_codegen_test.h"
Expand Down Expand Up @@ -57,6 +59,11 @@ class TritonFilecheckTest : public TritonTest {
std::vector<int64_t> output_tile_sizes, TritonIrEmitter emitter,
absl::string_view triton_fusion_name,
absl::string_view filecheck_pattern);

absl::Status CreateTritonIrAndFileCheck(
const HloComputation& computation, const TritonGemmConfig& config,
std::vector<int64_t> output_tile_sizes, TritonIrEmitter emitter,
absl::string_view filecheck_pattern);
};

} // namespace xla::gpu
Expand Down

0 comments on commit 33ac847

Please sign in to comment.