Skip to content

Commit

Permalink
Re-enable HoistLayoutConversion pattern and mixed-precision MMA for A…
Browse files Browse the repository at this point in the history
…mpere.

PiperOrigin-RevId: 621158262
  • Loading branch information
Moerafaat authored and tensorflower-gardener committed Apr 2, 2024
1 parent 3a531c9 commit e5c11f3
Show file tree
Hide file tree
Showing 7 changed files with 111 additions and 6 deletions.
47 changes: 47 additions & 0 deletions third_party/triton/cl609333259.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
This patch handles internal test failures. We can attempt to upstream this into
2 changes, but OpenAI might resist. For now, we should move this patch into the
internal ones. This is tracked here: b/331606551. These issues won't reproduce
upstream without removing a pass (which we do internally) that needs further
investigations (tracked here b/331360119).

diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
--- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
@@ -123,7 +115,8 @@ public:
PatternRewriter &rewriter) const override {
// Only consider conversions to dot operand.
auto cvtTy = cvt.getType().cast<RankedTensorType>();
- if (!cvtTy.getEncoding().isa<DotOperandEncodingAttr>())
+ auto dotOpEnc = cvtTy.getEncoding().dyn_cast<DotOperandEncodingAttr>();
+ if (!dotOpEnc)
return failure();

auto src = cvt.getSrc().getDefiningOp();
@@ -138,6 +131,12 @@ public:
[](Type ty) { return ty.isa<RankedTensorType>(); }))
return failure();

+ // Quick handling to fix loading issues when computing the original
+ // bitwidth is unable to realize that there is a mixed-precision dot
+ // (hence kWidth = 1) but wants to hoist through the type conversion.
+ if (isa<arith::ExtFOp>(src) && dotOpEnc.getKWidth() == 1)
+ return failure();
+
// Only consider custom conversions or arith ops.
// TODO(jlebar): Is this too restrictive?
if (!isa<FpToFpOp, BitcastOp>(src) &&
@@ -150,6 +149,14 @@ public:
if (isa<arith::TruncIOp, arith::TruncFOp, arith::SelectOp>(src))
return failure();

+ // Don't hoist through u1 -> fp casts as they aren't supported in
+ // ElementwiseOpToLLVM::reorderValues().
+ if (isa<arith::UIToFPOp>(src)) {
+ Type srcType = getElementTypeOrSelf(src->getOperand(0));
+ if (srcType.isInteger(1))
+ return failure();
+ }
+
// Check that the conversion is transitively dependent on a load, and all
// operations between the load and the conversion are layout preserving.
//
1 change: 1 addition & 0 deletions third_party/triton/workspace.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,6 @@ def repo():
"//third_party/triton:cl617812302.patch",
"//third_party/triton:cl619146327.patch",
"//third_party/triton:cl619443019.patch",
"//third_party/triton:cl609333259.patch",
],
)
47 changes: 47 additions & 0 deletions third_party/xla/third_party/triton/cl609333259.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
This patch handles internal test failures. We can attempt to upstream this into
2 changes, but OpenAI might resist. For now, we should move this patch into the
internal ones. This is tracked here: b/331606551. These issues won't reproduce
upstream without removing a pass (which we do internally) that needs further
investigations (tracked here b/331360119).

diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
--- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp
@@ -123,7 +115,8 @@ public:
PatternRewriter &rewriter) const override {
// Only consider conversions to dot operand.
auto cvtTy = cvt.getType().cast<RankedTensorType>();
- if (!cvtTy.getEncoding().isa<DotOperandEncodingAttr>())
+ auto dotOpEnc = cvtTy.getEncoding().dyn_cast<DotOperandEncodingAttr>();
+ if (!dotOpEnc)
return failure();

auto src = cvt.getSrc().getDefiningOp();
@@ -138,6 +131,12 @@ public:
[](Type ty) { return ty.isa<RankedTensorType>(); }))
return failure();

+ // Quick handling to fix loading issues when computing the original
+ // bitwidth is unable to realize that there is a mixed-precision dot
+ // (hence kWidth = 1) but wants to hoist through the type conversion.
+ if (isa<arith::ExtFOp>(src) && dotOpEnc.getKWidth() == 1)
+ return failure();
+
// Only consider custom conversions or arith ops.
// TODO(jlebar): Is this too restrictive?
if (!isa<FpToFpOp, BitcastOp>(src) &&
@@ -150,6 +149,14 @@ public:
if (isa<arith::TruncIOp, arith::TruncFOp, arith::SelectOp>(src))
return failure();

+ // Don't hoist through u1 -> fp casts as they aren't supported in
+ // ElementwiseOpToLLVM::reorderValues().
+ if (isa<arith::UIToFPOp>(src)) {
+ Type srcType = getElementTypeOrSelf(src->getOperand(0));
+ if (srcType.isInteger(1))
+ return failure();
+ }
+
// Check that the conversion is transitively dependent on a load, and all
// operations between the load and the conversion are layout preserving.
//
1 change: 1 addition & 0 deletions third_party/xla/third_party/triton/workspace.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,6 @@ def repo():
"//third_party/triton:cl617812302.patch",
"//third_party/triton:cl619146327.patch",
"//third_party/triton:cl619443019.patch",
"//third_party/triton:cl609333259.patch",
],
)
9 changes: 6 additions & 3 deletions third_party/xla/xla/service/gpu/gemm_fusion_autotuner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ ENTRY e {
EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
}

