Skip to content

Commit

Permalink
Refactor2023: Remove dependencies on Program::this_thread_config() in…
Browse files Browse the repository at this point in the history
… KernelCodeGen
  • Loading branch information
PGZXB committed Dec 6, 2022
1 parent a2f1e15 commit 187f206
Show file tree
Hide file tree
Showing 15 changed files with 54 additions and 31 deletions.
23 changes: 13 additions & 10 deletions taichi/codegen/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,27 +21,30 @@

namespace taichi::lang {

KernelCodeGen::KernelCodeGen(Kernel *kernel)
: prog(kernel->program), kernel(kernel) {
KernelCodeGen::KernelCodeGen(const CompileConfig *compile_config,
Kernel *kernel)
: prog(kernel->program), kernel(kernel), compile_config_(compile_config) {
this->ir = kernel->ir.get();
}

std::unique_ptr<KernelCodeGen> KernelCodeGen::create(Arch arch,
Kernel *kernel) {
std::unique_ptr<KernelCodeGen> KernelCodeGen::create(
const CompileConfig *compile_config,
Kernel *kernel) {
#ifdef TI_WITH_LLVM
const auto arch = compile_config->arch;
if (arch_is_cpu(arch) && arch != Arch::wasm) {
return std::make_unique<KernelCodeGenCPU>(kernel);
return std::make_unique<KernelCodeGenCPU>(compile_config, kernel);
} else if (arch == Arch::wasm) {
return std::make_unique<KernelCodeGenWASM>(kernel);
return std::make_unique<KernelCodeGenWASM>(compile_config, kernel);
} else if (arch == Arch::cuda) {
#if defined(TI_WITH_CUDA)
return std::make_unique<KernelCodeGenCUDA>(kernel);
return std::make_unique<KernelCodeGenCUDA>(compile_config, kernel);
#else
TI_NOT_IMPLEMENTED
#endif
} else if (arch == Arch::dx12) {
#if defined(TI_WITH_DX12)
return std::make_unique<KernelCodeGenDX12>(kernel);
return std::make_unique<KernelCodeGenDX12>(compile_config, kernel);
#else
TI_NOT_IMPLEMENTED
#endif
Expand All @@ -58,7 +61,7 @@ std::optional<LLVMCompiledKernel>
KernelCodeGen::maybe_read_compilation_from_cache(
const std::string &kernel_key) {
TI_AUTO_PROF;
const auto &config = prog->this_thread_config();
const auto &config = *compile_config_;
auto *llvm_prog = get_llvm_program(prog);
const auto &reader = llvm_prog->get_cache_reader();
if (!reader) {
Expand All @@ -84,7 +87,7 @@ void KernelCodeGen::cache_kernel(const std::string &kernel_key,

LLVMCompiledKernel KernelCodeGen::compile_kernel_to_module() {
auto *llvm_prog = get_llvm_program(prog);
const auto &config = prog->this_thread_config();
const auto &config = *compile_config_;
auto *tlctx = llvm_prog->get_llvm_context(config.arch);
std::string kernel_key = get_hashed_offline_cache_key(&config, kernel);
kernel->set_kernel_key_for_cache(kernel_key);
Expand Down
6 changes: 4 additions & 2 deletions taichi/codegen/codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ class KernelCodeGen {
IRNode *ir;

public:
explicit KernelCodeGen(Kernel *kernel);
explicit KernelCodeGen(const CompileConfig *compile_config, Kernel *kernel);

virtual ~KernelCodeGen() = default;

static std::unique_ptr<KernelCodeGen> create(Arch arch, Kernel *kernel);
static std::unique_ptr<KernelCodeGen> create(const CompileConfig *compile_config, Kernel *kernel);

virtual FunctionType compile_to_function() = 0;
virtual bool supports_offline_cache() const {
Expand All @@ -65,6 +65,8 @@ class KernelCodeGen {
void cache_kernel(const std::string &kernel_key,
const LLVMCompiledKernel &data);
#endif
private:
const CompileConfig *compile_config_{nullptr};
};

#ifdef TI_WITH_LLVM
Expand Down
3 changes: 2 additions & 1 deletion taichi/codegen/cpu/codegen_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ namespace taichi::lang {

class KernelCodeGenCPU : public KernelCodeGen {
public:
explicit KernelCodeGenCPU(Kernel *kernel) : KernelCodeGen(kernel) {
explicit KernelCodeGenCPU(const CompileConfig *compile_config, Kernel *kernel)
: KernelCodeGen(compile_config, kernel) {
}

// TODO: Stop defining this macro guards in the headers
Expand Down
4 changes: 3 additions & 1 deletion taichi/codegen/cuda/codegen_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ namespace taichi::lang {

class KernelCodeGenCUDA : public KernelCodeGen {
public:
explicit KernelCodeGenCUDA(Kernel *kernel) : KernelCodeGen(kernel) {
explicit KernelCodeGenCUDA(const CompileConfig *compile_config,
Kernel *kernel)
: KernelCodeGen(compile_config, kernel) {
}

// TODO: Stop defining this macro guards in the headers
Expand Down
4 changes: 3 additions & 1 deletion taichi/codegen/dx12/codegen_dx12.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ namespace taichi::lang {

class KernelCodeGenDX12 : public KernelCodeGen {
public:
explicit KernelCodeGenDX12(Kernel *kernel) : KernelCodeGen(kernel) {
explicit KernelCodeGenDX12(const CompileConfig *compile_config,
Kernel *kernel)
: KernelCodeGen(compile_config, kernel) {
}
struct CompileResult {
std::vector<std::vector<uint8_t>> task_dxil_source_codes;
Expand Down
4 changes: 3 additions & 1 deletion taichi/codegen/wasm/codegen_wasm.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ namespace taichi::lang {

class KernelCodeGenWASM : public KernelCodeGen {
public:
explicit KernelCodeGenWASM(Kernel *kernel) : KernelCodeGen(kernel) {
explicit KernelCodeGenWASM(const CompileConfig *compile_config,
Kernel *kernel)
: KernelCodeGen(compile_config, kernel) {
}

FunctionType compile_to_function() override;
Expand Down
2 changes: 1 addition & 1 deletion taichi/program/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,7 @@ std::unique_ptr<AotModuleBuilder> Program::make_aot_module_builder(
if (arch == Arch::wasm) {
// Have to check WASM first, or it dispatches to the LlvmProgramImpl.
#ifdef TI_WITH_LLVM
return std::make_unique<wasm::AotModuleBuilderImpl>();
return std::make_unique<wasm::AotModuleBuilderImpl>(&this_thread_config());
#else
TI_NOT_IMPLEMENTED
#endif
Expand Down
2 changes: 1 addition & 1 deletion taichi/runtime/cpu/aot_module_builder_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace taichi::lang {
namespace cpu {

LLVMCompiledKernel AotModuleBuilderImpl::compile_kernel(Kernel *kernel) {
auto cgen = KernelCodeGenCPU(kernel);
auto cgen = KernelCodeGenCPU(get_compile_config(), kernel);
return cgen.compile_kernel_to_module();
}

Expand Down
5 changes: 3 additions & 2 deletions taichi/runtime/cpu/aot_module_builder_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ namespace cpu {

class AotModuleBuilderImpl : public LlvmAotModuleBuilder {
public:
explicit AotModuleBuilderImpl(LlvmProgramImpl *prog)
: LlvmAotModuleBuilder(prog) {
explicit AotModuleBuilderImpl(const CompileConfig *compile_config,
LlvmProgramImpl *prog)
: LlvmAotModuleBuilder(compile_config, prog) {
}

private:
Expand Down
2 changes: 1 addition & 1 deletion taichi/runtime/cuda/aot_module_builder_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace taichi::lang {
namespace cuda {

LLVMCompiledKernel AotModuleBuilderImpl::compile_kernel(Kernel *kernel) {
auto cgen = KernelCodeGenCUDA(kernel);
auto cgen = KernelCodeGenCUDA(get_compile_config(), kernel);
return cgen.compile_kernel_to_module();
}

Expand Down
4 changes: 2 additions & 2 deletions taichi/runtime/cuda/aot_module_builder_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ namespace cuda {

class AotModuleBuilderImpl : public LlvmAotModuleBuilder {
public:
explicit AotModuleBuilderImpl(LlvmProgramImpl *prog)
: LlvmAotModuleBuilder(prog) {
explicit AotModuleBuilderImpl(const CompileConfig *compile_config, LlvmProgramImpl *prog)
: LlvmAotModuleBuilder(compile_config, prog) {
}

private:
Expand Down
9 changes: 8 additions & 1 deletion taichi/runtime/llvm/llvm_aot_module_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ namespace taichi::lang {

class LlvmAotModuleBuilder : public AotModuleBuilder {
public:
explicit LlvmAotModuleBuilder(LlvmProgramImpl *prog) : prog_(prog) {
explicit LlvmAotModuleBuilder(const CompileConfig *compile_config,
LlvmProgramImpl *prog)
: compile_config_(compile_config), prog_(prog) {
}

void dump(const std::string &output_dir,
Expand All @@ -30,8 +32,13 @@ class LlvmAotModuleBuilder : public AotModuleBuilder {
return cache_;
}

const CompileConfig *get_compile_config() const {
return compile_config_;
}

private:
mutable LlvmOfflineCache cache_;
const CompileConfig *compile_config_{nullptr};
LlvmProgramImpl *prog_ = nullptr;
};

Expand Down
8 changes: 4 additions & 4 deletions taichi/runtime/program_impls/llvm/llvm_program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ LlvmProgramImpl::LlvmProgramImpl(CompileConfig &config_,
}

FunctionType LlvmProgramImpl::compile(Kernel *kernel) {
auto codegen = KernelCodeGen::create(config->arch, kernel);
auto codegen = KernelCodeGen::create(config, kernel);
return codegen->compile_to_function();
}

Expand Down Expand Up @@ -89,18 +89,18 @@ void LlvmProgramImpl::materialize_snode_tree(SNodeTree *tree,
std::unique_ptr<AotModuleBuilder> LlvmProgramImpl::make_aot_module_builder(
const DeviceCapabilityConfig &caps) {
if (config->arch == Arch::x64 || config->arch == Arch::arm64) {
return std::make_unique<cpu::AotModuleBuilderImpl>(this);
return std::make_unique<cpu::AotModuleBuilderImpl>(config, this);
}

#if defined(TI_WITH_CUDA)
if (config->arch == Arch::cuda) {
return std::make_unique<cuda::AotModuleBuilderImpl>(this);
return std::make_unique<cuda::AotModuleBuilderImpl>(config, this);
}
#endif

#if defined(TI_WITH_DX12)
if (config->arch == Arch::dx12) {
return std::make_unique<directx12::AotModuleBuilderImpl>(this);
return std::make_unique<directx12::AotModuleBuilderImpl>(config, this);
}
#endif

Expand Down
6 changes: 4 additions & 2 deletions taichi/runtime/wasm/aot_module_builder_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
namespace taichi::lang {
namespace wasm {

AotModuleBuilderImpl::AotModuleBuilderImpl() : module_(nullptr) {
AotModuleBuilderImpl::AotModuleBuilderImpl(const CompileConfig *compile_config)
: compile_config_(compile_config), module_(nullptr) {
TI_AUTO_PROF
}

Expand All @@ -35,7 +36,8 @@ void AotModuleBuilderImpl::dump(const std::string &output_dir,

void AotModuleBuilderImpl::add_per_backend(const std::string &identifier,
Kernel *kernel) {
auto module_info = KernelCodeGenWASM(kernel).compile_kernel_to_module();
auto module_info =
KernelCodeGenWASM(compile_config_, kernel).compile_kernel_to_module();
if (module_) {
llvm::Linker::linkModules(*module_, std::move(module_info.module),
llvm::Linker::OverrideFromSrc);
Expand Down
3 changes: 2 additions & 1 deletion taichi/runtime/wasm/aot_module_builder_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace wasm {

class AotModuleBuilderImpl : public AotModuleBuilder {
public:
explicit AotModuleBuilderImpl();
explicit AotModuleBuilderImpl(const CompileConfig *compile_config);

void dump(const std::string &output_dir,
const std::string &filename) const override;
Expand All @@ -34,6 +34,7 @@ class AotModuleBuilderImpl : public AotModuleBuilder {

private:
void eliminate_unused_functions() const;
const CompileConfig *compile_config_{nullptr};
std::unique_ptr<llvm::Module> module_{nullptr};
std::vector<std::string> name_list_;
};
Expand Down

0 comments on commit 187f206

Please sign in to comment.