-
Notifications
You must be signed in to change notification settings - Fork 74.1k
/
cl566223642.patch
63 lines (60 loc) · 3.45 KB
/
cl566223642.patch
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
==== triton/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp#4 - /google/src/cloud/bchetioui/mlir_5cf714bb2f75552b10e1eb62fd07aec4b6033881_1695024466/triton/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp ====
# action=edit type=text
--- triton/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp 2023-08-03 10:50:52.000000000 -0700
+++ triton/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp 2023-09-18 01:11:06.000000000 -0700
@@ -321,9 +321,9 @@
Operation *rewriteForOp(OpBuilder &builder, scf::ForOp op,
std::stack<Operation *> &eraser) {
// Generate new iteration operands and set rewrited information
- SmallVector<Value> oldIterOperands = op.getIterOperands();
- SmallVector<Value> newIterOperands = op.getIterOperands();
- for (unsigned i = 0, oldI = 0, size = op.getNumIterOperands(); i < size;
+ SmallVector<Value> oldIterOperands = llvm::to_vector(op.getInitArgs());
+ SmallVector<Value> newIterOperands = llvm::to_vector(op.getInitArgs());
+ for (unsigned i = 0, oldI = 0, size = op.getInitArgs().size(); i < size;
++i, ++oldI) {
if (!triton::isTensorPointerType(newIterOperands[i].getType()))
continue;
@@ -346,7 +346,7 @@
// mapping. It may refer to a value in the old loop, but we will rewrite it
// later
IRMapping mapping;
- for (unsigned i = 0, oldI = 0; oldI < op.getNumIterOperands();
+ for (unsigned i = 0, oldI = 0; oldI < op.getInitArgs().size();
++i, ++oldI) {
auto oldRegionIterArg = op.getRegionIterArg(oldI);
if (triton::isTensorPointerType(oldRegionIterArg.getType())) {
@@ -373,7 +373,7 @@
}
// Replace later usages
- assert(op.getNumResults() == op.getNumIterOperands());
+ assert(op.getNumResults() == op.getInitArgs().size());
for (unsigned i = 0, oldI = 0; oldI < op.getNumResults(); ++i, ++oldI) {
auto oldResult = op.getResult(oldI);
if (triton::isTensorPointerType(oldResult.getType())) {
==== triton/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp#15 - /google/src/cloud/bchetioui/mlir_5cf714bb2f75552b10e1eb62fd07aec4b6033881_1695024466/triton/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp ====
# action=edit type=text
--- triton/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp 2023-08-23 07:42:50.000000000 -0700
+++ triton/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp 2023-09-18 01:14:11.000000000 -0700
@@ -718,8 +718,7 @@
// We need this to update operands for yield
// original block arg => new arg's idx
SmallVector<Value> newLoopArgs;
- for (auto v : forOp.getIterOperands())
- newLoopArgs.push_back(v);
+ for (auto v : forOp.getInitArgs()) newLoopArgs.push_back(v);
bufferIdx = newLoopArgs.size();
for (auto loadOp : validLoads)
==== triton/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp#7 - /google/src/cloud/bchetioui/mlir_5cf714bb2f75552b10e1eb62fd07aec4b6033881_1695024466/triton/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp ====
# action=edit type=text
--- triton/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp 2023-07-04 07:30:46.000000000 -0700
+++ triton/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp 2023-09-18 01:14:11.000000000 -0700
@@ -269,8 +269,7 @@
OpBuilder builder(forOp);
SmallVector<Value> loopArgs;
- for (auto v : forOp.getIterOperands())
- loopArgs.push_back(v);
+ for (auto v : forOp.getInitArgs()) loopArgs.push_back(v);
for (Value dot : dots) {
loopArgs.push_back(
operand2headPrefetch[dot.getDefiningOp<triton::DotOp>().getA()]);