[MergeFunc] Fix crash caused by bitcasting ArrayType #133259

Merged on Apr 4, 2025 (3 commits)

tobias-stadler

createCast in MergeFunctions did not handle ArrayTypes, so the thunk function could end up with a bitcast between two ArrayTypes, which triggers an assertion failure in the provided test case.

The version of createCast in GlobalMergeFunctions does handle ArrayTypes, so this common code has been factored out into the IRBuilder.
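
To illustrate the difference between the two helpers, here is a minimal, self-contained sketch. It does not use the LLVM API: `ToyType`, `oldMergeFuncDispatch`, `newDispatch`, and `planIsValid` are hypothetical names invented for this example. It only models which branch each helper takes, together with the LLVM rule that `bitcast` operates on non-aggregate first-class values.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Toy stand-in for llvm::Type. None of these names are LLVM API.
struct ToyType {
  enum Kind { Scalar, Struct, Array } kind;
  std::string name;           // scalar name, e.g. "double"; empty for aggregates
  std::vector<ToyType> elems; // struct fields, or one entry per array element
};

ToyType scalar(std::string Name) {
  return {ToyType::Scalar, std::move(Name), {}};
}
ToyType structOf(std::vector<ToyType> Fields) {
  return {ToyType::Struct, "", std::move(Fields)};
}
ToyType arrayOf(unsigned N, const ToyType &Elem) {
  return {ToyType::Array, "", std::vector<ToyType>(N, Elem)};
}

enum class Action { RecurseStruct, RecurseArray, LeafCast };

// Models the removed MergeFunctions helper: only structs are decomposed,
// so an array falls through to the leaf cast.
Action oldMergeFuncDispatch(const ToyType &Src) {
  if (Src.kind == ToyType::Struct)
    return Action::RecurseStruct;
  return Action::LeafCast;
}

// Models the shared helper: structs and arrays are both decomposed.
Action newDispatch(const ToyType &Src) {
  if (Src.kind == ToyType::Struct)
    return Action::RecurseStruct;
  if (Src.kind == ToyType::Array)
    return Action::RecurseArray;
  return Action::LeafCast;
}

// A plan is valid only if every leaf cast is applied to a non-aggregate.
bool planIsValid(const ToyType &Src, Action (*Dispatch)(const ToyType &)) {
  if (Dispatch(Src) == Action::LeafCast)
    return Src.kind == ToyType::Scalar;
  for (const ToyType &Elem : Src.elems)
    if (!planIsValid(Elem, Dispatch))
      return false;
  return true;
}

// Shape of %A_arr from the test: { [1 x { { double } }] }.
ToyType mergeFuncTestType() {
  return structOf({arrayOf(1, structOf({structOf({scalar("double")})}))});
}
```

In this model, `planIsValid(mergeFuncTestType(), oldMergeFuncDispatch)` is false because the array inside the outer struct reaches the leaf cast, while the same call with `newDispatch` is true.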

llvmbot commented Mar 27, 2025

@llvm/pr-subscribers-llvm-ir

@llvm/pr-subscribers-llvm-transforms

Author: Tobias Stadler (tobias-stadler)

Full diff: https://github.com/llvm/llvm-project/pull/133259.diff

5 Files Affected:

  • (modified) llvm/include/llvm/IR/IRBuilder.h (+7)
  • (modified) llvm/lib/CodeGen/GlobalMergeFunctions.cpp (+5-42)
  • (modified) llvm/lib/IR/IRBuilder.cpp (+35)
  • (modified) llvm/lib/Transforms/IPO/MergeFunctions.cpp (+2-29)
  • (added) llvm/test/Transforms/MergeFunc/crash-cast-arrays.ll (+38)
diff --git a/llvm/include/llvm/IR/IRBuilder.h b/llvm/include/llvm/IR/IRBuilder.h
index 750a99cc50dd7..a20fdec3f201d 100644
--- a/llvm/include/llvm/IR/IRBuilder.h
+++ b/llvm/include/llvm/IR/IRBuilder.h
@@ -2291,6 +2291,13 @@ class IRBuilderBase {
   // isSigned parameter.
   Value *CreateIntCast(Value *, Type *, const char *) = delete;
 
+  /// Cast between aggregate types that must have identical structure but may
+  /// differ in their leaf types. The leaf values are recursively extracted,
+  /// casted, and then reinserted into a value of type DestTy. The leaf types
+  /// must be castable using a bitcast or ptrcast, because signedness is
+  /// not specified.
+  Value *CreateAggregateCast(Value *V, Type *DestTy);
+
   //===--------------------------------------------------------------------===//
   // Instruction creation methods: Compare Instructions
   //===--------------------------------------------------------------------===//
diff --git a/llvm/lib/CodeGen/GlobalMergeFunctions.cpp b/llvm/lib/CodeGen/GlobalMergeFunctions.cpp
index e920b1be6822c..d4c53e79ed2e1 100644
--- a/llvm/lib/CodeGen/GlobalMergeFunctions.cpp
+++ b/llvm/lib/CodeGen/GlobalMergeFunctions.cpp
@@ -140,44 +140,6 @@ static bool ignoreOp(const Instruction *I, unsigned OpIdx) {
   return true;
 }
 
-static Value *createCast(IRBuilder<> &Builder, Value *V, Type *DestTy) {
-  Type *SrcTy = V->getType();
-  if (SrcTy->isStructTy()) {
-    assert(DestTy->isStructTy());
-    assert(SrcTy->getStructNumElements() == DestTy->getStructNumElements());
-    Value *Result = PoisonValue::get(DestTy);
-    for (unsigned int I = 0, E = SrcTy->getStructNumElements(); I < E; ++I) {
-      Value *Element =
-          createCast(Builder, Builder.CreateExtractValue(V, ArrayRef(I)),
-                     DestTy->getStructElementType(I));
-
-      Result = Builder.CreateInsertValue(Result, Element, ArrayRef(I));
-    }
-    return Result;
-  }
-  assert(!DestTy->isStructTy());
-  if (auto *SrcAT = dyn_cast<ArrayType>(SrcTy)) {
-    auto *DestAT = dyn_cast<ArrayType>(DestTy);
-    assert(DestAT);
-    assert(SrcAT->getNumElements() == DestAT->getNumElements());
-    Value *Result = PoisonValue::get(DestTy);
-    for (unsigned int I = 0, E = SrcAT->getNumElements(); I < E; ++I) {
-      Value *Element =
-          createCast(Builder, Builder.CreateExtractValue(V, ArrayRef(I)),
-                     DestAT->getElementType());
-
-      Result = Builder.CreateInsertValue(Result, Element, ArrayRef(I));
-    }
-    return Result;
-  }
-  assert(!DestTy->isArrayTy());
-  if (SrcTy->isIntegerTy() && DestTy->isPointerTy())
-    return Builder.CreateIntToPtr(V, DestTy);
-  if (SrcTy->isPointerTy() && DestTy->isIntegerTy())
-    return Builder.CreatePtrToInt(V, DestTy);
-  return Builder.CreateBitCast(V, DestTy);
-}
-
 void GlobalMergeFunc::analyze(Module &M) {
   ++NumAnalyzedModues;
   for (Function &Func : M) {
@@ -268,7 +230,7 @@ static Function *createMergedFunction(FuncMergeInfo &FI,
       if (OrigC->getType() != NewArg->getType()) {
         IRBuilder<> Builder(Inst->getParent(), Inst->getIterator());
         Inst->setOperand(OpndIndex,
-                         createCast(Builder, NewArg, OrigC->getType()));
+                         Builder.CreateAggregateCast(NewArg, OrigC->getType()));
       } else {
         Inst->setOperand(OpndIndex, NewArg);
       }
@@ -297,7 +259,8 @@ static void createThunk(FuncMergeInfo &FI, ArrayRef<Constant *> Params,
 
   // Add arguments which are passed through Thunk.
   for (Argument &AI : Thunk->args()) {
-    Args.push_back(createCast(Builder, &AI, ToFuncTy->getParamType(ParamIdx)));
+    Args.push_back(
+        Builder.CreateAggregateCast(&AI, ToFuncTy->getParamType(ParamIdx)));
     ++ParamIdx;
   }
 
@@ -305,7 +268,7 @@ static void createThunk(FuncMergeInfo &FI, ArrayRef<Constant *> Params,
   for (auto *Param : Params) {
     assert(ParamIdx < ToFuncTy->getNumParams());
     Args.push_back(
-        createCast(Builder, Param, ToFuncTy->getParamType(ParamIdx)));
+        Builder.CreateAggregateCast(Param, ToFuncTy->getParamType(ParamIdx)));
     ++ParamIdx;
   }
 
@@ -319,7 +282,7 @@ static void createThunk(FuncMergeInfo &FI, ArrayRef<Constant *> Params,
   if (Thunk->getReturnType()->isVoidTy())
     Builder.CreateRetVoid();
   else
-    Builder.CreateRet(createCast(Builder, CI, Thunk->getReturnType()));
+    Builder.CreateRet(Builder.CreateAggregateCast(CI, Thunk->getReturnType()));
 }
 
 // Check if the old merged/optimized IndexOperandHashMap is compatible with
diff --git a/llvm/lib/IR/IRBuilder.cpp b/llvm/lib/IR/IRBuilder.cpp
index 421b617a5fb7e..58a65ec646557 100644
--- a/llvm/lib/IR/IRBuilder.cpp
+++ b/llvm/lib/IR/IRBuilder.cpp
@@ -76,6 +76,41 @@ void IRBuilderBase::SetInstDebugLocation(Instruction *I) const {
     }
 }
 
+Value *IRBuilderBase::CreateAggregateCast(Value *V, Type *DestTy) {
+  Type *SrcTy = V->getType();
+  if (SrcTy == DestTy)
+    return V;
+  if (auto *SrcST = dyn_cast<StructType>(SrcTy)) {
+    assert(DestTy->isStructTy() && "Expected StructType");
+    auto *DestST = cast<StructType>(DestTy);
+    assert(SrcST->getNumElements() == DestST->getNumElements());
+    Value *Result = PoisonValue::get(DestTy);
+    for (unsigned int I = 0, E = SrcST->getNumElements(); I < E; ++I) {
+      Value *Element = CreateAggregateCast(CreateExtractValue(V, ArrayRef(I)),
+                                           DestST->getElementType(I));
+
+      Result = CreateInsertValue(Result, Element, ArrayRef(I));
+    }
+    return Result;
+  }
+  if (auto *SrcAT = dyn_cast<ArrayType>(SrcTy)) {
+    assert(DestTy->isArrayTy() && "Expected ArrayType");
+    auto *DestAT = cast<ArrayType>(DestTy);
+    assert(SrcAT->getNumElements() == DestAT->getNumElements());
+    Value *Result = PoisonValue::get(DestTy);
+    for (unsigned int I = 0, E = SrcAT->getNumElements(); I < E; ++I) {
+      Value *Element = CreateAggregateCast(CreateExtractValue(V, ArrayRef(I)),
+                                           DestAT->getElementType());
+
+      Result = CreateInsertValue(Result, Element, ArrayRef(I));
+    }
+    return Result;
+  }
+
+  assert(!DestTy->isAggregateType());
+  return CreateBitOrPointerCast(V, DestTy);
+}
+
 CallInst *
 IRBuilderBase::createCallHelper(Function *Callee, ArrayRef<Value *> Ops,
                                 const Twine &Name, FMFSource FMFSource,
diff --git a/llvm/lib/Transforms/IPO/MergeFunctions.cpp b/llvm/lib/Transforms/IPO/MergeFunctions.cpp
index 924db314674d5..c58c0f40c1b23 100644
--- a/llvm/lib/Transforms/IPO/MergeFunctions.cpp
+++ b/llvm/lib/Transforms/IPO/MergeFunctions.cpp
@@ -511,33 +511,6 @@ void MergeFunctions::replaceDirectCallers(Function *Old, Function *New) {
   }
 }
 
-// Helper for writeThunk,
-// Selects proper bitcast operation,
-// but a bit simpler then CastInst::getCastOpcode.
-static Value *createCast(IRBuilder<> &Builder, Value *V, Type *DestTy) {
-  Type *SrcTy = V->getType();
-  if (SrcTy->isStructTy()) {
-    assert(DestTy->isStructTy());
-    assert(SrcTy->getStructNumElements() == DestTy->getStructNumElements());
-    Value *Result = PoisonValue::get(DestTy);
-    for (unsigned int I = 0, E = SrcTy->getStructNumElements(); I < E; ++I) {
-      Value *Element =
-          createCast(Builder, Builder.CreateExtractValue(V, ArrayRef(I)),
-                     DestTy->getStructElementType(I));
-
-      Result = Builder.CreateInsertValue(Result, Element, ArrayRef(I));
-    }
-    return Result;
-  }
-  assert(!DestTy->isStructTy());
-  if (SrcTy->isIntegerTy() && DestTy->isPointerTy())
-    return Builder.CreateIntToPtr(V, DestTy);
-  else if (SrcTy->isPointerTy() && DestTy->isIntegerTy())
-    return Builder.CreatePtrToInt(V, DestTy);
-  else
-    return Builder.CreateBitCast(V, DestTy);
-}
-
 // Erase the instructions in PDIUnrelatedWL as they are unrelated to the
 // parameter debug info, from the entry block.
 void MergeFunctions::eraseInstsUnrelatedToPDI(
@@ -789,7 +762,7 @@ void MergeFunctions::writeThunk(Function *F, Function *G) {
   unsigned i = 0;
   FunctionType *FFTy = F->getFunctionType();
   for (Argument &AI : H->args()) {
-    Args.push_back(createCast(Builder, &AI, FFTy->getParamType(i)));
+    Args.push_back(Builder.CreateAggregateCast(&AI, FFTy->getParamType(i)));
     ++i;
   }
 
@@ -804,7 +777,7 @@ void MergeFunctions::writeThunk(Function *F, Function *G) {
   if (H->getReturnType()->isVoidTy()) {
     RI = Builder.CreateRetVoid();
   } else {
-    RI = Builder.CreateRet(createCast(Builder, CI, H->getReturnType()));
+    RI = Builder.CreateRet(Builder.CreateAggregateCast(CI, H->getReturnType()));
   }
 
   if (MergeFunctionsPDI) {
diff --git a/llvm/test/Transforms/MergeFunc/crash-cast-arrays.ll b/llvm/test/Transforms/MergeFunc/crash-cast-arrays.ll
new file mode 100644
index 0000000000000..fcbb06400a618
--- /dev/null
+++ b/llvm/test/Transforms/MergeFunc/crash-cast-arrays.ll
@@ -0,0 +1,38 @@
+; RUN: opt -S -passes=mergefunc < %s | FileCheck %s
+
+%A = type { double }
+; the intermediary struct causes A_arr and B_arr to be different types
+%A_struct = type { %A }
+%A_arr = type { [1 x %A_struct] }
+
+%B = type { double }
+%B_struct = type { %B }
+%B_arr = type { [1 x %B_struct] }
+
+declare void @noop()
+
+define %A_arr @a() {
+; CHECK-LABEL: define %A_arr @a() {
+; CHECK-NEXT:    call void @noop()
+; CHECK-NEXT:    ret %A_arr zeroinitializer
+;
+  call void @noop()
+  ret %A_arr zeroinitializer
+}
+
+define %B_arr @b() {
+; CHECK-LABEL: define %B_arr @b() {
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call %A_arr @a
+; CHECK-NEXT:    [[TMP2:%.*]] = extractvalue %A_arr [[TMP1]], 0
+; CHECK-NEXT:    [[TMP3:%.*]] = extractvalue [1 x %A_struct] [[TMP2]], 0
+; CHECK-NEXT:    [[TMP4:%.*]] = extractvalue %A_struct [[TMP3]], 0
+; CHECK-NEXT:    [[TMP5:%.*]] = extractvalue %A [[TMP4]], 0
+; CHECK-NEXT:    [[TMP6:%.*]] = insertvalue %B poison, double [[TMP5]], 0
+; CHECK-NEXT:    [[TMP7:%.*]] = insertvalue %B_struct poison, %B [[TMP6]], 0
+; CHECK-NEXT:    [[TMP8:%.*]] = insertvalue [1 x %B_struct] poison, %B_struct [[TMP7]], 0
+; CHECK-NEXT:    [[TMP9:%.*]] = insertvalue %B_arr poison, [1 x %B_struct] [[TMP8]], 0
+; CHECK-NEXT:    ret %B_arr [[TMP9]]
+;
+  call void @noop()
+  ret %B_arr zeroinitializer
+}

if (auto *SrcST = dyn_cast<StructType>(SrcTy)) {
assert(DestTy->isStructTy() && "Expected StructType");
auto *DestST = cast<StructType>(DestTy);
assert(SrcST->getNumElements() == DestST->getNumElements());

Add a message to the assert.

Comment on lines 97 to 105
assert(DestTy->isArrayTy() && "Expected ArrayType");
auto *DestAT = cast<ArrayType>(DestTy);
assert(SrcAT->getNumElements() == DestAT->getNumElements());
Value *Result = PoisonValue::get(DestTy);
for (unsigned int I = 0, E = SrcAT->getNumElements(); I < E; ++I) {
Value *Element = CreateAggregateCast(CreateExtractValue(V, ArrayRef(I)),
DestAT->getElementType());

Result = CreateInsertValue(Result, Element, ArrayRef(I));

Is this effectively the same as for the structure case, except for the cast/assert for the types? If so, would be good if we could unify the code.
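
One possible shape for such a unification is sketched below as a toy model. This is not the code that was merged and does not use the LLVM API: `ToyType`, `numElements`, `elementType`, `aggregateCast`, and `OpCounts` are hypothetical names. The idea is that two small accessors hide the struct/array difference, so a single loop can serve both cases. The model counts operations instead of emitting instructions.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Toy stand-in for llvm::Type. None of these names are LLVM API.
struct ToyType {
  enum Kind { Scalar, Struct, Array } kind;
  std::string name;           // scalar or named-struct name; empty for arrays
  std::vector<ToyType> elems; // struct: fields; array: one entry (element type)
  std::size_t count;          // array length; unused for other kinds
};

ToyType scalar(std::string Name) {
  return {ToyType::Scalar, std::move(Name), {}, 0};
}
ToyType structOf(std::string Name, std::vector<ToyType> Fields) {
  return {ToyType::Struct, std::move(Name), std::move(Fields), 0};
}
ToyType arrayOf(std::size_t N, const ToyType &Elem) {
  return {ToyType::Array, "", {Elem}, N};
}

// Named structs with identical bodies still count as different types here.
bool sameType(const ToyType &A, const ToyType &B) {
  if (A.kind != B.kind || A.name != B.name || A.count != B.count ||
      A.elems.size() != B.elems.size())
    return false;
  for (std::size_t I = 0; I < A.elems.size(); ++I)
    if (!sameType(A.elems[I], B.elems[I]))
      return false;
  return true;
}

// Uniform accessors that hide the struct/array difference.
std::size_t numElements(const ToyType &T) {
  return T.kind == ToyType::Array ? T.count : T.elems.size();
}
const ToyType &elementType(const ToyType &T, std::size_t I) {
  return T.kind == ToyType::Array ? T.elems[0] : T.elems[I];
}

struct OpCounts {
  unsigned extracts = 0, inserts = 0, leafCasts = 0;
};

// Single loop for both aggregate kinds.
void aggregateCast(const ToyType &Src, const ToyType &Dest, OpCounts &Ops) {
  if (sameType(Src, Dest))
    return; // mirrors the SrcTy == DestTy early exit
  if (Src.kind == ToyType::Scalar) {
    assert(Dest.kind == ToyType::Scalar && "leaf must map to a leaf");
    ++Ops.leafCasts;
    return;
  }
  assert(Src.kind == Dest.kind && "aggregate kinds must match");
  assert(numElements(Src) == numElements(Dest) && "element counts must match");
  for (std::size_t I = 0, E = numElements(Src); I < E; ++I) {
    ++Ops.extracts; // stands in for extractvalue
    aggregateCast(elementType(Src, I), elementType(Dest, I), Ops);
    ++Ops.inserts; // stands in for insertvalue
  }
}

// Builds { [1 x { { double } }] } with names derived from Prefix.
ToyType makeChain(const std::string &Prefix) {
  ToyType Inner = structOf(Prefix, {scalar("double")});
  ToyType Mid = structOf(Prefix + "_struct", {Inner});
  return structOf(Prefix + "_arr", {arrayOf(1, Mid)});
}
```

Running `aggregateCast(makeChain("A"), makeChain("B"), Ops)` in this model yields four extracts, four inserts, and no leaf casts, which matches the four `extractvalue` and four `insertvalue` CHECK lines in crash-cast-arrays.ll.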

@@ -0,0 +1,38 @@
; RUN: opt -S -passes=mergefunc < %s | FileCheck %s

%A = type { double }

Could you also add a test that requires a bitcast, e.g. one where one leaf type is double and the other is i64? I'm not sure whether mergefunc would handle such types, though.

Contributor Author

Added a ptrcast test instead of a bitcast test. I checked the code in FunctionComparator and searched the existing tests, but couldn't find a type that causes a bitcast in the thunk function.
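
For context on the ptrcast versus bitcast distinction, the leaf-level choice can be modeled as below. This is a toy sketch that follows the branch order of the removed `createCast` helpers (integer to pointer, pointer to integer, otherwise bitcast). `Leaf`, `LeafKind`, `CastOp`, and `selectLeafCast` are hypothetical names, not LLVM API, and the identical-type shortcut is an assumption modeled on the `SrcTy == DestTy` check in `CreateAggregateCast`.

```cpp
#include <cassert>
#include <string>

// Toy model of leaf (non-aggregate) types. None of these names are LLVM API.
enum class LeafKind { Integer, Pointer, Other };
enum class CastOp { None, PtrToInt, IntToPtr, BitCast };

struct Leaf {
  LeafKind kind;
  std::string name; // e.g. "i64", "ptr", "double"
};

// Picks the cast for one leaf pair, in the same order as the removed helpers.
CastOp selectLeafCast(const Leaf &Src, const Leaf &Dest) {
  if (Src.kind == Dest.kind && Src.name == Dest.name)
    return CastOp::None; // identical types need no cast
  if (Src.kind == LeafKind::Integer && Dest.kind == LeafKind::Pointer)
    return CastOp::IntToPtr;
  if (Src.kind == LeafKind::Pointer && Dest.kind == LeafKind::Integer)
    return CastOp::PtrToInt;
  return CastOp::BitCast; // e.g. double <-> i64
}

// Sample leaf types used for illustration.
const Leaf PtrTy{LeafKind::Pointer, "ptr"};
const Leaf I64Ty{LeafKind::Integer, "i64"};
const Leaf DoubleTy{LeafKind::Other, "double"};
```

In this model a `ptr`/`i64` pair selects a pointer cast, and only a pair such as `double`/`i64` would select a bitcast, which is the case the reply above could not reproduce through mergefunc.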

fhahn left a comment

LGTM, thanks!

tobias-stadler and others added 3 commits April 3, 2025 16:38
createCast in MergeFunctions did not consider ArrayTypes, which results
in the creation of a bitcast between ArrayTypes in the thunk function,
leading to an assertion failure in the provided test case.

The version of createCast in GlobalMergeFunctions does handle
ArrayTypes, so this common code has been factored out into the
IRBuilder.
Co-authored-by: Florian Hahn <flo@fhahn.com>
tobias-stadler force-pushed the fix-mergefunc-crash-cast branch from a577553 to dcf4911 on April 3, 2025 16:29
tobias-stadler merged commit 1302610 into llvm:main on Apr 4, 2025
11 checks passed