Skip to content

[MLIR][AArch64] Check indexing maps before checking for dimensions compatibility #145702

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

momchil-velikov
Copy link
Collaborator

In LowerContractionToSVEI8MMPattern check we have the expected indexing maps before deciding which operand dimension must match with which.

For example, with indexing map like:

lhs: (m, n, k) -> (m, k)
rhs: (m, n, k) -> (n, k)
acc: (m, n, k) -> (m, n)

we would like the second lhs dimension (columns) to match with the second rhs (rows, transposed) whereas with indexing maps like

lhs: (m, n, k) -> (m, k)
rhs: (m, n, k) -> (k, n)
acc: (m, n, k) -> (m, n)

we would like the second lhs dimension (columns) to match with the first rhs (rows, canonical matrix multiplication).

Since only the first kind of indexing maps is supported, the patch does not change anything of significance, just the notification message when the pattern would fail to apply anyway.

…mpatibility

In `LowerContractionToSVEI8MMPattern` check we have the expected
indexing maps before deciding which operand dimension must match
with which.

For example, with indexing map like:

    lhs: (m, n, k) -> (m, k)
    rhs: (m, n, k) -> (n, k)
    acc: (m, n, k) -> (m, n)

we would like the second `lhs` dimension (columns) to match with
the second `rhs` (rows, transposed) whereas with indexing maps like

    lhs: (m, n, k) -> (m, k)
    rhs: (m, n, k) -> (k, n)
    acc: (m, n, k) -> (m, n)

we would like the second `lhs` dimension (columns) to match with
the first `rhs` (rows, canonical matrix multiplication).

Since only the first kind of indexing maps is supported, the patch
does not change anything of significance, just the notification message
when the pattern would fail to apply anyway.
@llvmbot
Copy link
Member

llvmbot commented Jun 25, 2025

@llvm/pr-subscribers-mlir-sve

Author: Momchil Velikov (momchil-velikov)

Changes

In LowerContractionToSVEI8MMPattern check we have the expected indexing maps before deciding which operand dimension must match with which.

For example, with indexing map like:

lhs: (m, n, k) -> (m, k)
rhs: (m, n, k) -> (n, k)
acc: (m, n, k) -> (m, n)

we would like the second lhs dimension (columns) to match with the second rhs (rows, transposed) whereas with indexing maps like

lhs: (m, n, k) -> (m, k)
rhs: (m, n, k) -> (k, n)
acc: (m, n, k) -> (m, n)

we would like the second lhs dimension (columns) to match with the first rhs (rows, canonical matrix multiplication).

Since only the first kind of indexing maps is supported, the patch does not change anything of significance, just the notification message when the pattern would fail to apply anyway.


Full diff: https://github.com/llvm/llvm-project/pull/145702.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp (+19-19)
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
index a1209fe8230e2..70d2e06f48902 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
@@ -136,11 +136,26 @@ class LowerContractionToSVEI8MMPattern
   LogicalResult matchAndRewrite(vector::ContractionOp op,
                                 PatternRewriter &rewriter) const override {
 
-    Location loc = op.getLoc();
+    // Check permutation maps. For now only accept
+    //   lhs: (d0, d1, d2) -> (d0, d2)
+    //   rhs: (d0, d1, d2) -> (d1, d2)
+    //   acc: (d0, d1, d2) -> (d0, d1)
+    // This corresponds to matrix multiplication with transposed RHS.
+    if (op.getIndexingMapsArray()[0] !=
+            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u},
+                                                 op.getContext()) ||
+        op.getIndexingMapsArray()[1] !=
+            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u},
+                                                 op.getContext()) ||
+        op.getIndexingMapsArray()[2] !=
+            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 1u},
+                                                 op.getContext()))
+      return rewriter.notifyMatchFailure(op, "non-matching permutation maps");
+
     mlir::VectorType lhsType = op.getLhsType();
     mlir::VectorType rhsType = op.getRhsType();
 
-    // Check the rank the types so we can safely examine their dimensions.
+    // Check the rank of the types so we can safely examine their dimensions.
     if (lhsType.getRank() != 2 || rhsType.getRank() != 2)
       return rewriter.notifyMatchFailure(op, "non-matching operand shape");
 
@@ -159,22 +174,6 @@ class LowerContractionToSVEI8MMPattern
         !rhsType.getScalableDims()[0])
       return rewriter.notifyMatchFailure(op, "non-matching operand shape");
 
-    // Check permutation maps. For now only accept
-    //   lhs: (d0, d1, d2) -> (d0, d2)
-    //   rhs: (d0, d1, d2) -> (d1, d2)
-    //   acc: (d0, d1, d2) -> (d0, d1)
-    // This corresponds to matrix multiplication with transposed RHS.
-    if (op.getIndexingMapsArray()[0] !=
-            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u},
-                                                 op.getContext()) ||
-        op.getIndexingMapsArray()[1] !=
-            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u},
-                                                 op.getContext()) ||
-        op.getIndexingMapsArray()[2] !=
-            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 1u},
-                                                 op.getContext()))
-      return rewriter.notifyMatchFailure(op, "non-matching permutation maps");
-
     // Check iterator types for matrix multiplication.
     auto itTypes = op.getIteratorTypesArray();
     if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
@@ -228,6 +227,7 @@ class LowerContractionToSVEI8MMPattern
                                    /*scalableDims=*/{true});
 
     // Extract LHS sub-tiles with logicall shape <2x8>.
