-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[MLIR][AArch64] Check indexing maps before checking for dimensions compatibility #145702
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
…mpatibility In `LowerContractionToSVEI8MMPattern` check we have the expected indexing maps before deciding which operand dimension must match with which. For example, with indexing map like: lhs: (m, n, k) -> (m, k) rhs: (m, n, k) -> (n, k) acc: (m, n, k) -> (m, n) we would like the second `lhs` dimension (columns) to match with the second `rhs` (rows, transposed) whereas with indexing maps like lhs: (m, n, k) -> (m, k) rhs: (m, n, k) -> (k, n) acc: (m, n, k) -> (m, n) we would like the second `lhs` dimension (columns) to match with the first `rhs` (rows, canonical matrix multiplication). Since only the first kind of indexing maps is supported, the patch does not change anything of significance, just the notification message when the pattern would fail to apply anyway.
@llvm/pr-subscribers-mlir-sve Author: Momchil Velikov (momchil-velikov) ChangesIn For example, with indexing map like:
we would like the second
we would like the second Since only the first kind of indexing maps is supported, the patch does not change anything of significance, just the notification message when the pattern would fail to apply anyway. Full diff: https://github.com/llvm/llvm-project/pull/145702.diff 1 Files Affected:
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
index a1209fe8230e2..70d2e06f48902 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
@@ -136,11 +136,26 @@ class LowerContractionToSVEI8MMPattern
LogicalResult matchAndRewrite(vector::ContractionOp op,
PatternRewriter &rewriter) const override {
- Location loc = op.getLoc();
+ // Check permutation maps. For now only accept
+ // lhs: (d0, d1, d2) -> (d0, d2)
+ // rhs: (d0, d1, d2) -> (d1, d2)
+ // acc: (d0, d1, d2) -> (d0, d1)
+ // This corresponds to matrix multiplication with transposed RHS.
+ if (op.getIndexingMapsArray()[0] !=
+ AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u},
+ op.getContext()) ||
+ op.getIndexingMapsArray()[1] !=
+ AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u},
+ op.getContext()) ||
+ op.getIndexingMapsArray()[2] !=
+ AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 1u},
+ op.getContext()))
+ return rewriter.notifyMatchFailure(op, "non-matching permutation maps");
+
mlir::VectorType lhsType = op.getLhsType();
mlir::VectorType rhsType = op.getRhsType();
- // Check the rank the types so we can safely examine their dimensions.
+ // Check the rank of the types so we can safely examine their dimensions.
if (lhsType.getRank() != 2 || rhsType.getRank() != 2)
return rewriter.notifyMatchFailure(op, "non-matching operand shape");
@@ -159,22 +174,6 @@ class LowerContractionToSVEI8MMPattern
!rhsType.getScalableDims()[0])
return rewriter.notifyMatchFailure(op, "non-matching operand shape");
- // Check permutation maps. For now only accept
- // lhs: (d0, d1, d2) -> (d0, d2)
- // rhs: (d0, d1, d2) -> (d1, d2)
- // acc: (d0, d1, d2) -> (d0, d1)
- // This corresponds to matrix multiplication with transposed RHS.
- if (op.getIndexingMapsArray()[0] !=
- AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u},
- op.getContext()) ||
- op.getIndexingMapsArray()[1] !=
- AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u},
- op.getContext()) ||
- op.getIndexingMapsArray()[2] !=
- AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 1u},
- op.getContext()))
- return rewriter.notifyMatchFailure(op, "non-matching permutation maps");
-
// Check iterator types for matrix multiplication.
auto itTypes = op.getIteratorTypesArray();
if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
@@ -228,6 +227,7 @@ class LowerContractionToSVEI8MMPattern
/*scalableDims=*/{true});
// Extract LHS sub-tiles with logicall shape <2x8>.
+ Location loc = op.getLoc();
SmallVector<Value> lhsTile;
for (int64_t i = 0; i < M; i += 2) {
// Extract two consecutive rows of the LHS tile.
@@ -283,7 +283,7 @@ class LowerContractionToSVEI8MMPattern
if (mmlaOp == MMLA::MixedSwapped) {
// We need to swap the positions of the LHS and RHS (since we don't have
// a signed * unsigned operation), but then each individual 2x2 tile of
- // the acumulator and (later) the result need to be transposed.
+ // the accumulator and (later) the result need to be transposed.
accTileVec = rewriter.create<vector::InterleaveOp>(loc, r0, r1);
} else {
// Bitcast them to 64-bit elements, so subsequent
|
@llvm/pr-subscribers-mlir Author: Momchil Velikov (momchil-velikov) ChangesIn For example, with indexing map like:
we would like the second
we would like the second Since only the first kind of indexing maps is supported, the patch does not change anything of significance, just the notification message when the pattern would fail to apply anyway. Full diff: https://github.com/llvm/llvm-project/pull/145702.diff 1 Files Affected:
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
index a1209fe8230e2..70d2e06f48902 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
@@ -136,11 +136,26 @@ class LowerContractionToSVEI8MMPattern
LogicalResult matchAndRewrite(vector::ContractionOp op,
PatternRewriter &rewriter) const override {
- Location loc = op.getLoc();
+ // Check permutation maps. For now only accept
+ // lhs: (d0, d1, d2) -> (d0, d2)
+ // rhs: (d0, d1, d2) -> (d1, d2)
+ // acc: (d0, d1, d2) -> (d0, d1)
+ // This corresponds to matrix multiplication with transposed RHS.
+ if (op.getIndexingMapsArray()[0] !=
+ AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u},
+ op.getContext()) ||
+ op.getIndexingMapsArray()[1] !=
+ AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u},
+ op.getContext()) ||
+ op.getIndexingMapsArray()[2] !=
+ AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 1u},
+ op.getContext()))
+ return rewriter.notifyMatchFailure(op, "non-matching permutation maps");
+
mlir::VectorType lhsType = op.getLhsType();
mlir::VectorType rhsType = op.getRhsType();
- // Check the rank the types so we can safely examine their dimensions.
+ // Check the rank of the types so we can safely examine their dimensions.
if (lhsType.getRank() != 2 || rhsType.getRank() != 2)
return rewriter.notifyMatchFailure(op, "non-matching operand shape");
@@ -159,22 +174,6 @@ class LowerContractionToSVEI8MMPattern
!rhsType.getScalableDims()[0])
return rewriter.notifyMatchFailure(op, "non-matching operand shape");
- // Check permutation maps. For now only accept
- // lhs: (d0, d1, d2) -> (d0, d2)
- // rhs: (d0, d1, d2) -> (d1, d2)
- // acc: (d0, d1, d2) -> (d0, d1)
- // This corresponds to matrix multiplication with transposed RHS.
- if (op.getIndexingMapsArray()[0] !=
- AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u},
- op.getContext()) ||
- op.getIndexingMapsArray()[1] !=
- AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u},
- op.getContext()) ||
- op.getIndexingMapsArray()[2] !=
- AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 1u},
- op.getContext()))
- return rewriter.notifyMatchFailure(op, "non-matching permutation maps");
-
// Check iterator types for matrix multiplication.
auto itTypes = op.getIteratorTypesArray();
if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
@@ -228,6 +227,7 @@ class LowerContractionToSVEI8MMPattern
/*scalableDims=*/{true});
// Extract LHS sub-tiles with logicall shape <2x8>.
+ Location loc = op.getLoc();
SmallVector<Value> lhsTile;
for (int64_t i = 0; i < M; i += 2) {
// Extract two consecutive rows of the LHS tile.
@@ -283,7 +283,7 @@ class LowerContractionToSVEI8MMPattern
if (mmlaOp == MMLA::MixedSwapped) {
// We need to swap the positions of the LHS and RHS (since we don't have
// a signed * unsigned operation), but then each individual 2x2 tile of
- // the acumulator and (later) the result need to be transposed.
+ // the accumulator and (later) the result need to be transposed.
accTileVec = rewriter.create<vector::InterleaveOp>(loc, r0, r1);
} else {
// Bitcast them to 64-bit elements, so subsequent
|
// Check permutation maps. For now only accept | ||
// lhs: (d0, d1, d2) -> (d0, d2) | ||
// rhs: (d0, d1, d2) -> (d1, d2) | ||
// acc: (d0, d1, d2) -> (d0, d1) | ||
// This corresponds to matrix multiplication with transposed RHS. | ||
if (op.getIndexingMapsArray()[0] != | ||
AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u}, | ||
op.getContext()) || | ||
op.getIndexingMapsArray()[1] != | ||
AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u}, | ||
op.getContext()) || | ||
op.getIndexingMapsArray()[2] != | ||
AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 1u}, | ||
op.getContext())) | ||
return rewriter.notifyMatchFailure(op, "non-matching permutation maps"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't these permutation maps imply the dimensionality of lhs
and rhs
? It feels like there's no need to check the rank below?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, indeed, done.
In
LowerContractionToSVEI8MMPattern
check we have the expected indexing maps before deciding which operand dimension must match with which.For example, with indexing map like:
we would like the second
lhs
dimension (columns) to match with the secondrhs
(rows, transposed) whereas with indexing maps likewe would like the second
lhs
dimension (columns) to match with the firstrhs
(rows, canonical matrix multiplication).Since only the first kind of indexing maps is supported, the patch does not change anything of significance, just the notification message when the pattern would fail to apply anyway.