// Modify block_k back to 16 once b/331362083 is fixed.
TEST_F(GemmFusionAutotunerTest, DoNotRunAutotuningKernelSpillingRegisters) {
const std::string kHloText = R"(
HloModule m
Expand All @@ -435,7 +436,7 @@ ENTRY %e {
%get-tuple-element.7020 = s8[12288,1536]{1,0} parameter(0)
%convert = s8[4,12288]{1,0} parameter(1)
ROOT %triton = s8[4,1536]{1,0} fusion(s8[12288,1536]{1,0} %get-tuple-element.7020, s8[4,12288]{1,0} %convert), kind=kCustom, calls=%triton_gemm_dot,
backend_config={"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"256","block_n":"256","block_k":"16","split_k":"1","num_stages":"1","num_warps":"16","num_ctas":"1"}}}
backend_config={"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"256","block_n":"256","block_k":"32","split_k":"1","num_stages":"1","num_warps":"16","num_ctas":"1"}}}
})";

auto module = ParseAndReturnVerifiedModule(kHloText).value();
Expand All @@ -452,6 +453,7 @@ ENTRY %e {
"Compilation result discarded due to register spilling")));
}

// Modify block_k back to 16 once b/331362083 is fixed.
TEST_F(GemmFusionAutotunerTest,
DoNotFilterOutAutotuningKernelSpillingRegisters) {
const std::string kHloText = R"(
Expand All @@ -470,7 +472,7 @@ ENTRY %e {
%get-tuple-element.7020 = s8[12288,1536]{1,0} parameter(0)
%convert = s8[4,12288]{1,0} parameter(1)
ROOT %triton = s8[4,1536]{1,0} fusion(s8[12288,1536]{1,0} %get-tuple-element.7020, s8[4,12288]{1,0} %convert), kind=kCustom, calls=%triton_gemm_dot,
backend_config={"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"256","block_n":"256","block_k":"16","split_k":"1","num_stages":"1","num_warps":"16","num_ctas":"1"}}}
backend_config={"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"256","block_n":"256","block_k":"32","split_k":"1","num_stages":"1","num_warps":"16","num_ctas":"1"}}}
})";

auto module = ParseAndReturnVerifiedModule(kHloText).value();
Expand All @@ -493,6 +495,7 @@ ENTRY %e {
EXPECT_NE(executable, nullptr);
}

// Modify block_k back to 16 once b/331362083 is fixed.
TEST_F(GemmFusionAutotunerTest, RunAutotuningKernelNotSpillingRegisters) {
const std::string kHloText = R"(
HloModule m
Expand All @@ -508,7 +511,7 @@ ENTRY %e {
%p0 = s8[12288,1536]{1,0} parameter(0)
%p1 = f16[4,12288]{1,0} parameter(1)
ROOT %triton_dot = f16[4,1536]{1,0} fusion(s8[12288,1536]{1,0} %p0, f16[4,12288]{1,0} %p1), kind=kCustom, calls=%triton_gemm_dot,
backend_config={"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"16","block_n":"32","block_k":"16","split_k":"1","num_stages":"1","num_warps":"2","num_ctas":"1"}}}
backend_config={"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"16","block_n":"32","block_k":"32","split_k":"1","num_stages":"1","num_warps":"2","num_ctas":"1"}}}
})";

auto module = ParseAndReturnVerifiedModule(kHloText).value();
Expand Down
8 changes: 6 additions & 2 deletions third_party/xla/xla/service/gpu/ir_emitter_triton.cc
Original file line number Diff line number Diff line change
Expand Up @@ -980,7 +980,12 @@ absl::Status CreateTritonPipeline(
}

pm.addPass(mt::gpu::createOptimizeDotOperandsPass());
pm.addPass(mt::gpu::createRemoveLayoutConversionsPass());
// We need to disable this pass because it undoes the hoisting of dot_operand
// layout conversion done in
// triton/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp in
// HoistLayoutConversion pattern.
// Bug: b/331360119
// pm.addPass(mt::gpu::createRemoveLayoutConversionsPass());
pm.addPass(mt::gpu::createReduceDataDuplicationPass());
pm.addPass(mt::gpu::createReorderInstructionsPass());
pm.addPass(mlir::createCSEPass());
Expand Down Expand Up @@ -2759,7 +2764,6 @@ absl::StatusOr<TritonWrapperResult> TritonWrapper(
if (debug_options.xla_gpu_enable_triton_hopper()) {
// Set environment variables for consumption by Triton.
tsl::setenv("ENABLE_MMA_V3", "true", true /*overwrite*/);
tsl::setenv("ENABLE_PIPELINING", "true", true /*overwrite*/);
}

TF_ASSIGN_OR_RETURN(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,9 @@ INSTANTIATE_TEST_SUITE_P(RewriteTestSuite, MixedTypeTest,
// TritonRewriteTest2Params{F32, F16},
// TritonRewriteTest2Params{F32, BF16},
MixTypeParams{S8, BF16, 24, 40, 8},
MixTypeParams{S8, F16, 80, 16, 32, 1e-3, 1e-6},
// Modify the case below to use k = 32 instead of
// 16 once b/331362083 is fixed.
MixTypeParams{S8, F16, 80, 32, 32, 1e-3, 1e-6},
MixTypeParams{F16, F32, 127, 3, 300, 1e-2, 1e-2},
MixTypeParams{F16, BF16, 544, 96, 16, 1e-3, 1e-3},
MixTypeParams{BF16, F32, 77, 500, 333, 3e-3, 3e-3},
Expand Down

0 comments on commit e5c11f3

Please sign in to comment.