-
Notifications
You must be signed in to change notification settings - Fork 74k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Re-enable HoistLayoutConversion pattern and mixed-precision MMA for A…
…mpere. PiperOrigin-RevId: 621158262
- Loading branch information
1 parent
3a531c9
commit e5c11f3
Showing
7 changed files
with
111 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
This patch handles internal test failures. We can attempt to upstream this into | ||
2 changes, but OpenAI might resist. For now, we should move this patch into the | ||
internal ones. This is tracked here: b/331606551. These issues won't reproduce | ||
upstream without removing a pass (which we do internally) that needs further | ||
investigations (tracked here b/331360119). | ||
|
||
diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp | ||
--- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp | ||
+++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp | ||
@@ -123,7 +115,8 @@ public: | ||
PatternRewriter &rewriter) const override { | ||
// Only consider conversions to dot operand. | ||
auto cvtTy = cvt.getType().cast<RankedTensorType>(); | ||
- if (!cvtTy.getEncoding().isa<DotOperandEncodingAttr>()) | ||
+ auto dotOpEnc = cvtTy.getEncoding().dyn_cast<DotOperandEncodingAttr>(); | ||
+ if (!dotOpEnc) | ||
return failure(); | ||
|
||
auto src = cvt.getSrc().getDefiningOp(); | ||
@@ -138,6 +131,12 @@ public: | ||
[](Type ty) { return ty.isa<RankedTensorType>(); })) | ||
return failure(); | ||
|
||
+ // Quick handling to fix loading issues when computing the original | ||
+ // bitwidth is unable to realize that there is a mixed-precision dot | ||
+ // (hence kWidth = 1) but wants to hoist through the type conversion. | ||
+ if (isa<arith::ExtFOp>(src) && dotOpEnc.getKWidth() == 1) | ||
+ return failure(); | ||
+ | ||
// Only consider custom conversions or arith ops. | ||
// TODO(jlebar): Is this too restrictive? | ||
if (!isa<FpToFpOp, BitcastOp>(src) && | ||
@@ -150,6 +149,14 @@ public: | ||
if (isa<arith::TruncIOp, arith::TruncFOp, arith::SelectOp>(src)) | ||
return failure(); | ||
|
||
+ // Don't hoist through u1 -> fp casts as they aren't supported in | ||
+ // ElementwiseOpToLLVM::reorderValues(). | ||
+ if (isa<arith::UIToFPOp>(src)) { | ||
+ Type srcType = getElementTypeOrSelf(src->getOperand(0)); | ||
+ if (srcType.isInteger(1)) | ||
+ return failure(); | ||
+ } | ||
+ | ||
// Check that the conversion is transitively dependent on a load, and all | ||
// operations between the load and the conversion are layout preserving. | ||
// |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
This patch handles internal test failures. We can attempt to upstream this into | ||
2 changes, but OpenAI might resist. For now, we should move this patch into the | ||
internal ones. This is tracked here: b/331606551. These issues won't reproduce | ||
upstream without removing a pass (which we do internally) that needs further | ||
investigations (tracked here b/331360119). | ||
|
||
diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp | ||
--- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp | ||
+++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp | ||
@@ -123,7 +115,8 @@ public: | ||
PatternRewriter &rewriter) const override { | ||
// Only consider conversions to dot operand. | ||
auto cvtTy = cvt.getType().cast<RankedTensorType>(); | ||
- if (!cvtTy.getEncoding().isa<DotOperandEncodingAttr>()) | ||
+ auto dotOpEnc = cvtTy.getEncoding().dyn_cast<DotOperandEncodingAttr>(); | ||
+ if (!dotOpEnc) | ||
return failure(); | ||
|
||
auto src = cvt.getSrc().getDefiningOp(); | ||
@@ -138,6 +131,12 @@ public: | ||
[](Type ty) { return ty.isa<RankedTensorType>(); })) | ||
return failure(); | ||
|
||
+ // Quick handling to fix loading issues when computing the original | ||
+ // bitwidth is unable to realize that there is a mixed-precision dot | ||
+ // (hence kWidth = 1) but wants to hoist through the type conversion. | ||
+ if (isa<arith::ExtFOp>(src) && dotOpEnc.getKWidth() == 1) | ||
+ return failure(); | ||
+ | ||
// Only consider custom conversions or arith ops. | ||
// TODO(jlebar): Is this too restrictive? | ||
if (!isa<FpToFpOp, BitcastOp>(src) && | ||
@@ -150,6 +149,14 @@ public: | ||
if (isa<arith::TruncIOp, arith::TruncFOp, arith::SelectOp>(src)) | ||
return failure(); | ||
|
||
+ // Don't hoist through u1 -> fp casts as they aren't supported in | ||
+ // ElementwiseOpToLLVM::reorderValues(). | ||
+ if (isa<arith::UIToFPOp>(src)) { | ||
+ Type srcType = getElementTypeOrSelf(src->getOperand(0)); | ||
+ if (srcType.isInteger(1)) | ||
+ return failure(); | ||
+ } | ||
+ | ||
// Check that the conversion is transitively dependent on a load, and all | ||
// operations between the load and the conversion are layout preserving. | ||
// |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters