Skip to content

Commit

Permalink
[LLVM] Replace calls to Type::getVectorNumElements (apache#5398)
Browse files Browse the repository at this point in the history
This function has recently been removed from LLVM 11. Use alternative
way to obtain vector element count (VectorType::getNumElements) which
works for all LLVM versions.
  • Loading branch information
Krzysztof Parzyszek authored and Trevor Morris committed Jun 8, 2020
1 parent 507b73e commit 5f4f523
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 17 deletions.
19 changes: 9 additions & 10 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ llvm::Value* CodeGenLLVM::CreateBroadcast(llvm::Value* value, int lanes) {
}

llvm::Value* CodeGenLLVM::CreateVecSlice(llvm::Value* vec, int begin, int extent) {
int num_elems = static_cast<int>(vec->getType()->getVectorNumElements());
int num_elems = llvm::cast<llvm::VectorType>(vec->getType())->getNumElements();
if (extent == num_elems && begin == 0) return vec;
CHECK(begin >= 0 && extent <= num_elems) << "Slicing out of bound!\n";
std::vector<llvm::Constant*> indices;
Expand All @@ -490,7 +490,7 @@ llvm::Value* CodeGenLLVM::CreateVecSlice(llvm::Value* vec, int begin, int extent
}

llvm::Value* CodeGenLLVM::CreateVecFlip(llvm::Value* vec) {
int num_elems = static_cast<int>(vec->getType()->getVectorNumElements());
int num_elems = llvm::cast<llvm::VectorType>(vec->getType())->getNumElements();
#if TVM_LLVM_VERSION >= 110
std::vector<int> indices;
#else
Expand All @@ -505,7 +505,7 @@ llvm::Value* CodeGenLLVM::CreateVecFlip(llvm::Value* vec) {
llvm::Value* CodeGenLLVM::CreateVecPad(llvm::Value* vec, int target_lanes) {
llvm::Value* mask = llvm::UndefValue::get(
DTypeToLLVMType(DataType::Int(32, target_lanes)));
int num_elems = static_cast<int>(vec->getType()->getVectorNumElements());
int num_elems = llvm::cast<llvm::VectorType>(vec->getType())->getNumElements();
if (num_elems == target_lanes) return vec;
CHECK_LT(num_elems, target_lanes);
for (int i = 0; i < num_elems; ++i) {
Expand All @@ -519,16 +519,15 @@ llvm::Value* CodeGenLLVM::CreateVecConcat(std::vector<llvm::Value*> vecs) {
int total_lanes = 0;

for (llvm::Value* v : vecs) {
total_lanes += static_cast<int>(
v->getType()->getVectorNumElements());
total_lanes += llvm::cast<llvm::VectorType>(v->getType())->getNumElements();
}
while (vecs.size() > 1) {
std::vector<llvm::Value*> new_vecs;
for (size_t i = 0; i < vecs.size() - 1; i += 2) {
llvm::Value* lhs = vecs[i];
llvm::Value* rhs = vecs[i + 1];
const size_t lhs_lanes = lhs->getType()->getVectorNumElements();
const size_t rhs_lanes = rhs->getType()->getVectorNumElements();
const size_t lhs_lanes = llvm::cast<llvm::VectorType>(lhs->getType())->getNumElements();
const size_t rhs_lanes = llvm::cast<llvm::VectorType>(rhs->getType())->getNumElements();
if (lhs_lanes < rhs_lanes) {
lhs = CreateVecPad(lhs, rhs_lanes);
} else if (rhs_lanes < lhs_lanes) {
Expand Down Expand Up @@ -870,16 +869,16 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
return builder_->CreateFCmpUNO(a, a);
} else if (op->is_intrinsic("vectorlow")) {
llvm::Value *v = MakeValue(op->args[0]);
int l = v->getType()->getVectorNumElements();
int l = llvm::cast<llvm::VectorType>(v->getType())->getNumElements();
return CreateVecSlice(v, 0, l/2);
} else if (op->is_intrinsic("vectorhigh")) {
llvm::Value *v = MakeValue(op->args[0]);
int l = v->getType()->getVectorNumElements();
int l = llvm::cast<llvm::VectorType>(v->getType())->getNumElements();
return CreateVecSlice(v, l/2, l/2);
} else if (op->is_intrinsic("vectorcombine")) {
llvm::Value *v0 = MakeValue(op->args[0]);
llvm::Value *v1 = MakeValue(op->args[1]);
int num_elems = static_cast<int>(v0->getType()->getVectorNumElements()) * 2;
int num_elems = llvm::cast<llvm::VectorType>(v0->getType())->getNumElements() * 2;
#if TVM_LLVM_VERSION >= 110
std::vector<int> indices;
#else
Expand Down
13 changes: 6 additions & 7 deletions src/target/llvm/codegen_x86_64.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,21 +123,20 @@ llvm::Value* CodeGenX86_64::CallVectorIntrin(llvm::Intrinsic::ID id, size_t intr

const std::vector<llvm::Value*>& args) {
llvm::Function* f = llvm::Intrinsic::getDeclaration(module_.get(), id, {});
if (intrin_lanes == result_ty->getVectorNumElements()) {
size_t num_elems = llvm::cast<llvm::VectorType>(result_ty)->getNumElements();
if (intrin_lanes == num_elems) {
return builder_->CreateCall(f, args);
}

// Otherwise, we split the vector into intrin_lanes sized elements (widening where necessary),
// compute each result, and then concatenate the vectors (slicing the result if necessary).
CHECK_LT(intrin_lanes, result_ty->getVectorNumElements());
CHECK_LT(intrin_lanes, num_elems);
std::vector<llvm::Value*> split_results;
for (size_t i = 0;
i < static_cast<size_t>(result_ty->getVectorNumElements());
i += intrin_lanes) {
for (size_t i = 0; i < num_elems; i += intrin_lanes) {
std::vector<llvm::Value*> split_args;
for (const auto& v : args) {
if (v->getType()->isVectorTy()) {
CHECK_EQ(v->getType()->getVectorNumElements(), result_ty->getVectorNumElements());
CHECK_EQ(llvm::cast<llvm::VectorType>(v->getType())->getNumElements(), num_elems);
split_args.push_back(CreateVecSlice(v, i, intrin_lanes));
} else {
split_args.push_back(v);
Expand All @@ -147,7 +146,7 @@ llvm::Value* CodeGenX86_64::CallVectorIntrin(llvm::Intrinsic::ID id, size_t intr
id, intrin_lanes, llvm::VectorType::get(result_ty->getScalarType(), intrin_lanes),
split_args));
}
return CreateVecSlice(CreateVecConcat(split_results), 0, result_ty->getVectorNumElements());
return CreateVecSlice(CreateVecConcat(split_results), 0, num_elems);
}

TVM_REGISTER_GLOBAL("tvm.codegen.llvm.target_x86-64")
Expand Down

0 comments on commit 5f4f523

Please sign in to comment.