Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ConstantFold] Support scalable constant splats in ConstantFoldCastInstruction #133207

Merged

Conversation

lukel97
Copy link
Contributor

@lukel97 lukel97 commented Mar 27, 2025

Previously only fixed vector splats were handled. This adds supports for scalable vectors too by allowing ConstantExpr splats.

We need to add the extra V->getType()->isVectorTy() check because a ConstantExpr might be a scalar to vector bitcast.

By allowing ConstantExprs this also allow fixed vector ConstantExprs to be folded, which causes the diffs in llvm/test/Analysis/ValueTracking/known-bits-from-operator-constexpr.ll and llvm/test/Transforms/InstSimplify/ConstProp/cast-vector.ll. I can remove them from this PR if reviewers would prefer.

Fixes #132922

@llvmbot
Copy link
Member

llvmbot commented Mar 27, 2025

@llvm/pr-subscribers-llvm-analysis
@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-llvm-ir

Author: Luke Lau (lukel97)

Changes

Stacked on #132960 to prevent a regression

Previously only fixed vector splats were handled. This adds supports for scalable vectors too by allowing ConstantExpr splats.

We need to add the extra V->getType()->isVectorTy() check because a ConstantExpr might be a scalar to vector bitcast.

I believe this will also allow casts of fixed vector ConstantExprs to be folded but I couldn't come up with a test case for this, the ConstantExprs seem to be folded away before reaching InstCombine.

Fixes #132922


Full diff: https://github.com/llvm/llvm-project/pull/133207.diff

5 Files Affected:

  • (modified) llvm/lib/IR/ConstantFold.cpp (+6-4)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp (+6-6)
  • (modified) llvm/test/Transforms/InstCombine/fpextend.ll (+11)
  • (modified) llvm/test/Transforms/InstCombine/scalable-const-fp-splat.ll (+12)
  • (modified) llvm/test/Transforms/InstCombine/scalable-trunc.ll (+15)
diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp
index b577f69eeaba0..692d15546e70e 100644
--- a/llvm/lib/IR/ConstantFold.cpp
+++ b/llvm/lib/IR/ConstantFold.cpp
@@ -160,10 +160,10 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
   // If the cast operand is a constant vector, perform the cast by
   // operating on each element. In the cast of bitcasts, the element
   // count may be mismatched; don't attempt to handle that here.
