[mlir][linalg] fix OuterUnitDims linalg.pack decomposition pattern #141613
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg
Author: Christopher McGirr (chrsmcgrr)
Changes
Given a `linalg.pack` with `outer_dims_perm = [1, 2, 0]`, `inner_dims_pos = [2, 0]` and `inner_tiles = [4, 1]` that packs `tensor<1x1x4xf32>` into `tensor<1x1x1x4x1xf32>`, we would generate an invalid transpose operation because the calculated permutation would be `[0, 2, 0]`. The following change modifies how we calculate the permutation array and ensures that the dimension indices given in the permutation array are unique. The above example would then translate to a transpose having a permutation of `[1, 2, 0]`.
Full diff: https://github.com/llvm/llvm-project/pull/141613.diff
2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 8718c57b9e86c..7b6c8243d1040 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1205,16 +1205,23 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
// %init = tensor.empty()
// %transposed_tile = linalg.transpose ins(%source_or_padded_source),
// outs(%init)
- // Two assumptions are made:
- // 1. All outer dims are 1 - the corresponding transposition doesn't matter.
- // 2. Inner dims position correspond to the trailing `numTiles` dims.
- SmallVector<int64_t> tilesPermNormalized =
- getPackUnpackNormalizedPerm(srcRank, packOp.getInnerDimsPos());
+ // Assumptions made:
+ // 1. Inner dims position correspond to the trailing `numTiles` dims.
SmallVector<int64_t> srcPermForTranspose;
- for (int64_t i = 0; i < (srcRank - numTiles); i++)
+ ArrayRef<int64_t> innerDimPos(packOp.getInnerDimsPos());
+ for (int64_t i = 0; i < srcRank; i++) {
+ // As we assume the trailing dimensions of the inner dim position correspond
+ // to the trailing indices of the transpose permutation, we need to
+ // calculate the remaining indicies of the transpose permutation. This is
+ // done by adding the indices not contained in the inner dimension position.
+ // For example if we have a source tensor of dimensions [0, 1, 2, 3]
+ // and inner dim position of [3, 0], the remaining indices are [1, 2].
+ // and the transpose will be [1, 2, 3, 0].
+ if (llvm::is_contained(innerDimPos, i))
+ continue;
srcPermForTranspose.push_back(i);
-
- srcPermForTranspose.append(SmallVector<int64_t>(packOp.getInnerDimsPos()));
+ }
+ srcPermForTranspose.append(innerDimPos.begin(), innerDimPos.end());
LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n"
<< "perm: " << llvm::interleaved(srcPermForTranspose)
diff --git a/mlir/test/Dialect/Linalg/decompose-pack.mlir b/mlir/test/Dialect/Linalg/decompose-pack.mlir
index 911b453f919c3..6d091406a639c 100644
--- a/mlir/test/Dialect/Linalg/decompose-pack.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-pack.mlir
@@ -229,3 +229,22 @@ func.func @simple_KCRS_to_KRSCsr(%arg0: tensor<1x1x32x8xf32>, %arg1: tensor<1x1x
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
// CHECK-SAME: [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1]
// CHECK: return %[[INSERT]]
+
+// -----
+
+func.func @pack_with_unit_outer_dims_and_unit_inner(%arg0: tensor<1x1x4xf32>, %arg1: tensor<1x1x1x4x1xf32>) -> tensor<1x1x1x4x1xf32> {
+ %pack = linalg.pack %arg0 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [4, 1] into %arg1 : tensor<1x1x4xf32> -> tensor<1x1x1x4x1xf32>
+ return %pack : tensor<1x1x1x4x1xf32>
+}
+
+// CHECK-LABEL: func.func @pack_with_unit_outer_dims_and_unit_inner
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x4x1xf32>
+// CHECK: %[[TRANSP:.+]] = linalg.transpose
+// CHECK-SAME: ins(%[[SRC]] : tensor<1x1x4xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x4x1xf32>)
+// CHECK-SAME: permutation = [1, 2, 0]
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
+// CHECK-SAME: [0, 0, 0, 0, 0] [1, 1, 1, 4, 1] [1, 1, 1, 1, 1] : tensor<1x4x1xf32> into tensor<1x1x1x4x1xf32>
+// CHECK: return %[[INSERT]]
\ No newline at end of file
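To make the change in `Transforms.cpp` easier to follow, here is a minimal standalone sketch of the permutation logic before and after the patch. It uses `std::vector` instead of the LLVM containers, and the function names `permBeforeFix` and `permAfterFix` are hypothetical, introduced only for illustration. It also assumes that `numTiles` equals the number of entries in `inner_dims_pos`, which the diff does not show explicitly.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the pre-patch logic: take the first
// (srcRank - numTiles) indices, then append inner_dims_pos.
// Assumption: numTiles is the number of inner_dims_pos entries.
std::vector<int64_t> permBeforeFix(int64_t srcRank,
                                   const std::vector<int64_t> &innerDimsPos) {
  const int64_t numTiles = static_cast<int64_t>(innerDimsPos.size());
  assert(numTiles <= srcRank && "more tiles than source dimensions");
  std::vector<int64_t> perm;
  for (int64_t i = 0; i < srcRank - numTiles; ++i)
    perm.push_back(i);
  perm.insert(perm.end(), innerDimsPos.begin(), innerDimsPos.end());
  return perm;
}

// Hypothetical stand-in for the post-patch logic: take every source index
// that is not in inner_dims_pos (in increasing order), then append
// inner_dims_pos.
std::vector<int64_t> permAfterFix(int64_t srcRank,
                                  const std::vector<int64_t> &innerDimsPos) {
  std::vector<int64_t> perm;
  for (int64_t i = 0; i < srcRank; ++i) {
    if (std::find(innerDimsPos.begin(), innerDimsPos.end(), i) !=
        innerDimsPos.end())
      continue; // Index is placed later, as part of the trailing tile dims.
    perm.push_back(i);
  }
  perm.insert(perm.end(), innerDimsPos.begin(), innerDimsPos.end());
  return perm;
}
```

With `srcRank = 3` and `inner_dims_pos = [2, 0]`, the first function yields `[0, 2, 0]` (index 0 appears twice) and the second yields `[1, 2, 0]`, matching the permutation checked in the new test.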
IIUC, you extend the pattern to handle the case that there are non ones in unpacked outer dimensions? I.e., should we relax the check in line 1165 - 1169? Then you are not fixing a corner case. Instead, you extend the support in general?
E.g., I'd expect the below test case working with your support, if I read your intention correctly.
func.func @main(%arg0: tensor<2x1x1x4x1xf32>, %arg1: tensor<1x2x4xf32>) -> tensor<2x1x1x4x1xf32> {
  %pack = linalg.pack %arg1 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [4, 1] into %arg0 : tensor<1x2x4xf32> -> tensor<2x1x1x4x1xf32>
  return %pack : tensor<2x1x1x4x1xf32>
}
Thanks for the quick reply, @hanhanW. Not quite: at least for my use case, I am still only concerned with unit outer dimensions in the unpacked case. AFAIK, the outer dimension in my case would be index My change is more about the adjacent trailing dimensions, as @banach-space has now explained. I would be happy to extend 1165-1169 if anyone needs it.
If it's not required, I would refrain from extending it right now. These "decomposition" patterns are already riddled with assumptions that we neither document nor test (like the case with non-adjacent dims that you discovered). Extending them could lead to even more un-verified assumptions. Btw, @chrsmcgrr , could you also the check Thanks!
@banach-space @hanhanW I've updated the comments and removed the adjacent trailing dimensions check, as it is no longer needed. This change will allow for that use case. I have also added the corresponding test to the unpack version, which works fine out of the box. Looking at the unpack pattern, I can't see a clean way of making the patterns symmetrical, so I will leave it for now. Let me know what you think.
5c41b17 to 4f378a5
Thanks for the updates! Some minor comments inline.
4f378a5 to 79e0ff5
Nice, thank you, LGTM!
Given the following example:
```
module {
  func.func @main(%arg0: tensor<1x1x1x4x1xf32>, %arg1: tensor<1x1x4xf32>) -> tensor<1x1x1x4x1xf32> {
    %pack = linalg.pack %arg1 outer_dims_perm = [1, 2, 0] inner_dims_pos = [2, 0] inner_tiles = [4, 1] into %arg0 : tensor<1x1x4xf32> -> tensor<1x1x1x4x1xf32>
    return %pack : tensor<1x1x1x4x1xf32>
  }
}
```
We would generate an invalid transpose operation because the calculated permutation would be `[0, 2, 0]`, which is semantically incorrect, as the permutation must contain unique integers corresponding to the source tensor dimensions.

The following change modifies how we calculate the permutation array and ensures that the dimension indices given in the permutation array are unique. The above example would then translate to a transpose having a permutation of `[1, 2, 0]`. This follows the rule that `inner_dims_pos` is appended to the permutation array and the preceding indices are filled with the remaining dimensions.
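The uniqueness requirement described in the commit message can be illustrated with a small standalone check. This is only a sketch, not the verifier MLIR actually runs: `isValidPermutation` and `transposedShape` are hypothetical helper names, and the shape rule (result dimension `i` takes its size from source dimension `perm[i]`) is assumed because it is consistent with the `tensor<1x4x1xf32>` result expected by the test in this PR.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Returns true if `perm` contains each index in [0, perm.size()) exactly once.
bool isValidPermutation(const std::vector<int64_t> &perm) {
  const int64_t rank = static_cast<int64_t>(perm.size());
  std::vector<bool> seen(perm.size(), false);
  for (int64_t idx : perm) {
    if (idx < 0 || idx >= rank || seen[idx])
      return false; // Out of range, or the same dimension used twice.
    seen[idx] = true;
  }
  return true;
}

// Assumed shape rule: result dimension i has the size of source
// dimension perm[i].
std::vector<int64_t> transposedShape(const std::vector<int64_t> &shape,
                                     const std::vector<int64_t> &perm) {
  assert(perm.size() == shape.size() && isValidPermutation(perm));
  std::vector<int64_t> result;
  for (int64_t idx : perm)
    result.push_back(shape[idx]);
  return result;
}
```

Under these assumptions, `[0, 2, 0]` is rejected because dimension 0 appears twice and dimension 1 never appears, while `[1, 2, 0]` is accepted and maps the source shape `1x1x4` to `1x4x1`.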
79e0ff5 to af4d38d
I think all the comments are addressed, so we can land the PR? @chrsmcgrr let us know if you need any of us to help merge it.
@hanhanW Yes, if you could merge it, that would be great :)
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/55/builds/13385 Here is the relevant piece of the build log for the reference