Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[refactor] [cuda] Refactor offline-cache and support it on arch=cuda #4600

Merged
merged 5 commits into from
Mar 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions taichi/backends/cpu/codegen_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,15 @@ class CodeGenLLVMCPU : public CodeGenLLVM {
public:
using IRVisitor::visit;

CodeGenLLVMCPU(Kernel *kernel, IRNode *ir, bool needs_cache)
: CodeGenLLVM(kernel, ir, nullptr, needs_cache) {
CodeGenLLVMCPU(Kernel *kernel, IRNode *ir)
: CodeGenLLVM(kernel, ir, nullptr) {
TI_AUTO_PROF
}

// Opts the CPU LLVM codegen in to the offline kernel cache.
// CodeGenLLVM::gen() checks this (together with config.offline_cache and
// !kernel->is_evaluator) before it tries to load or store a cached module.
bool supports_offline_cache() const override {
return true;
}

void create_offload_range_for(OffloadedStmt *stmt) override {
int step = 1;

Expand Down Expand Up @@ -195,7 +199,7 @@ class CodeGenLLVMCPU : public CodeGenLLVM {

FunctionType CodeGenCPU::codegen() {
TI_AUTO_PROF
return CodeGenLLVMCPU(kernel, ir, needs_cache_).gen();
return CodeGenLLVMCPU(kernel, ir).gen();
}

TLANG_NAMESPACE_END
6 changes: 1 addition & 5 deletions taichi/backends/cpu/codegen_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,10 @@ TLANG_NAMESPACE_BEGIN

class CodeGenCPU : public KernelCodeGen {
public:
CodeGenCPU(Kernel *kernel, IRNode *ir = nullptr, bool needs_cache = false)
: KernelCodeGen(kernel, ir), needs_cache_(needs_cache) {
CodeGenCPU(Kernel *kernel, IRNode *ir = nullptr) : KernelCodeGen(kernel, ir) {
}

FunctionType codegen() override;

private:
bool needs_cache_{false};
};

TLANG_NAMESPACE_END
6 changes: 4 additions & 2 deletions taichi/backends/cuda/codegen_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,12 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {
: CodeGenLLVM(kernel, ir) {
}

// Opts the CUDA LLVM codegen in to the offline kernel cache.
// CodeGenLLVM::gen() checks this (together with config.offline_cache and
// !kernel->is_evaluator) before it tries to load or store a cached module.
bool supports_offline_cache() const override {
return true;
}

FunctionType compile_module_to_executable() override {
#ifdef TI_WITH_CUDA
eliminate_unused_functions();

auto offloaded_local = offloaded_tasks;
for (auto &task : offloaded_local) {
llvm::Function *func = module->getFunction(task.name);
Expand Down
4 changes: 4 additions & 0 deletions taichi/backends/wasm/codegen_wasm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,10 @@ class CodeGenLLVMWASM : public CodeGenLLVM {

FunctionType gen() override {
TI_AUTO_PROF
// lower kernel
if (!kernel->lowered()) {
kernel->lower();
}
// emit_to_module
stat.add("codegen_taichi_kernel_function");
auto offloaded_task_name = init_taichi_kernel_function();
Expand Down
5 changes: 2 additions & 3 deletions taichi/codegen/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,10 @@ KernelCodeGen::KernelCodeGen(Kernel *kernel, IRNode *ir)

std::unique_ptr<KernelCodeGen> KernelCodeGen::create(Arch arch,
Kernel *kernel,
Stmt *stmt,
bool needs_cache) {
Stmt *stmt) {
#ifdef TI_WITH_LLVM
if (arch_is_cpu(arch) && arch != Arch::wasm) {
return std::make_unique<CodeGenCPU>(kernel, stmt, needs_cache);
return std::make_unique<CodeGenCPU>(kernel, stmt);
} else if (arch == Arch::wasm) {
return std::make_unique<CodeGenWASM>(kernel, stmt);
} else if (arch == Arch::cuda) {
Expand Down
3 changes: 1 addition & 2 deletions taichi/codegen/codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ class KernelCodeGen {

static std::unique_ptr<KernelCodeGen> create(Arch arch,
Kernel *kernel,
Stmt *stmt = nullptr,
bool needs_cache = false);
Stmt *stmt = nullptr);

virtual FunctionType codegen() = 0;
};
Expand Down
64 changes: 49 additions & 15 deletions taichi/codegen/codegen_llvm.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#ifdef TI_WITH_LLVM
#include "taichi/codegen/codegen_llvm.h"

#include "taichi/llvm/llvm_offline_cache.h"
#include "taichi/ir/statements.h"
#include "taichi/struct/struct_llvm.h"
#include "taichi/util/file_sequence_writer.h"
Expand Down Expand Up @@ -304,8 +304,7 @@ void CodeGenLLVM::emit_struct_meta_base(const std::string &name,

CodeGenLLVM::CodeGenLLVM(Kernel *kernel,
IRNode *ir,
std::unique_ptr<llvm::Module> &&module,
bool needs_cache)
std::unique_ptr<llvm::Module> &&module)
// TODO: simplify LLVMModuleBuilder ctor input
: LLVMModuleBuilder(
module == nullptr ? kernel->program->get_llvm_program_impl()
Expand All @@ -314,7 +313,6 @@ CodeGenLLVM::CodeGenLLVM(Kernel *kernel,
: std::move(module),
kernel->program->get_llvm_program_impl()->get_llvm_context(
kernel->arch)),
needs_cache_(needs_cache),
kernel(kernel),
ir(ir),
prog(kernel->program) {
Expand Down Expand Up @@ -2268,17 +2266,6 @@ void CodeGenLLVM::eliminate_unused_functions() {

FunctionType CodeGenLLVM::compile_module_to_executable() {
TI_AUTO_PROF
eliminate_unused_functions();

auto *llvm_prog = prog->get_llvm_program_impl();
if (needs_cache_) {
std::vector<std::string> offloaded_task_name_list;
for (auto &task : offloaded_tasks) {
offloaded_task_name_list.push_back(task.name);
}
llvm_prog->cache_kernel(this->kernel->get_key(), this->module.get(),
std::move(offloaded_task_name_list));
}

tlctx->add_module(std::move(module));

Expand Down Expand Up @@ -2384,7 +2371,42 @@ void CodeGenLLVM::emit_to_module() {
}

// Builds the executable for `kernel`, consulting the offline cache first.
//
// The cache is used only when all of the following hold: the program config
// enables `offline_cache`, this codegen reports supports_offline_cache(), and
// the kernel is not an evaluator kernel.
//
// Cache hit : the cached LLVM module replaces `this->module`, the offloaded
//             task list is rebuilt from the cached (name, block_dim, grid_dim)
//             records, the kernel is flagged via set_from_offline_cache(), and
//             lowering / emission are skipped entirely.
// Cache miss: the kernel is lowered (if needed), emitted, pruned of unused
//             functions, stored through cache_module(), and then compiled.
// Cache off : same as a miss, minus the cache_module() call.
FunctionType CodeGenLLVM::gen() {
  const auto &compile_config = prog->config;
  const bool use_offline_cache = compile_config.offline_cache &&
                                 this->supports_offline_cache() &&
                                 !kernel->is_evaluator;

  std::string cache_key;
  if (use_offline_cache) {
    cache_key = get_offline_cache_key_of_kernel(kernel);

    LlvmOfflineCacheFileReader cache_reader(
        compile_config.offline_cache_file_path);
    LlvmOfflineCache::KernelCacheData cached_kernel;
    // NOTE(review): the context is looked up with config.arch here, while the
    // constructor uses kernel->arch — confirm the two always agree.
    auto *taichi_llvm_ctx =
        this->prog->get_llvm_program_impl()->get_llvm_context(
            compile_config.arch);
    auto &thread_llvm_ctx = *taichi_llvm_ctx->get_this_thread_context();

    const bool cache_hit =
        cache_reader.get_kernel_cache(cached_kernel, cache_key,
                                      thread_llvm_ctx);
    if (cache_hit) {
      // Adopt the cached module and recreate one OffloadedTask per record.
      this->module = std::move(cached_kernel.owned_module);
      for (auto &cached_task : cached_kernel.offloaded_task_list) {
        auto &restored = this->offloaded_tasks.emplace_back(this);
        restored.name = std::move(cached_task.name);
        restored.block_dim = cached_task.block_dim;
        restored.grid_dim = cached_task.grid_dim;
      }
      kernel->set_from_offline_cache();
      return compile_module_to_executable();
    }
  }

  // Regular path: nothing usable in the cache (or caching is disabled).
  if (!kernel->lowered()) {
    kernel->lower();
  }
  emit_to_module();
  eliminate_unused_functions();
  // Reaching this point with caching enabled means the lookup above missed,
  // so the freshly generated module is stored for the next run.
  if (use_offline_cache) {
    cache_module(cache_key);
  }
  return compile_module_to_executable();
}

Expand Down Expand Up @@ -2451,6 +2473,18 @@ void CodeGenLLVM::visit(FuncCallStmt *stmt) {
}
}

void CodeGenLLVM::cache_module(const std::string &kernel_key) {
using OffloadedTaskCache = LlvmOfflineCache::OffloadedTaskCacheData;
std::vector<OffloadedTaskCache> offloaded_task_list;
for (auto &task : offloaded_tasks) {
auto &task_cache = offloaded_task_list.emplace_back();
task_cache.name = task.name;
task_cache.block_dim = task.block_dim;
task_cache.grid_dim = task.grid_dim;
}
prog->get_llvm_program_impl()->cache_kernel(kernel_key, this->module.get(),
std::move(offloaded_task_list));
}
TLANG_NAMESPACE_END

#endif // #ifdef TI_WITH_LLVM
17 changes: 10 additions & 7 deletions taichi/codegen/codegen_llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ class OffloadedTask {
using task_fp_type = int32 (*)(void *);
task_fp_type func;

int block_dim;
int grid_dim;
int block_dim{0};
int grid_dim{0};

OffloadedTask(CodeGenLLVM *codegen);

Expand All @@ -48,9 +48,6 @@ class FunctionCreationGuard {
};

class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
private:
bool needs_cache_{false};

public:
Kernel *kernel;
IRNode *ir;
Expand Down Expand Up @@ -86,8 +83,7 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

CodeGenLLVM(Kernel *kernel,
IRNode *ir = nullptr,
std::unique_ptr<llvm::Module> &&module = nullptr,
bool needs_cache = false);
std::unique_ptr<llvm::Module> &&module = nullptr);

Arch current_arch() {
return kernel->arch;
Expand Down Expand Up @@ -131,6 +127,10 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {

virtual FunctionType gen();

// Whether this codegen may load kernels from / store kernels to the offline
// cache. The base implementation opts out; backends override this to return
// true. gen() only touches the cache when this returns true.
virtual bool supports_offline_cache() const {
return false;
}

// For debugging only
virtual llvm::Value *create_print(std::string tag,
DataType dt,
Expand Down Expand Up @@ -391,6 +391,9 @@ class CodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
llvm::Value *bitcast_to_u64(llvm::Value *val, DataType type);

~CodeGenLLVM() override = default;

private:
void cache_module(const std::string &kernel_key);
};

TLANG_NAMESPACE_END
Expand Down
34 changes: 25 additions & 9 deletions taichi/llvm/llvm_offline_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,22 @@
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_os_ostream.h"
#include "llvm/IR/Module.h"
#include "taichi/ir/transforms.h"

#include "picosha2.h"

namespace taichi {
namespace lang {

// Computes the offline-cache key for `kernel`: a one-character tag ('g' when
// kernel->grad is set, 'n' otherwise) followed by the SHA-256 hex digest of
// the kernel IR's printed form.
//
// Side effect: runs irpass::re_id on the kernel's IR before printing
// (presumably to make the printed text independent of statement-id
// numbering — TODO confirm).
std::string get_offline_cache_key_of_kernel(Kernel *kernel) {
  auto *ir_root = kernel->ir.get();
  irpass::re_id(ir_root);

  std::string ir_text;
  irpass::print(ir_root, &ir_text);

  std::string digest;
  picosha2::hash256_hex_string(ir_text, digest);

  const char grad_tag = kernel->grad ? 'g' : 'n';
  return std::string(1, grad_tag) + digest;
}

bool LlvmOfflineCacheFileReader::get_kernel_cache(
LlvmOfflineCache::KernelCacheData &res,
const std::string &key,
Expand All @@ -32,7 +44,9 @@ bool LlvmOfflineCacheFileReader::get_kernel_cache(
std::getline(in, line, '\n');
if (line.empty())
break;
res.offloaded_task_name_list.push_back(std::move(line));
std::istringstream iss(line);
auto &task = res.offloaded_task_list.emplace_back();
iss >> task.name >> task.block_dim >> task.grid_dim;
}
}
return true;
Expand All @@ -49,11 +63,11 @@ void LlvmOfflineCacheFileWriter::dump() {
llvm::LLVMContext ctx;
llvm::raw_os_ostream llvm_os(os);
if (v.module) {
mangle_offloaded_task_name(k, v.module, v.offloaded_task_name_list);
mangle_offloaded_task_name(k, v.module, v.offloaded_task_list);
v.module->print(llvm_os, nullptr);
} else if (v.owned_module) {
mangle_offloaded_task_name(k, v.owned_module.get(),
v.offloaded_task_name_list);
v.offloaded_task_list);
v.owned_module->print(llvm_os, nullptr);
} else
TI_ASSERT(false);
Expand All @@ -62,8 +76,9 @@ void LlvmOfflineCacheFileWriter::dump() {
std::string filename = filename_prefix + "_otnl.txt";
std::ofstream os(filename, std::ios::out | std::ios::binary);
TI_ERROR_IF(!os.is_open(), "File {} open failed", filename);
for (const auto &name : v.offloaded_task_name_list) {
os << name << '\n';
for (const auto &task : v.offloaded_task_list) {
os << task.name << ' ' << task.block_dim << ' ' << task.grid_dim
<< '\n';
}
}
}
Expand All @@ -72,15 +87,16 @@ void LlvmOfflineCacheFileWriter::dump() {
void LlvmOfflineCacheFileWriter::mangle_offloaded_task_name(
const std::string &kernel_key,
llvm::Module *module,
std::vector<std::string> &offloaded_task_name_list) {
std::vector<LlvmOfflineCache::OffloadedTaskCacheData>
&offloaded_task_list) {
if (!mangled_) {
std::size_t cnt = 0;
for (auto &e : offloaded_task_name_list) {
for (auto &e : offloaded_task_list) {
std::string mangled_name = kernel_key + std::to_string(cnt++);
auto func = module->getFunction(e);
auto func = module->getFunction(e.name);
TI_ASSERT(func != nullptr);
func->setName(mangled_name);
e = mangled_name;
e.name = mangled_name;
}
}
}
Expand Down
13 changes: 11 additions & 2 deletions taichi/llvm/llvm_offline_cache.h
Original file line number Diff line number Diff line change
@@ -1,18 +1,26 @@
#pragma once

#include "taichi/common/core.h"
#include "taichi/program/kernel.h"
#include "taichi/llvm/llvm_fwd.h"
#include "taichi/util/io.h"

namespace taichi {
namespace lang {

std::string get_offline_cache_key_of_kernel(Kernel *kernel);

struct LlvmOfflineCache {
// Per-offloaded-task record stored with a cached kernel. The fields mirror
// OffloadedTask's name / block_dim / grid_dim: CodeGenLLVM::cache_module()
// copies them in, and CodeGenLLVM::gen() copies them back on a cache hit.
struct OffloadedTaskCacheData {
// Function name of the task inside the cached module. The cache file stores
// each record as "name block_dim grid_dim" on one line and reads it back with
// operator>>, so the name must not contain whitespace.
std::string name;
int block_dim{0};
int grid_dim{0};
};
struct KernelCacheData {
std::string kernel_key;
std::unique_ptr<llvm::Module> owned_module{nullptr};
llvm::Module *module{nullptr};
std::vector<std::string> offloaded_task_name_list;
std::vector<OffloadedTaskCacheData> offloaded_task_list;

KernelCacheData() = default;
KernelCacheData(KernelCacheData &&) = default;
Expand Down Expand Up @@ -58,7 +66,8 @@ class LlvmOfflineCacheFileWriter {
void mangle_offloaded_task_name(
const std::string &kernel_key,
llvm::Module *module,
std::vector<std::string> &offloaded_task_name_list);
std::vector<LlvmOfflineCache::OffloadedTaskCacheData>
&offloaded_task_list);

std::string path_;
LlvmOfflineCache data_;
Expand Down
Loading