-  if ((isa<ConstantVector>(V) || isa<ConstantDataVector>(V)) &&
-      DestTy->isVectorTy() &&
-      cast<FixedVectorType>(DestTy)->getNumElements() ==
-          cast<FixedVectorType>(V->getType())->getNumElements()) {
+  if ((isa<ConstantVector, ConstantDataVector, ConstantExpr>(V)) &&
+      DestTy->isVectorTy() && V->getType()->isVectorTy() &&
+      cast<VectorType>(DestTy)->getElementCount() ==
+          cast<VectorType>(V->getType())->getElementCount()) {
     VectorType *DestVecTy = cast<VectorType>(DestTy);
     Type *DstEltTy = DestVecTy->getElementType();
     // Fast path for splatted constants.
@@ -174,6 +174,8 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
       return ConstantVector::getSplat(
           cast<VectorType>(DestTy)->getElementCount(), Res);
     }
+    if (isa<ScalableVectorType>(DestTy))
+      return nullptr;
     SmallVector<Constant *, 16> res;
     Type *Ty = IntegerType::get(V->getContext(), 32);
     for (unsigned i = 0,
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 4ec1af394464b..1a95636f37ed7 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -1684,12 +1684,12 @@ static Type *getMinimumFPType(Value *V, bool PreferBFloat) {
     if (Type *T = shrinkFPConstant(CFP, PreferBFloat))
       return T;
 
-  // We can only correctly find a minimum type for a scalable vector when it is
-  // a splat. For splats of constant values the fpext is wrapped up as a
-  // ConstantExpr.
-  if (auto *FPCExt = dyn_cast<ConstantExpr>(V))
-    if (FPCExt->getOpcode() == Instruction::FPExt)
-      return FPCExt->getOperand(0)->getType();
+  // Try to shrink scalable and fixed splat vectors.
+  if (auto *FPC = dyn_cast<Constant>(V))
+    if (isa<VectorType>(V->getType()))
+      if (auto *Splat = dyn_cast_or_null<ConstantFP>(FPC->getSplatValue()))
+        if (Type *T = shrinkFPConstant(Splat, PreferBFloat))
+          return T;
 
   // Try to shrink a vector of FP constants. This returns nullptr on scalable
   // vectors
diff --git a/llvm/test/Transforms/InstCombine/fpextend.ll b/llvm/test/Transforms/InstCombine/fpextend.ll
index c9adbe10d8db4..9125339c00ecf 100644
--- a/llvm/test/Transforms/InstCombine/fpextend.ll
+++ b/llvm/test/Transforms/InstCombine/fpextend.ll
@@ -448,3 +448,14 @@ define bfloat @bf16_frem(bfloat %x) {
   %t3 = fptrunc float %t2 to bfloat
   ret bfloat %t3
 }
+
+define <4 x float> @v4f32_fadd(<4 x float> %a) {
+; CHECK-LABEL: @v4f32_fadd(
+; CHECK-NEXT:    [[TMP1:%.*]] = fadd <4 x float> [[A:%.*]], splat (float -1.000000e+00)
+; CHECK-NEXT:    ret <4 x float> [[TMP1]]
+;
+  %2 = fpext <4 x float> %a to <4 x double>
+  %4 = fadd <4 x double> %2, splat (double -1.000000e+00)
+  %5 = fptrunc <4 x double> %4 to <4 x float>
+  ret <4 x float> %5
+}
diff --git a/llvm/test/Transforms/InstCombine/scalable-const-fp-splat.ll b/llvm/test/Transforms/InstCombine/scalable-const-fp-splat.ll
index 731b079881f08..0982ecfbd3ea3 100644
--- a/llvm/test/Transforms/InstCombine/scalable-const-fp-splat.ll
+++ b/llvm/test/Transforms/InstCombine/scalable-const-fp-splat.ll
@@ -13,3 +13,15 @@ define <vscale x 2 x float> @shrink_splat_scalable_extend(<vscale x 2 x float> %
   %5 = fptrunc <vscale x 2 x double> %4 to <vscale x 2 x float>
   ret <vscale x 2 x float> %5
 }
+
+define <vscale x 2 x float> @shrink_splat_scalable_extend_rhs_constexpr(<vscale x 2 x float> %a) {
+; CHECK-LABEL: define <vscale x 2 x float> @shrink_splat_scalable_extend_rhs_constexpr(
+; CHECK-SAME: <vscale x 2 x float> [[A:%.*]]) {
+; CHECK-NEXT:    [[TMP3:%.*]] = fadd <vscale x 2 x float> [[A]], splat (float -1.000000e+00)
+; CHECK-NEXT:    ret <vscale x 2 x float> [[TMP3]]
+;
+  %2 = fpext <vscale x 2 x float> %a to <vscale x 2 x double>
+  %4 = fadd <vscale x 2 x double> %2, splat (double -1.000000e+00)
+  %5 = fptrunc <vscale x 2 x double> %4 to <vscale x 2 x float>
+  ret <vscale x 2 x float> %5
+}
diff --git a/llvm/test/Transforms/InstCombine/scalable-trunc.ll b/llvm/test/Transforms/InstCombine/scalable-trunc.ll
index dcf4abe10425b..6272ccfe9cdbd 100644
--- a/llvm/test/Transforms/InstCombine/scalable-trunc.ll
+++ b/llvm/test/Transforms/InstCombine/scalable-trunc.ll
@@ -20,6 +20,21 @@ entry:
   ret void
 }
 
+define <vscale x 1 x i8> @constant_splat_trunc() {
+; CHECK-LABEL: @constant_splat_trunc(
+; CHECK-NEXT:    ret <vscale x 1 x i8> splat (i8 1)
+;
+  %1 = trunc <vscale x 1 x i64> splat (i64 1) to <vscale x 1 x i8>
+  ret <vscale x 1 x i8> %1
+}
+
+define <vscale x 1 x i8> @constant_splat_trunc_constantexpr() {
+; CHECK-LABEL: @constant_splat_trunc_constantexpr(
+; CHECK-NEXT:    ret <vscale x 1 x i8> splat (i8 1)
+;
+  ret <vscale x 1 x i8> trunc (<vscale x 1 x i64> splat (i64 1) to <vscale x 1 x i8>)
+}
+
 declare void @llvm.aarch64.sve.st1.nxv2i32(<vscale x 2 x i32>, <vscale x 2 x i1>, ptr)
 declare <vscale x 2 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv2i1(<vscale x 16 x i1>)
 declare <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32 %pattern)

@lukel97 lukel97 force-pushed the instcombine/constant-fold-scalable-splat-cast branch from 41f9586 to 687f0ea Compare March 27, 2025 17:02
@lukel97
Copy link
Contributor Author

lukel97 commented Mar 27, 2025

Rebased now that #132960 is landed, and I've also updated a bunch of test cases outside of InstCombine that I forgot to check, sorry!

lukel97 added 3 commits March 28, 2025 09:36
…struction

Previously only fixed vector splats were handled. This adds supports for scalable vectors too by allowing ConstantExpr splats.

We need to add the extra V->getType()->isVectorTy() check because a ConstantExpr might be a scalar to vector bitcast.

By allowing ConstantExprs this also allow fixed vector ConstantExprs to be folded, which causes the diffs in llvm/test/Analysis/ValueTracking/known-bits-from-operator-constexpr.ll and llvm/test/Transforms/InstSimplify/ConstProp/cast-vector.ll. I can remove them from this PR if reviewers would prefer.

Fixes llvm#132922
@lukel97 lukel97 force-pushed the instcombine/constant-fold-scalable-splat-cast branch from 85a04e2 to ccb0eda Compare March 28, 2025 09:45
Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@lukel97 lukel97 merged commit 79435de into llvm:main Apr 3, 2025
11 checks passed
lukel97 added a commit that referenced this pull request Apr 3, 2025
…" (#134262)

This reapplies #132522.

Previously casts of scalable m_ImmConstant splats weren't being folded
by ConstantFoldCastOperand, triggering the "Constant-fold of ImmConstant
should not fail" assertion.

There are no changes to the code in this PR, instead we just needed
#133207 to land first.

A test has been added for the assertion in
llvm/test/Transforms/InstSimplify/vec-icmp-of-cast.ll
@icmp_ult_sext_scalable_splat_is_true.

<hr/>

#118806 fixed an infinite loop in FoldShiftByConstant that could occur
when the shift amount was a ConstantExpr.

However this meant that FoldShiftByConstant no longer kicked in for
scalable vectors because scalable splats are represented by
ConstantExprs.

This fixes it by allowing scalable splats of non-ConstantExprs in
m_ImmConstant, which also fixes a few other test cases where scalable
splats were being missed.

But I'm also hoping that UseConstantIntForScalableSplat will eventually
remove the need for this.

I noticed this when trying to reverse a combine on RISC-V in #132245,
and saw that the resulting vector and scalar forms were different.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[InstCombine] Scalable splats don't have trunc constant folded
3 participants