Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[llvm] Avoid creating new LLVM contexts when updating struct module #5397

Merged
merged 11 commits into from
Jul 13, 2022
5 changes: 3 additions & 2 deletions taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -925,8 +925,9 @@ void CodeGenLLVM::emit_gc(OffloadedStmt *stmt) {
}

llvm::Value *CodeGenLLVM::create_call(llvm::Value *func,
llvm::ArrayRef<llvm::Value *> args) {
check_func_call_signature(func, args);
llvm::ArrayRef<llvm::Value *> args_arr) {
std::vector<llvm::Value *> args = args_arr;
check_func_call_signature(func, args, builder.get());
#ifdef TI_LLVM_15
llvm::FunctionType *func_ty = nullptr;
if (auto *fn = llvm::dyn_cast<llvm::Function>(func)) {
Expand Down
63 changes: 61 additions & 2 deletions taichi/codegen/llvm/llvm_codegen_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,59 @@ namespace lang {
std::string type_name(llvm::Type *type) {
std::string type_name_str;
llvm::raw_string_ostream rso(type_name_str);
type->print(rso);
type->print(rso, /*IsForDebug=*/false, /*NoDetails=*/true);
return type_name_str;
}

/*
* Determine whether two types are the same
* (a type is a renamed version of the other one) based on the
* type name. Check recursively if the types are function types.
*
* The name of a type imported multiple times is added a suffix starting with a
* "." following by a number. For example, "RuntimeContext" may be renamed to
* names like "RuntimeContext.0" and "RuntimeContext.8".
*/
bool is_same_type(llvm::Type *required, llvm::Type *provided) {
jim19930609 marked this conversation as resolved.
Show resolved Hide resolved
if (required == provided) {
return true;
}
if (required->isPointerTy() != provided->isPointerTy()) {
return false;
}
if (required->isPointerTy()) {
required = required->getPointerElementType();
provided = provided->getPointerElementType();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: is it cleaner to just return is_same_type(required->getPointerElementType(), provided->getPointerElementType()) here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah... I'll send another pr to fix this.

}
if (required->isFunctionTy() != provided->isFunctionTy()) {
return false;
}
if (required->isFunctionTy()) {
auto req_func = llvm::dyn_cast<llvm::FunctionType>(required);
auto prov_func = llvm::dyn_cast<llvm::FunctionType>(provided);
if (!is_same_type(req_func->getReturnType(), prov_func->getReturnType())) {
return false;
}
if (req_func->getNumParams() != prov_func->getNumParams()) {
return false;
}
for (int j = 0; j < req_func->getNumParams(); j++) {
if (!is_same_type(req_func->getParamType(j),
prov_func->getParamType(j))) {
return false;
}
}
return true;
}
auto req_name = type_name(required);
auto prov_name = type_name(provided);
int min_len = std::min(req_name.size(), prov_name.size());
return req_name.substr(0, min_len) == prov_name.substr(0, min_len);
}

void check_func_call_signature(llvm::Value *func,
std::vector<llvm::Value *> arglist) {
std::vector<llvm::Value *> &arglist,
llvm::IRBuilder<> *builder) {
llvm::FunctionType *func_type = nullptr;
if (llvm::Function *fn = llvm::dyn_cast<llvm::Function>(func)) {
func_type = fn->getFunctionType();
Expand All @@ -31,7 +78,19 @@ void check_func_call_signature(llvm::Value *func,
for (int i = 0; i < num_params; i++) {
auto required = func_type->getFunctionParamType(i);
auto provided = arglist[i]->getType();
/*
jim19930609 marked this conversation as resolved.
Show resolved Hide resolved
* When importing a module from file, the imported `llvm::Type`s can get
* conflict with the same type in the llvm::Context. In such scenario,
* the imported types will be renamed from "original_type" to
* "original_type.xxx", making them separate types in essence.
* To make the types of the argument and parameter the same,
* a pointer cast must be performed.
*/
if (required != provided) {
if (is_same_type(required, provided)) {
arglist[i] = builder->CreatePointerCast(arglist[i], required);
continue;
}
TI_INFO("Function : {}", std::string(func->getName()));
TI_INFO(" Type : {}", type_name(func->getType()));
if (&required->getContext() != &provided->getContext()) {
Expand Down
12 changes: 7 additions & 5 deletions taichi/codegen/llvm/llvm_codegen_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ inline constexpr char kLLVMPhysicalCoordinatesName[] = "PhysicalCoordinates";
std::string type_name(llvm::Type *type);

void check_func_call_signature(llvm::Value *func,
std::vector<llvm::Value *> arglist);
std::vector<llvm::Value *> &arglist,
llvm::IRBuilder<> *builder);

class LLVMModuleBuilder {
public:
Expand Down Expand Up @@ -121,8 +122,9 @@ class LLVMModuleBuilder {
const std::string &func_name,
const std::vector<llvm::Value *> &arglist) {
auto func = get_runtime_function(func_name);
check_func_call_signature(func, arglist);
return builder->CreateCall(func, arglist);
std::vector<llvm::Value *> args = arglist;
check_func_call_signature(func, args, builder);
return builder->CreateCall(func, args);
}

template <typename... Args>
Expand All @@ -131,7 +133,7 @@ class LLVMModuleBuilder {
Args &&...args) {
auto func = get_runtime_function(func_name);
auto arglist = std::vector<llvm::Value *>({args...});
check_func_call_signature(func, arglist);
check_func_call_signature(func, arglist, builder);
return builder->CreateCall(func, arglist);
}

Expand Down Expand Up @@ -186,7 +188,7 @@ class RuntimeObject {
llvm::Value *call(const std::string &func_name, Args &&...args) {
auto func = get_func(func_name);
auto arglist = std::vector<llvm::Value *>({ptr, args...});
check_func_call_signature(func, arglist);
check_func_call_signature(func, arglist, builder);
return builder->CreateCall(func, arglist);
}

Expand Down
6 changes: 0 additions & 6 deletions taichi/runtime/llvm/llvm_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -529,12 +529,6 @@ void TaichiLLVMContext::set_struct_module(
continue;
}
TI_ASSERT(!data->runtime_module);
data->struct_module.reset();
old_contexts_.push_back(std::move(data->thread_safe_llvm_context));
data->thread_safe_llvm_context =
std::make_unique<llvm::orc::ThreadSafeContext>(
std::make_unique<llvm::LLVMContext>());
data->llvm_context = data->thread_safe_llvm_context->getContext();
data->struct_module = clone_module_to_context(
this_thread_data->struct_module.get(), data->llvm_context);
}
Expand Down
3 changes: 0 additions & 3 deletions taichi/runtime/llvm/llvm_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,6 @@ class TaichiLLVMContext {
std::mutex thread_map_mut_;

std::unordered_map<int, std::vector<std::string>> snode_tree_funcs_;
std::vector<std::unique_ptr<llvm::orc::ThreadSafeContext>>
old_contexts_; // Contains old contexts that have modules stored inside
// the offline cache.
};

class LlvmModuleBitcodeLoader {
Expand Down