+    Location loc = op.getLoc();
     SmallVector<Value> lhsTile;
     for (int64_t i = 0; i < M; i += 2) {
       // Extract two consecutive rows of the LHS tile.
@@ -283,7 +283,7 @@ class LowerContractionToSVEI8MMPattern
       if (mmlaOp == MMLA::MixedSwapped) {
         // We need to swap the positions of the LHS and RHS (since we don't have
         // a signed * unsigned operation), but then each individual 2x2 tile of
-        // the acumulator and (later) the result need to be transposed.
+        // the accumulator and (later) the result need to be transposed.
         accTileVec = rewriter.create<vector::InterleaveOp>(loc, r0, r1);
       } else {
         // Bitcast them to 64-bit elements, so subsequent

@llvmbot
Copy link
Member

llvmbot commented Jun 25, 2025

@llvm/pr-subscribers-mlir

Author: Momchil Velikov (momchil-velikov)

Changes

In LowerContractionToSVEI8MMPattern check we have the expected indexing maps before deciding which operand dimension must match with which.

For example, with indexing map like:

lhs: (m, n, k) -&gt; (m, k)
rhs: (m, n, k) -&gt; (n, k)
acc: (m, n, k) -&gt; (m, n)

we would like the second lhs dimension (columns) to match with the second rhs (rows, transposed) whereas with indexing maps like

lhs: (m, n, k) -&gt; (m, k)
rhs: (m, n, k) -&gt; (k, n)
acc: (m, n, k) -&gt; (m, n)

we would like the second lhs dimension (columns) to match with the first rhs (rows, canonical matrix multiplication).

Since only the first kind of indexing maps is supported, the patch does not change anything of significance, just the notification message when the pattern would fail to apply anyway.


Full diff: https://github.com/llvm/llvm-project/pull/145702.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp (+19-19)
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
index a1209fe8230e2..70d2e06f48902 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
@@ -136,11 +136,26 @@ class LowerContractionToSVEI8MMPattern
   LogicalResult matchAndRewrite(vector::ContractionOp op,
                                 PatternRewriter &rewriter) const override {
 
-    Location loc = op.getLoc();
+    // Check permutation maps. For now only accept
+    //   lhs: (d0, d1, d2) -> (d0, d2)
+    //   rhs: (d0, d1, d2) -> (d1, d2)
+    //   acc: (d0, d1, d2) -> (d0, d1)
+    // This corresponds to matrix multiplication with transposed RHS.
+    if (op.getIndexingMapsArray()[0] !=
+            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u},
+                                                 op.getContext()) ||
+        op.getIndexingMapsArray()[1] !=
+            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u},
+                                                 op.getContext()) ||
+        op.getIndexingMapsArray()[2] !=
+            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 1u},
+                                                 op.getContext()))
+      return rewriter.notifyMatchFailure(op, "non-matching permutation maps");
+
     mlir::VectorType lhsType = op.getLhsType();
     mlir::VectorType rhsType = op.getRhsType();
 
-    // Check the rank the types so we can safely examine their dimensions.
+    // Check the rank of the types so we can safely examine their dimensions.
     if (lhsType.getRank() != 2 || rhsType.getRank() != 2)
       return rewriter.notifyMatchFailure(op, "non-matching operand shape");
 
@@ -159,22 +174,6 @@ class LowerContractionToSVEI8MMPattern
         !rhsType.getScalableDims()[0])
       return rewriter.notifyMatchFailure(op, "non-matching operand shape");
 
-    // Check permutation maps. For now only accept
-    //   lhs: (d0, d1, d2) -> (d0, d2)
-    //   rhs: (d0, d1, d2) -> (d1, d2)
-    //   acc: (d0, d1, d2) -> (d0, d1)
-    // This corresponds to matrix multiplication with transposed RHS.
-    if (op.getIndexingMapsArray()[0] !=
-            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u},
-                                                 op.getContext()) ||
-        op.getIndexingMapsArray()[1] !=
-            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u},
-                                                 op.getContext()) ||
-        op.getIndexingMapsArray()[2] !=
-            AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 1u},
-                                                 op.getContext()))
-      return rewriter.notifyMatchFailure(op, "non-matching permutation maps");
-
     // Check iterator types for matrix multiplication.
     auto itTypes = op.getIteratorTypesArray();
     if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
@@ -228,6 +227,7 @@ class LowerContractionToSVEI8MMPattern
                                    /*scalableDims=*/{true});
 
     // Extract LHS sub-tiles with logicall shape <2x8>.
+    Location loc = op.getLoc();
     SmallVector<Value> lhsTile;
     for (int64_t i = 0; i < M; i += 2) {
       // Extract two consecutive rows of the LHS tile.
@@ -283,7 +283,7 @@ class LowerContractionToSVEI8MMPattern
       if (mmlaOp == MMLA::MixedSwapped) {
         // We need to swap the positions of the LHS and RHS (since we don't have
         // a signed * unsigned operation), but then each individual 2x2 tile of
-        // the acumulator and (later) the result need to be transposed.
+        // the accumulator and (later) the result need to be transposed.
         accTileVec = rewriter.create<vector::InterleaveOp>(loc, r0, r1);
       } else {
         // Bitcast them to 64-bit elements, so subsequent

Comment on lines 139 to 153
// Check permutation maps. For now only accept
// lhs: (d0, d1, d2) -> (d0, d2)
// rhs: (d0, d1, d2) -> (d1, d2)
// acc: (d0, d1, d2) -> (d0, d1)
// This corresponds to matrix multiplication with transposed RHS.
if (op.getIndexingMapsArray()[0] !=
AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u},
op.getContext()) ||
op.getIndexingMapsArray()[1] !=
AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u},
op.getContext()) ||
op.getIndexingMapsArray()[2] !=
AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 1u},
op.getContext()))
return rewriter.notifyMatchFailure(op, "non-matching permutation maps");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't these permutation maps imply the dimensionality of lhs and rhs? It feels like there's no need to check the rank below?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, indeed, done.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants