Skip to content

[mlir][Vector] Add constant folding for vector.from_elements operation #145849

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

yangtetris
Copy link
Contributor

Summary

This PR adds a new folding pattern for vector.from_elements that canonicalizes it to arith.constant when all input operands are constants.

Implementation Details

Leverages FoldAdaptor capabilities: Uses adaptor.getElements() to access pre-computed constant attributes, avoiding redundant pattern matching on operands.

Example Transformation

Before:
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%c2_i32 = arith.constant 2 : i32
%c3_i32 = arith.constant 3 : i32
%v = vector.from_elements %c0_i32, %c1_i32, %c2_i32, %c3_i32 : vector<2x2xi32>

After:
%v = arith.constant dense<[[0, 1], [2, 3]]> : vector<2x2xi32>

@llvmbot
Copy link
Member

llvmbot commented Jun 26, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Yang Bai (yangtetris)

Changes

Summary

This PR adds a new folding pattern for vector.from_elements that canonicalizes it to arith.constant when all input operands are constants.

Implementation Details

Leverages FoldAdaptor capabilities: Uses adaptor.getElements() to access pre-computed constant attributes, avoiding redundant pattern matching on operands.

Example Transformation

Before:
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%c2_i32 = arith.constant 2 : i32
%c3_i32 = arith.constant 3 : i32
%v = vector.from_elements %c0_i32, %c1_i32, %c2_i32, %c3_i32 : vector&lt;2x2xi32&gt;

After:
%v = arith.constant dense&lt;[[0, 1], [2, 3]]&gt; : vector&lt;2x2xi32&gt;

Full diff: https://github.com/llvm/llvm-project/pull/145849.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+23-1)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+14)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 862ed7bae1fbb..9afb443cebc13 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2459,8 +2459,30 @@ static OpFoldResult foldFromElementsToElements(FromElementsOp fromElementsOp) {
   return {};
 }
 
+/// Fold vector.from_elements to a constant when all operands are constants.
+/// Example:
+///   %c1 = arith.constant 1 : i32
+///   %c2 = arith.constant 2 : i32
+///   %v = vector.from_elements %c1, %c2 : vector<2xi32>
+/// =>
+///   %v = arith.constant dense<[1, 2]> : vector<2xi32>
+///
+static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp,
+                                               ArrayRef<Attribute> elements) {
+  if (llvm::any_of(elements, [](Attribute attr) { return !attr; }))
+    return {};
+
+  auto destType = cast<VectorType>(fromElementsOp.getType());
+  return DenseElementsAttr::get(destType, elements);
+}
+
 OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
-  return foldFromElementsToElements(*this);
+  if (auto res = foldFromElementsToElements(*this))
+    return res;
+  if (auto res = foldFromElementsToConstant(*this, adaptor.getElements()))
+    return res;
+
+  return {};
 }
 
 /// Rewrite a vector.from_elements into a vector.splat if all elements are the
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 65b73375831da..d56c64552f9e7 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -3075,6 +3075,20 @@ func.func @from_elements_to_elements_shuffle(%a: vector<4x2xf32>) -> vector<4x2x
 
 // -----
 
+// CHECK-LABEL: func @from_elements_to_constant
+func.func @from_elements_to_constant() -> vector<2x2xi32> {
+  %c0_i32 = arith.constant 0 : i32
+  %c1_i32 = arith.constant 1 : i32
+  %c2_i32 = arith.constant 2 : i32
+  %c3_i32 = arith.constant 3 : i32
+  // CHECK: %[[RES:.*]] = arith.constant dense<{{\[\[0, 1\], \[2, 3\]\]}}> : vector<2x2xi32>
+  %res = vector.from_elements %c0_i32, %c1_i32, %c2_i32, %c3_i32 : vector<2x2xi32>
+  // CHECK: return %[[RES]]
+  return %res : vector<2x2xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @vector_insert_const_regression(
 //       CHECK:   llvm.mlir.undef
 //       CHECK:   vector.insert

Copy link
Member

@Groverkss Groverkss left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nicely implemented, well done! LGTM.

@yangtetris
Copy link
Contributor Author

I found that the following MLIR code causes a crash:

%c1 = llvm.mlir.constant(1 : index) : i64
%c2 = llvm.mlir.constant(2 : index) : i64
%res = vector.from_elements %c1, %c2 : vector<2xi64>

The error is:
llvm-project/mlir/lib/IR/BuiltinAttributes.cpp:978: static mlir::DenseElementsAttr mlir::DenseElementsAttr::get(mlir::ShapedType, llvm::ArrayRefmlir::Attribute): Assertion `intAttr.getType() == eltType && "expected integer attribute type to equal element type"' failed.

It seems that this issue is the same as this one. So I committed a change that applies convertIntegerAttr before creating the dense elements attribute.

@Groverkss could you please take another look? Thanks!

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG! Dropped a few comments, thanks!

}

// -----

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is it possible to add a test that crashes without using llvm.mlir.constant ops?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's kind of difficult for me as I don't know which ops with the constant trait(besides llvm.mlir.constant) allow mismatched attribute and return types. And, I don't think it's worth introducing a new custom op for this test :(

@yangtetris
Copy link
Contributor Author

@dcaballe @Groverkss Please help to land it when you are free. Thanks!

@dcaballe dcaballe merged commit 393a75e into llvm:main Jul 1, 2025
7 checks passed
@yangtetris yangtetris deleted the fold_vector_from_elements_to_constant branch July 1, 2025 03:41
rlavaee pushed a commit to rlavaee/llvm-project that referenced this pull request Jul 1, 2025
llvm#145849)

### Summary

This PR adds a new folding pattern for **vector.from_elements** that
canonicalizes it to **arith.constant** when all input operands are
constants.

### Implementation Details

**Leverages FoldAdaptor capabilities**: Uses adaptor.getElements() to
access **pre-computed** constant attributes, avoiding redundant pattern
matching on operands.

### Example Transformation
```
Before:
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%c2_i32 = arith.constant 2 : i32
%c3_i32 = arith.constant 3 : i32
%v = vector.from_elements %c0_i32, %c1_i32, %c2_i32, %c3_i32 : vector<2x2xi32>

After:
%v = arith.constant dense<[[0, 1], [2, 3]]> : vector<2x2xi32>
```

---------

Co-authored-by: Yang Bai <yangb@nvidia.com>
rlavaee pushed a commit to rlavaee/llvm-project that referenced this pull request Jul 1, 2025
llvm#145849)

### Summary

This PR adds a new folding pattern for **vector.from_elements** that
canonicalizes it to **arith.constant** when all input operands are
constants.

### Implementation Details

**Leverages FoldAdaptor capabilities**: Uses adaptor.getElements() to
access **pre-computed** constant attributes, avoiding redundant pattern
matching on operands.

### Example Transformation
```
Before:
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%c2_i32 = arith.constant 2 : i32
%c3_i32 = arith.constant 3 : i32
%v = vector.from_elements %c0_i32, %c1_i32, %c2_i32, %c3_i32 : vector<2x2xi32>

After:
%v = arith.constant dense<[[0, 1], [2, 3]]> : vector<2x2xi32>
```

---------

Co-authored-by: Yang Bai <yangb@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants